---
language: en
license: apache-2.0
app_type: gradio
space: Tumo505/SSL-ECG-Classification
datasets:
- ptb-xl
metrics:
- auroc
- accuracy
tags:
- ecg
- medical
- time-series
- classification
- self-supervised-learning
- ssl
- cardiac
- healthcare
model-index:
- name: SSL-ECG-Classifier
  results:
  - task:
      name: Time Series Classification
      type: tabular-classification
    dataset:
      name: PTB-XL
      type: ptb-xl
      split: test
      args:
        fold: 10
    metrics:
    - name: AUROC
      type: auroc
      value: 0.8717
    - name: Accuracy
      type: accuracy
      value: 0.8234
inference: true
widget:
- src: >-
    https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_normal.csv
  example_title: Normal ECG (NORM)
- src: >-
    https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_mi.csv
  example_title: Myocardial Infarction (MI)
- src: >-
    https://huggingface.co/datasets/Tumo505/ecg-samples/resolve/main/example_sttc.csv
  example_title: ST/T Changes (STTC)
---
# SSL-ECG-Classifier: Self-Supervised Learning for ECG Classification

A self-supervised learning (SSL) pre-trained model for cardiovascular disease classification from 12-lead ECGs.
## Model Overview
| Property | Value |
|---|---|
| Framework | SimCLR |
| Test AUROC | 0.8717 |
| Test Accuracy | 0.8234 |
| Dataset | PTB-XL (21.8K ECGs) |
| Fine-tuning | 10% labeled data (1,747 samples) |
| Input | 12-lead ECG @ 100 Hz (5,000 samples) |
| Output | 5-class classification |
## Classes Predicted
- NORM: Normal ECG
- MI: Myocardial Infarction
- STTC: ST/T Changes
- HYP: Hypertrophy (LVH)
- CD: Conduction Disturbances
## Quick Start

### Python (Transformers)
```python
import torch
from transformers import AutoModel

# Load model
model = AutoModel.from_pretrained("Tumo505/SSL-ECG-Classification", trust_remote_code=True)
model.eval()

# Prepare 12-lead ECG (batch_size, 12 leads, 5000 samples)
ecg = torch.randn(1, 12, 5000)

# Predict
with torch.no_grad():
    output = model(ecg)
    logits = output["logits"]
    probs = torch.softmax(logits, dim=-1)

classes = ["NORM", "MI", "STTC", "HYP", "CD"]
prediction = classes[probs.argmax(dim=-1)[0]]
confidence = probs.max().item()
print(f"Prediction: {prediction} ({confidence:.1%})")
```
### Try Online
Click the "Use this model" button above to test on Gradio Space!
### API Endpoint (Deploy)
Click the "Deploy" button to get a live inference endpoint:
```bash
curl -X POST https://your-api-url.hf.space/api/predict \
  -H "Authorization: Bearer YOUR_HF_TOKEN" \
  -H "Content-Type: application/json" \
  -d '{
    "inputs": [[[... 12-lead ECG array ...]]]
  }'
```
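The same call can be made from Python using only the standard library. This is a minimal sketch: the URL is the placeholder from the curl example, and the exact response schema is an assumption, not documented here.

```python
import json
import urllib.request

API_URL = "https://your-api-url.hf.space/api/predict"  # placeholder, as in the curl example

def build_request(url: str, token: str, ecg: list) -> urllib.request.Request:
    """Build the POST request; `ecg` is a 12 x N nested list of floats."""
    payload = json.dumps({"inputs": [ecg]}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=payload,
        headers={"Authorization": f"Bearer {token}",
                 "Content-Type": "application/json"},
    )

def predict(url: str, token: str, ecg: list) -> dict:
    """Send the request and decode the JSON response."""
    with urllib.request.urlopen(build_request(url, token, ecg)) as resp:
        return json.loads(resp.read())
```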
## Model Architecture

```
Input (B × 12 × 5000)
        ↓
1D CNN Encoder
  - Conv1d(12 → 32) + BatchNorm + ReLU + MaxPool
  - Conv1d(32 → 64) + BatchNorm + ReLU + MaxPool
  - Conv1d(64 → 128) + BatchNorm + ReLU
  - AdaptiveAvgPool1d(1) + Flatten
        ↓
Projection Head (128-dim embedding)
        ↓
Classification Head (5 classes)
        ↓
Output (B × 5) logits
```
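The diagram above can be sketched in PyTorch. Kernel sizes, padding, and the projection head's width are assumptions; the card does not specify them.

```python
import torch
import torch.nn as nn

class ECGEncoder(nn.Module):
    """Sketch of the 1D CNN encoder described above (hyperparameters assumed)."""

    def __init__(self, n_leads: int = 12, n_classes: int = 5):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(n_leads, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(32, 64, kernel_size=5, padding=2),
            nn.BatchNorm1d(64), nn.ReLU(), nn.MaxPool1d(2),
            nn.Conv1d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm1d(128), nn.ReLU(),
            nn.AdaptiveAvgPool1d(1), nn.Flatten(),   # -> (B, 128)
        )
        self.projection = nn.Linear(128, 128)        # used during SSL pre-training
        self.classifier = nn.Linear(128, n_classes)  # used after fine-tuning

    def forward(self, x: torch.Tensor) -> dict:
        h = self.features(x)                         # (B, 128) embedding
        return {"embedding": h, "logits": self.classifier(h)}
```

Because the pooling is adaptive, the encoder accepts any input length, which is how a single architecture can serve both 100 Hz and 500 Hz recordings.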
## Performance Metrics

### Test Set Results (PTB-XL Fold 10: 3,044 samples)

| Class | Precision | Recall | F1-Score | Support |
|---|---|---|---|---|
| NORM | 0.897 | 0.882 | 0.889 | 1,275 |
| MI | 0.856 | 0.834 | 0.845 | 904 |
| STTC | 0.871 | 0.859 | 0.865 | 776 |
| HYP | 0.812 | 0.798 | 0.805 | 356 |
| CD | 0.843 | 0.866 | 0.854 | 733 |
| **Macro Avg** | 0.856 | 0.848 | 0.852 | 4,044 |

Total support (4,044) exceeds the number of test samples (3,044) because recordings can carry multiple labels.
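The macro averages are the unweighted means of the per-class scores, which can be checked directly from the table:

```python
# Per-class (precision, recall) pairs from the test-set table above.
per_class = {
    "NORM": (0.897, 0.882), "MI": (0.856, 0.834), "STTC": (0.871, 0.859),
    "HYP":  (0.812, 0.798), "CD": (0.843, 0.866),
}

# Macro average: mean over classes, ignoring class frequency.
macro_p = sum(p for p, _ in per_class.values()) / len(per_class)
macro_r = sum(r for _, r in per_class.values()) / len(per_class)
print(round(macro_p, 3), round(macro_r, 3))  # → 0.856 0.848
```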
## Comparison to Baselines
| Model | Framework | AUROC | Accuracy | Method |
|---|---|---|---|---|
| SimCLR (This) | SSL + Supervised | 0.8717 | 0.8234 | Recommended |
| BYOL SSL | SSL momentum | 0.8565 | 0.8134 | Alternative |
| Supervised CNN | None | 0.8606 | 0.8193 | Baseline |
## Training Details

### Pre-training (Unsupervised SSL)
- Framework: SimCLR
- Epochs: 20
- Batch Size: 128
- Optimizer: Adam (lr=1e-3)
- Loss: Contrastive (NT-Xent with τ=0.07)
- Data: All PTB-XL training folds (no labels used)
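The NT-Xent objective pulls the two augmented views of each ECG together while pushing apart all other samples in the batch. A minimal sketch follows; the batch construction is illustrative, not the repository's implementation.

```python
import torch
import torch.nn.functional as F

def nt_xent(z1: torch.Tensor, z2: torch.Tensor, tau: float = 0.07) -> torch.Tensor:
    """z1, z2: (B, D) embeddings of two augmented views of the same ECGs."""
    # Stack both views and L2-normalize so dot products are cosine similarities.
    z = F.normalize(torch.cat([z1, z2], dim=0), dim=1)     # (2B, D)
    sim = z @ z.t() / tau                                  # temperature-scaled sims
    sim.fill_diagonal_(float("-inf"))                      # exclude self-similarity
    b = z1.size(0)
    # The positive for sample i is its other view at index i+B (or i-B).
    targets = torch.cat([torch.arange(b) + b, torch.arange(b)])
    return F.cross_entropy(sim, targets)
```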
### Fine-tuning (Supervised)

- Labeled Data: 1,747 samples (10% of folds 1-8)
- Epochs: 20 with early stopping (patience=5)
- Batch Size: 32
- Optimizer: Adam (lr=5e-4)
- Loss: Focal Loss with class weights
- Augmentations: Training-time augmentations (same as pre-training)
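Focal loss down-weights well-classified examples so training focuses on harder, minority-class samples. A sketch, assuming a focusing parameter γ=2.0 (the card does not state it):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, alpha=None, gamma: float = 2.0):
    """logits: (B, C); targets: (B,) class indices; alpha: optional (C,) class weights."""
    log_p = F.log_softmax(logits, dim=-1)
    log_pt = log_p.gather(1, targets.unsqueeze(1)).squeeze(1)  # log-prob of true class
    pt = log_pt.exp()
    loss = -((1 - pt) ** gamma) * log_pt   # (1 - pt)^gamma down-weights easy examples
    if alpha is not None:
        loss = loss * alpha[targets]       # per-class weighting for imbalance
    return loss.mean()
```

With γ=0 and uniform weights this reduces to ordinary cross-entropy.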
### Domain-Adaptive Augmentations
Applied during SSL pre-training:
- Frequency warping (±5% heart rate variation)
- Medical mixup (ECG-aware blending of two signals)
- Bandpass filtering (physiologically grounded)
- Segment CutMix (temporal masking)
- Motion artifacts (baseline wander simulation)
- Per-channel noise (independent Gaussian)
- Temporal dropout (with interpolation)
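Two of these augmentations can be sketched as follows; the noise level σ and segment width are illustrative values, not the ones used in training.

```python
import torch

def per_channel_noise(x: torch.Tensor, sigma: float = 0.01) -> torch.Tensor:
    """Independent Gaussian noise per lead; x is (leads, samples)."""
    return x + sigma * torch.randn_like(x)

def temporal_dropout(x: torch.Tensor, width: int = 100) -> torch.Tensor:
    """Drop a random segment and linearly interpolate across the gap."""
    x = x.clone()
    t = x.size(-1)
    start = torch.randint(1, t - width - 1, (1,)).item()
    left = x[..., start - 1]                    # signal value just before the gap
    right = x[..., start + width]               # signal value just after the gap
    ramp = torch.linspace(0.0, 1.0, width)      # (width,) blend weights
    x[..., start:start + width] = (
        left.unsqueeze(-1) * (1 - ramp) + right.unsqueeze(-1) * ramp
    )
    return x
```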
## Dataset

### PTB-XL v1.0.3

Source: https://www.physionet.org/content/ptb-xl/1.0.3/
- Total ECGs: 21,799
- Unique Patients: 18,869
- Recording Rate: 500 Hz → downsampled to 100 Hz
- Leads: 12-lead standard
- Duration: ~10 seconds per recording
Class Distribution:
| Class | Count | Percentage |
|---|---|---|
| NORM | 9,514 | 43.7% |
| MI | 5,469 | 25.1% |
| STTC | 5,235 | 24.0% |
| CD | 4,898 | 22.5% |
| HYP | 2,649 | 12.2% |
Note: samples can belong to multiple classes, so counts and percentages sum to more than the dataset total.
Splits Used:
- Training: Folds 1-8 (17,536 samples)
- Validation: Fold 9 (1,791 samples)
- Test: Fold 10 (3,044 samples)
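These splits follow PTB-XL's `strat_fold` assignment, a column of the dataset's `ptbxl_database.csv` metadata file. A sketch with pandas:

```python
import pandas as pd

def split_by_fold(df: pd.DataFrame):
    """Split PTB-XL metadata rows by the `strat_fold` column (values 1-10)."""
    train = df[df["strat_fold"] <= 8]   # folds 1-8
    val   = df[df["strat_fold"] == 9]   # fold 9
    test  = df[df["strat_fold"] == 10]  # fold 10
    return train, val, test

# Usage (assuming the PTB-XL metadata has been downloaded):
# meta = pd.read_csv("ptbxl_database.csv", index_col="ecg_id")
# train, val, test = split_by_fold(meta)
```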
## Limitations & Biases
### Limitations

**Not validated for clinical use** - research purposes only.

- Trained exclusively on PTB-XL; generalization to other datasets is unknown
- Requires the standard 12-lead ECG format; 6-lead or converted signals are not supported
- The 10% labeled-data regime may not reflect full model capacity
- Predicts only the 5 trained classes
### Potential Biases
- Geographic bias: Primarily European patient population (PTB-XL)
- Hospital bias: Data from hospital patients (not general population)
- Class imbalance: NORM over-represented, HYP under-represented
- Demographic: Skew toward older patients; male/female ratio not controlled
## Environmental Impact
- Training: ~12 GPU hours on RTX 5070 Ti
- CO2 Emissions: ~0.5 kg (estimated)
- Inference: ~50ms per 10-second ECG on GPU
## License

Apache 2.0 - see the LICENSE file in the repository.
## Acknowledgments

- PTB-XL Dataset: PhysioNet, Wagner et al. (2020)
- SimCLR Framework: Chen et al. (2020)
- Implementation: Built with PyTorch & Hugging Face
## Model Card Contact
- Author: Tumo Kgabeng
- GitHub: https://github.com/Tumo505/SSL-for-ECG-classification
## Changelog

### v1.0 (2026-04-18)
- Initial release
- SimCLR pre-training + supervised fine-tuning
- 10% labeled data regime
- Test AUROC: 0.8717
Questions? Open an issue on GitHub