bruAristimunha committed · verified
Commit 6318a15 · 1 Parent(s): a6b6c7d

Add architecture-only model card

Files changed (1): README.md (+365, -0)

---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---

# AttentionBaseNet

AttentionBaseNet from Wimpff M et al. (2023).

> **Architecture-only repository.** This repo documents the
> `braindecode.models.AttentionBaseNet` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import AttentionBaseNet

model = AttentionBaseNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them
to match your recording.
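
As a quick sanity check, you can push a random batch through the freshly built model. The snippet below assumes the instantiation above (22 channels, 4-second windows at 250 Hz, i.e. 1000 samples) and prints the shape of the classification output:

```python
import torch

# Dummy EEG batch: (batch, n_chans, n_times) with n_times = 250 Hz * 4 s = 1000 samples.
x = torch.randn(8, 22, 1000)

with torch.no_grad():
    scores = model(x)

print(scores.shape)  # expected: torch.Size([8, 4]), one score per output class
```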

## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.AttentionBaseNet.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/attentionbasenet.py#L29>

## Architecture description

The block below is the rendered class docstring (parameters,
references, architecture figure where available).

<div class='bd-doc'><main>
<p>AttentionBaseNet from Wimpff M et al (2023) [Martin2023]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span>

.. figure:: https://content.cld.iop.org/journals/1741-2552/21/3/036020/revision2/jnead48b9f2_hr.jpg
   :align: center
   :alt: AttentionBaseNet Architecture
   :width: 640px

.. rubric:: Architectural Overview

AttentionBaseNet is a *convolution-first* network with a *channel-attention* stage.
The end-to-end flow is:

- (i) :class:`_FeatureExtractor` learns a temporal filter bank and per-filter spatial
  projections (depthwise across electrodes), then condenses time by pooling;
- (ii) **Channel Expansion** uses a ``1x1`` convolution to set the feature width;
- (iii) :class:`_ChannelAttentionBlock` refines features via depthwise–pointwise temporal
  convs and an optional channel-attention module (SE/CBAM/ECA/…);
- (iv) **Classifier** flattens the sequence and applies a linear readout.

This design mirrors shallow CNN pipelines (EEGNet-style stem) but inserts a pluggable
attention unit that *re-weights channels* (and optionally temporal positions) before
classification.

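A quick way to see this composition on a concrete instance is to list the top-level
submodules and their parameter counts (a minimal sketch; the exact attribute names are
internal implementation details and may differ between braindecode versions):

.. code::

    from braindecode.models import AttentionBaseNet

    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)

    # Print each top-level stage with its class name and parameter count.
    for name, module in model.named_children():
        n_params = sum(p.numel() for p in module.parameters())
        print(f"{name}: {module.__class__.__name__} ({n_params:,} parameters)")
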
84
+
85
+ .. rubric:: Macro Components
86
+
87
+ - :class:`_FeatureExtractor` **(Shallow conv stem → condensed feature map)**
88
+
89
+ - *Operations.*
90
+ - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(1, L_t)`` creates a learned
91
+ FIR-like filter bank with ``n_temporal_filters`` maps.
92
+ - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=n_temporal_filters``)
93
+ with kernel ``(n_chans, 1)`` learns per-filter spatial projections over the full montage.
94
+ - **BatchNorm → ELU → AvgPool → Dropout** stabilize and downsample time.
95
+ - Output shape: ``(B, F2, 1, T₁)`` with ``F2 = n_temporal_filters x spatial_expansion``.
96
+
97
+ *Interpretability/robustness.* Temporal kernels behave as analyzable FIR filters; the
98
+ depthwise spatial step yields rhythm-specific topographies. Pooling acts as a local
99
+ integrator that reduces variance on short EEG windows.
100
+
101
+ - **Channel Expansion**
102
+
103
+ - *Operations.*
104
+ - A ``1x1`` conv → BN → activation maps ``F2 → ch_dim`` without changing
105
+ the temporal length ``T₁`` (shape: ``(B, ch_dim, 1, T₁)``).
106
+ This sets the embedding width for the attention block.
107
+
108
+ - :class:`_ChannelAttentionBlock` **(temporal refinement + channel attention)**
109
+
110
+ - *Operations.*
111
+ - **Depthwise temporal conv** ``(1, L_a)`` (groups=``ch_dim``) + **pointwise ``1x1``**,
112
+ BN and activation → preserves shape ``(B, ch_dim, 1, T₁)`` while refining timing.
113
+ - **Optional attention module** (see *Additional Mechanisms*) applies channel reweighting
114
+ (some variants also apply temporal gating).
115
+ - **AvgPool (1, P₂)** with stride ``(1, S₂)`` and **Dropout** → outputs
116
+ ``(B, ch_dim, 1, T₂)``.
117
+
118
+ *Role.* Emphasizes informative channels (and, in certain modes, salient time steps)
119
+ before the classifier; complements the convolutional priors with adaptive re-weighting.
120
+
121
+ - **Classifier (aggregation + readout)**
122
+
123
+ *Operations.* :class:`torch.nn.Flatten` → :class:`torch.nn.Linear` from
124
+ ``(B, ch_dim·T₂)`` to classes.
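As a rough illustration of the intermediate lengths, the sketch below applies standard
no-padding pooling arithmetic to the defaults (pooling ``75/15`` then ``8/8``,
``ch_dim=16``) for a 1000-sample window, treating the convolutions as length-preserving;
the exact values depend on the implementation's padding choices:

.. code::

    def pooled_len(n, length, stride):
        # Standard output-length formula for pooling without padding.
        return (n - length) // stride + 1

    n_times = 1000                     # 4 s at 250 Hz
    t1 = pooled_len(n_times, 75, 15)   # after the stem's AvgPool      -> 62
    t2 = pooled_len(t1, 8, 8)          # after the attention block     -> 7
    flat = 16 * t2                     # ch_dim * T2 fed to the linear head -> 112
    print(t1, t2, flat)
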
.. rubric:: Convolutional Details

- **Temporal (where time-domain patterns are learned).**
  Wide kernels in the stem (``(1, L_t)``) act as a learned filter bank for oscillatory
  bands/transients; the attention block's depthwise temporal conv (``(1, L_a)``) sharpens
  short-term dynamics after downsampling. Pool sizes/strides (``P₁,S₁`` then ``P₂,S₂``)
  set the token rate and effective temporal resolution.

- **Spatial (how electrodes are processed).**
  A depthwise spatial conv with kernel ``(n_chans, 1)`` spans the full montage to
  learn *per-temporal-filter* spatial projections (no cross-filter mixing at this step),
  mirroring the interpretable spatial stage in shallow CNNs.

- **Spectral (how frequency content is captured).**
  No explicit Fourier/wavelet transform is used in the stem—spectral selectivity
  emerges from learned temporal kernels. When ``attention_mode="fca"``, a frequency
  channel attention (DCT-based) summarizes frequencies to drive channel weights.

.. rubric:: Attention / Sequential Modules

- **Type.** Channel attention chosen by ``attention_mode`` (SE, ECA, CBAM, CAT, GSoP,
  EncNet, GE, GCT, SRM, CATLite). Most operate purely on channels; CBAM/CAT additionally
  include temporal attention.

- **Shapes.** Input/Output around attention: ``(B, ch_dim, 1, T₁)``. Re-arrangements
  (if any) are internal to the module; the block returns the same shape before pooling.

- **Role.** Re-weights channels (and optionally time) to highlight informative sources
  and suppress distractors, improving SNR ahead of the linear head.

.. rubric:: Additional Mechanisms

**Attention variants at a glance:**

- ``"se"``: Squeeze-and-Excitation (global pooling → bottleneck → gates).
- ``"gsop"``: Global second-order pooling (covariance-aware channel weights).
- ``"fca"``: Frequency Channel Attention (DCT summary; uses ``seq_len`` and ``freq_idx``).
- ``"encnet"``: EncNet with learned codewords (uses ``n_codewords``).
- ``"eca"``: Efficient Channel Attention (local 1-D conv over channel descriptor; uses ``kernel_size``).
- ``"ge"``: Gather–Excite (context pooling with optional MLP; can use ``extra_params``).
- ``"gct"``: Gated Channel Transformation (global context normalization + gating).
- ``"srm"``: Style-based recalibration (mean–std descriptors; optional MLP).
- ``"cbam"``: Channel then temporal attention (uses ``kernel_size``).
- ``"cat"`` / ``"catlite"``: Collaborative (channel ± temporal) attention; *lite* omits temporal.
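
The variant is selected purely through the constructor. A minimal sketch using keyword
names from the *Parameters* section below (shape arguments as in the quick-start example):

.. code::

    from braindecode.models import AttentionBaseNet

    # Squeeze-and-Excitation channel attention; reduction_rate sets the bottleneck.
    model_se = AttentionBaseNet(
        n_chans=22, n_outputs=4, n_times=1000,
        attention_mode="se", reduction_rate=4,
    )

    # CBAM adds temporal attention on top of channel attention; kernel_size sets its conv.
    model_cbam = AttentionBaseNet(
        n_chans=22, n_outputs=4, n_times=1000,
        attention_mode="cbam", kernel_size=9,
    )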

**Auto-compatibility on short inputs:**

If the input duration is too short for the configured kernels/pools, the implementation
**automatically rescales** temporal lengths/strides downward (with a warning) to keep
shapes valid and preserve the pipeline semantics.
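
To check whether this rescaling was triggered for a given window length, one option is to
capture warnings at construction time (a sketch; whether a warning actually fires depends
on the chosen kernels and pools):

.. code::

    import warnings

    from braindecode.models import AttentionBaseNet

    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        # Deliberately short window: 0.5 s at 250 Hz.
        short_model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=125)

    for w in caught:
        print(w.message)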

.. rubric:: Usage and Configuration

- ``n_temporal_filters``, ``temporal_filter_length`` and ``spatial_expansion``:
  control the capacity and the number of spatial projections in the stem.
- ``pool_length_inp``, ``pool_stride_inp`` then ``pool_length``, ``pool_stride``:
  trade temporal resolution for compute; they determine the final sequence length ``T₂``.
- ``ch_dim``: width after the ``1x1`` expansion and the effective embedding size for attention.
- ``attention_mode`` + its specific hyperparameters (``reduction_rate``,
  ``kernel_size``, ``seq_len``, ``freq_idx``, ``n_codewords``, ``use_mlp``):
  select and tune the reweighting mechanism.
- ``drop_prob_inp`` and ``drop_prob_attn``: regularize stem and attention stages.
- **Training tips.**

  Start with moderate pooling (e.g., ``P₁=75,S₁=15``) and ELU activations; enable attention
  only after the stem learns stable filters. For small datasets, prefer simpler modes
  (``"se"``, ``"eca"``) before heavier ones (``"gsop"``, ``"encnet"``).
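
Putting these knobs together, a hedged configuration sketch (keyword names as listed
under *Parameters*; values are the documented defaults plus a lightweight ECA attention):

.. code::

    from braindecode.models import AttentionBaseNet

    model = AttentionBaseNet(
        n_chans=22, n_outputs=4, n_times=1000,
        # Stem: capacity and temporal condensation.
        n_temporal_filters=40, spatial_expansion=1,
        pool_length_inp=75, pool_stride_inp=15, drop_prob_inp=0.5,
        # Embedding width for the attention block.
        ch_dim=16,
        # Lightweight channel attention, then final pooling.
        attention_mode="eca", kernel_size=9,
        pool_length=8, pool_stride=8, drop_prob_attn=0.5,
    )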

Parameters
----------
n_temporal_filters : int, optional
    Number of temporal convolutional filters in the first layer. This defines
    the number of output channels after the temporal convolution.
    Default is 40.
temp_filter_length : int, default=15
    The length of the temporal filters in the convolutional layers.
spatial_expansion : int, optional
    Multiplicative factor to expand the spatial dimensions. Used to increase
    the capacity of the model by expanding spatial features. Default is 1.
pool_length_inp : int, optional
    Length of the pooling window in the input layer. Determines how much
    temporal information is aggregated during pooling. Default is 75.
pool_stride_inp : int, optional
    Stride of the pooling operation in the input layer. Controls the
    downsampling factor in the temporal dimension. Default is 15.
drop_prob_inp : float, optional
    Dropout rate applied after the input layer. This is the probability of
    zeroing out elements during training to prevent overfitting.
    Default is 0.5.
ch_dim : int, optional
    Number of channels in the subsequent convolutional layers. This controls
    the depth of the network after the initial layer. Default is 16.
attention_mode : str, optional
    The type of attention mechanism to apply. If `None`, no attention is applied.

    - "se" for Squeeze-and-excitation network
    - "gsop" for Global Second-Order Pooling
    - "fca" for Frequency Channel Attention Network
    - "encnet" for context encoding module
    - "eca" for Efficient channel attention for deep convolutional neural networks
    - "ge" for Gather-Excite
    - "gct" for Gated Channel Transformation
    - "srm" for Style-based Recalibration Module
    - "cbam" for Convolutional Block Attention Module
    - "cat" for Learning to collaborate channel and temporal attention
      from multi-information fusion
    - "catlite" for Learning to collaborate channel attention
      from multi-information fusion (lite version, cat w/o temporal attention)

pool_length : int, default=8
    The length of the window for the average pooling operation.
pool_stride : int, default=8
    The stride of the average pooling operation.
drop_prob_attn : float, default=0.5
    The dropout rate for regularization for the attention layer. Values should be between 0 and 1.
reduction_rate : int, default=4
    The reduction rate used in the attention mechanism to reduce dimensionality
    and computational complexity.
use_mlp : bool, default=False
    Flag to indicate whether an MLP (Multi-Layer Perceptron) should be used within
    the attention mechanism for further processing.
freq_idx : int, default=0
    DCT index used in fca attention mechanism.
n_codewords : int, default=4
    The number of codewords (clusters) used in attention mechanisms that employ
    quantization or clustering strategies.
kernel_size : int, default=9
    The kernel size used in certain types of attention mechanisms for convolution
    operations.
activation : type[nn.Module], default=nn.ELU
    Activation function class to apply. Should be a PyTorch activation
    module class like ``nn.ReLU`` or ``nn.ELU``. Default is ``nn.ELU``.
extra_params : bool, default=False
    Flag to indicate whether additional, custom parameters should be passed to
    the attention mechanism.

Notes
-----
- Sequence length after each stage is computed internally; the final classifier expects
  a flattened ``ch_dim x T₂`` vector.
- Attention operates on the *channel* dimension by design; temporal gating exists only in
  specific variants (CBAM/CAT).
- The paper and original code, with more details about the methodological
  choices, are available in [Martin2023]_ and [MartinCode]_.

.. versionadded:: 0.9

References
----------
.. [Martin2023] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B., 2023.
   EEG motor imagery decoding: A framework for comparative analysis with
   channel attention mechanisms. arXiv preprint arXiv:2310.11198.
.. [MartinCode] Wimpff, M., Gizzi, L., Zerfowski, J. and Yang, B.
   GitHub: https://github.com/martinwimpff/channel-attention (accessed 2024-03-28).

.. rubric:: Hugging Face Hub integration

When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

    pip install braindecode[hub]

**Pushing a model to the Hub:**

.. code::

    from braindecode.models import AttentionBaseNet

    # Train your model
    model = AttentionBaseNet(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-attentionbasenet-model",
        commit_message="Initial model upload",
    )

**Loading a model from the Hub:**

.. code::

    from braindecode.models import AttentionBaseNet

    # Load pretrained model
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = AttentionBaseNet.from_pretrained("username/my-attentionbasenet-model", n_outputs=4)

**Extracting features and replacing the head:**

.. code::

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)
    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)

**Saving and restoring full configuration:**

.. code::

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = AttentionBaseNet.from_config(config)  # reconstruct (no weights)
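
The configuration round-trip above restores the architecture but not the learned weights.
If you also want the weights locally, plain PyTorch checkpointing works alongside it
(generic ``torch`` usage, not a braindecode-specific API):

.. code::

    import torch

    # Save the learned weights next to the JSON config.
    torch.save(model.state_dict(), "attentionbasenet_weights.pt")

    # Later: rebuild the architecture from the config, then restore the weights.
    model2 = AttentionBaseNet.from_config(config)
    model2.load_state_dict(torch.load("attentionbasenet_weights.pt"))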

All model parameters (both EEG-specific and model-specific such as
dropout rates, activation functions, number of filters) are automatically
saved to the Hub and restored when loading.

See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title = {Braindecode: a deep learning library for raw electrophysiological data},
  author = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year = {2025},
  doi = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode).
If you fine-tune from a pretrained checkpoint, the resulting weights
inherit the license of that checkpoint and its training corpus.