PierreGtch committed on
Commit 5a28bad · verified · 1 Parent(s): ac15189

Re-upload checkpoint in flat state_dict layout.


The previous layout nested the encoder and contextualizer under
`encoder_state_dict` / `contextualizer_state_dict` top-level keys, which braindecode's loader did not unwrap: `load_state_dict(..., strict=False)` silently matched 0 of 99 weights, and every call to `BENDR.from_pretrained` returned a randomly initialized model. The flat layout matches braindecode's `BENDR` key namespace (`encoder.encoder.*`, `contextualizer.*`, plus the randomly initialized classification head), so `from_pretrained(..., strict=True, n_outputs=2)` now loads all 99 pretrained weights. See braindecode PR #992.

Files changed (4)
  1. README.md +10 -102
  2. config.json +16 -9
  3. model.safetensors +3 -0
  4. pytorch_model.bin +2 -2
README.md CHANGED
@@ -1,106 +1,14 @@
  ---
- license: apache-2.0
- datasets:
- - Sleep-EDF
- - TUAB
- - MOABB
- language:
- - en
+ library_name: braindecode
+ license: bsd-3-clause
  tags:
- - eeg
- - brain
- - timeseries
- - self-supervised
- - transformer
- - biomedical
- - neuroscience
+ - BENDR
+ - braindecode
+ - model_hub_mixin
+ - pytorch_model_hub_mixin
  ---
 
- # BENDR: BErt-inspired Neural Data Representations
-
- Pretrained BENDR model for EEG classification tasks. This is the official Braindecode implementation
- of BENDR from Kostas et al. (2021).
-
- ## Model Details
-
- - **Model Type**: Transformer-based EEG encoder
- - **Pretraining**: Self-supervised learning on masked sequence reconstruction
- - **Architecture**:
-   - Convolutional Encoder: 6 blocks with 512 hidden units
-   - Transformer Contextualizer: 8 layers, 8 attention heads
-   - Total Parameters: ~157M
- - **Input**: Raw EEG signals (20 channels, variable length)
- - **Output**: Contextualized representations or class predictions
-
- ## Usage
-
- ```python
- from braindecode.models import BENDR
- import torch
-
- # Load pretrained model
- model = BENDR(n_chans=20, n_outputs=2)
-
- # Load pretrained weights from Hugging Face
- from huggingface_hub import hf_hub_download
- checkpoint_path = hf_hub_download(repo_id="braindecode/bendr-pretrained-v1", filename="pytorch_model.bin")
- checkpoint = torch.load(checkpoint_path)
- model.load_state_dict(checkpoint["model_state_dict"], strict=False)
-
- # Use for inference
- model.eval()
- with torch.no_grad():
-     eeg_data = torch.randn(1, 20, 600)  # (batch, channels, time)
-     predictions = model(eeg_data)
- ```
-
- ## Fine-tuning
-
- ```python
- import torch
- from torch.optim import Adam
-
- # Freeze encoder for transfer learning
- for param in model.encoder.parameters():
-     param.requires_grad = False
-
- # Fine-tune on downstream task
- optimizer = Adam(model.parameters(), lr=0.0001)
- ```
-
- ## Paper
-
- [BENDR: Using transformers and a contrastive self-supervised learning task to learn from massive amounts of EEG data](https://doi.org/10.3389/fnhum.2021.653659)
-
- Kostas, D., Aroca-Ouellette, S., & Rudzicz, F. (2021).
- Frontiers in Human Neuroscience, 15, 653659.
-
- ## Citation
-
- ```bibtex
- @article{kostas2021bendr,
-   title={BENDR: Using transformers and a contrastive self-supervised learning task to learn from massive amounts of EEG data},
-   author={Kostas, Demetres and Aroca-Ouellette, St{\'e}phane and Rudzicz, Frank},
-   journal={Frontiers in Human Neuroscience},
-   volume={15},
-   pages={653659},
-   year={2021},
-   publisher={Frontiers}
- }
- ```
-
- ## Implementation Notes
-
- - Start token is correctly extracted at index 0 (BERT [CLS] convention)
- - Uses T-Fixup weight initialization for stability
- - Includes LayerDrop for regularization
- - All architectural improvements from original paper maintained
-
- ## License
-
- Apache 2.0
-
- ## Authors
-
- - Braindecode Team
- - Original paper: Kostas et al. (2021)
+ This model has been pushed to the Hub using the [PytorchModelHubMixin](https://huggingface.co/docs/huggingface_hub/package_reference/mixins#huggingface_hub.PyTorchModelHubMixin) integration:
+ - Code: https://braindecode.org
+ - Paper: [More Information Needed]
+ - Docs: https://braindecode.org/stable/generated/braindecode.models.BENDR.html
config.json CHANGED
@@ -1,16 +1,19 @@
  {
- "model_type": "bendr",
  "n_chans": 20,
+ "n_outputs": 2,
+ "n_times": 1000,
+ "chs_info": null,
+ "input_window_seconds": null,
+ "sfreq": 250,
  "encoder_h": 512,
  "contextualizer_hidden": 3076,
- "transformer_heads": 8,
- "transformer_layers": 8,
- "position_encoder_length": 25,
+ "projection_head": false,
  "drop_prob": 0.1,
  "layer_drop": 0.0,
- "start_token": -5,
- "final_layer": true,
- "projection_head": false,
+ "activation": "torch.nn.modules.activation.GELU",
+ "transformer_layers": 8,
+ "transformer_heads": 8,
+ "position_encoder_length": 25,
  "enc_width": [
    3,
    2,
@@ -27,6 +30,10 @@
    2,
    2
  ],
- "notes": "Pretrained BENDR model for EEG classification",
- "paper": "https://doi.org/10.3389/fnhum.2021.653659"
+ "start_token": -5,
+ "final_layer": true,
+ "encoder_only": false,
+ "n_chans_pretrained": null,
+ "chan_proj_max_norm": 1.0,
+ "braindecode_version": "1.5.0dev0"
  }
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7e4b432a3206cce274d060aa35286592c47e51181b55b71e791f4515ca6752a5
+ size 628580476
pytorch_model.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:58696d59ae4fb3d041837746c6c6225fa851841f68753e8cc28d4ecd4383d828
- size 628594288
+ oid sha256:1157d9c148850a91443093cc483ee98339a5f25b7948d4386c46de77131fb5c0
+ size 628611067