---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---

# EEGConformer

EEG Conformer from Song et al. (2022).

> **Architecture-only repository.** This repo documents the
> `braindecode.models.EEGConformer` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import EEGConformer

model = EEGConformer(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them
to match your recording.
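
As a quick sanity check, you can push a dummy batch through the untrained model and confirm
the output shape. The `(batch, n_chans, n_times)` input layout and the 1000-sample window
(4 s at 250 Hz) simply follow the example configuration above.

```python
import torch
from braindecode.models import EEGConformer

model = EEGConformer(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)

# Dummy batch: 8 windows, 22 channels, 1000 samples (4 s at 250 Hz)
x = torch.randn(8, 22, 1000)
with torch.no_grad():
    y = model(x)
print(y.shape)  # expected: torch.Size([8, 4])
```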

## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.EEGConformer.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/eegconformer.py#L14>

## Architecture description

The block below is the rendered class docstring (parameters,
references, architecture figure where available).

<div class='bd-doc'><main>
<p>EEG Conformer from Song et al. (2022) [song2022]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span>

.. figure:: https://raw.githubusercontent.com/eeyhsong/EEG-Conformer/refs/heads/main/visualization/Fig1.png
   :align: center
   :alt: EEGConformer Architecture
   :width: 600px

.. rubric:: Architectural Overview

EEG-Conformer is a *convolution-first* model augmented with a *lightweight transformer
encoder*. The end-to-end flow is:

- (i) :class:`_PatchEmbedding` converts the continuous EEG into a compact sequence of tokens via a
  :class:`ShallowFBCSPNet`-style temporal–spatial conv stem and temporal pooling;
- (ii) :class:`_TransformerEncoder` applies small multi-head self-attention to integrate
  longer-range temporal context across tokens;
- (iii) :class:`_ClassificationHead` aggregates the sequence and performs a linear readout.

This preserves the strong inductive biases of shallow CNN filter banks while adding
just enough attention to capture dependencies beyond the pooling horizon [song2022]_.

.. rubric:: Macro Components

- :class:`_PatchEmbedding` **(Shallow conv stem → tokens)**

  - *Operations.*

    - A temporal convolution (:class:`torch.nn.Conv2d`) with kernel ``(1 x L_t)`` forms a data-driven "filter bank";
    - a spatial convolution (:class:`torch.nn.Conv2d`) with kernel ``(n_chans x 1)`` projects across electrodes,
      collapsing the channel axis into a virtual channel;
    - **normalization** with :class:`torch.nn.BatchNorm2d`;
    - **activation** with :class:`torch.nn.ELU`;
    - **average pooling** with :class:`torch.nn.AvgPool2d` along time (kernel ``(1, P)`` with stride ``(1, S)``);
    - a final ``1x1`` :class:`torch.nn.Linear` projection.

  The result is rearranged to a token sequence ``(B, S_tokens, D)``, where ``D = n_filters_time``.

  *Interpretability/robustness.* Temporal kernels can be inspected as FIR filters;
  the spatial conv yields channel projections analogous to :class:`ShallowFBCSPNet`'s learned
  spatial filters. Temporal pooling stabilizes statistics and reduces sequence length.
  (A standalone shape sketch of this tokenization path follows this list.)

- :class:`_TransformerEncoder` **(context over temporal tokens)**

  - *Operations.*

    - A stack of ``num_layers`` encoder blocks (:class:`_TransformerEncoderBlock`).
    - Each block applies LayerNorm (:class:`torch.nn.LayerNorm`),
    - multi-head self-attention (``num_heads``) with dropout + residual (:class:`MultiHeadAttention`, :class:`torch.nn.Dropout`),
    - LayerNorm (:class:`torch.nn.LayerNorm`),
    - and a 2-layer feed-forward block (≈4x expansion, :class:`torch.nn.GELU`) with dropout + residual.

  Shapes remain ``(B, S_tokens, D)`` throughout.

  *Role.* Small attention focuses on interactions among *temporal patches* (not channels),
  extending effective receptive fields at modest cost.

- :class:`_ClassificationHead` **(aggregation + readout)**

  - *Operations.*

    - Flatten (:class:`torch.nn.Flatten`) the sequence to ``(B, S_tokens·D)``;
    - MLP (:class:`torch.nn.Linear` → activation (default: :class:`torch.nn.ELU`) → :class:`torch.nn.Dropout` → :class:`torch.nn.Linear`);
    - final Linear to classes.

  With ``return_features=True``, features before the last Linear can be exported for
  linear probing or downstream tasks.
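
The following is a minimal, self-contained sketch of the tokenization path described above,
written with plain ``torch`` modules for illustration only; layer names, kernel sizes, and the
exact ordering inside ``_PatchEmbedding`` may differ from braindecode's implementation.

.. code::

    import torch
    from torch import nn

    B, C, T = 2, 22, 1000          # batch, channels, samples
    D, L_t, P, S = 40, 25, 75, 15  # illustrative filter count, kernel and pool sizes

    x = torch.randn(B, 1, C, T)                    # (B, 1, n_chans, n_times)
    x = nn.Conv2d(1, D, kernel_size=(1, L_t))(x)   # temporal "filter bank"
    x = nn.Conv2d(D, D, kernel_size=(C, 1))(x)     # spatial projection across electrodes
    x = nn.ELU()(nn.BatchNorm2d(D)(x))             # normalization + activation
    x = nn.AvgPool2d((1, P), stride=(1, S))(x)     # temporal pooling into "patches"
    tokens = x.squeeze(2).transpose(1, 2)          # (B, S_tokens, D)
    print(tokens.shape)                            # torch.Size([2, 61, 40]) with these settings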

.. rubric:: Convolutional Details

- **Temporal (where time-domain patterns are learned).**
  The initial ``(1 x L_t)`` conv per channel acts as a *learned filter bank* for oscillatory
  bands and transients. Subsequent **AvgPool** along time performs local integration,
  converting activations into “patches” (tokens). Pool length/stride control the
  token rate and set the lower bound on temporal context within each token.

- **Spatial (how electrodes are processed).**
  A single conv with kernel ``(n_chans x 1)`` spans the full montage to learn spatial
  projections for each temporal feature map, collapsing the channel axis into a
  virtual channel before tokenization. This mirrors the shallow spatial step in
  :class:`ShallowFBCSPNet` (temporal filters → spatial projection → temporal condensation).

- **Spectral (how frequency content is captured).**
  No explicit Fourier/wavelet stage is used. Spectral selectivity emerges implicitly
  from the learned temporal kernels; pooling further smooths high-frequency noise.
  The effective spectral resolution is thus governed by ``L_t`` and the pooling
  configuration.
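
Because each temporal kernel is a plain FIR filter, its implied frequency response can be
examined with an FFT. The snippet below is a sketch using a random stand-in kernel; to inspect
a trained model you would substitute the weights of its temporal convolution.

.. code::

    import torch

    sfreq, L_t = 250.0, 25
    kernel = torch.randn(L_t)                        # stand-in for one learned temporal kernel
    freqs = torch.fft.rfftfreq(L_t, d=1.0 / sfreq)   # frequency bins in Hz
    response = torch.fft.rfft(kernel).abs()          # magnitude response of the FIR filter
    print(freqs.shape, response.shape)               # both of length L_t // 2 + 1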

.. rubric:: Attention / Sequential Modules

- **Type.** Standard multi-head self-attention (MHA) with ``num_heads`` heads over the token sequence.
- **Shapes.** Input/Output: ``(B, S_tokens, D)``; attention operates along the ``S_tokens`` axis.
- **Role.** Re-weights and integrates evidence across pooled windows, capturing dependencies
  longer than any single token while leaving channel relationships to the convolutional stem.
  The design is intentionally *small*—attention refines rather than replaces convolutional feature extraction.
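
The shape contract of this stage can be reproduced with PyTorch's built-in multi-head
attention; the embedding size and head count below are illustrative and need not match
braindecode's internal implementation.

.. code::

    import torch
    from torch import nn

    B, S_tokens, D, num_heads = 2, 61, 40, 10
    tokens = torch.randn(B, S_tokens, D)

    mha = nn.MultiheadAttention(embed_dim=D, num_heads=num_heads, batch_first=True)
    out, attn_weights = mha(tokens, tokens, tokens)   # self-attention over temporal tokens
    print(out.shape)                                  # torch.Size([2, 61, 40]), shape unchanged
    print(attn_weights.shape)                         # torch.Size([2, 61, 61]), token-to-token weights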

.. rubric:: Additional Mechanisms

- **Parallel with ShallowFBCSPNet.** Both begin with a learned temporal filter bank,
  spatial projection across electrodes, and early temporal condensation.
  :class:`ShallowFBCSPNet` then computes band-power (via squaring/log-variance), whereas
  EEG-Conformer applies BN/ELU and **continues with attention** over tokens to
  refine temporal context before classification.

- **Tokenization knob.** ``pool_time_length`` and especially ``pool_time_stride`` set
  the number of tokens ``S_tokens``. Smaller strides → more tokens and higher attention
  capacity (but higher compute); larger strides → fewer tokens and stronger inductive bias.
  (See the arithmetic sketch after this list.)

- **Embedding dimension = filters.** ``n_filters_time`` serves double duty as both the
  number of temporal filters in the stem and the transformer's embedding size ``D``,
  simplifying dimensional alignment.
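
As a rough rule of thumb (ignoring padding and other implementation details of the stem), the
token count follows standard convolution/pooling output-size arithmetic; the values below are
the same illustrative settings used in the tokenization sketch above.

.. code::

    n_times = 1000            # samples per window
    filter_time_length = 25   # L_t
    pool_time_length = 75     # P
    pool_time_stride = 15     # S

    t_after_conv = n_times - filter_time_length + 1   # valid temporal convolution
    s_tokens = (t_after_conv - pool_time_length) // pool_time_stride + 1
    print(s_tokens)           # 61 with these settings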

.. rubric:: Usage and Configuration

- **Instantiation.** Choose ``n_filters_time`` (embedding size ``D``) and
  ``filter_time_length`` to match the rhythms of interest. Tune
  ``pool_time_length/stride`` to trade temporal resolution for sequence length.
  Keep ``num_layers`` modest (e.g., 4–6) and set ``num_heads`` to divide ``D``.
  ``final_fc_length="auto"`` infers the flattened size from PatchEmbedding.
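
A minimal configuration sketch using the parameter names documented below; the concrete values
are illustrative choices, not recommended defaults.

.. code::

    from braindecode.models import EEGConformer

    model = EEGConformer(
        n_chans=22,
        n_outputs=4,
        n_times=1000,            # 4 s at 250 Hz
        n_filters_time=40,       # embedding size D
        filter_time_length=25,
        pool_time_length=75,
        pool_time_stride=15,
        num_layers=6,
        num_heads=10,            # must divide n_filters_time
        final_fc_length="auto",
    )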

Notes
-----
The authors recommend using data augmentation before using Conformer,
e.g. segmentation and recombination.
Please refer to the original paper and code for more details [ConformerCode]_.

The model was initially tuned on 4 seconds of 250 Hz data.
Please adjust the scale of the temporal convolutional layer
and the pooling layer for better performance.

.. versionadded:: 0.8

We aggregate the parameters based on the parts of the model, or
by where each parameter is first used, e.g. ``n_filters_time``.

.. versionadded:: 1.1

Parameters
----------
n_filters_time: int
    Number of temporal filters; also defines the embedding size.
filter_time_length: int
    Length of the temporal filter.
pool_time_length: int
    Length of the temporal pooling kernel.
pool_time_stride: int
    Stride between temporal pooling windows.
drop_prob: float
    Dropout rate of the convolutional layer.
num_layers: int
    Number of self-attention layers.
num_heads: int
    Number of attention heads.
att_drop_prob: float
    Dropout rate of the self-attention layer.
final_fc_length: int | str
    The dimension of the fully connected layer.
return_features: bool
    If True, the forward method returns the features before the
    last classification layer. Defaults to False.
activation: nn.Module
    Activation function as parameter. Default is ``nn.ELU``.
activation_transfor: nn.Module
    Activation function applied in the FeedForwardBlock module
    inside the transformer. Default is ``nn.GELU``.

References
----------
.. [song2022] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
   Conformer: Convolutional transformer for EEG decoding and visualization.
   IEEE Transactions on Neural Systems and Rehabilitation Engineering,
   31, pp.710-719. https://ieeexplore.ieee.org/document/9991178
.. [ConformerCode] Song, Y., Zheng, Q., Liu, B. and Gao, X., 2022. EEG
   Conformer: Convolutional transformer for EEG decoding and visualization.
   https://github.com/eeyhsong/EEG-Conformer

.. rubric:: Hugging Face Hub integration

When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

    pip install braindecode[hub]

**Pushing a model to the Hub:**

.. code::

    from braindecode.models import EEGConformer

    # Train your model
    model = EEGConformer(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-eegconformer-model",
        commit_message="Initial model upload",
    )

**Loading a model from the Hub:**

.. code::

    from braindecode.models import EEGConformer

    # Load pretrained model
    model = EEGConformer.from_pretrained("username/my-eegconformer-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = EEGConformer.from_pretrained("username/my-eegconformer-model", n_outputs=4)

**Extracting features and replacing the head:**

.. code::

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)
    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)

**Saving and restoring full configuration:**

.. code::

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = EEGConformer.from_config(config)  # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific such as
dropout rates, activation functions, number of filters) are automatically
saved to the Hub and restored when loading.

See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode).
If you fine-tune from a published checkpoint, the resulting weights
inherit the license of that checkpoint and its training corpus.