ECG Heartbeat Classification: A Deep Transferable Representation
Paper • 1805.00794 • Published
Unsupervised cardiac arrhythmia detection on ECG signals, built entirely from scratch in PyTorch.
A 1D CNN encoder compresses each heartbeat into noise-robust features. A Masked Autoregressive Flow (MAF) then learns an exact probability density over the encoded features of normal beats. Any beat with a low log p(x) is flagged as anomalous, so no anomaly labels are needed during training.
Try the model interactively in the browser: 👉 HF Spaces — Live Demo
| Metric | Score |
|---|---|
| AUROC | 0.9300 |
| F1 Score | 0.8800 |
| Precision | 0.9100 |
| Recall | 0.8500 |
Evaluated on 21,892 test beats from the MIT-BIH Arrhythmia Database (all 5 classes). Trained on normal beats only — the model never sees any anomaly during training.
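The metrics above can be reproduced in outline from per-beat scores. The following is a minimal sketch, assuming that every non-zero class counts as an anomaly and that the anomaly score is the negated log-probability. The helper names `precision_recall_f1` and `auroc`, the toy arrays, and the `-45.0` cut-off are illustrative and do not come from the repository.

```python
import numpy as np

def precision_recall_f1(is_anomaly, flagged):
    """Precision, recall and F1 for boolean ground truth and boolean predictions."""
    is_anomaly = np.asarray(is_anomaly, dtype=bool)
    flagged = np.asarray(flagged, dtype=bool)
    tp = np.sum(flagged & is_anomaly)
    fp = np.sum(flagged & ~is_anomaly)
    fn = np.sum(~flagged & is_anomaly)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return float(precision), float(recall), float(f1)

def auroc(is_anomaly, anomaly_score):
    """Probability that a random anomaly outscores a random normal beat (ties count half).

    Pairwise version: simple, but memory grows with n_anomalies * n_normals.
    """
    is_anomaly = np.asarray(is_anomaly, dtype=bool)
    scores = np.asarray(anomaly_score, dtype=float)
    pos, neg = scores[is_anomaly], scores[~is_anomaly]
    greater = (pos[:, None] > neg[None, :]).sum()
    ties = (pos[:, None] == neg[None, :]).sum()
    return float((greater + 0.5 * ties) / (len(pos) * len(neg)))

# Toy stand-ins for real class labels and model.log_prob outputs
class_labels = np.array([0, 0, 1, 1, 0])
log_probs = np.array([-30.0, -35.0, -80.0, -40.0, -50.0])

is_anomaly = class_labels != 0          # assumption: classes 1-4 are anomalies
flagged = log_probs < -45.0             # hypothetical decision threshold
precision, recall, f1 = precision_recall_f1(is_anomaly, flagged)
auc = auroc(is_anomaly, -log_probs)     # low log p(x) means high anomaly score
```

AUROC is threshold-free, while precision, recall and F1 depend on the chosen cut-off, so the two kinds of numbers should be read separately.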
```python
from huggingface_hub import hf_hub_download
import torch
import json

# Download weights and config
model_path = hf_hub_download("Aibygaurav/ecg-anomaly-maf", "best_model.pt")
config_path = hf_hub_download("Aibygaurav/ecg-anomaly-maf", "train_config.json")
with open(config_path) as f:
    config = json.load(f)

# Load model (requires src/ from GitHub repo)
from src.hybrid_model import HybridECGModel

model = HybridECGModel(
    input_len=187,
    latent_dim=config["latent_dim"],
    n_layers=config["n_layers"],
    hidden_dims=config["hidden_dims"],
)
model.load_state_dict(torch.load(model_path, map_location="cpu", weights_only=True))
model.eval()

# Score a beat: higher log p(x) = more normal, lower = more anomalous
beat = torch.randn(1, 187)  # replace with a real normalised ECG beat
with torch.no_grad():
    log_prob = model.log_prob(beat).item()
print(f"log p(x) = {log_prob:.2f}")
```
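A raw log-probability still has to be compared against a threshold before a beat can be called anomalous. The sketch below shows one common way to pick it, assuming the threshold is set at a low percentile of scores on held-out normal beats. The function names `fit_threshold` and `flag_anomalies`, the 5% default, and the synthetic scores are assumptions for illustration, not the repository's procedure.

```python
import numpy as np

def fit_threshold(normal_log_probs, percentile=5.0):
    """Return the log p(x) value below which `percentile` % of normal beats fall."""
    return float(np.percentile(np.asarray(normal_log_probs, dtype=float), percentile))

def flag_anomalies(log_probs, threshold):
    """Boolean mask that is True where a beat scores below the threshold."""
    return np.asarray(log_probs, dtype=float) < threshold

# Synthetic stand-in scores; in practice these would come from model.log_prob
rng = np.random.default_rng(0)
val_normal_scores = rng.normal(loc=-40.0, scale=5.0, size=1000)

threshold = fit_threshold(val_normal_scores, percentile=5.0)
flags = flag_anomalies([-35.0, -90.0], threshold)
```

A lower percentile flags fewer normal beats by mistake at the cost of missing more anomalies, so the value is worth tuning on validation data.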
Input: 187 time steps (one heartbeat)
↓
CNN Encoder
Conv1D blocks → 64-dimensional feature vector
↓
MAF (8 layers, alternating reversal)
Each layer: MADE with hidden dims [512, 512]
↓
z ~ N(0, I) ← base distribution
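The diagram implies a change-of-variables computation: the flow maps the feature vector to z, and the log-density is the base log-density of z plus the log-determinants of each layer. The following NumPy sketch illustrates that bookkeeping under the standard affine MAF parameterisation, z_i = (h_i - mu_i) * exp(-alpha_i). The conditioner functions here are toy stand-ins for MADE networks, a reversal is applied after every layer, and all names are illustrative; the repository's implementation may differ in detail.

```python
import numpy as np

def affine_autoregressive_forward(h, mu_fn, alpha_fn):
    """Map h -> z with z_i = (h_i - mu_i) * exp(-alpha_i).

    mu_fn / alpha_fn must return arrays whose i-th entry depends only on
    h[:i] (the property a MADE network enforces). The Jacobian is then
    triangular, so log|det dz/dh| = -sum(alpha) costs O(D).
    """
    mu, alpha = mu_fn(h), alpha_fn(h)
    z = (h - mu) * np.exp(-alpha)
    return z, -np.sum(alpha)

def standard_normal_log_prob(z):
    """Log-density of z under N(0, I)."""
    return -0.5 * np.sum(z ** 2) - 0.5 * z.size * np.log(2.0 * np.pi)

def log_prob(h, layers):
    """Base log-density of the final z plus the accumulated log-determinants."""
    total_log_det = 0.0
    for mu_fn, alpha_fn in layers:
        h, log_det = affine_autoregressive_forward(h, mu_fn, alpha_fn)
        h = h[::-1]  # reversal is a permutation, so it adds 0 to the log-det
        total_log_det += log_det
    return standard_normal_log_prob(h) + total_log_det

# Toy conditioners: entry i uses only entries before i
toy_layer = (
    lambda h: 0.1 * np.concatenate([[0.0], np.cumsum(h)[:-1]]),
    lambda h: 0.05 * np.concatenate([[0.0], h[:-1]]),
)
features = np.array([0.3, -1.2, 0.7])      # stand-in for the 64-dim encoder output
score = log_prob(features, [toy_layer] * 8)
```

Reversing the order between layers lets dimensions that were conditioned on few inputs in one layer be conditioned on many in the next.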
Two-stage training:
MIT-BIH Arrhythmia Database — preprocessed CSV version (Kaggle)
| Class | Description | Train count |
|---|---|---|
| 0 | Normal | ~72,000 |
| 1 | Supraventricular | ~2,500 |
| 2 | Ventricular | ~7,000 |
| 3 | Fusion | ~800 |
| 4 | Unclassifiable | ~7,000 |
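Because training uses normal beats only, the first preprocessing step is to separate class 0 from the rest. The sketch below assumes the Kaggle CSV layout of 188 columns per row, with 187 signal samples followed by the class label, and uses a synthetic array in place of the real file. The file name in the comment and the helper names are assumptions, not taken from the repository.

```python
import numpy as np

def split_beats_and_labels(rows):
    """Split an (N, 188) array into (N, 187) beats and (N,) integer labels."""
    rows = np.asarray(rows, dtype=np.float32)
    beats = rows[:, :-1]
    labels = rows[:, -1].astype(np.int64)
    return beats, labels

def normal_only(beats, labels, normal_class=0):
    """Keep only the beats labelled as normal (class 0 in the table above)."""
    return beats[labels == normal_class]

# Synthetic stand-in for something like np.loadtxt("mitbih_train.csv", delimiter=",")
rng = np.random.default_rng(0)
fake_rows = np.hstack([rng.random((6, 187)), np.array([[0], [0], [2], [4], [0], [1]])])

beats, labels = split_beats_and_labels(fake_rows)
train_beats = normal_only(beats, labels)   # what the flow would be trained on
```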
No external normalising flow libraries. All components implemented manually:
- MaskedLinear: linear layer with binary autoregressive weight masks
- MADE: Masked Autoencoder for Distribution Estimation
- MAFLayer: single affine flow transformation with O(D) log-determinant
- MAF: stacked flow with exact log_prob and sample
- ECGEncoder: 1D CNN compressing 187-step waveform to 64 features
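To make the masking idea concrete, here is a minimal sketch of a masked linear layer and the masks for a one-hidden-layer MADE. It assumes the natural input ordering and a cyclic degree assignment, and emits one output per input. The repository's `MaskedLinear` and `MADE` may be structured differently, for example with two hidden layers of 512 units and separate mean and log-scale outputs. `made_masks` is a hypothetical helper name.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MaskedLinear(nn.Linear):
    """Linear layer whose weight is multiplied elementwise by a fixed 0/1 mask."""

    def __init__(self, in_features, out_features, mask):
        super().__init__(in_features, out_features)
        # Buffer: saved in the state dict and moved by .to(), but never trained
        self.register_buffer("mask", mask.to(self.weight.dtype))

    def forward(self, x):
        return F.linear(x, self.weight * self.mask, self.bias)

def made_masks(dim, hidden):
    """Masks for a one-hidden-layer MADE over `dim` >= 2 inputs."""
    in_deg = torch.arange(1, dim + 1)                # input degrees 1..D
    hid_deg = torch.arange(hidden) % (dim - 1) + 1   # hidden degrees 1..D-1
    # Hidden unit k may see input d only if deg(k) >= deg(d)
    mask_in = hid_deg[:, None] >= in_deg[None, :]
    # Output d may see hidden unit k only if deg(d) > deg(k)
    mask_out = in_deg[:, None] > hid_deg[None, :]
    return mask_in, mask_out

dim, hidden = 4, 16
mask_in, mask_out = made_masks(dim, hidden)
net = nn.Sequential(
    MaskedLinear(dim, hidden, mask_in),
    nn.ReLU(),
    MaskedLinear(hidden, dim, mask_out),
)

# Output i should depend only on inputs before i: strictly lower-triangular Jacobian
x = torch.randn(dim)
jac = torch.autograd.functional.jacobian(net, x)
```

Checking that the Jacobian has no entries on or above the diagonal is a quick sanity test that the masks enforce the autoregressive property.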