This EEGNetv4 model is implemented with Braindecode 0.8.1 and skorch 0.15.
Model details
- Architecture: EEGNet by Lawhern et al.
- Accuracy: 86%
- NonTarget recall: 0.86
- NonTarget precision: 0.97
- Target recall: 0.84
- Target precision: 0.54
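The per-class figures above are standard precision and recall. As a minimal sketch, the plain-Python helper below (`per_class_precision_recall` is hypothetical, and the toy labels are illustrative only, not outputs of this model) shows how they are computed from predictions. The toy case also shows how a rarer class, as Target trials typically are in ERP data, can reach high recall but low precision when a few majority-class trials are misclassified as that class.

```python
from collections import Counter


def per_class_precision_recall(y_true, y_pred):
    """Hypothetical helper: return {class: (precision, recall)} for each true class."""
    true_count = Counter(y_true)
    pred_count = Counter(y_pred)
    # Count true positives per class.
    tp = Counter(t for t, p in zip(y_true, y_pred) if t == p)
    result = {}
    for c in sorted(true_count):
        precision = tp[c] / pred_count[c] if pred_count[c] else 0.0
        recall = tp[c] / true_count[c]
        result[c] = (precision, recall)
    return result


# Toy data: 8 NonTarget and 2 Target trials; 2 NonTarget trials are
# misclassified as Target, and both Target trials are detected.
y_true = ["NonTarget"] * 8 + ["Target"] * 2
y_pred = ["NonTarget"] * 6 + ["Target"] * 2 + ["Target"] * 2
metrics = per_class_precision_recall(y_true, y_pred)
```

Here Target recall is 1.0 but Target precision is only 0.5, while NonTarget keeps precision 1.0 with recall 0.75.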
Training details
- Trained on the Lee 2019 ERP dataset (http://moabb.neurotechx.com/docs/generated/moabb.datasets.Lee2019_ERP.html#moabb.datasets.Lee2019_ERP)
- Dropout rate of 25%
- Class-rebalanced label weights, computed after data preprocessing
- 8 temporal filters (F1) with 2 spatial filters per temporal filter (depth multiplier D)
- Batch size of 128
- The dataset is shuffled and a random 20% is held out as a validation set
- Trained for 1000 epochs; the model with the lowest validation loss is saved
Get started with the model
```python
import torch as th
import torch.nn as nn
from braindecode.models import EEGNetv4
from huggingface_hub import hf_hub_download
from skorch import NeuralNet

# Download the saved skorch checkpoint files from the Hugging Face Hub.
path_params = hf_hub_download(
    repo_id='guido151/EEGNetv4',
    filename='EEGNetv4_Lee2019_ERP/params.pt',
)
path_optimizer = hf_hub_download(
    repo_id='guido151/EEGNetv4',
    filename='EEGNetv4_Lee2019_ERP/optimizer.pt',
)
path_history = hf_hub_download(
    repo_id='guido151/EEGNetv4',
    filename='EEGNetv4_Lee2019_ERP/history.json',
)
path_criterion = hf_hub_download(
    repo_id='guido151/EEGNetv4',
    filename='EEGNetv4_Lee2019_ERP/criterion.pt',
)

# Rebuild the architecture; these sizes must match the saved parameters.
model = EEGNetv4(
    n_chans=19,
    n_outputs=2,
    n_times=128,
)

# The class weights are placeholders that load_params overwrites with the
# saved criterion state. Use floats: loading copies the saved values into
# this tensor, so an integer tensor would truncate them.
net = NeuralNet(
    model,
    criterion=nn.CrossEntropyLoss(weight=th.tensor([1.0, 1.0])),
)
net.initialize()
net.load_params(
    f_params=path_params,
    f_optimizer=path_optimizer,
    f_criterion=path_criterion,
    f_history=path_history,
)
```
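Get the FID model (feature extractor for the Fréchet Inception Distance)
```python
from copy import deepcopy

import torch.nn as nn
from braindecode.models import EEGNetv4


def get_fid_model(model: EEGNetv4) -> nn.Module:
    """Return a frozen copy of `model` that outputs intermediate features."""
    fid_model = deepcopy(model)
    # EEGNetv4 is an nn.Sequential: replace every layer from index 14 onward
    # with a no-op so the copy returns features instead of class outputs.
    for i in range(14, len(fid_model)):
        fid_model[i] = nn.Identity()
    fid_model.eval()
    for param in fid_model.parameters():
        param.requires_grad = False
    return fid_model
```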
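The card does not show how the extracted features are turned into a score. The usual FID compares Gaussians fitted to real and generated feature sets: FID = ||mu_r - mu_g||^2 + Tr(S_r + S_g - 2 (S_r S_g)^(1/2)). The following is a minimal NumPy sketch with a hypothetical `frechet_distance` helper. It assumes the outputs of `get_fid_model(net.module_)` have been flattened to shape (n_samples, n_features), for example with `out.flatten(start_dim=1).numpy()` inside `torch.no_grad()`.

```python
import numpy as np


def frechet_distance(feats_real, feats_fake):
    """Hypothetical helper: Fréchet distance between Gaussian fits of two feature sets.

    Both inputs have shape (n_samples, n_features).
    """
    mu_r, mu_g = feats_real.mean(axis=0), feats_fake.mean(axis=0)
    s_r = np.cov(feats_real, rowvar=False)
    s_g = np.cov(feats_fake, rowvar=False)
    # Tr((S_r S_g)^(1/2)) equals the sum of the square roots of the eigenvalues
    # of S_r @ S_g, which are real and non-negative for covariance matrices;
    # clip tiny negative values caused by rounding.
    eigvals = np.linalg.eigvals(s_r @ s_g)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    mean_term = np.sum((mu_r - mu_g) ** 2)
    return float(mean_term + np.trace(s_r) + np.trace(s_g) - 2.0 * tr_sqrt)
```

Identical feature sets give a distance of 0, and shifting every feature by a constant while keeping the covariance unchanged adds only the squared mean difference.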
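Get the IS model (classifier for the Inception Score)
```python
from copy import deepcopy

import torch.nn as nn
from braindecode.models import EEGNetv4


def get_is_model(model: EEGNetv4) -> nn.Module:
    # Keep the full classifier, frozen and in evaluation mode.
    is_model = deepcopy(model)
    is_model.eval()
    for param in is_model.parameters():
        param.requires_grad = False
    return is_model
```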
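The Inception Score is commonly computed from the classifier's predicted class distributions as IS = exp(E_x[KL(p(y|x) || p(y))]), where p(y) is the mean prediction over all samples. The sketch below uses a hypothetical `inception_score` helper in NumPy that takes the raw outputs of `get_is_model(net.module_)` as an array of shape (n_samples, n_classes). Applying softmax gives the same class probabilities whether those outputs are logits or log-probabilities, because softmax ignores a constant added to each row. With two classes the score lies between 1 and 2.

```python
import numpy as np


def inception_score(outputs, eps=1e-12):
    """Hypothetical helper: exp(mean KL(p(y|x) || p(y))) from model outputs."""
    outputs = np.asarray(outputs, dtype=np.float64)
    # Numerically stable softmax over the class axis.
    z = outputs - outputs.max(axis=1, keepdims=True)
    p_yx = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
    # Marginal class distribution over all samples.
    p_y = p_yx.mean(axis=0, keepdims=True)
    kl = (p_yx * (np.log(p_yx + eps) - np.log(p_y + eps))).sum(axis=1)
    return float(np.exp(kl.mean()))
```

Uniform predictions give the minimum score of 1, while confident predictions spread evenly over both classes approach 2.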