EEG Sleep Stage Prediction with Neural Networks in Python

By NeuralNine

Machine Learning, Deep Learning, Neural Networks, Signal Processing

Key Concepts

  • EEG (Electroencephalography): A neurophysiological monitoring method to record the electrical activity of the brain.
  • Sleep EDF Database: A publicly available dataset containing EEG and other physiological signals recorded during sleep.
  • PSG (Polysomnography): A sleep study that records brain waves, oxygen levels, heart rate, breathing, and eye movements during sleep.
  • Hypnogram: A graph that shows the different sleep stages a person goes through during the night.
  • Deep Convolutional Neural Network (CNN): A type of artificial neural network commonly used for image recognition but also applicable to time-series data like EEG.
  • PyTorch: An open-source machine learning framework used for building and training neural networks.
  • MNE-Python: An open-source Python package for analyzing and visualizing electrophysiological data, including EEG.
  • Time Series Data: Data points indexed in time order.
  • Sampling Frequency (Fs): The number of samples of a signal taken per second.
  • Microvolts (µV): A unit of electrical potential equal to one millionth of a volt.
  • Bipolar Montage: A configuration of EEG electrodes where the difference in electrical potential between two electrodes is measured.
  • 10-20 System: A standardized system for the placement of EEG electrodes on the scalp.
  • Sleep Stages: Different phases of sleep characterized by distinct brain wave patterns and physiological activity (e.g., Wake, Stage 1, Stage 2, Stage 3, Stage 4, REM).
  • Data Preprocessing: Steps taken to clean, transform, and prepare data for machine learning models.
  • Normalization: Rescaling data to a common scale, for example zero mean and unit variance, often to make training more stable and improve model performance.
  • Train-Test Split: Dividing a dataset into two subsets: one for training the model and one for evaluating its performance on unseen data.
  • Stratification: Ensuring that the proportion of different classes in the target variable is maintained in both the training and testing sets.
  • Tensors: Multi-dimensional arrays used in PyTorch for numerical computation.
  • GPU (Graphics Processing Unit): A specialized processor originally designed to accelerate image rendering. Its massively parallel architecture also makes it well suited to speeding up deep learning computations.
  • CNN Layers:
    • Conv1D: A 1D convolutional layer for processing sequential data.
    • ReLU (Rectified Linear Unit): An activation function that introduces non-linearity.
    • MaxPool1D: A 1D max pooling layer for down-sampling.
    • Flatten: A layer that reshapes each sample's multi-dimensional output into a 1D vector while keeping the batch dimension.
    • Linear (Fully Connected): A layer where every neuron is connected to every neuron in the previous layer.
    • Dropout: A regularization technique to prevent overfitting by randomly setting a fraction of input units to zero during training.
  • Loss Function: A function that quantifies the error between the predicted output and the actual target.
  • Optimizer: An algorithm used to adjust the parameters of a neural network to minimize the loss function.
  • Epoch: One complete pass through the entire training dataset.
  • Batch Size: The number of training examples used in one iteration of the training process.
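  • Backpropagation: The algorithm that computes the gradients of the loss with respect to the model weights by propagating errors backward through the network. The optimizer then uses these gradients to update the weights.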
  • Accuracy Score: A metric that measures the proportion of correctly classified instances.
  • Classification Report: A report that provides precision, recall, F1-score, and support for each class.
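  • Confusion Matrix: A table that summarizes the performance of a classification model. For each true class, it counts how many instances were predicted as each class. With only two classes, this reduces to the counts of true positives, true negatives, false positives, and false negatives.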

Project Overview: Sleep Stage Classification using EEG Data

This project implements a machine learning model that predicts sleep stages from EEG time-series data. The primary goal is to build and train a deep convolutional neural network (CNN) in PyTorch that distinguishes wakefulness from the individual sleep stages (Stages 1 to 4 and REM). The project uses the Sleep EDF database and walks step by step from data loading and exploration to model training and evaluation.

Data Acquisition and Exploration

Sleep EDF Database

  • The Sleep EDF database is the source of the medical data.
  • It contains data from multiple subjects, with two types of files per subject:
    • PSG files: Contain seven channels of time-series data, including two EEG channels (FPZ-CZ and PZ-OZ), which are the focus of this project. Other channels like body temperature are ignored.
    • Hypnogram files: Provide the ground truth, indicating the actual sleep stage for specific time intervals.
  • Data can be downloaded via zip files, wget, or the AWS CLI.

Data Loading and Initial Inspection

  • The mne.io.read_raw_edf() function is used to load raw EDF files.
  • The raw_edf object contains meta-information about the file, including channels and sampling frequency.
  • The sampling frequency (Fs) for this dataset is 100 Hz, meaning 100 measurements per second.
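  • The raw_edf.get_data() method extracts the time-series data for all channels as a NumPy array of shape (n_channels, n_samples). MNE returns these values in volts, so multiplying by 1e6 gives microvolts.
  • The EEG channels of interest are labeled 'EEG Fpz-Cz' and 'EEG Pz-Oz' in the file headers.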
  • The data is visualized to understand its characteristics, with the y-axis representing voltage in microvolts (µV) and the x-axis representing time.
  • The channel names come from the 10-20 electrode placement system. Each EEG channel is a bipolar derivation, so 'Fpz-Cz' is the potential difference between the Fpz and Cz electrodes.
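The inspection steps can be sketched in plain NumPy. This sketch assumes the PSG file was opened with raw = mne.io.read_raw_edf(path, preload=True). In that case raw.get_data(), raw.ch_names and raw.info["sfreq"] supply the array, the channel names and Fs. The helper extract_channel is a hypothetical name, and synthetic noise stands in for a real recording.

```python
import numpy as np

def extract_channel(data, ch_names, channel, fs):
    """Return (time_s, signal_uv) for one channel of a (n_channels, n_samples) array.

    `data` is assumed to be in volts, as returned by MNE's Raw.get_data().
    """
    idx = list(ch_names).index(channel)           # raises ValueError if the name is missing
    signal_uv = data[idx] * 1e6                   # volts -> microvolts
    time_s = np.arange(signal_uv.shape[0]) / fs   # sample index -> seconds
    return time_s, signal_uv

# Synthetic stand-in for raw.get_data(): 2 channels, 5 seconds at 100 Hz, in volts
fs = 100.0
ch_names = ["EEG Fpz-Cz", "EEG Pz-Oz"]
data = np.random.default_rng(0).normal(0.0, 20e-6, size=(2, 500))

t, eeg = extract_channel(data, ch_names, "EEG Fpz-Cz", fs)
print(t[:3], eeg.shape)
```

With matplotlib, plt.plot(t, eeg), with the axes labeled in seconds and µV, gives the kind of plot described above.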

Ground Truth (Hypnogram) Loading

  • Annotations (sleep stages) are loaded from the hypnogram files with mne.read_annotations(), which returns an mne.Annotations object.
  • Key attributes of annotations include:
    • onset: The starting time of a sleep stage.
    • duration: The duration of a sleep stage.
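    • description: The label for the sleep stage. In the Sleep EDF hypnograms these strings read 'Sleep stage W', 'Sleep stage 1', and so on. The stage code is 'W' for wake, '1' to '4' for sleep Stages 1 to 4, 'R' for REM, and '?' for unknown.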
  • A mapping (stage_map) converts these stage codes into numerical labels for machine learning. '?' has no entry:
    • W: 0
    • 1: 1
    • 2: 2
    • 3: 3
    • 4: 4
    • R: 5
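The annotation-to-label step can be sketched as follows. The sketch assumes the onsets, durations and descriptions are available as plain sequences, as exposed by an mne.Annotations object through .onset, .duration and .description. It keys on the last character of each description, so 'Sleep stage W' and a bare 'W' both map to 0. Codes without a mapping ('?', 'Movement time') are skipped. The function name annotations_to_intervals is hypothetical.

```python
STAGE_MAP = {"W": 0, "1": 1, "2": 2, "3": 3, "4": 4, "R": 5}

def annotations_to_intervals(onsets, durations, descriptions, fs):
    """Convert annotation fields into (start_sample, end_sample, label) tuples."""
    intervals = []
    for onset, duration, desc in zip(onsets, durations, descriptions):
        code = desc.strip()[-1:]          # 'Sleep stage W' -> 'W'; '' for an empty string
        if code not in STAGE_MAP:
            continue                      # skip unknown stages and movement time
        start = int(round(onset * fs))                # seconds -> sample index
        end = int(round((onset + duration) * fs))
        intervals.append((start, end, STAGE_MAP[code]))
    return intervals

intervals = annotations_to_intervals(
    onsets=[0.0, 90.0, 150.0],
    durations=[90.0, 60.0, 30.0],
    descriptions=["Sleep stage W", "Sleep stage 1", "Sleep stage ?"],
    fs=100.0,
)
print(intervals)  # [(0, 9000, 0), (9000, 15000, 1)]
```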

Data Preparation

Single Channel Training

  • The project first trains a model using only the 'EEG Fpz-Cz' channel.
  • A window_length of 30 seconds (3000 samples at 100 Hz) is defined to segment the data for prediction.
  • The process involves iterating through a subset of subjects (first 25) to avoid memory issues.
  • For each subject, PSG and hypnogram files are loaded.
  • The selected EEG channel data and annotations are extracted.
  • The annotations are converted to numerical labels using the stage_map.
  • The continuous EEG data is segmented into 30-second windows.
  • Each window is added to X_data, and its corresponding sleep stage label is added to Y_data.
  • The onset and duration from annotations are converted to sample indices by multiplying by the sampling frequency (Fs).
  • The resulting X_data has a shape of (number of windows, window length) and Y_data has a shape of (number of windows,).
  • Normalization: The EEG data is z-score normalized by subtracting the mean and dividing by the standard deviation. This centers it around zero with unit variance.
  • Train-Test Split: The data is split into training (80%) and testing (20%) sets using train_test_split with stratification to maintain class distribution.
  • Tensor Conversion: NumPy arrays are converted into PyTorch tensors and moved to the available device (GPU if available, otherwise CPU). nn.Conv1d expects input of shape (batch, channels, length), so unsqueeze(1) adds a channel dimension, turning X from (N, 3000) into (N, 1, 3000).
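The preparation pipeline above can be sketched end to end on synthetic data. This sketch assumes labeled intervals are already given as (start_sample, end_sample, label) tuples and that only full 30-second windows are kept. It also assumes normalization uses global statistics, which is one reading of "subtract the mean, divide by the standard deviation"; per-window scaling is an alternative. The helper name segment is hypothetical.

```python
import numpy as np
import torch
from sklearn.model_selection import train_test_split

WINDOW = 3000  # 30 s at Fs = 100 Hz

def segment(signal, intervals, window=WINDOW):
    """Cut each labeled (start, end, label) interval into non-overlapping windows."""
    xs, ys = [], []
    for start, end, label in intervals:
        # only full windows are kept; a remainder shorter than `window` is dropped
        for s in range(start, end - window + 1, window):
            xs.append(signal[s:s + window])
            ys.append(label)
    return np.array(xs), np.array(ys)

# Synthetic stand-in: 10 minutes of signal, 5 min labeled W (0) and 5 min Stage 2 (2)
rng = np.random.default_rng(0)
signal = rng.normal(size=60_000)
X, y = segment(signal, [(0, 30_000, 0), (30_000, 60_000, 2)])

# Global z-score normalization
X = (X - X.mean()) / X.std()

# Stratified 80/20 split keeps the class proportions in both sets
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_train_t = torch.tensor(X_train, dtype=torch.float32).unsqueeze(1).to(device)  # (N, 1, 3000)
X_test_t = torch.tensor(X_test, dtype=torch.float32).unsqueeze(1).to(device)
y_train_t = torch.tensor(y_train, dtype=torch.long).to(device)  # CrossEntropyLoss wants int64 labels
y_test_t = torch.tensor(y_test, dtype=torch.long).to(device)
print(X_train_t.shape, X_test_t.shape)
```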

Two Channel Training

  • The process is repeated using both the 'EEG Fpz-Cz' and 'EEG Pz-Oz' channels.
  • The key difference is that data from both channels is retrieved.
  • Instead of appending individual windows, the matching windows from both channels are stacked with np.stack(..., axis=0). Each sample then has shape (2, 3000), and X_data has shape (number of windows, 2, 3000).
  • The model architecture is adjusted to accept two input channels.
  • The unsqueeze(1) step is removed as the stacking already creates the necessary channel dimension.
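A minimal sketch of the two-channel segmentation, assuming both channels share the same sampling rate and the same (start, end, label) intervals. The function name segment_two_channels is hypothetical. The stacked array already has the (N, channels, length) layout that nn.Conv1d expects, which is why no unsqueeze is needed.

```python
import numpy as np

def segment_two_channels(fpz_cz, pz_oz, intervals, window=3000):
    """Return X of shape (n_windows, 2, window) and y of shape (n_windows,)."""
    xs, ys = [], []
    for start, end, label in intervals:
        for s in range(start, end - window + 1, window):
            # two (window,) arrays -> one (2, window) sample
            xs.append(np.stack([fpz_cz[s:s + window], pz_oz[s:s + window]], axis=0))
            ys.append(label)
    return np.array(xs), np.array(ys)

rng = np.random.default_rng(1)
fpz_cz = rng.normal(size=9000)   # synthetic 90 s per channel
pz_oz = rng.normal(size=9000)
X2, y2 = segment_two_channels(fpz_cz, pz_oz, [(0, 9000, 1)])
print(X2.shape)  # (3, 2, 3000)
```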

Model Training

Convolutional Neural Network (CNN) Architecture

  • A sequential CNN model is defined using torch.nn.Sequential.
  • Layers:
    1. Conv1D: Input channels: 1 (for single channel), Output channels: 32, Kernel size: 50, Stride: 6.
    2. ReLU: Activation function.
    3. MaxPool1D: Kernel size: 8.
    4. Conv1D: Input channels: 32, Output channels: 64, Kernel size: 8, Stride: 1 (default).
    5. ReLU: Activation function.
    6. MaxPool1D: Kernel size: 8.
    7. Flatten: Reshapes the output to a 1D vector.
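    8. Linear: Input features: 64 * 6 = 384, Output features: 128. The factor 6 follows from the sequence length through the layers: 3000 → 492 after the first Conv1D ((3000 - 50) // 6 + 1), → 61 after max pooling (492 // 8), → 54 after the second Conv1D (61 - 8 + 1), → 6 after the final max pooling (54 // 8).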
    9. ReLU: Activation function.
    10. Dropout: Probability: 0.5 (50%) for regularization.
    11. Linear: Input features: 128, Output features: 6 (number of sleep stages: W, 1, 2, 3, 4, R).
  • The model is moved to the selected device (GPU/CPU).
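The listed architecture can be sketched as a torch.nn.Sequential. The in_channels parameter covers both the one-channel and two-channel variants. The factory name build_model is hypothetical. A dummy forward pass confirms that the flattened size is 64 * 6 = 384 for 3000-sample windows.

```python
import torch
import torch.nn as nn

def build_model(in_channels=1, n_classes=6):
    return nn.Sequential(
        nn.Conv1d(in_channels, 32, kernel_size=50, stride=6),  # 3000 -> 492
        nn.ReLU(),
        nn.MaxPool1d(8),                                        # 492 -> 61
        nn.Conv1d(32, 64, kernel_size=8),                       # 61 -> 54 (stride 1)
        nn.ReLU(),
        nn.MaxPool1d(8),                                        # 54 -> 6
        nn.Flatten(),                                           # 64 channels * 6 steps = 384
        nn.Linear(64 * 6, 128),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(128, n_classes),                              # logits for W, 1, 2, 3, 4, R
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = build_model(in_channels=1).to(device)
out = model(torch.zeros(4, 1, 3000, device=device))
print(out.shape)  # torch.Size([4, 6])
```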

Training Parameters

  • Loss Function: torch.nn.CrossEntropyLoss is used, suitable for multi-class classification.
  • Optimizer: torch.optim.Adam is chosen with a learning rate of 0.001.
  • Epochs: Training is performed for 10 epochs.
  • Batch Size: A batch size of 64 is used.
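The training parameters could be set up as below. Batching through TensorDataset and DataLoader is one common option, assumed here rather than taken from the video, which may slice the tensors manually. The tiny linear model and random tensors are placeholders for the CNN and the prepared data.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

# Placeholder data with the single-channel shapes from the preparation step
X_train = torch.randn(200, 1, 3000)
y_train = torch.randint(0, 6, (200,))

model = nn.Sequential(nn.Flatten(), nn.Linear(3000, 6))   # stand-in model, not the CNN

criterion = nn.CrossEntropyLoss()                          # takes raw logits and integer labels
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
EPOCHS = 10
BATCH_SIZE = 64

loader = DataLoader(TensorDataset(X_train, y_train), batch_size=BATCH_SIZE, shuffle=True)
print(len(loader))  # 4 batches: 64 + 64 + 64 + 8
```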

Training Loop

  • The model is set to training mode (model.train()).
  • The total loss for each epoch is tracked.
  • The training data is iterated over in batches.
  • For each batch:
    • Gradients are zeroed out (optimizer.zero_grad()).
    • Predictions are made (outputs = model(batch_x)).
    • The loss is calculated (loss = criterion(outputs, batch_y)).
    • The loss is backpropagated (loss.backward()).
    • The optimizer updates the model parameters (optimizer.step()).
    • The batch loss is added to the total epoch loss.
  • The average loss per epoch is printed.
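The loop described above can be sketched as a function. This version assumes the data are already tensors on one device and shuffles with a fresh permutation each epoch. The name train and its defaults (10 epochs, batch size 64, learning rate 0.001) mirror the stated parameters but are otherwise illustrative. A small synthetic toy problem replaces the EEG data so the demo runs quickly.

```python
import torch
import torch.nn as nn

def train(model, X_train, y_train, epochs=10, batch_size=64, lr=0.001):
    """Mini-batch training loop; returns the average loss of each epoch."""
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    history = []
    for epoch in range(epochs):
        model.train()                                   # enable dropout
        total_loss, n_batches = 0.0, 0
        perm = torch.randperm(len(X_train), device=X_train.device)  # shuffle each epoch
        for i in range(0, len(X_train), batch_size):
            idx = perm[i:i + batch_size]
            batch_x, batch_y = X_train[idx], y_train[idx]
            optimizer.zero_grad()                       # clear gradients from the last step
            outputs = model(batch_x)                    # forward pass, returns logits
            loss = criterion(outputs, batch_y)
            loss.backward()                             # backpropagation computes gradients
            optimizer.step()                            # optimizer updates the parameters
            total_loss += loss.item()
            n_batches += 1
        avg_loss = total_loss / n_batches
        history.append(avg_loss)
        print(f"Epoch {epoch + 1}/{epochs}, average loss: {avg_loss:.4f}")
    return history

# Linearly separable toy problem (not EEG data) so the loop runs in seconds
torch.manual_seed(0)
X_toy = torch.randn(256, 1, 20)
y_toy = (X_toy.mean(dim=(1, 2)) > 0).long()
toy_model = nn.Sequential(nn.Flatten(), nn.Linear(20, 2))
history = train(toy_model, X_toy, y_toy, epochs=20, lr=0.01)
```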

Model Evaluation

Single Channel Performance

  • After training, the model is set to evaluation mode (model.eval()).
  • Gradient calculations are disabled (torch.no_grad()).
  • Predictions are made on the X_test data.
  • The predicted class is the index of the largest output logit, found with torch.argmax along the class dimension (dim=1). This is also the class with the highest softmax probability.
  • Predictions and true labels are converted back to NumPy arrays on the CPU.
  • Accuracy Score: An accuracy of approximately 92.6% is achieved.
  • Confusion Matrix: A confusion matrix is generated to visualize the classification performance across different sleep stages. It reveals that while overall accuracy is high, some sleep stages are misclassified more frequently (e.g., Stage 1 often predicted as REM, Stage 3 as Stage 2).
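The evaluation steps can be sketched with scikit-learn's accuracy_score, confusion_matrix and classification_report. The function name evaluate is hypothetical. A stand-in model whose bias always favors stage "2" makes the numbers easy to check by hand. In the resulting matrix, rows are true stages and columns are predicted stages.

```python
import torch
import torch.nn as nn
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

STAGE_NAMES = ["W", "1", "2", "3", "4", "R"]

def evaluate(model, X_test, y_test):
    model.eval()                              # disable dropout for inference
    with torch.no_grad():                     # no gradient tracking needed
        logits = model(X_test)
        preds = torch.argmax(logits, dim=1)   # index of the largest logit per sample
    y_pred = preds.cpu().numpy()
    y_true = y_test.cpu().numpy()
    labels = list(range(len(STAGE_NAMES)))
    acc = accuracy_score(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=labels)  # rows: true, columns: predicted
    print(classification_report(y_true, y_pred, labels=labels,
                                target_names=STAGE_NAMES, zero_division=0))
    return acc, cm

# Stand-in model: zero weights and a bias that always selects stage "2"
model = nn.Sequential(nn.Flatten(), nn.Linear(10, 6))
with torch.no_grad():
    model[1].weight.zero_()
    model[1].bias.copy_(torch.tensor([0.0, 0.0, 1.0, 0.0, 0.0, 0.0]))

X_test = torch.randn(5, 1, 10)
y_test = torch.tensor([2, 2, 0, 2, 5])
acc, cm = evaluate(model, X_test, y_test)
print(acc)  # 0.6
```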

Two Channel Performance

  • The training and evaluation process is repeated with the two-channel input.
  • The model architecture is adapted to accept two input channels.
  • Accuracy Score: The accuracy increases to approximately 93.5%.
  • Confusion Matrix: The confusion matrix shows improvements for some classes, while Stages 1 and 2 are classified slightly worse.

Conclusion and Takeaways

This project successfully demonstrates how to build and train a deep convolutional neural network using PyTorch for sleep stage classification from EEG data. Key takeaways include:

  • Data Handling: Effective loading, exploration, and preprocessing of time-series EEG data using libraries like MNE-Python.
  • Feature Engineering: Segmenting continuous EEG data into fixed-length windows for supervised learning.
  • Model Architecture: Designing and implementing a CNN suitable for time-series analysis.
  • Training Process: Understanding the PyTorch training loop, including loss functions, optimizers, batching, and backpropagation.
  • Evaluation Metrics: Utilizing accuracy and confusion matrices to assess model performance and identify areas for improvement.
  • Impact of Input Features: Demonstrating that using multiple EEG channels can lead to improved classification accuracy compared to a single channel.

The project provides a practical guide for implementing deep learning on physiological time-series data, with actionable insights for further experimentation and model refinement.
