bruAristimunha commited on
Commit
fb8937f
·
verified ·
1 Parent(s): 32f9a50

Replace with clean markdown card

Browse files
Files changed (1) hide show
  1. README.md +28 -274
README.md CHANGED
@@ -14,13 +14,12 @@ tags:
14
 
15
  # SSTDPN
16
 
17
- SSTDPN from Can Han et al (2025) .
18
 
19
- > **Architecture-only repository.** This repo documents the
20
  > `braindecode.models.SSTDPN` class. **No pretrained weights are
21
- > distributed here** instantiate the model and train it on your own
22
- > data, or fine-tune from a published foundation-model checkpoint
23
- > separately.
24
 
25
  ## Quick start
26
 
@@ -39,292 +38,47 @@ model = SSTDPN(
39
  )
40
  ```
41
 
42
- The signal-shape arguments above are example defaults — adjust them
43
- to match your recording.
44
 
45
  ## Documentation
46
-
47
- - Full API reference (parameters, references, architecture figure):
48
- <https://braindecode.org/stable/generated/braindecode.models.SSTDPN.html>
49
- - Interactive browser with live instantiation:
50
  <https://huggingface.co/spaces/braindecode/model-explorer>
51
  - Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/sstdpn.py#L17>
52
 
53
- ## Architecture description
54
-
55
- The block below is the rendered class docstring (parameters,
56
- references, architecture figure where available).
57
-
58
- <div class='bd-doc'><main>
59
- <p>SSTDPN from Can Han et al (2025) [Han2025]_.</p>
60
- <span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span>
61
-
62
-
63
-
64
- .. figure:: https://raw.githubusercontent.com/hancan16/SST-DPN/refs/heads/main/figs/framework.png
65
- :align: center
66
- :alt: SSTDPN Architecture
67
- :width: 1000px
68
-
69
- The **Spatial-Spectral** and **Temporal - Dual Prototype Network** (SST-DPN)
70
- is an end-to-end 1D convolutional architecture designed for motor imagery (MI) EEG decoding,
71
- aiming to address challenges related to discriminative feature extraction and
72
- small-sample sizes [Han2025]_.
73
-
74
- The framework systematically addresses three key challenges: multi-channel spatial–spectral
75
- features and long-term temporal features [Han2025]_.
76
-
77
- .. rubric:: Architectural Overview
78
-
79
- SST-DPN consists of a feature extractor (_SSTEncoder, comprising Adaptive Spatial-Spectral
80
- Fusion and Multi-scale Variance Pooling) followed by Dual Prototype Learning classification [Han2025]_.
81
-
82
- 1. **Adaptive Spatial-Spectral Fusion (ASSF)**: Uses :class:`_DepthwiseTemporalConv1d` to generate a
83
- multi-channel spatial-spectral representation, followed by :class:`_SpatSpectralAttn`
84
- (Spatial-Spectral Attention) to model relationships and highlight key spatial-spectral
85
- channels [Han2025]_.
86
-
87
- 2. **Multi-scale Variance Pooling (MVP)**: Applies :class:`_MultiScaleVarPooler` with variance pooling
88
- at multiple temporal scales to capture long-range temporal dependencies, serving as an
89
- efficient alternative to transformers [Han2025]_.
90
-
91
- 3. **Dual Prototype Learning (DPL)**: A training strategy that employs two sets of
92
- prototypes—Inter-class Separation Prototypes (proto_sep) and Intra-class Compact
93
- Prototypes (proto_cpt)—to optimize the feature space, enhancing generalization ability and
94
- preventing overfitting on small datasets [Han2025]_. During inference (forward pass),
95
- classification decisions are based on the distance (dot product) between the
96
- feature vector and proto_sep for each class [Han2025]_.
97
-
98
- .. rubric:: Macro Components
99
-
100
- - `SSTDPN.encoder` **(Feature Extractor)**
101
-
102
- - *Operations.* Combines Adaptive Spatial-Spectral Fusion and Multi-scale Variance Pooling
103
- via an internal :class:`_SSTEncoder`.
104
- - *Role.* Maps the raw MI-EEG trial :math:`X_i \in \mathbb{R}^{C \times T}` to the
105
- feature space :math:`z_i \in \mathbb{R}^d`.
106
-
107
- - `_SSTEncoder.temporal_conv` **(Depthwise Temporal Convolution for Spectral Extraction)**
108
-
109
- - *Operations.* Internal :class:`_DepthwiseTemporalConv1d` applying separate temporal
110
- convolution filters to each channel with kernel size `temporal_conv_kernel_size` and
111
- depth multiplier `n_spectral_filters_temporal` (equivalent to :math:`F_1` in the paper).
112
- - *Role.* Extracts multiple distinct spectral bands from each EEG channel independently.
113
-
114
- - `_SSTEncoder.spt_attn` **(Spatial-Spectral Attention for Channel Gating)**
115
-
116
- - *Operations.* Internal :class:`_SpatSpectralAttn` module using Global Context Embedding
117
- via variance-based pooling, followed by adaptive channel normalization and gating.
118
- - *Role.* Reweights channels in the spatial-spectral dimension to extract efficient and
119
- discriminative features by emphasizing task-relevant regions and frequency bands.
120
-
121
- - `_SSTEncoder.chan_conv` **(Pointwise Fusion across Channels)**
122
-
123
- - *Operations.* A 1D pointwise convolution with `n_fused_filters` output channels
124
- (equivalent to :math:`F_2` in the paper), followed by BatchNorm and the specified
125
- `activation` function (default: ELU).
126
- - *Role.* Fuses the weighted spatial-spectral features across all electrodes to produce
127
- a fused representation :math:`X_{fused} \in \mathbb{R}^{F_2 \times T}`.
128
-
129
- - `_SSTEncoder.mvp` **(Multi-scale Variance Pooling for Temporal Extraction)**
130
-
131
- - *Operations.* Internal :class:`_MultiScaleVarPooler` using :class:`_VariancePool1D`
132
- layers at multiple scales (`mvp_kernel_sizes`), followed by concatenation.
133
- - *Role.* Captures long-range temporal features at multiple time scales. The variance
134
- operation leverages the prior that variance represents EEG spectral power.
135
-
136
- - `SSTDPN.proto_sep` / `SSTDPN.proto_cpt` **(Dual Prototypes)**
137
-
138
- - *Operations.* Learnable vectors optimized during training using prototype learning losses.
139
- The `proto_sep` (Inter-class Separation Prototype) is constrained via L2 weight-normalization
140
- (:math:`\lVert s_i \rVert_2 \leq` `proto_sep_maxnorm`) during inference.
141
- - *Role.* `proto_sep` achieves inter-class separation; `proto_cpt` enhances intra-class compactness.
142
-
143
- .. rubric:: How the information is encoded temporally, spatially, and spectrally
144
-
145
- * **Temporal.**
146
- The initial :class:`_DepthwiseTemporalConv1d` uses a large kernel (e.g., 75). The MVP module employs pooling
147
- kernels that are much larger (e.g., 50, 100, 200 samples) to capture long-term temporal
148
- features effectively. Large kernel pooling layers are shown to be superior to transformer
149
- modules for this task in EEG decoding according to [Han2025]_.
150
-
151
- * **Spatial.**
152
- The initial convolution at the classes :class:`_DepthwiseTemporalConv1d` groups parameter :math:`h=1`,
153
- meaning :math:`F_1` temporal filters are shared across channels. The Spatial-Spectral Attention
154
- mechanism explicitly models the relationships among these channels in the spatial-spectral
155
- dimension, allowing for finer-grained spatial feature modeling compared to conventional
156
- GCNs according to the authors [Han2025]_.
157
- In other words, all electrode channels share :math:`F_1` temporal filters
158
- independently to produce the spatial-spectral representation.
159
-
160
- * **Spectral.**
161
- Spectral information is implicitly extracted via the :math:`F_1` filters in :class:`_DepthwiseTemporalConv1d`.
162
- Furthermore, the use of Variance Pooling (in MVP) explicitly leverages the neurophysiological
163
- prior that the **variance of EEG signals represents their spectral power**, which is an
164
- important feature for distinguishing different MI classes [Han2025]_.
165
-
166
- .. rubric:: Additional Mechanisms
167
-
168
- - **Attention.** A lightweight Spatial-Spectral Attention mechanism models spatial-spectral relationships
169
- at the channel level, distinct from applying attention to deep feature dimensions,
170
- which is common in comparison methods like :class:`ATCNet`.
171
- - **Regularization.** Dual Prototype Learning acts as a regularization technique
172
- by optimizing the feature space to be compact within classes and separated between
173
- classes. This enhances model generalization and classification performance, particularly
174
- useful for limited data typical of MI-EEG tasks, without requiring external transfer
175
- learning data, according to [Han2025]_.
176
-
177
- Notes
178
- -----
179
- * The implementation of the DPL loss functions (:math:`\mathcal{L}_S`, :math:`\mathcal{L}_C`, :math:`\mathcal{L}_{EF}`)
180
- and the optimization of ICPs are typically handled outside the primary ``forward`` method, within the training strategy
181
- (see Ref. 52 in [Han2025]_).
182
- * The default parameters are configured based on the BCI Competition IV 2a dataset.
183
- * The use of Prototype Learning (PL) methods is novel in the field of EEG-MI decoding.
184
- * **Lowest FLOPs:** Achieves the lowest Floating Point Operations (FLOPs) (9.65 M) among competitive
185
- SOTA methods, including braindecode models like :class:`ATCNet` (29.81 M) and
186
- :class:`EEGConformer` (63.86 M), demonstrating computational efficiency [Han2025]_.
187
- * **Transformer Alternative:** Multi-scale Variance Pooling (MVP) provides a accuracy
188
- improvement over temporal attention transformer modules in ablation studies, offering a more
189
- efficient alternative to transformer-based approaches like :class:`EEGConformer` [Han2025]_.
190
-
191
- .. warning::
192
-
193
- **Important:** To utilize the full potential of SSTDPN with Dual Prototype Learning (DPL),
194
- users must implement the DPL optimization strategy outside the model's forward method.
195
- For implementation details and training strategies, please consult the official code at
196
- [Han2025Code]_:
197
- https://github.com/hancan16/SST-DPN/blob/main/train.py
198
-
199
- Parameters
200
- ----------
201
- n_spectral_filters_temporal : int, optional
202
- Number of spectral filters extracted per channel via temporal convolution.
203
- These represent the temporal spectral bands (equivalent to :math:`F_1` in the paper).
204
- Default is 9.
205
-
206
- n_fused_filters : int, optional
207
- Number of output filters after pointwise fusion convolution.
208
- These fuse the spectral filters across all channels (equivalent to :math:`F_2` in the paper).
209
- Default is 48.
210
-
211
- temporal_conv_kernel_size : int, optional
212
- Kernel size for the temporal convolution layer. Controls the receptive field for extracting
213
- spectral information. Default is 75 samples.
214
-
215
- mvp_kernel_sizes : list[int], optional
216
- Kernel sizes for Multi-scale Variance Pooling (MVP) module.
217
- Larger kernels capture long-term temporal dependencies .
218
-
219
- return_features : bool, optional
220
- If True, the forward pass returns (features, logits). If False, returns only logits.
221
- Default is False.
222
-
223
- proto_sep_maxnorm : float, optional
224
- Maximum L2 norm constraint for Inter-class Separation Prototypes during forward pass.
225
- This constraint acts as an implicit force to push features away from the origin. Default is 1.0.
226
-
227
- proto_cpt_std : float, optional
228
- Standard deviation for Intra-class Compactness Prototype initialization. Default is 0.01.
229
-
230
- spt_attn_global_context_kernel : int, optional
231
- Kernel size for global context embedding in Spatial-Spectral Attention module.
232
- Default is 250 samples.
233
-
234
- spt_attn_epsilon : float, optional
235
- Small epsilon value for numerical stability in Spatial-Spectral Attention. Default is 1e-5.
236
-
237
- spt_attn_mode : str, optional
238
- Embedding computation mode for Spatial-Spectral Attention ('var', 'l2', or 'l1').
239
- Default is 'var' (variance-based mean-var operation).
240
-
241
- activation : nn.Module, optional
242
- Activation function to apply after the pointwise fusion convolution in :class:`_SSTEncoder`.
243
- Should be a PyTorch activation module class. Default is nn.ELU.
244
-
245
-
246
- References
247
- ----------
248
- .. [Han2025] Han, C., Liu, C., Wang, J., Wang, Y., Cai, C.,
249
- & Qian, D. (2025). A spatial–spectral and temporal dual
250
- prototype network for motor imagery brain–computer
251
- interface. Knowledge-Based Systems, 315, 113315.
252
- .. [Han2025Code] Han, C., Liu, C., Wang, J., Wang, Y.,
253
- Cai, C., & Qian, D. (2025). A spatial–spectral and
254
- temporal dual prototype network for motor imagery
255
- brain–computer interface. Knowledge-Based Systems,
256
- 315, 113315. GitHub repository.
257
- https://github.com/hancan16/SST-DPN.
258
-
259
- .. rubric:: Hugging Face Hub integration
260
-
261
- When the optional ``huggingface_hub`` package is installed, all models
262
- automatically gain the ability to be pushed to and loaded from the
263
- Hugging Face Hub. Install with::
264
-
265
- pip install braindecode[hub]
266
-
267
- **Pushing a model to the Hub:**
268
-
269
- .. code::
270
- from braindecode.models import SSTDPN
271
-
272
- # Train your model
273
- model = SSTDPN(n_chans=22, n_outputs=4, n_times=1000)
274
- # ... training code ...
275
-
276
- # Push to the Hub
277
- model.push_to_hub(
278
- repo_id="username/my-sstdpn-model",
279
- commit_message="Initial model upload",
280
- )
281
-
282
- **Loading a model from the Hub:**
283
-
284
- .. code::
285
- from braindecode.models import SSTDPN
286
-
287
- # Load pretrained model
288
- model = SSTDPN.from_pretrained("username/my-sstdpn-model")
289
-
290
- # Load with a different number of outputs (head is rebuilt automatically)
291
- model = SSTDPN.from_pretrained("username/my-sstdpn-model", n_outputs=4)
292
-
293
- **Extracting features and replacing the head:**
294
 
295
- .. code::
296
- import torch
297
 
298
- x = torch.randn(1, model.n_chans, model.n_times)
299
- # Extract encoder features (consistent dict across all models)
300
- out = model(x, return_features=True)
301
- features = out["features"]
302
 
303
- # Replace the classification head
304
- model.reset_head(n_outputs=10)
305
 
306
- **Saving and restoring full configuration:**
307
 
308
- .. code::
309
- import json
 
 
 
 
 
 
 
 
 
 
 
310
 
311
- config = model.get_config() # all __init__ params
312
- with open("config.json", "w") as f:
313
- json.dump(config, f)
314
 
315
- model2 = SSTDPN.from_config(config) # reconstruct (no weights)
316
 
317
- All model parameters (both EEG-specific and model-specific such as
318
- dropout rates, activation functions, number of filters) are automatically
319
- saved to the Hub and restored when loading.
320
 
321
- See :ref:`load-pretrained-models` for a complete tutorial.</main>
322
- </div>
323
 
324
  ## Citation
325
 
326
- Please cite both the original paper for this architecture (see the
327
- *References* section above) and braindecode:
328
 
329
  ```bibtex
330
  @article{aristimunha2025braindecode,
 
14
 
15
  # SSTDPN
16
 
17
+ SSTDPN from Can Han et al (2025) [Han2025].
18
 
19
+ > **Architecture-only repository.** Documents the
20
  > `braindecode.models.SSTDPN` class. **No pretrained weights are
21
+ > distributed here.** Instantiate the model and train it on your own
22
+ > data.
 
23
 
24
  ## Quick start
25
 
 
38
  )
39
  ```
40
 
41
+ The signal-shape arguments above are illustrative defaults — adjust to
42
+ match your recording.
43
 
44
  ## Documentation
45
+ - Full API reference: <https://braindecode.org/stable/generated/braindecode.models.SSTDPN.html>
46
+ - Interactive browser (live instantiation, parameter counts):
 
 
47
  <https://huggingface.co/spaces/braindecode/model-explorer>
48
  - Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/sstdpn.py#L17>
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ ## Architecture
 
52
 
53
+ ![SSTDPN architecture](https://raw.githubusercontent.com/hancan16/SST-DPN/refs/heads/main/figs/framework.png)
 
 
 
54
 
 
 
55
 
56
+ ## Parameters
57
 
58
+ | Parameter | Type | Description |
59
+ |---|---|---|
60
+ | `n_spectral_filters_temporal` | int, optional | Number of spectral filters extracted per channel via temporal convolution. These represent the temporal spectral bands (equivalent to :math:`F_1` in the paper). Default is 9. |
61
+ | `n_fused_filters` | int, optional | Number of output filters after pointwise fusion convolution. These fuse the spectral filters across all channels (equivalent to :math:`F_2` in the paper). Default is 48. |
62
+ | `temporal_conv_kernel_size` | int, optional | Kernel size for the temporal convolution layer. Controls the receptive field for extracting spectral information. Default is 75 samples. |
63
+ | `mvp_kernel_sizes` | list[int], optional | Kernel sizes for Multi-scale Variance Pooling (MVP) module. Larger kernels capture long-term temporal dependencies . |
64
+ | `return_features` | bool, optional | If True, the forward pass returns (features, logits). If False, returns only logits. Default is False. |
65
+ | `proto_sep_maxnorm` | float, optional | Maximum L2 norm constraint for Inter-class Separation Prototypes during forward pass. This constraint acts as an implicit force to push features away from the origin. Default is 1.0. |
66
+ | `proto_cpt_std` | float, optional | Standard deviation for Intra-class Compactness Prototype initialization. Default is 0.01. |
67
+ | `spt_attn_global_context_kernel` | int, optional | Kernel size for global context embedding in Spatial-Spectral Attention module. Default is 250 samples. |
68
+ | `spt_attn_epsilon` | float, optional | Small epsilon value for numerical stability in Spatial-Spectral Attention. Default is 1e-5. |
69
+ | `spt_attn_mode` | str, optional | Embedding computation mode for Spatial-Spectral Attention ('var', 'l2', or 'l1'). Default is 'var' (variance-based mean-var operation). |
70
+ | `activation` | nn.Module, optional | Activation function to apply after the pointwise fusion convolution in :class:`_SSTEncoder`. Should be a PyTorch activation module class. Default is nn.ELU. |
71
 
 
 
 
72
 
73
+ ## References
74
 
75
+ 1. Han, C., Liu, C., Wang, J., Wang, Y., Cai, C., & Qian, D. (2025). A spatial–spectral and temporal dual prototype network for motor imagery brain–computer interface. Knowledge-Based Systems, 315, 113315.
76
+ 2. Han, C., Liu, C., Wang, J., Wang, Y., Cai, C., & Qian, D. (2025). A spatial–spectral and temporal dual prototype network for motor imagery brain–computer interface. Knowledge-Based Systems, 315, 113315. GitHub repository. https://github.com/hancan16/SST-DPN.
 
77
 
 
 
78
 
79
  ## Citation
80
 
81
+ Cite the original architecture paper (see *References* above) and braindecode:
 
82
 
83
  ```bibtex
84
  @article{aristimunha2025braindecode,