bruAristimunha committed · a5ea1ce · verified · 1 Parent(s): 90a6a4f

Replace with clean markdown card

Files changed (1): README.md +36 −295
README.md CHANGED
@@ -14,13 +14,12 @@ tags:
 
 # ATCNet
 
- ATCNet from Altaheri et al (2022) [1]_.
 
- > **Architecture-only repository.** This repo documents the
 > `braindecode.models.ATCNet` class. **No pretrained weights are
- > distributed here** instantiate the model and train it on your own
- > data, or fine-tune from a published foundation-model checkpoint
- > separately.
 
 ## Quick start
 
@@ -39,313 +38,55 @@ model = ATCNet(
 )
 ```
 
- The signal-shape arguments above are example defaults — adjust them
- to match your recording.
 
 ## Documentation
-
- - Full API reference (parameters, references, architecture figure):
-   <https://braindecode.org/stable/generated/braindecode.models.ATCNet.html>
- - Interactive browser with live instantiation:
   <https://huggingface.co/spaces/braindecode/model-explorer>
 - Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/atcnet.py#L15>
 
- ## Architecture description
-
- The block below is the rendered class docstring (parameters,
- references, architecture figure where available).
-
- ATCNet from Altaheri et al (2022) [1]_.
-
- Tags: Convolution · Recurrent · Attention/Transformer
-
- .. figure:: https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png
-    :align: center
-    :alt: ATCNet Architecture
-    :width: 650px
-
- .. rubric:: Architectural Overview
-
- ATCNet is a *convolution-first* architecture augmented with a *lightweight attention–TCN*
- sequence module. The end-to-end flow is:
-
- - (i) :class:`_ConvBlock` learns temporal filter banks and spatial projections (EEGNet-style),
-   downsampling time to a compact feature map;
- - (ii) sliding windows carve overlapping temporal windows from this map;
- - (iii) for each window, :class:`_AttentionBlock` applies small multi-head self-attention
-   over time, followed by a :class:`_TCNResidualBlock` stack (causal, dilated);
- - (iv) window-level features are aggregated (mean of window logits or concatenation)
-   and mapped via a max-norm–constrained linear layer.
-
- Relative to ViT, ATCNet replaces linear patch projection with learned *temporal–spatial*
- convolutions; it processes *parallel* window encoders (attention→TCN) instead of a deep
- stack; and swaps the MLP head for a TCN suited to 1-D EEG sequences. A shape-level
- sketch of this flow follows below.
-
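To make steps (i)-(iv) concrete, here is a shape-level sketch in plain PyTorch. The sizes (`B`, `F2`, `T_c`, `n`) are illustrative stand-ins and the attention/TCN branch is stubbed out; this is not the library's implementation:

```python
import torch

B, F2, T_c, n = 8, 32, 20, 5            # illustrative: batch, stem features, condensed time, windows
feature_map = torch.randn(B, F2, T_c)   # (i) stand-in for the conv stem's output

T_w = T_c - n + 1                       # (ii) window width: one start index per window
window_feats = []
for w in range(n):
    window = feature_map[..., w:w + T_w]     # (B, F2, T_w)
    attended = window                        # (iii) attention + TCN branch stubbed out here
    window_feats.append(attended[..., -1])   # last causal step as the window feature -> (B, F2)

pooled = torch.stack(window_feats).mean(dim=0)   # (iv) aggregate before the max-norm readout
```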
- .. rubric:: Macro Components
-
- - :class:`_ConvBlock` **(shallow conv stem → feature map)**
-
-   *Operations.*
-
-   - **Temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_t, 1)`` builds a
-     FIR-like filter bank (``F1`` maps).
-   - **Depthwise spatial conv** (:class:`torch.nn.Conv2d`, ``groups=F1``) with kernel
-     ``(1, n_chans)`` learns per-filter spatial projections (akin to EEGNet's CSP-like step).
-   - **BN → ELU → AvgPool → Dropout** to stabilize and condense activations.
-   - **Refining temporal conv** (:class:`torch.nn.Conv2d`) with kernel ``(L_r, 1)`` +
-     **BN → ELU → AvgPool → Dropout**.
-
-   The output shape is ``(B, F2, T_c, 1)`` with ``F2 = F1·D`` and ``T_c = T/(P1·P2)``.
-   Temporal kernels behave as FIR filters; the depthwise spatial conv yields frequency-specific
-   topographies. Pooling acts as a local integrator, reducing variance and imposing a
-   useful inductive bias on short EEG windows. A stem sketch follows below.
-
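A minimal sketch of the stem under the shapes above, assuming a `(B, 1, T, C)` input layout and symmetric padding; kernel and pool sizes follow the documented defaults (`L_t=64`, `L_r=16`, `P1=8`, `P2=7`), but this is not the exact braindecode code:

```python
import torch
from torch import nn

F1, D, n_chans = 16, 2, 22
F2 = F1 * D
L_t, L_r, P1, P2 = 64, 16, 8, 7

stem = nn.Sequential(
    # temporal conv: FIR-like filter bank over time, input laid out as (B, 1, T, C)
    nn.Conv2d(1, F1, kernel_size=(L_t, 1), padding=(L_t // 2, 0), bias=False),
    # depthwise spatial conv across the full montage, one projection per temporal filter
    nn.Conv2d(F1, F2, kernel_size=(1, n_chans), groups=F1, bias=False),
    nn.BatchNorm2d(F2), nn.ELU(), nn.AvgPool2d((P1, 1)), nn.Dropout(0.3),
    # refining temporal conv + second condensation stage
    nn.Conv2d(F2, F2, kernel_size=(L_r, 1), padding=(L_r // 2, 0), bias=False),
    nn.BatchNorm2d(F2), nn.ELU(), nn.AvgPool2d((P2, 1)), nn.Dropout(0.3),
)

x = torch.randn(8, 1, 1125, n_chans)   # (B, 1, T, C): 4.5 s at 250 Hz
print(stem(x).shape)                   # torch.Size([8, 32, 20, 1]) ~ (B, F2, T/(P1*P2), 1)
```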
- - **Sliding-Window Sequencer**
-
-   From the condensed time axis (length ``T_c``), ATCNet forms ``n`` overlapping windows
-   of width ``T_w = T_c - n + 1`` (one start per index). Each window produces a sequence
-   ``(B, F2, T_w)`` forwarded to its own attention–TCN branch. This creates *parallel*
-   encoders over shifted contexts and is key to robustness on nonstationary EEG.
-
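One way to materialize all `n` windows at once is `Tensor.unfold`; an illustrative equivalent of the per-start-index slicing described above, not the library's code:

```python
import torch

B, F2, T_c, n = 8, 32, 20, 5
T_w = T_c - n + 1                       # = 16
fmap = torch.randn(B, F2, T_c)

# unfold over time with size T_w and step 1: one window per start index
windows = fmap.unfold(-1, T_w, 1)       # (B, F2, n, T_w)
assert windows.shape == (B, F2, n, T_w)
```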
- - :class:`_AttentionBlock` **(small MHA on temporal positions)**
-
-   Attention here is *local to a window* and purely temporal.
-
-   *Operations.*
-
-   - Rearrange to ``(B, T_w, F2)``,
-   - normalization with :class:`torch.nn.LayerNorm`,
-   - custom multi-head attention :class:`_MHA` (``num_heads=H``, per-head dim ``d_h``) + residual add,
-   - dropout with :class:`torch.nn.Dropout`,
-   - rearrange back to ``(B, F2, T_w)``.
-
-   *Role.* Re-weights evidence across the window, letting the model emphasize informative
-   segments (onsets, bursts) before causal convolutions aggregate history.
-
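A sketch of the same sequence using `torch.nn.MultiheadAttention` as a stand-in for the custom `_MHA` (an assumption: this forces `embed_dim = F2`, whereas `_MHA` decouples `H·d_h` from `F2`):

```python
import torch
from torch import nn

B, F2, T_w, H = 8, 32, 16, 2
x = torch.randn(B, F2, T_w)

norm = nn.LayerNorm(F2)
mha = nn.MultiheadAttention(embed_dim=F2, num_heads=H, batch_first=True)
drop = nn.Dropout(0.5)

t = x.permute(0, 2, 1)                    # (B, T_w, F2): attention runs over time
attn_out, _ = mha(norm(t), norm(t), norm(t))
t = t + drop(attn_out)                    # residual add
y = t.permute(0, 2, 1)                    # back to (B, F2, T_w)
```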
- - :class:`_TCNResidualBlock` **(causal dilated temporal CNN)**
-
-   *Operations:*
-
-   - Two :class:`braindecode.modules.CausalConv1d` layers per block, with dilation ``1, 2, 4, …``
-     across blocks, each followed by :class:`torch.nn.ELU` + :class:`torch.nn.BatchNorm1d` +
-     :class:`torch.nn.Dropout`, plus a residual connection (identity or 1x1 mapping).
-   - The final feature used per window is the *last* causal step ``[..., -1]`` (forecast-style).
-
-   *Role.* Efficient long-range temporal integration with stable gradients; the dilated
-   receptive field complements attention's soft selection.
-
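A minimal causal residual block in this spirit, assuming equal input/output channels so the residual is an identity; left-only padding plays the role of `CausalConv1d`:

```python
import torch
from torch import nn

class CausalTCNBlock(nn.Module):
    """Illustrative causal residual block (left-padded convs, dilation d)."""
    def __init__(self, channels: int, kernel_size: int, dilation: int, p: float = 0.3):
        super().__init__()
        pad = (kernel_size - 1) * dilation            # pad only on the left => causal
        self.pad = nn.ConstantPad1d((pad, 0), 0.0)
        self.conv1 = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)
        self.conv2 = nn.Conv1d(channels, channels, kernel_size, dilation=dilation)
        self.bn1, self.bn2 = nn.BatchNorm1d(channels), nn.BatchNorm1d(channels)
        self.act, self.drop = nn.ELU(), nn.Dropout(p)

    def forward(self, x):
        h = self.drop(self.act(self.bn1(self.conv1(self.pad(x)))))
        h = self.drop(self.act(self.bn2(self.conv2(self.pad(h)))))
        return self.act(h + x)                        # identity residual

x = torch.randn(8, 32, 16)                            # (B, F2, T_w)
block = CausalTCNBlock(channels=32, kernel_size=4, dilation=1)
feature = block(x)[..., -1]                           # last causal step -> (8, 32)
```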
- - **Aggregation & Classifier**
-
-   *Operations:*
-
-   - Either (a) map each window feature ``(B, F2)`` to logits via :class:`braindecode.modules.MaxNormLinear`
-     and **average** across windows (default, matching the official code), or
-   - (b) **concatenate** all window features into ``(B, n·F2)`` and apply a single :class:`MaxNormLinear`.
-
-   The max-norm constraint regularizes the readout.
-
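Both aggregation modes, sketched with a plain `torch.nn.Linear` standing in for `MaxNormLinear` (i.e. without the max-norm constraint):

```python
import torch
from torch import nn

B, F2, n, n_outputs = 8, 32, 5, 4
window_feats = [torch.randn(B, F2) for _ in range(n)]

# (a) average per-window logits (default, matching the official code)
head = nn.Linear(F2, n_outputs)                  # stand-in for MaxNormLinear
logits_avg = torch.stack([head(f) for f in window_feats]).mean(dim=0)

# (b) concatenate window features, single readout (paper's variant, concat=True)
head_cat = nn.Linear(n * F2, n_outputs)
logits_cat = head_cat(torch.cat(window_feats, dim=1))
```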
- .. rubric:: Convolutional Details
-
- - **Temporal.** Temporal structure is learned in three places:
-
-   - (1) the stem's wide ``(L_t, 1)`` conv (learned filter bank),
-   - (2) the refining ``(L_r, 1)`` conv after pooling (short-term dynamics), and
-   - (3) the TCN's causal 1-D convolutions with exponentially increasing dilation
-     (long-range dependencies). The minimum sequence length required by the TCN stack is
-     ``(K_t - 1)·2^{L-1} + 1``; the implementation *auto-scales* kernels/pools/windows
-     when inputs are shorter to preserve feasibility.
-
- - **Spatial.** A depthwise spatial conv spans the **full montage** (kernel ``(1, n_chans)``),
-   producing *per-temporal-filter* spatial projections (no cross-filter mixing at this step).
-   This mirrors EEGNet's interpretability: each temporal filter has its own spatial pattern.
-
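For the defaults (`K_t = 4`, `L = 2`), that minimum works out to `(4 - 1)·2^(2-1) + 1 = 7` condensed steps; a quick check:

```python
def tcn_min_length(kernel_size: int, depth: int) -> int:
    """Minimum condensed sequence length the TCN stack can consume."""
    return (kernel_size - 1) * 2 ** (depth - 1) + 1

assert tcn_min_length(kernel_size=4, depth=2) == 7
```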
- .. rubric:: Attention / Sequential Modules
-
- - **Type.** Multi-head self-attention with ``H`` heads and per-head dim ``d_h``, implemented
-   in :class:`_MHA`, allowing ``embed_dim = H·d_h`` independent of the input and output dims.
- - **Shapes.** ``(B, F2, T_w) → (B, T_w, F2) → (B, F2, T_w)``. Attention operates along
-   the **temporal** axis within a window; channels/features stay in the embedding dim ``F2``.
- - **Role.** Highlights salient temporal positions prior to causal convolution; small attention
-   keeps compute modest while improving context modeling over pooled features.
-
- .. rubric:: Additional Mechanisms
-
- - **Parallel encoders over shifted windows.** Improves montage/phase robustness by
-   ensembling nearby contexts rather than committing to a single segmentation.
- - **Max-norm classifier.** Enforces weight-norm constraints at the readout, a common
-   stabilization trick in EEG decoding.
- - **ViT vs. ATCNet (design choices).** Convolutional *nonlinear* projection rather than
-   linear patchification; attention followed by a **TCN** (not an MLP); *parallel* window
-   encoders rather than stacked encoders.
-
- .. rubric:: Usage and Configuration
-
- - ``conv_block_n_filters`` (F1) and ``conv_block_depth_mult`` (D) → capacity of the stem
-   (with ``F2 = F1·D`` feeding attention/TCN), dimensions aligned to ``F2``, like :class:`EEGNet`.
- - Pool sizes ``P1, P2`` trade temporal resolution for stability/compute; they set
-   ``T_c = T/(P1·P2)`` and thus the window width ``T_w``.
- - ``n_windows`` controls the ensemble over shifts (compute ∝ number of windows).
- - ``num_heads`` and ``head_dim`` set attention capacity; keep ``H·d_h ≈ F2``.
- - ``tcn_depth`` and ``tcn_kernel_size`` govern the receptive field; larger values demand
-   longer inputs (see the minimum length above). The implementation warns and *rescales*
-   kernels/pools/windows if inputs are too short.
- - **Aggregation choice.** ``concat=False`` (default, average of per-window logits) matches
-   the official code; ``concat=True`` mirrors the paper's concatenation variant.
-
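For instance, a wider variant that respects these guidelines; the values here are arbitrary examples, not recommendations:

```python
from braindecode.models import ATCNet

# a wider stem: F2 = 32 * 2 = 64 features feed each attention/TCN branch
model = ATCNet(
    n_chans=22, n_outputs=4, n_times=1125,
    conv_block_n_filters=32,   # F1
    conv_block_depth_mult=2,   # D
    n_windows=3,               # fewer parallel branches -> less compute
    num_heads=4, head_dim=16,  # keep H * d_h = 64 ~ F2
)
```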
- Parameters
- ----------
- input_window_seconds : float, optional
-     Time length of inputs, in seconds. Defaults to 4.5 s, as in the BCI-IV 2a
-     dataset.
- sfreq : int, optional
-     Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the
-     BCI-IV 2a dataset.
- conv_block_n_filters : int
-     Number of temporal filters in the first convolutional layer of the
-     convolutional block, denoted F1 in figure 2 of the paper [1]_. Defaults
-     to 16 as in [1]_.
- conv_block_kernel_length_1 : int
-     Length of temporal filters in the first convolutional layer of the
-     convolutional block, denoted Kc in table 1 of the paper [1]_. Defaults
-     to 64 as in [1]_.
- conv_block_kernel_length_2 : int
-     Length of temporal filters in the last convolutional layer of the
-     convolutional block. Defaults to 16 as in [1]_.
- conv_block_pool_size_1 : int
-     Length of the first average pooling kernel in the convolutional block.
-     Defaults to 8 as in [1]_.
- conv_block_pool_size_2 : int
-     Length of the second average pooling kernel in the convolutional block,
-     denoted P2 in table 1 of the paper [1]_. Defaults to 7 as in [1]_.
- conv_block_depth_mult : int
-     Depth multiplier of the depthwise convolution in the convolutional block,
-     denoted D in table 1 of the paper [1]_. Defaults to 2 as in [1]_.
- conv_block_dropout : float
-     Dropout probability used in the convolution block, denoted pc in
-     table 1 of the paper [1]_. Defaults to 0.3 as in [1]_.
- n_windows : int
-     Number of sliding windows, denoted n in [1]_. Defaults to 5 as in [1]_.
- head_dim : int
-     Embedding dimension used in each self-attention head, denoted dh in
-     table 1 of the paper [1]_. Defaults to 8 as in [1]_.
- num_heads : int
-     Number of attention heads, denoted H in table 1 of the paper [1]_.
-     Defaults to 2 as in [1]_.
- att_dropout : float
-     Dropout probability used in the attention block, denoted pa in table 1
-     of the paper [1]_. Defaults to 0.5 as in [1]_.
- tcn_depth : int
-     Depth of the Temporal Convolutional Network block (i.e. number of TCN
-     residual blocks), denoted L in table 1 of the paper [1]_. Defaults to 2
-     as in [1]_.
- tcn_kernel_size : int
-     Temporal kernel size used in the TCN block, denoted Kt in table 1 of the
-     paper [1]_. Defaults to 4 as in [1]_.
- tcn_dropout : float
-     Dropout probability used in the TCN block, denoted pt in table 1
-     of the paper [1]_. Defaults to 0.3 as in [1]_.
- tcn_activation : torch.nn.Module
-     Nonlinear activation to use. Defaults to ``nn.ELU()``.
- concat : bool
-     When ``True``, concatenates each sliding-window embedding before
-     feeding it to a fully-connected layer, as done in [1]_. When ``False``,
-     maps each sliding window to ``n_outputs`` logits and averages them.
-     Defaults to ``False``, contrary to what is reported in [1]_ but
-     matching what the official code does [2]_.
- max_norm_const : float
-     Maximum L2-norm constraint imposed on the weights of the last
-     fully-connected layer. Defaults to 0.25.
-
- Notes
- -----
- - Inputs substantially shorter than the implied minimum length trigger **automatic
-   downscaling** of kernels, pools, windows, and TCN kernel size to maintain validity.
- - The attention–TCN sequence operates **per window**; the last causal step is used as the
-   window feature, aligning the temporal semantics across windows.
-
- .. versionadded:: 1.1
-
-    More detailed documentation of the model.
-
- References
- ----------
- .. [1] H. Altaheri, G. Muhammad, M. Alsulaiman (2022).
-    *Physics-informed attention temporal convolutional network for EEG-based motor imagery classification.*
-    IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.
- .. [2] Official EEG-ATCNet implementation (TensorFlow):
-    https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
-
- .. rubric:: Hugging Face Hub integration
-
- When the optional ``huggingface_hub`` package is installed, all models
- automatically gain the ability to be pushed to and loaded from the
- Hugging Face Hub. Install with::
-
-     pip install braindecode[hub]
-
- **Pushing a model to the Hub:**
-
- .. code:: python
-
-     from braindecode.models import ATCNet
-
-     # Train your model
-     model = ATCNet(n_chans=22, n_outputs=4, n_times=1000)
-     # ... training code ...
-
-     # Push to the Hub
-     model.push_to_hub(
-         repo_id="username/my-atcnet-model",
-         commit_message="Initial model upload",
-     )
-
- **Loading a model from the Hub:**
-
- .. code:: python
-
-     from braindecode.models import ATCNet
-
-     # Load pretrained model
-     model = ATCNet.from_pretrained("username/my-atcnet-model")
-
-     # Load with a different number of outputs (head is rebuilt automatically)
-     model = ATCNet.from_pretrained("username/my-atcnet-model", n_outputs=4)
-
- **Extracting features and replacing the head:**
-
- .. code:: python
-
-     import torch
-
-     x = torch.randn(1, model.n_chans, model.n_times)
-     # Extract encoder features (consistent dict across all models)
-     out = model(x, return_features=True)
-     features = out["features"]
-
-     # Replace the classification head
-     model.reset_head(n_outputs=10)
-
- **Saving and restoring full configuration:**
-
- .. code:: python
-
-     import json
-
-     config = model.get_config()  # all __init__ params
-     with open("config.json", "w") as f:
-         json.dump(config, f)
-
-     model2 = ATCNet.from_config(config)  # reconstruct (no weights)
-
- All model parameters (both EEG-specific and model-specific such as
- dropout rates, activation functions, number of filters) are automatically
- saved to the Hub and restored when loading.
-
- See :ref:`load-pretrained-models` for a complete tutorial.
-
 ## Citation
 
- Please cite both the original paper for this architecture (see the
- *References* section above) and braindecode:
 
 ```bibtex
 @article{aristimunha2025braindecode,
 
 
 # ATCNet
 
+ ATCNet from Altaheri et al (2022) [1].
 
+ > **Architecture-only repository.** Documents the
 > `braindecode.models.ATCNet` class. **No pretrained weights are
+ > distributed here.** Instantiate the model and train it on your own
+ > data.
 
 ## Quick start
 
 )
 ```
 
+ The signal-shape arguments above are illustrative defaults — adjust to
+ match your recording.
 
 ## Documentation
+ - Full API reference: <https://braindecode.org/stable/generated/braindecode.models.ATCNet.html>
+ - Interactive browser (live instantiation, parameter counts):
   <https://huggingface.co/spaces/braindecode/model-explorer>
 - Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/atcnet.py#L15>
 
+ ## Architecture
 
+ ![ATCNet architecture](https://user-images.githubusercontent.com/25565236/185449791-e8539453-d4fa-41e1-865a-2cf7e91f60ef.png)
 
+ ## Parameters
 
+ | Parameter | Type | Description |
+ |---|---|---|
+ | `input_window_seconds` | float, optional | Time length of inputs, in seconds. Defaults to 4.5 s, as in the BCI-IV 2a dataset. |
+ | `sfreq` | int, optional | Sampling frequency of the inputs, in Hz. Defaults to 250 Hz, as in the BCI-IV 2a dataset. |
+ | `conv_block_n_filters` | int | Number of temporal filters in the first convolutional layer of the convolutional block, denoted F1 in figure 2 of the paper [1]. Defaults to 16 as in [1]. |
+ | `conv_block_kernel_length_1` | int | Length of temporal filters in the first convolutional layer of the convolutional block, denoted Kc in table 1 of the paper [1]. Defaults to 64 as in [1]. |
+ | `conv_block_kernel_length_2` | int | Length of temporal filters in the last convolutional layer of the convolutional block. Defaults to 16 as in [1]. |
+ | `conv_block_pool_size_1` | int | Length of the first average pooling kernel in the convolutional block. Defaults to 8 as in [1]. |
+ | `conv_block_pool_size_2` | int | Length of the second average pooling kernel in the convolutional block, denoted P2 in table 1 of the paper [1]. Defaults to 7 as in [1]. |
+ | `conv_block_depth_mult` | int | Depth multiplier of the depthwise convolution in the convolutional block, denoted D in table 1 of the paper [1]. Defaults to 2 as in [1]. |
+ | `conv_block_dropout` | float | Dropout probability used in the convolution block, denoted pc in table 1 of the paper [1]. Defaults to 0.3 as in [1]. |
+ | `n_windows` | int | Number of sliding windows, denoted n in [1]. Defaults to 5 as in [1]. |
+ | `head_dim` | int | Embedding dimension used in each self-attention head, denoted dh in table 1 of the paper [1]. Defaults to 8 as in [1]. |
+ | `num_heads` | int | Number of attention heads, denoted H in table 1 of the paper [1]. Defaults to 2 as in [1]. |
+ | `att_dropout` | float | Dropout probability used in the attention block, denoted pa in table 1 of the paper [1]. Defaults to 0.5 as in [1]. |
+ | `tcn_depth` | int | Depth of the Temporal Convolutional Network block (i.e. number of TCN residual blocks), denoted L in table 1 of the paper [1]. Defaults to 2 as in [1]. |
+ | `tcn_kernel_size` | int | Temporal kernel size used in the TCN block, denoted Kt in table 1 of the paper [1]. Defaults to 4 as in [1]. |
+ | `tcn_dropout` | float | Dropout probability used in the TCN block, denoted pt in table 1 of the paper [1]. Defaults to 0.3 as in [1]. |
+ | `tcn_activation` | torch.nn.Module | Nonlinear activation to use. Defaults to `nn.ELU()`. |
+ | `concat` | bool | When `True`, concatenates each sliding-window embedding before feeding it to a fully-connected layer, as done in [1]. When `False`, maps each sliding window to `n_outputs` logits and averages them. Defaults to `False`, contrary to what is reported in [1] but matching what the official code does [2]. |
+ | `max_norm_const` | float | Maximum L2-norm constraint imposed on the weights of the last fully-connected layer. Defaults to 0.25. |
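An example tying the table's defaults together: `n_times` follows from `input_window_seconds × sfreq`, and `concat=True` selects the paper's concatenation variant (values are the table defaults, not a tuned configuration):

```python
from braindecode.models import ATCNet

input_window_seconds, sfreq = 4.5, 250        # table defaults (BCI-IV 2a)
n_times = int(input_window_seconds * sfreq)   # 1125 samples

# paper-style aggregation: concatenate window embeddings before the readout
model = ATCNet(n_chans=22, n_outputs=4, n_times=n_times, concat=True)
```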
 
+ ## References
 
+ 1. H. Altaheri, G. Muhammad, M. Alsulaiman (2022). *Physics-informed attention temporal convolutional network for EEG-based motor imagery classification.* IEEE Transactions on Industrial Informatics. doi:10.1109/TII.2022.3197419.
+ 2. Official EEG-ATCNet implementation (TensorFlow): https://github.com/Altaheri/EEG-ATCNet/blob/main/models.py
 
 ## Citation
 
+ Cite the original architecture paper (see *References* above) and braindecode:
 
 ```bibtex
 @article{aristimunha2025braindecode,