---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- foundation-model
- convolutional
- transformer
---

# PBT

Patched Brain Transformer (PBT) model from Klein et al. (2025).

> **Architecture-only repository.** This repo documents the
> `braindecode.models.PBT` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import PBT

model = PBT(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them
to match your recording.

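As a quick sanity check, you can pass a dummy batch through the freshly
instantiated model. This is a minimal sketch assuming the example
configuration above; braindecode models take input shaped
`(batch, n_chans, n_times)`, here `n_times = sfreq * input_window_seconds = 1000`.

```python
import torch

x = torch.randn(8, 22, 1000)   # 8 trials, 22 channels, 4 s at 250 Hz
with torch.no_grad():
    scores = model(x)
print(scores.shape)            # expected: torch.Size([8, 4]), one score per output
```
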
## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.PBT.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/patchedtransformer.py#L17>

## Architecture description

The block below is the rendered class docstring (parameters,
references, architecture figure where available).

<div class='bd-doc'><main>
<p>Patched Brain Transformer (PBT) model from Klein et al. (2025) [pbt]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#d9534f;color:white;font-size:11px;font-weight:600;margin-right:4px;">Foundation Model</span>

This implementation is based on https://github.com/timonkl/PatchedBrainTransformer/

.. figure:: https://raw.githubusercontent.com/timonkl/PatchedBrainTransformer/refs/heads/main/PBT_sketch.png
   :align: center
   :alt: Patched Brain Transformer Architecture
   :width: 680px

PBT tokenizes EEG trials into per-channel patches, linearly projects each
patch to a model embedding dimension, prepends a classification token and
adds channel-aware positional embeddings. The token sequence is processed
by a Transformer encoder stack and classification is performed from the
classification token.

.. rubric:: Macro Components

- ``PBT.tokenization`` **(patch extraction)**

  *Operations.* The pre-processed EEG signal :math:`X \in \mathbb{R}^{C \times T}`
  (with :math:`C = \text{n_chans}` and :math:`T = \text{n_times}`) is divided into
  non-overlapping patches of size :math:`d_{\text{input}}` along the time axis.
  This process yields :math:`N` total patches, calculated as
  :math:`N = C \left\lfloor \frac{T}{D} \right\rfloor` (where :math:`D = d_{\text{input}}`).
  When time shifts are applied, :math:`N` decreases to
  :math:`N = C \left\lfloor \frac{T - T_{\text{aug}}}{D} \right\rfloor`.

  *Role.* Tokenizes EEG trials into fixed-size, per-channel patches so the model
  remains adaptive to different numbers of channels and recording lengths. The
  process is inspired by Vision Transformers [visualtransformer]_ and adapted
  for the GPT context from [efficient-batchpacking]_. An end-to-end sketch of
  all macro components follows this list.

- ``PBT.patch_projection`` **(patch embedding)**

  *Operations.* The linear layer ``PBT.patch_projection`` maps the tokens from dimension
  :math:`d_{\text{input}}` to the Transformer embedding dimension :math:`d_{\text{model}}`.
  Patches :math:`X_P` are projected as :math:`X_E = X_P W_E^\top`, where
  :math:`W_E \in \mathbb{R}^{d_{\text{model}} \times D}`. In this configuration
  :math:`d_{\text{model}} = 2D` with :math:`D = d_{\text{input}}`.

  *Interpretability.* Learns periodic structures similar to frequency filters in
  the first convolutional layers of CNNs (for example :class:`~braindecode.models.EEGNet`).
  The learned filters frequently focus on the high-frequency range (20-40 Hz),
  which correlates with beta and gamma waves linked to higher concentration levels.

- ``PBT.cls_token`` **(classification token)**

  *Operations.* A classification token :math:`[c_{\text{ls}}] \in \mathbb{R}^{1 \times d_{\text{model}}}`
  is prepended to the projected patch sequence :math:`X_E`. The CLS token can optionally
  be learnable (see ``learnable_cls``).

  *Role.* Acts as a dedicated readout token that aggregates information through the
  Transformer encoder stack.

- ``PBT.pos_embedding`` **(positional embedding)**

  *Operations.* Positional indices are generated by ``PBT.linear_projection``, an instance
  of :class:`~braindecode.models.patchedtransformer._ChannelEncoding`, and mapped to vectors
  through :class:`~torch.nn.Embedding`. The embedding table
  :math:`W_{\text{pos}} \in \mathbb{R}^{(N+1) \times d_{\text{model}}}` is added to the token
  sequence, yielding :math:`X_{\text{pos}} = [c_{\text{ls}}, X_E] + W_{\text{pos}}`.

  *Role/Interpretability.* Introduces spatial and temporal dependence to counter the
  position invariance of the Transformer encoder. The learned positional embedding
  exposes spatial relationships, often revealing a symmetric pattern in central regions
  (C1-C6) associated with the motor cortex.

- ``PBT.transformer_encoder`` **(sequence processing and attention)**

  *Operations.* The token sequence passes through :math:`n_{\text{blocks}}` Transformer
  encoder layers. Each block combines a Multi-Head Self-Attention (MHSA) module with
  ``num_heads`` attention heads and a Feed-Forward Network (FFN). Both MHSA
  and FFN use parallel residual connections with Layer Normalization inside the blocks
  and apply dropout (``drop_prob``) within the Transformer components.

  *Role/Robustness.* Self-attention enables every token to consider all others, capturing
  global temporal and spatial dependencies immediately and adaptively. This architecture
  accommodates arbitrary numbers of patches and channels, supporting pre-training across
  diverse datasets.

- ``PBT.final_layer`` **(readout)**

  *Operations.* A linear layer operates on the processed CLS token only, and the model
  predicts class probabilities as :math:`y = \operatorname{softmax}([c_{\text{ls}}] W_{\text{class}}^\top + b_{\text{class}})`.

  *Role.* Performs the final classification from the information aggregated into the CLS
  token after the Transformer encoder stack.

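The following is a minimal, self-contained sketch of the macro components
listed above, written in plain PyTorch for illustration only. It is **not**
the braindecode implementation: variable names such as ``patch_projection``
merely mirror the prose, the positional embedding is simplified to a plain
index-based lookup rather than the channel-aware ``_ChannelEncoding``, and
all sizes are example values.

.. code::

    import torch
    import torch.nn as nn

    C, T = 22, 1000            # n_chans, n_times
    D = 64                     # d_input: patch size in samples
    d_model = 2 * D            # the "d_model = 2D" configuration above

    # 1) Tokenization: non-overlapping per-channel patches, N = C * floor(T / D)
    x = torch.randn(8, C, T)                         # (batch, channels, time)
    n_per_chan = T // D
    x = x[:, :, : n_per_chan * D]                    # drop the incomplete tail
    patches = x.reshape(8, C * n_per_chan, D)        # (batch, N, D)
    N = patches.shape[1]

    # 2) Patch embedding: linear projection from D to d_model
    patch_projection = nn.Linear(D, d_model)
    tokens = patch_projection(patches)               # (batch, N, d_model)

    # 3) CLS token and (simplified) positional embedding
    cls_token = torch.zeros(1, 1, d_model)
    pos_embedding = nn.Embedding(N + 1, d_model)
    tokens = torch.cat([cls_token.expand(8, -1, -1), tokens], dim=1)
    tokens = tokens + pos_embedding(torch.arange(N + 1))

    # 4) Transformer encoder stack
    layer = nn.TransformerEncoderLayer(
        d_model=d_model, nhead=8, dropout=0.1, batch_first=True
    )
    encoder = nn.TransformerEncoder(layer, num_layers=4)
    tokens = encoder(tokens)

    # 5) Readout from the CLS token only; softmax of these logits gives class probabilities
    final_layer = nn.Linear(d_model, 4)              # 4 = n_outputs
    logits = final_layer(tokens[:, 0])               # (batch, n_outputs)
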
.. rubric:: Convolutional Details

PBT omits convolutional layers; equivalent feature extraction is carried out by the patch
pipeline and attention stack.

* **Temporal.** Tokenization slices the EEG into fixed windows of size :math:`D = d_{\text{input}}`
  (for the default configuration, :math:`D=64` samples :math:`\approx 0.256\,\text{s}` at
  :math:`250\,\text{Hz}`), while ``PBT.patch_projection`` learns periodic patterns within each
  patch. The Transformer encoder then models long- and short-range temporal dependencies through
  self-attention.

* **Spatial.** Patches are channel-specific, keeping the architecture adaptive to any electrode
  montage. Channel-aware positional encodings :math:`W_{\text{pos}}` capture relationships between
  nearby sensors; learned embeddings often form symmetric motifs across motor cortex electrodes
  (C1–C6), and self-attention propagates information across all channels jointly.

* **Spectral.** ``PBT.patch_projection`` acts similarly to the first convolutional layer in
  :class:`~braindecode.models.EEGNet`, learning frequency-selective filters without an explicit
  Fourier transform. The highest-energy filters typically reside between :math:`20` and
  :math:`40\,\text{Hz}`, aligning with beta/gamma rhythms tied to focused motor imagery.

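To illustrate the spectral point above, one way to inspect frequency selectivity
is to look at the learned patch-projection weights in the Fourier domain. This is
a hypothetical sketch: how the trained weight matrix is accessed depends on the
actual module layout, so a random stand-in tensor is used here.

.. code::

    import torch

    sfreq, d_input = 250, 64
    weights = torch.randn(128, d_input)                 # stand-in for trained (d_model, d_input) weights
    spectrum = torch.fft.rfft(weights, dim=-1).abs()    # magnitude spectrum of each filter
    freqs = torch.fft.rfftfreq(d_input, d=1.0 / sfreq)  # frequency axis in Hz
    peak_freqs = freqs[spectrum.argmax(dim=-1)]         # dominant frequency per filter
    # For a trained PBT, many of these peaks reportedly fall in the 20-40 Hz band.
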
.. rubric:: Attention / Sequential Modules

* **Attention Details.** ``PBT.transformer_encoder`` stacks :math:`n_{\text{blocks}}` Transformer
  encoder layers with Multi-Head Self-Attention. Every token attends to all others, enabling
  immediate global integration across time and channels and supporting heterogeneous datasets.
  Attention rollout visualisations highlight strong activations over motor cortex electrodes
  (C3, C4, Cz) during motor imagery decoding.

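The attention rollout mentioned above can be computed from per-layer attention
maps. The sketch below only shows the rollout arithmetic (following Abnar &
Zuidema, 2020); it does not show how to extract attention maps from PBT, so
random stand-ins are used.

.. code::

    import torch

    n_layers, n_tokens = 4, 331                  # e.g. 1 CLS token + N patch tokens
    attn_maps = [torch.softmax(torch.randn(n_tokens, n_tokens), dim=-1)
                 for _ in range(n_layers)]       # stand-ins for head-averaged attention

    rollout = torch.eye(n_tokens)
    for attn in attn_maps:
        attn = 0.5 * attn + 0.5 * torch.eye(n_tokens)  # account for residual connections
        attn = attn / attn.sum(dim=-1, keepdim=True)   # re-normalise rows
        rollout = attn @ rollout

    cls_relevance = rollout[0, 1:]   # contribution of each patch token to the CLS token
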
.. warning::

    **Important:** Like the other foundation models in Braindecode, :class:`PBT` is
    designed for large-scale pre-training and fine-tuning. Training from
    scratch on small datasets may lead to suboptimal results. Cross-dataset
    pre-training and subsequent fine-tuning are recommended to leverage the
    full potential of this architecture.

Parameters
----------
d_input : int, optional
    Size (in samples) of each patch (token) extracted along the time axis.
embed_dim : int, optional
    Transformer embedding dimensionality.
num_layers : int, optional
    Number of Transformer encoder layers.
num_heads : int, optional
    Number of attention heads.
drop_prob : float, optional
    Dropout probability used in Transformer components.
learnable_cls : bool, optional
    Whether the classification token is learnable.
bias_transformer : bool, optional
    Whether to use bias in Transformer linear layers.
activation : nn.Module, optional
    Activation function class to use in Transformer feed-forward layers.

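As a sketch of how the documented parameters map onto a constructor call, the
snippet below combines the signal-shape arguments used elsewhere in this card
with the architecture arguments listed above. The values are illustrative only;
consult the braindecode API reference for the actual defaults and constraints.

.. code::

    from braindecode.models import PBT

    model = PBT(
        n_chans=22,          # signal shape
        n_outputs=4,
        n_times=1000,
        d_input=64,          # patch size in samples
        embed_dim=128,       # d_model = 2 * d_input in the configuration above
        num_layers=4,        # Transformer encoder depth
        num_heads=8,
        drop_prob=0.1,
        learnable_cls=True,
    )
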
References
----------
.. [pbt] Klein, T., Minakowski, P., & Sager, S. (2025).
   Flexible Patched Brain Transformer model for EEG decoding.
   Scientific Reports, 15(1), 1-12.
   https://www.nature.com/articles/s41598-025-86294-3
.. [visualtransformer] Dosovitskiy, A., Beyer, L., Kolesnikov, A.,
   Weissenborn, D., Zhai, X., Unterthiner, T., Dehghani, M.,
   Minderer, M., Heigold, G., Gelly, S., Uszkoreit, J. & Houlsby,
   N. (2021). An Image is Worth 16x16 Words: Transformers for Image
   Recognition at Scale. International Conference on Learning
   Representations (ICLR).
.. [efficient-batchpacking] Krell, M. M., Kosec, M., Perez, S. P., &
   Fitzgibbon, A. (2021). Efficient sequence packing without
   cross-contamination: Accelerating large language models without
   impacting performance. arXiv preprint arXiv:2107.02027.

.. rubric:: Hugging Face Hub integration

When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

    pip install braindecode[hub]

**Pushing a model to the Hub:**

.. code::

    from braindecode.models import PBT

    # Train your model
    model = PBT(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-pbt-model",
        commit_message="Initial model upload",
    )

**Loading a model from the Hub:**

.. code::

    from braindecode.models import PBT

    # Load pretrained model
    model = PBT.from_pretrained("username/my-pbt-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = PBT.from_pretrained("username/my-pbt-model", n_outputs=4)

**Extracting features and replacing the head:**

.. code::

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)
    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)

**Saving and restoring full configuration:**

.. code::

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = PBT.from_config(config)  # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific, such as
dropout rates, activation functions, and number of filters) are automatically
saved to the Hub and restored when loading.

See :ref:`load-pretrained-models` for a complete tutorial.
</main>
</div>

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title   = {Braindecode: a deep learning library for raw electrophysiological data},
  author  = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year    = {2025},
  doi     = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode).
Weights derived from pre-training, if you fine-tune from a checkpoint,
inherit the license of that checkpoint and its training corpus.