---
license: bsd-3-clause
library_name: braindecode
pipeline_tag: feature-extraction
tags:
- eeg
- biosignal
- pytorch
- neuroscience
- braindecode
- convolutional
- transformer
---

# CTNet

CTNet from Zhao, W. et al. (2024).

> **Architecture-only repository.** This repo documents the
> `braindecode.models.CTNet` class. **No pretrained weights are
> distributed here** — instantiate the model and train it on your own
> data, or fine-tune from a published foundation-model checkpoint
> separately.

## Quick start

```bash
pip install braindecode
```

```python
from braindecode.models import CTNet

model = CTNet(
    n_chans=22,
    sfreq=250,
    input_window_seconds=4.0,
    n_outputs=4,
)
```

The signal-shape arguments above are example defaults — adjust them
to match your recording.

## Documentation

- Full API reference (parameters, references, architecture figure):
  <https://braindecode.org/stable/generated/braindecode.models.CTNet.html>
- Interactive browser with live instantiation:
  <https://huggingface.co/spaces/braindecode/model-explorer>
- Source on GitHub: <https://github.com/braindecode/braindecode/blob/master/braindecode/models/ctnet.py#L27>

## Architecture description

The block below is the rendered class docstring (parameters,
references, architecture figure where available).

<div class='bd-doc'><main>
<p>CTNet from Zhao, W. et al. (2024) [ctnet]_.</p>
<span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#5cb85c;color:white;font-size:11px;font-weight:600;margin-right:4px;">Convolution</span><span style="display:inline-block;padding:2px 8px;border-radius:4px;background:#56B4E9;color:white;font-size:11px;font-weight:600;margin-right:4px;">Attention/Transformer</span>

A Convolutional Transformer Network for EEG-Based Motor Imagery Classification

.. figure:: https://raw.githubusercontent.com/snailpt/CTNet/main/architecture.png
   :align: center
   :alt: CTNet Architecture

CTNet is an end-to-end neural network architecture designed for classifying motor imagery (MI) tasks from EEG signals.
The model combines convolutional neural networks (CNNs) with a Transformer encoder to capture both local and global temporal dependencies in the EEG data.

The architecture consists of three main components:

1. **Convolutional Module**:

   - Applies :class:`EEGNet` to perform feature extraction, denoted here as
     the ``_PatchEmbeddingEEGNet`` module.

2. **Transformer Encoder Module**:

   - Utilizes multi-head self-attention mechanisms as in EEGConformer, but
     with residual blocks.

3. **Classifier Module**:

   - Combines features from both the convolutional module
     and the Transformer encoder.
   - Flattens the combined features and applies dropout for regularization.
   - Uses a fully connected layer to produce the final classification output.

Parameters
----------
activation : nn.Module, default=nn.GELU
    Activation function to use in the network.
num_heads : int, default=4
    Number of attention heads in the Transformer encoder.
embed_dim : int or None, default=None
    Embedding size (dimensionality) for the Transformer encoder.
num_layers : int, default=6
    Number of encoder layers in the Transformer.
n_filters_time : int, default=20
    Number of temporal filters in the first convolutional layer.
kernel_size : int, default=64
    Kernel size for the temporal convolutional layer.
depth_multiplier : int, default=2
    Multiplier for the number of depth-wise convolutional filters.
pool_size_1 : int, default=8
    Pooling size for the first average pooling layer.
pool_size_2 : int, default=8
    Pooling size for the second average pooling layer.
cnn_drop_prob : float, default=0.3
    Dropout probability after convolutional layers.
att_positional_drop_prob : float, default=0.1
    Dropout probability for the positional encoding in the Transformer.
final_drop_prob : float, default=0.5
    Dropout probability before the final classification layer.

Notes
-----
This implementation is adapted from the original CTNet source code
[ctnetcode]_ to comply with Braindecode's model standards.

References
----------
.. [ctnet] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
   CTNet: a convolutional transformer network for EEG-based motor imagery
   classification. Scientific Reports, 14(1), 20237.
.. [ctnetcode] Zhao, W., Jiang, X., Zhang, B., Xiao, S., & Weng, S. (2024).
   CTNet source code:
   https://github.com/snailpt/CTNet

.. rubric:: Hugging Face Hub integration

When the optional ``huggingface_hub`` package is installed, all models
automatically gain the ability to be pushed to and loaded from the
Hugging Face Hub. Install with::

    pip install braindecode[hub]

**Pushing a model to the Hub:**

.. code:: python

    from braindecode.models import CTNet

    # Train your model
    model = CTNet(n_chans=22, n_outputs=4, n_times=1000)
    # ... training code ...

    # Push to the Hub
    model.push_to_hub(
        repo_id="username/my-ctnet-model",
        commit_message="Initial model upload",
    )

**Loading a model from the Hub:**

.. code:: python

    from braindecode.models import CTNet

    # Load pretrained model
    model = CTNet.from_pretrained("username/my-ctnet-model")

    # Load with a different number of outputs (head is rebuilt automatically)
    model = CTNet.from_pretrained("username/my-ctnet-model", n_outputs=4)

**Extracting features and replacing the head:**

.. code:: python

    import torch

    x = torch.randn(1, model.n_chans, model.n_times)
    # Extract encoder features (consistent dict across all models)
    out = model(x, return_features=True)
    features = out["features"]

    # Replace the classification head
    model.reset_head(n_outputs=10)

**Saving and restoring full configuration:**

.. code:: python

    import json

    config = model.get_config()  # all __init__ params
    with open("config.json", "w") as f:
        json.dump(config, f)

    model2 = CTNet.from_config(config)  # reconstruct (no weights)

All model parameters (both EEG-specific and model-specific, such as
dropout rates, activation functions, and number of filters) are
automatically saved to the Hub and restored when loading.

See :ref:`load-pretrained-models` for a complete tutorial.</main>
</div>

## Citation

Please cite both the original paper for this architecture (see the
*References* section above) and braindecode:

```bibtex
@article{aristimunha2025braindecode,
  title = {Braindecode: a deep learning library for raw electrophysiological data},
  author = {Aristimunha, Bruno and others},
  journal = {Zenodo},
  year = {2025},
  doi = {10.5281/zenodo.17699192},
}
```

## License

BSD-3-Clause for the model code (matching braindecode).
Pretraining-derived weights, if you fine-tune from a checkpoint,
inherit the license of that checkpoint and its training corpus.