
EEGNet V4 is implemented using Braindecode version 0.8.1 and Skorch version 0.15.

Model details

  • Architecture: EEGNet by Lawhern et al.
  • Accuracy: 86%
  • NonTarget recall: 0.86
  • NonTarget precision: 0.97
  • Target recall: 0.84
  • Target precision: 0.54
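As a worked example, the per-class F1 scores implied by the precision and recall figures above can be derived directly. This is only a small sketch; the `f1_score` helper is a hypothetical name introduced here and is not part of the repository.

```python
def f1_score(precision: float, recall: float) -> float:
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# Figures taken from the list above.
non_target_f1 = f1_score(precision=0.97, recall=0.86)
target_f1 = f1_score(precision=0.54, recall=0.84)

print(f"NonTarget F1: {non_target_f1:.2f}")
print(f"Target F1:    {target_f1:.2f}")
```

The Target F1 comes out noticeably lower than the NonTarget F1 because Target precision (0.54) is much lower than Target recall (0.84).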

Training details

  • Trained on the Lee 2019 ERP dataset (http://moabb.neurotechx.com/docs/generated/moabb.datasets.Lee2019_ERP.html#moabb.datasets.Lee2019_ERP)
  • Dropout rate of 25%
  • Class-rebalanced label weighting, computed after data preprocessing
  • 8 temporal filters with 2 spatial filters per temporal filter (the EEGNet-8,2 configuration, i.e. F1=8 and D=2)
  • Batch size of 128
  • Dataset is shuffled and a random 20% is used as a validation set
  • Trained for 1000 epochs; the model with the lowest validation loss is saved
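The exact rebalancing scheme is not stated here. One common choice, shown below as an assumption rather than as the author's method, is inverse-frequency weighting, where each class weight is `n_samples / (n_classes * class_count)`. The helper name `balanced_class_weights`, the label encoding, and the toy labels are all hypothetical.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Inverse-frequency weights: n_samples / (n_classes * count_of_class)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {
        cls: n_samples / (n_classes * count)
        for cls, count in sorted(counts.items())
    }


# Toy labels only (assumed encoding: 0 = NonTarget, 1 = Target); not the real data.
toy_labels = [0, 0, 0, 0, 0, 1]
weights = balanced_class_weights(toy_labels)
print(weights)
```

Weights computed this way could be passed, in class-index order, as the `weight` tensor of `nn.CrossEntropyLoss`, so that the rarer class contributes more to the loss.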

Get started with the Model

from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from skorch import NeuralNet
import torch.nn as nn
import torch as th

path_params = hf_hub_download(
    repo_id='guido151/EEGNetv4',
    filename='EEGNetv4_Lee2019_ERP/params.pt',
)
path_optimizer = hf_hub_download(
    repo_id='guido151/EEGNetv4',
    filename='EEGNetv4_Lee2019_ERP/optimizer.pt',
)
path_history = hf_hub_download(
    repo_id='guido151/EEGNetv4',
    filename='EEGNetv4_Lee2019_ERP/history.json',
)
path_criterion = hf_hub_download(
    repo_id='guido151/EEGNetv4',
    filename='EEGNetv4_Lee2019_ERP/criterion.pt',
)

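# The architecture must match the checkpoint: 19 channels, 2 classes, 128 samples per epoch.
model = EEGNetv4(
    n_chans=19,
    n_outputs=2,
    n_times=128,
)

# The class weights are placeholders of the right shape (float, as
# CrossEntropyLoss requires); the saved values are loaded from criterion.pt below.
net = NeuralNet(
    model,
    criterion=nn.CrossEntropyLoss(weight=th.tensor([1.0, 1.0])),
)
net.initialize()
net.load_params(
    f_params=path_params,
    f_optimizer=path_optimizer,
    f_criterion=path_criterion,
    f_history=path_history,
)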
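The model above is built with `n_chans=19` and `n_times=128`, so it expects input of shape `(batch, 19, 128)`. The sketch below checks and converts a NumPy array accordingly; `prepare_batch` is a hypothetical helper, and the random array merely stands in for preprocessed EEG epochs.

```python
import numpy as np

N_CHANS, N_TIMES = 19, 128  # must match the EEGNetv4 constructor arguments


def prepare_batch(epochs: np.ndarray) -> np.ndarray:
    """Return epochs as a float32 array of shape (batch, N_CHANS, N_TIMES)."""
    epochs = np.asarray(epochs)
    if epochs.ndim == 2:  # a single epoch: add the batch dimension
        epochs = epochs[np.newaxis]
    if epochs.ndim != 3 or epochs.shape[1:] != (N_CHANS, N_TIMES):
        raise ValueError(
            f"expected (batch, {N_CHANS}, {N_TIMES}), got {epochs.shape}"
        )
    return epochs.astype(np.float32)


# Placeholder data, not real EEG.
X = prepare_batch(np.random.randn(4, N_CHANS, N_TIMES))
print(X.shape, X.dtype)
```

With the loaded `net`, calling `net.predict(X)` on such an array would then return the model output for each epoch.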

Get the FID model

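from copy import deepcopy

def get_fid_model(model: EEGNetv4) -> nn.Module:
    # Work on a copy so the original model is left untouched.
    fid_model = deepcopy(model)
    # Replace every layer from index 14 onward with a pass-through, so the
    # copy returns the features of the earlier layers rather than class scores.
    for i in range(14, len(fid_model)):
        fid_model[i] = nn.Identity()
    # Freeze the copy: evaluation mode, no gradients.
    fid_model.eval()
    for param in fid_model.parameters():
        param.requires_grad = False
    return fid_model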
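Assuming FID here refers to the Fréchet Inception Distance, with this truncated EEGNet standing in for the Inception network, the score compares feature statistics of real and generated signals. The sketch below computes the distance from two feature arrays, each flattened to shape `(n_samples, n_features)`. `frechet_distance` is a hypothetical helper, and the random arrays are placeholders rather than model outputs.

```python
import numpy as np


def frechet_distance(feats_a: np.ndarray, feats_b: np.ndarray) -> float:
    """Fréchet distance between Gaussians fitted to two (n_samples, n_features) arrays."""
    mu_a, mu_b = feats_a.mean(axis=0), feats_b.mean(axis=0)
    cov_a = np.cov(feats_a, rowvar=False)
    cov_b = np.cov(feats_b, rowvar=False)
    # Tr(sqrt(cov_a @ cov_b)) equals the sum of the square roots of the
    # eigenvalues of cov_a @ cov_b, which are real and non-negative in theory;
    # small numerical deviations are clipped away.
    eigvals = np.linalg.eigvals(cov_a @ cov_b)
    trace_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = np.sum((mu_a - mu_b) ** 2)
    return float(mean_term + np.trace(cov_a) + np.trace(cov_b) - 2.0 * trace_sqrt)


rng = np.random.default_rng(0)
real_feats = rng.normal(size=(500, 8))  # placeholder features
fake_feats = real_feats + 1.0           # same covariance, mean shifted by 1 per dimension

print(frechet_distance(real_feats, real_feats))
print(frechet_distance(real_feats, fake_feats))
```

Because the two placeholder sets share a covariance, only the mean term contributes in the second call, which makes the example easy to check by hand.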

Get the IS model

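from copy import deepcopy

def get_is_model(model: EEGNetv4) -> nn.Module:
    # The full classifier is kept; it is only copied and frozen for evaluation.
    is_model = deepcopy(model)
    is_model.eval()
    for param in is_model.parameters():
        param.requires_grad = False
    return is_model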
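Assuming IS refers to the Inception Score, the frozen classifier's class probabilities for generated signals would be summarised as the exponential of the mean KL divergence between each prediction and the marginal class distribution. The sketch below works on an array of probabilities of shape `(n_samples, n_classes)`; `inception_score` is a hypothetical helper, and the model outputs would first need converting to probabilities in whatever way matches the output format of the Braindecode version in use.

```python
import numpy as np


def inception_score(probs: np.ndarray, eps: float = 1e-12) -> float:
    """exp of the mean KL divergence between p(y|x) and the marginal p(y)."""
    probs = np.clip(probs, eps, 1.0)  # avoid log(0)
    marginal = probs.mean(axis=0, keepdims=True)
    kl = np.sum(probs * (np.log(probs) - np.log(marginal)), axis=1)
    return float(np.exp(kl.mean()))


# Toy probabilities, not model outputs.
uniform = np.full((4, 2), 0.5)                   # uninformative predictions
confident = np.array([[1.0, 0.0], [0.0, 1.0]])   # confident and class-balanced

print(inception_score(uniform))
print(inception_score(confident))
```

With two classes the score ranges from 1 (every prediction equals the marginal) up to 2 (confident predictions spread evenly over both classes).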