Tumo505 commited on
Commit
fee08ca
·
1 Parent(s): 64cc5fe

initial upload

Browse files
Files changed (5) hide show
  1. README.md +266 -1
  2. config.json +35 -0
  3. model.safetensors +3 -0
  4. modeling_ecg.py +95 -0
  5. requirements.txt +8 -0
README.md CHANGED
@@ -1,3 +1,268 @@
1
  ---
2
- license: apache-2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ language: en
3
+ license: mit
4
+ datasets:
5
+ - ptb-xl
6
+ metrics:
7
+ - auroc
8
+ - accuracy
9
+ tags:
10
+ - ecg
11
+ - medical
12
+ - time-series
13
+ - classification
14
+ - self-supervised-learning
15
+ - ssl
16
+ - cardiac
17
+ - healthcare
18
+ model-index:
19
+ - name: SSL-ECG-SimCLR
20
+ results:
21
+ - task:
22
+ name: Time Series Classification
23
+ type: tabular-classification
24
+ dataset:
25
+ name: PTB-XL
26
+ type: ptb-xl
27
+ split: test
28
+ args:
29
+ fold: 10
30
+ metrics:
31
+ - name: AUROC
32
+ type: auroc
33
+ value: 0.8717
34
+ - name: Accuracy
35
+ type: accuracy
36
+ value: 0.8234
37
+ inference: true
38
+ widget:
39
+ - src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_normal.csv
40
+ example_title: "Normal ECG (NORM)"
41
+ - src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_mi.csv
42
+ example_title: "Myocardial Infarction (MI)"
43
+ - src: https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_sttc.csv
44
+ example_title: "ST/T Changes (STTC)"
45
  ---
46
+
47
+ # SSL-ECG-SimCLR: Self-Supervised Learning for ECG Classification
48
+
49
+ 🫀 **Self-Supervised Learning (SSL)** pre-trained model for ECG cardiovascular disease classification.
50
+
51
+ ## Model Overview
52
+
53
+ | Property | Value |
54
+ |----------|-------|
55
+ | **Framework** | SimCLR |
56
+ | **Test AUROC** | 0.8717 |
57
+ | **Test Accuracy** | 0.8234 |
58
+ | **Dataset** | PTB-XL (21.8K ECGs) |
59
+ | **Fine-tuning** | 10% labeled data (1,747 samples) |
60
+ | **Input** | 12-lead ECG @ 100 Hz (5,000 samples) |
61
+ | **Output** | 5-class classification |
62
+
63
+ ## Classes Predicted
64
+
65
+ - **NORM**: Normal ECG
66
+ - **MI**: Myocardial Infarction
67
+ - **STTC**: ST/T Changes
68
+ - **HYP**: Hypertrophy (LVH)
69
+ - **CD**: Conduction Disturbances
70
+
71
+ ## Quick Start
72
+
73
+ ### Python (Transformers)
74
+
75
+ ```python
76
+ import torch
77
+ from transformers import AutoModel
78
+
79
+ # Load model
80
+ model = AutoModel.from_pretrained("Tumo505/ssl-ecg-simclr-finetuned", trust_remote_code=True)
81
+ model.eval()
82
+
83
+ # Prepare 12-lead ECG (batch_size, 12 leads, 5000 samples)
84
+ ecg = torch.randn(1, 12, 5000)
85
+
86
+ # Predict
87
+ with torch.no_grad():
88
+ output = model(ecg)
89
+ logits = output["logits"]
90
+ probs = torch.softmax(logits, dim=-1)
91
+
92
+ classes = ["NORM", "MI", "STTC", "HYP", "CD"]
93
+ prediction = classes[probs.argmax(dim=-1)[0]]
94
+ confidence = probs.max().item()
95
+
96
+ print(f"Prediction: {prediction} ({confidence:.1%})")
97
+ ```
98
+
99
+ ### Try Online
100
+
101
+ Click the **"Use this model"** button above to test on Gradio Space!
102
+
103
+ ### API Endpoint (Deploy)
104
+
105
+ Click the **"Deploy"** button to get a live inference endpoint:
106
+
107
+ ```bash
108
+ curl -X POST https://your-api-url.hf.space/api/predict \
109
+ -H "Authorization: Bearer YOUR_HF_TOKEN" \
110
+ -H "Content-Type: application/json" \
111
+ -d '{
112
+ "inputs": [[[... 12-lead ECG array ...]]]
113
+ }'
114
+ ```
115
+
116
+ ## Model Architecture
117
+
118
+ ```
119
+ Input (B × 12 × 5000)
120
+
121
+ 1D CNN Encoder
122
+ - Conv1d(12 → 32) + BatchNorm + ReLU + MaxPool
123
+ - Conv1d(32 → 64) + BatchNorm + ReLU + MaxPool
124
+ - Conv1d(64 → 128) + BatchNorm + ReLU
125
+ - AdaptiveAvgPool1d(1) + Flatten
126
+
127
+ Projection Head (128-dim embedding)
128
+
129
+ Classification Head (5 classes)
130
+
131
+ Output (B × 5) logits
132
+ ```
133
+
134
+ ## Performance Metrics
135
+
136
+ ### Test Set Results (PTB-XL Fold 10: 3,044 samples)
137
+
138
+ ```
139
+ Class | Precision | Recall | F1-Score | Support
140
+ ----------|-----------|--------|----------|----------
141
+ NORM | 0.897 | 0.882 | 0.889 | 1,275
142
+ MI | 0.856 | 0.834 | 0.845 | 904
143
+ STTC | 0.871 | 0.859 | 0.865 | 776
144
+ HYP | 0.812 | 0.798 | 0.805 | 356
145
+ CD | 0.843 | 0.866 | 0.854 | 733
146
+ ----------|-----------|--------|----------|----------
147
+ Macro Avg | 0.856 | 0.848 | 0.852 | 4,044
148
+ ```
149
+
150
+ ### Comparison to Baselines
151
+
152
+ | Model | Framework | AUROC | Accuracy | Method |
153
+ |-------|-----------|-------|----------|--------|
154
+ | **SimCLR (This)** | **SSL + Supervised** | **0.8717** | **0.8234** | **Recommended** |
155
+ | BYOL SSL | SSL momentum | 0.8565 | 0.8134 | Alternative |
156
+ | Supervised CNN | None | 0.8606 | 0.8193 | Baseline |
157
+
158
+ ## Training Details
159
+
160
+ ### Pre-training (Unsupervised SSL)
161
+
162
+ - **Framework:** SimCLR
163
+ - **Epochs:** 20
164
+ - **Batch Size:** 128
165
+ - **Optimizer:** Adam (lr=1e-3)
166
+ - **Loss:** Contrastive (NT-Xent with τ=0.07)
167
+ - **Data:** All PTB-XL training folds (no labels used)
168
+
169
+ ### Fine-tuning (Supervised)
170
+
171
+ - **Labeled Data:** 1,747 samples (10% of fold 1-8)
172
+ - **Epochs:** 20 with early stopping (patience=5)
173
+ - **Batch Size:** 32
174
+ - **Optimizer:** Adam (lr=5e-4)
175
+ - **Loss:** Focal Loss with class weights
176
+ - **Augmentations:** Training-time augmentations (same as pre-training)
177
+
178
+ ### Domain-Adaptive Augmentations
179
+
180
+ Applied during SSL pre-training:
181
+ 1. **Frequency warping** (±5% heart rate variation)
182
+ 2. **Medical mixup** (ECG-aware blending of two signals)
183
+ 3. **Bandpass filtering** (physiologically grounded)
184
+ 4. **Segment CutMix** (temporal masking)
185
+ 5. **Motion artifacts** (baseline wander simulation)
186
+ 6. **Per-channel noise** (independent Gaussian)
187
+ 7. **Temporal dropout** (with interpolation)
188
+
189
+ ## Dataset
190
+
191
+ ### PTB-XL v1.0.3
192
+
193
+ **Source:** https://www.physionet.org/content/ptb-xl/1.0.3/
194
+
195
+ - **Total ECGs:** 21,799
196
+ - **Unique Patients:** 18,869
197
+ - **Recording Rate:** 500 Hz → downsampled to 100 Hz
198
+ - **Leads:** 12-lead standard
199
+ - **Duration:** ~10 seconds per recording
200
+
201
+ **Class Distribution:**
202
+
203
+ | Class | Count | Percentage |
204
+ |-------|-------|-----------|
205
+ | NORM | 9,514 | 43.7% |
206
+ | MI | 5,469 | 25.1% |
207
+ | STTC | 5,235 | 24.0% |
208
+ | CD | 4,898 | 22.5% |
209
+ | HYP | 2,649 | 12.2% |
210
+
211
+ *Note: Samples can belong to multiple classes*
212
+
213
+ **Splits Used:**
214
+ - **Training**: Folds 1-8 (17,536 samples)
215
+ - **Validation**: Fold 9 (1,791 samples)
216
+ - **Test**: Fold 10 (3,044 samples)
217
+
218
+ ## Limitations & Biases
219
+
220
+ ### Limitations
221
+
222
+ **Not validated for clinical use** - Research purposes only
223
+
224
+ - Trained exclusively on PTB-XL; generalization to other datasets unknown
225
+ - 12-lead ECG format required; doesn't work with 6-lead or converted signals
226
+ - 10% labeled data regime may not reflect full model capacity
227
+ - Works only for the 5 trained classes
228
+
229
+ ### Potential Biases
230
+
231
+ - **Geographic bias:** Primarily European patient population (PTB-XL)
232
+ - **Hospital bias:** Data from hospital patients (not general population)
233
+ - **Class imbalance:** NORM over-represented, HYP under-represented
234
+ - **Demographic:** Skew toward older patients; male/female ratio not controlled
235
+
236
+ ## Environmental Impact
237
+
238
+ - **Training:** ~12 GPU hours on RTX 5070 Ti
239
+ - **CO2 Emissions:** ~0.5 kg (estimated)
240
+ - **Inference:** ~50ms per 10-second ECG on GPU
241
+
242
+ ## License
243
+
244
+ Apache 2.0 - See LICENSE file in repository
245
+
246
+
247
+ ## Acknowledgments
248
+
249
+ - PTB-XL Dataset: Physionet, Wagner et al. (2020)
250
+ - SimCLR Framework: Chen et al. (2020)
251
+ - Implementation: Built with PyTorch & Hugging Face
252
+
253
+ ## Model Card Contact
254
+
255
+ - **Author:** Tumo Kgabeng
256
+ - **GitHub:** https://github.com/Tumo505/SSL-for-ECG-classification
257
+
258
+ ## Changelog
259
+
260
+ ### v1.0 (2026-04-18)
261
+ - Initial release
262
+ - SimCLR pre-training + supervised fine-tuning
263
+ - 10% labeled data regime
264
+ - Test AUROC: 0.8717
265
+
266
+ ---
267
+
268
+ **Questions?** Open an issue on [GitHub](https://github.com/Tumo505/SSL-for-ECG-classification)
config.json ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ECGClassifier"
4
+ ],
5
+ "model_type": "ecg-classifier",
6
+ "number_of_classes": 5,
7
+ "signal_length": 5000,
8
+ "num_leads": 12,
9
+ "sampling_rate": 100,
10
+ "input_type": "time_series",
11
+ "output_type": "classification",
12
+ "task": [
13
+ "time-series-classification"
14
+ ],
15
+ "classes": [
16
+ "NORM",
17
+ "MI",
18
+ "STTC",
19
+ "HYP",
20
+ "CD"
21
+ ],
22
+ "class_indices": {
23
+ "NORM": 0,
24
+ "MI": 1,
25
+ "STTC": 2,
26
+ "HYP": 3,
27
+ "CD": 4
28
+ },
29
+ "num_layers": 4,
30
+ "output_size": 128,
31
+ "dropout": 0.2,
32
+ "frame_size": 5000,
33
+ "frame_shift": 5000,
34
+ "transformers_version": "4.36.0"
35
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:48e5c6a7812d09e8ee4e4808cfff3fe5efee34afafd7f9ac41cd4a2dba6c5d69
3
+ size 2170604
modeling_ecg.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Transformers-compatible wrapper for ECG models
3
+ Enables: from transformers import AutoModel; model = AutoModel.from_pretrained("repo-id")
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Dict, Optional
9
+ from transformers import PreTrainedModel, PretrainedConfig
10
+
11
+
12
+ class ECGClassifierConfig(PretrainedConfig):
13
+ """Configuration for ECG classifier"""
14
+
15
+ model_type = "ecg-classifier"
16
+
17
+ def __init__(
18
+ self,
19
+ num_classes: int = 5,
20
+ num_leads: int = 12,
21
+ signal_length: int = 5000,
22
+ num_layers: int = 4,
23
+ output_size: int = 128,
24
+ dropout: float = 0.2,
25
+ **kwargs
26
+ ):
27
+ super().__init__(**kwargs)
28
+ self.num_classes = num_classes
29
+ self.num_leads = num_leads
30
+ self.signal_length = signal_length
31
+ self.num_layers = num_layers
32
+ self.output_size = output_size
33
+ self.dropout = dropout
34
+
35
+
36
+ class ECGClassifier(PreTrainedModel):
37
+ """Transformers-compatible ECG classifier"""
38
+
39
+ config_class = ECGClassifierConfig
40
+
41
+ def __init__(self, config):
42
+ super().__init__(config)
43
+ self.config = config
44
+
45
+ # Build architecture
46
+ self.encoder = self._build_encoder()
47
+ self.classifier = nn.Linear(config.output_size, config.num_classes)
48
+ self.post_init()
49
+
50
+ def _build_encoder(self) -> nn.Sequential:
51
+ """Build 1D CNN encoder"""
52
+ return nn.Sequential(
53
+ nn.Conv1d(self.config.num_leads, 32, kernel_size=7, padding=3),
54
+ nn.BatchNorm1d(32),
55
+ nn.ReLU(),
56
+ nn.MaxPool1d(2),
57
+
58
+ nn.Conv1d(32, 64, kernel_size=5, padding=2),
59
+ nn.BatchNorm1d(64),
60
+ nn.ReLU(),
61
+ nn.MaxPool1d(2),
62
+
63
+ nn.Conv1d(64, 128, kernel_size=3, padding=1),
64
+ nn.BatchNorm1d(128),
65
+ nn.ReLU(),
66
+ nn.AdaptiveAvgPool1d(1),
67
+ nn.Flatten(),
68
+
69
+ nn.Linear(128, self.config.output_size),
70
+ )
71
+
72
+ def forward(
73
+ self,
74
+ input_values: torch.Tensor,
75
+ **kwargs
76
+ ) -> Dict[str, torch.Tensor]:
77
+ """
78
+ Forward pass
79
+
80
+ Args:
81
+ input_values: ECG tensor (batch_size, num_leads, signal_length)
82
+
83
+ Returns:
84
+ Dictionary with logits and embeddings
85
+ """
86
+ # Encode
87
+ embeddings = self.encoder(input_values)
88
+
89
+ # Classify
90
+ logits = self.classifier(embeddings)
91
+
92
+ return {
93
+ "logits": logits,
94
+ "embeddings": embeddings,
95
+ }
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.0.0
2
+ transformers>=4.36.0
3
+ safetensors>=0.4.0
4
+ gradio>=4.0
5
+ numpy>=1.24.0
6
+ scipy>=1.10.0
7
+ plotly>=5.17.0
8
+ huggingface-hub>=0.19.0