Spaces:
Sleeping
Sleeping
| #!/usr/bin/env python3 | |
| """ | |
| Gradio interface for ECG classification | |
| Deploy to Hugging Face Spaces | |
| """ | |
| import gradio as gr | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| import plotly.graph_objects as go | |
| from huggingface_hub import hf_hub_download | |
| import tempfile | |
| import shutil | |
| from pathlib import Path | |
# Run inference on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Hugging Face Hub repository that hosts the fine-tuned weights.
REPO_ID = "Tumo505/SSL-ECG-Classification-model-card"

# Diagnostic classes; predict_ecg maps the argmax logit index i to CLASS_LABELS[i].
CLASS_LABELS = ["NORM", "MI", "STTC", "HYP", "CD"]

# Bar colour for each class in the probability chart (same key order as CLASS_LABELS).
CLASS_COLORS = dict(
    zip(
        CLASS_LABELS,
        ["#90EE90", "#FF6B6B", "#FFD93D", "#6C5CE7", "#A29BFE"],
    )
)
# Define model architecture (1D CNN)
class ECGClassifier(nn.Module):
    """Three-stage 1-D CNN encoder followed by a linear classification head.

    Expects input shaped (batch, num_leads, time) and returns raw logits shaped
    (batch, num_classes). The submodule layout (``encoder.<index>`` and
    ``classifier``) defines the state-dict keys, so it must stay in sync with
    the checkpoint on the Hub.
    """

    def __init__(self, num_classes=5, num_leads=12, output_size=128):
        super().__init__()
        # (in_channels, out_channels, kernel_size); padding = kernel_size // 2
        # keeps the time length unchanged through each convolution.
        stages = [(num_leads, 32, 7), (32, 64, 5), (64, 128, 3)]
        layers = []
        for position, (c_in, c_out, k) in enumerate(stages):
            layers.extend([
                nn.Conv1d(c_in, c_out, kernel_size=k, padding=k // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ])
            # Downsample between stages, but not after the last one.
            if position < len(stages) - 1:
                layers.append(nn.MaxPool1d(2))
        # Global average over time -> fixed-size embedding regardless of input length.
        layers.extend([nn.AdaptiveAvgPool1d(1), nn.Flatten(), nn.Linear(128, output_size)])
        self.encoder = nn.Sequential(*layers)
        self.classifier = nn.Linear(output_size, num_classes)

    def forward(self, x):
        """Return class logits for a batch of multi-lead signals."""
        return self.classifier(self.encoder(x))
def _wfdb_record_base(file_path):
    """Return the WFDB record name: the path with its .hea/.dat suffix removed."""
    return str(Path(file_path).with_suffix(""))


def _load_wfdb(file_path):
    """Load a WFDB record from either its .hea header or its .dat signal file.

    Both ``<record>.hea`` and ``<record>.dat`` must be in the same directory.
    Returns ``record.p_signal`` (per the wfdb docs, laid out samples x channels).
    """
    base = _wfdb_record_base(file_path)
    header_path, data_path = base + ".hea", base + ".dat"
    if not (Path(header_path).exists() and Path(data_path).exists()):
        raise FileNotFoundError(
            "WFDB format requires TWO files:\n"
            f"  1. {Path(header_path).name} (header)\n"
            f"  2. {Path(data_path).name} (data)\n\n"
            "Please upload both files together, or upload both in a single ZIP archive."
        )
    try:
        import wfdb

        record = wfdb.rdrecord(base)
    except Exception as e:
        raise RuntimeError(f"WFDB error: {str(e)}") from e
    ecg = record.p_signal
    print(f"WFDB (.hea/.dat) loaded: {ecg.shape}")
    return ecg


def _load_mat(file_path):
    """Load a MATLAB .mat file, preferring well-known ECG variable names."""
    try:
        from scipy import io
    except ImportError as e:
        raise ImportError("SciPy required: pip install scipy") from e
    try:
        mat_data = io.loadmat(file_path)
    except NotImplementedError:
        # loadmat rejects MATLAB v7.3 files, which are HDF5 containers.
        return _load_hdf5(file_path)
    for key in ("ecg", "ECG", "signal", "data", "val"):
        if key in mat_data:
            ecg = np.array(mat_data[key])
            print(f"MATLAB loaded ({key}): {ecg.shape}")
            return ecg
    # No standard name: fall back to the largest numeric array of rank <= 2.
    arrays = {
        k: v
        for k, v in mat_data.items()
        if isinstance(v, np.ndarray) and v.ndim <= 2 and np.issubdtype(v.dtype, np.number)
    }
    if not arrays:
        raise ValueError("No ECG data found in .mat file")
    key = max(arrays, key=lambda k: arrays[k].size)
    ecg = arrays[key]
    print(f"MATLAB loaded ({key}): {ecg.shape}")
    return ecg


def _load_hdf5(file_path):
    """Load an HDF5 file: a well-known dataset name if present, else the first dataset found."""
    try:
        import h5py
    except ImportError as e:
        raise ImportError("h5py required: pip install h5py") from e
    with h5py.File(file_path, "r") as f:
        for key in ("ecg", "ECG", "signal", "data", "waveform"):
            # Only datasets hold arrays; a group with one of these names is skipped.
            if key in f and isinstance(f[key], h5py.Dataset):
                ecg = np.array(f[key])
                print(f"HDF5 loaded ({key}): {ecg.shape}")
                return ecg
        dataset_names = []

        def _remember(name, obj):
            if isinstance(obj, h5py.Dataset):
                dataset_names.append(name)

        f.visititems(_remember)
        if not dataset_names:
            raise ValueError("No ECG data found in HDF5 file")
        key = dataset_names[0]
        ecg = np.array(f[key])
    print(f"HDF5 loaded ({key}): {ecg.shape}")
    return ecg


def _load_edf(file_path):
    """Load every EDF signal whose length matches the first signal, one row per signal."""
    try:
        import pyedflib
    except ImportError as e:
        raise ImportError("pyedflib required: pip install pyedflib") from e
    reader = pyedflib.EdfReader(file_path)
    try:
        signals = [reader.readSignal(i) for i in range(reader.signals_in_file)]
    finally:
        reader.close()
    if not signals:
        raise ValueError("No signals found in EDF file")
    # Signals with a different length (presumably a different sampling rate)
    # cannot be stacked time-aligned, so they are dropped.
    n_samples = len(signals[0])
    ecg = np.array([s for s in signals if len(s) == n_samples], dtype=np.float64)
    print(f"EDF loaded: {ecg.shape}")
    return ecg


def _load_dicom(file_path):
    """Load the first waveform multiplex group of a DICOM file as (channels, samples)."""
    try:
        import pydicom
    except ImportError as e:
        raise ImportError("pydicom required: pip install pydicom") from e
    ds = pydicom.dcmread(file_path)
    if not getattr(ds, "WaveformSequence", None):
        raise ValueError("No waveform data in DICOM file")
    if hasattr(ds, "waveform_array"):
        # Newer pydicom decodes (and scales) the group as (samples, channels).
        samples = np.asarray(ds.waveform_array(0), dtype=np.float32)
    else:
        group = ds.WaveformSequence[0]
        bits = int(group.WaveformBitsAllocated)
        raw = np.frombuffer(group.WaveformData, dtype=np.dtype(f"<i{bits // 8}"))
        # WaveformData is multiplexed: all channels of sample 0, then sample 1, ...
        samples = raw.reshape(
            int(group.NumberOfWaveformSamples), int(group.NumberOfWaveformChannels)
        ).astype(np.float32)
    ecg = samples.T
    print(f"DICOM loaded: {ecg.shape}")
    return ecg


_HL7_NS = "{urn:hl7-org:v3}"


def _hl7_slist_values(value_el):
    """Decode an HL7 SLIST_PQ ``<value>`` element as ``origin + scale * digits``.

    Returns None when the element has no ``<digits>`` (e.g. the time axis value).
    """
    digits_el = value_el.find(f"{_HL7_NS}digits")
    if digits_el is None or not (digits_el.text or "").strip():
        return None
    digits = np.array(digits_el.text.split(), dtype=np.float64)
    origin_el = value_el.find(f"{_HL7_NS}origin")
    scale_el = value_el.find(f"{_HL7_NS}scale")
    origin = float(origin_el.get("value", 0)) if origin_el is not None else 0.0
    scale = float(scale_el.get("value", 1)) if scale_el is not None else 1.0
    return origin + scale * digits


def _load_hl7_xml(file_path):
    """Load lead waveforms from an HL7 aECG XML file, one row per lead.

    Uses the sequences of the first ``<series>`` that carries digits. If none
    does, falls back to every ``<value>`` with digits in the document.
    """
    # NOTE(review): xml.etree is not hardened against every malicious-XML attack;
    # consider defusedxml since this parses untrusted uploads.
    import xml.etree.ElementTree as ET

    sequence_values = (
        f"{_HL7_NS}component/{_HL7_NS}sequenceSet/{_HL7_NS}component/"
        f"{_HL7_NS}sequence/{_HL7_NS}value"
    )
    try:
        root = ET.parse(file_path).getroot()
        waveforms = []
        for series in root.iter(f"{_HL7_NS}series"):
            waveforms = [
                w for w in map(_hl7_slist_values, series.findall(sequence_values)) if w is not None
            ]
            if waveforms:
                break
        if not waveforms:
            waveforms = [
                w for w in map(_hl7_slist_values, root.iter(f"{_HL7_NS}value")) if w is not None
            ]
        if not waveforms:
            raise ValueError("No waveform data in XML file")
        # Edge-pad shorter leads so every row has the same length.
        max_len = max(w.size for w in waveforms)
        ecg = np.array([np.pad(w, (0, max_len - w.size), mode="edge") for w in waveforms])
    except Exception as e:
        raise ValueError(f"XML parsing error: {str(e)}") from e
    print(f"XML (HL7 aECG) loaded: {ecg.shape}")
    return ecg


def _load_npy(file_path):
    """Load a NumPy .npy array (object/pickled arrays are refused for safety)."""
    ecg = np.load(file_path, allow_pickle=False)
    print(f"NumPy loaded: {ecg.shape}")
    return ecg


def _split_binary_channels(flat):
    """Reshape a flat sample vector into 12, 2 or 1 rows (first channel count that divides it)."""
    for n_channels in (12, 2):
        if flat.size % n_channels == 0:
            return flat.reshape(n_channels, -1)
    return flat.reshape(1, -1)


def _load_binary(file_path):
    """Load headerless binary samples: try float32, then float64, then plain text."""
    # NOTE(review): headerless ECG dumps are often int16 ADC counts; reading them
    # as float gives meaningless values. Nothing in the file says which it is.
    if Path(file_path).stat().st_size == 0:
        raise ValueError("Binary file is empty")
    for dtype, label in ((np.float32, "float32"), (np.float64, "float64")):
        try:
            ecg = _split_binary_channels(np.fromfile(file_path, dtype=dtype))
        except Exception:
            continue
        print(f"Binary ({label}) loaded: {ecg.shape}")
        return ecg
    ecg = np.loadtxt(file_path)
    if ecg.ndim == 1:
        ecg = ecg.reshape(1, -1)
    print(f"Binary as text loaded: {ecg.shape}")
    return ecg


_TEXT_DELIMITERS = (",", "\t", ";")


def _sniff_delimiter(file_path):
    """Guess the column delimiter from the first non-empty, non-comment line.

    Picks the most frequent of comma/tab/semicolon, or None (any whitespace)
    when none of them occurs.
    """
    with open(file_path, "r", encoding="utf-8", errors="replace") as fh:
        for line in fh:
            stripped = line.strip()
            if not stripped or stripped.startswith("#"):
                continue
            best = max(_TEXT_DELIMITERS, key=stripped.count)
            return best if stripped.count(best) else None
    return None


def _load_text(file_path):
    """Load a delimited numeric text table, dropping non-numeric header rows/columns."""
    delimiter = _sniff_delimiter(file_path)
    # genfromtxt turns unparsable cells (e.g. a header) into NaN instead of raising.
    ecg = np.atleast_2d(np.genfromtxt(file_path, delimiter=delimiter, dtype=float))
    if ecg.size:
        ecg = ecg[~np.isnan(ecg).all(axis=1)]
    if ecg.size:
        ecg = ecg[:, ~np.isnan(ecg).all(axis=0)]
    if ecg.size == 0:
        raise ValueError("No numeric data found in text file")
    print(f"Text format loaded: {ecg.shape}")
    return ecg


# Extension -> loader. Anything not listed is parsed as delimited text.
_ECG_LOADERS = {
    ".hea": _load_wfdb,
    ".dat": _load_wfdb,
    ".mat": _load_mat,
    ".h5": _load_hdf5,
    ".hdf5": _load_hdf5,
    ".edf": _load_edf,
    ".dcm": _load_dicom,
    ".xml": _load_hl7_xml,
    ".npy": _load_npy,
    ".raw": _load_binary,
    ".bin": _load_binary,
    ".bat": _load_binary,
    ".ecg": _load_binary,
}


def load_ecg_file(file_path):
    """
    Comprehensive ECG file loader supporting multiple formats.

    Supported formats:
    - Text: CSV, TXT, TSV (comma / tab / semicolon / whitespace, auto-detected)
    - NumPy: .npy
    - PhysioNet: .hea/.dat (WFDB; both files must be in the same directory)
    - MATLAB: .mat (including v7.3 / HDF5-based files)
    - HDF5: .h5, .hdf5
    - EDF: .edf (European Data Format)
    - DICOM: .dcm
    - XML: .xml (HL7 aECG)
    - Binary: .raw, .bin, .bat, .ecg

    Args:
        file_path: path (str or Path-like) of the uploaded file.

    Returns:
        The raw signal as a NumPy array. Orientation depends on the format
        (channels x samples or samples x channels); callers must normalise it.

    Raises:
        Exception: ``"Failed to load <ext> file: <reason>"`` wrapping the
        underlying error.
    """
    file_path = str(file_path)
    extension = Path(file_path).suffix.lower()
    print(f"Loading {extension} format from: {file_path}")
    loader = _ECG_LOADERS.get(extension, _load_text)
    try:
        return loader(file_path)
    except Exception as e:
        raise Exception(f"Failed to load {extension} file: {str(e)}") from e
# Load model
# Build the classifier and load the fine-tuned weights from the Hub once, at
# import time. `model` is bound only after every step succeeds. On any failure
# it stays None, so predict_ecg reports a loading error instead of silently
# classifying with randomly initialised weights.
model = None
try:
    print("Loading model from Hub...")
    from safetensors.torch import load_file

    _net = ECGClassifier(num_classes=len(CLASS_LABELS), num_leads=12, output_size=128)
    # Download weights from Hub
    weights_path = hf_hub_download(repo_id=REPO_ID, filename="model.safetensors")
    state_dict = load_file(weights_path)
    # strict=False tolerates extra or missing checkpoint entries, but the
    # mismatch is reported below rather than silently ignored.
    load_result = _net.load_state_dict(state_dict, strict=False)
    expected_keys = set(_net.state_dict())
    missing_keys = set(load_result.missing_keys)
    if missing_keys == expected_keys:
        raise RuntimeError(
            "None of the checkpoint tensors match ECGClassifier; "
            f"first checkpoint keys: {sorted(state_dict)[:5]}"
        )
    if missing_keys:
        print(
            f"WARNING: {len(missing_keys)} model tensor(s) missing from checkpoint, "
            f"left at initial values: {sorted(missing_keys)}"
        )
    if load_result.unexpected_keys:
        print(
            f"Ignoring {len(load_result.unexpected_keys)} unexpected checkpoint key(s): "
            f"{sorted(load_result.unexpected_keys)}"
        )
    _net.to(device)
    _net.eval()
    model = _net
    print("Model loaded successfully")
except Exception as e:
    model = None
    print(f"Error loading model: {e}")
    import traceback

    traceback.print_exc()
# Upper bound on the total uncompressed size of an uploaded ZIP archive; guards
# the temp directory against decompression bombs.
_MAX_ZIP_UNCOMPRESSED_BYTES = 200 * 1024 * 1024

# Model input geometry: leads and samples per lead fed to the classifier.
# NOTE(review): the results panel advertises "12-lead ECG @ 100 Hz"; 5000 samples
# at 100 Hz is 50 s, so confirm this matches the length used during training.
_MODEL_LEADS = 12
_MODEL_SAMPLES = 5000


def _collect_uploaded_paths(file_obj):
    """Normalise the Gradio File value (None, path, file object, or list) to path strings."""
    if file_obj is None:
        return []
    items = file_obj if isinstance(file_obj, (list, tuple)) else [file_obj]
    return [item if isinstance(item, str) else str(getattr(item, "name", item)) for item in items]


def _expand_zip_archives(paths):
    """Replace every ``.zip`` path with the files it contains, extracted to a fresh temp dir."""
    import zipfile

    expanded = []
    for path in paths:
        if Path(path).suffix.lower() != ".zip":
            expanded.append(path)
            continue
        with zipfile.ZipFile(path) as archive:
            members = [info for info in archive.infolist() if not info.is_dir()]
            total_size = sum(info.file_size for info in members)
            if total_size > _MAX_ZIP_UNCOMPRESSED_BYTES:
                raise ValueError(f"ZIP archive is too large when extracted ({total_size} bytes)")
            target_dir = Path(tempfile.mkdtemp(prefix="ecg_zip_"))
            # extractall strips absolute paths and ".." components, so members
            # cannot be written outside target_dir.
            archive.extractall(target_dir, members=members)
        expanded.extend(
            str(p)
            for p in sorted(target_dir.rglob("*"))
            if p.is_file() and "__MACOSX" not in p.parts and not p.name.startswith(".")
        )
    return expanded


def _pick_primary_file(paths):
    """Choose which uploaded file to load.

    For a WFDB upload (.hea + .dat), returns a header that has a matching .dat
    (or the first header). Every .dat is copied next to it, because uploads may
    land in different directories and wfdb reads the record files side by side.
    Otherwise returns the first path.
    """
    hea_files = [p for p in paths if p.lower().endswith(".hea")]
    dat_files = [p for p in paths if p.lower().endswith(".dat")]
    if hea_files and dat_files:
        dat_stems = {Path(p).stem for p in dat_files}
        hea_path = next((p for p in hea_files if Path(p).stem in dat_stems), hea_files[0])
        hea_dir = Path(hea_path).parent
        for dat_path in dat_files:
            target = hea_dir / Path(dat_path).name
            if target.resolve() != Path(dat_path).resolve():
                shutil.copy(dat_path, target)
        print(f"WFDB pair detected: {hea_path} + {', '.join(dat_files)}")
        return hea_path
    if len(paths) > 1:
        print(f"Multiple files uploaded, using first: {paths[0]}")
    return paths[0]


def _arrange_leads(ecg, n_leads=_MODEL_LEADS):
    """Return a float64 array shaped (n_leads, T) from a signal of unknown orientation.

    - An axis of length ``n_leads`` is the lead axis, so multi-lead recordings in
      either (leads, samples) or (samples, leads) layout are used directly.
    - With fewer channels than ``n_leads`` (e.g. single-lead data), channels are
      assumed to lie on the shorter axis; the first channel is replicated.
    - Otherwise the array is treated as a table of recordings, one per row
      (UCR-style). The first row is used, minus a leading integer class-label
      column if present, and replicated.

    Raises:
        ValueError: the array is empty or has more than two non-singleton axes.
    """
    arr = np.asarray(ecg, dtype=np.float64)
    if arr.ndim > 2:
        arr = np.squeeze(arr)
    if arr.ndim > 2:
        raise ValueError(f"Expected a 1-D or 2-D signal, got shape {np.shape(ecg)}")
    arr = np.atleast_2d(arr)  # (T,) -> (1, T)
    if arr.size == 0:
        raise ValueError("ECG file contains no samples")
    rows, cols = arr.shape
    if rows == n_leads:
        return arr
    if cols == n_leads:
        return arr.T
    if min(rows, cols) < n_leads:
        channels = arr if rows <= cols else arr.T
        print(f"Replicating single lead to {n_leads} leads for model compatibility...")
        return np.tile(channels[0], (n_leads, 1))
    record = arr[0]
    labels = arr[:, 0]
    looks_like_labels = (
        np.all(np.isfinite(labels))
        and np.all(labels == np.round(labels))
        and np.unique(labels).size <= 20
    )
    if looks_like_labels:
        print("Detected class label in first column, removing it...")
        record = record[1:]
    print(f"Using first of {rows} recordings, replicated to {n_leads} leads...")
    return np.tile(record, (n_leads, 1))


def _fit_length(leads, n_samples=_MODEL_SAMPLES):
    """Edge-pad or trim every lead (row) to exactly ``n_samples`` time steps."""
    length = leads.shape[1]
    if length < n_samples:
        print(f"Padding: {length} values -> {n_samples}")
        return np.pad(leads, ((0, 0), (0, n_samples - length)), mode="edge")
    if length > n_samples:
        print(f"Trimming: {length} values -> {n_samples}")
        return leads[:, :n_samples]
    return leads


def _build_probability_chart(probs, pred_idx):
    """Horizontal bar chart of class probabilities with the predicted class outlined."""
    pred_class = CLASS_LABELS[pred_idx]
    confidence = float(probs[pred_idx])
    is_pred = [i == pred_idx for i in range(len(CLASS_LABELS))]
    fig = go.Figure()
    fig.add_trace(go.Bar(
        y=CLASS_LABELS,
        x=probs,
        orientation="h",
        marker=dict(
            color=[CLASS_COLORS.get(c, "#87CEEB") for c in CLASS_LABELS],
            line=dict(
                color=["#000000" if flag else "#CCCCCC" for flag in is_pred],
                width=[3 if flag else 1 for flag in is_pred],
            ),
        ),
        text=[f"{p:.1%}" for p in probs],
        textposition="auto",
        hovertemplate="<b>%{y}</b><br>Probability: %{x:.2%}<extra></extra>",
    ))
    fig.update_layout(
        title=dict(
            text=f"ECG Classification Results<br><sub>Prediction: <b>{pred_class}</b> ({confidence:.1%})</sub>",
            x=0.5,
            xanchor="center",
        ),
        xaxis_title="Model Confidence",
        yaxis_title="Diagnostic Class",
        height=450,
        showlegend=False,
        font=dict(size=12),
        plot_bgcolor="rgba(240,240,240,0.5)",
    )
    return fig


def _format_results_markdown(probs, pred_idx):
    """Markdown summary: predicted class, confidence, per-class table and model info."""
    pred_class = CLASS_LABELS[pred_idx]
    confidence = float(probs[pred_idx])
    table_rows = "\n".join(f"| {label} | {p:.2%} |" for label, p in zip(CLASS_LABELS, probs))
    return (
        "## Prediction Complete\n\n"
        f"### Primary Diagnosis: **{pred_class}**\n\n"
        f"### Confidence: **{confidence:.1%}**\n\n"
        "---\n\n"
        "### All Class Probabilities:\n\n"
        "| Class | Probability |\n"
        "|-------|-------------|\n"
        f"{table_rows}\n\n"
        "---\n\n"
        "**Model Information:**\n\n"
        "- Framework: SimCLR SSL\n"
        "- Training Data: PTB-XL (10% labeled)\n"
        "- Test AUROC: 0.8717\n"
        "- Input: 12-lead ECG @ 100 Hz\n\n"
        "**Disclaimer:** This is a research model for demonstration only. "
        "Not validated for clinical use.\n"
    )


def predict_ecg(file_obj):
    """Classify an uploaded ECG; main Gradio callback.

    Args:
        file_obj: value of the Gradio File component. This is a path, a file
            object, or a list of them (multi-upload). ``.zip`` archives are
            extracted, and WFDB .hea/.dat pairs are detected.

    Returns:
        ``(markdown, figure)``: the results summary and a Plotly probability
        chart. On any failure the markdown describes the error and the figure
        is None.
    """
    if model is None:
        return (
            "**Model Loading Error**\n"
            "The model failed to load. Please try again or contact support.",
            None,
        )
    try:
        paths = _expand_zip_archives(_collect_uploaded_paths(file_obj))
        if not paths:
            return ("**Error**: No files uploaded", None)
        file_path = _pick_primary_file(paths)

        print(f"Loading file: {file_path}")
        try:
            ecg = load_ecg_file(file_path)
        except Exception as e:
            return (f"**Loading Error**: {str(e)}", None)

        try:
            leads = _fit_length(_arrange_leads(ecg, _MODEL_LEADS), _MODEL_SAMPLES)
        except (ValueError, TypeError) as e:
            return (f"**Shape Error**\n{e}\nFile format not supported.", None)
        if not np.all(np.isfinite(leads)):
            return ("**Data Error**: the signal contains NaN or infinite values.", None)
        print(f"Final shape: {leads.shape}")

        # Normalize each lead independently (z-score; epsilon guards flat leads).
        leads = (leads - leads.mean(axis=1, keepdims=True)) / (leads.std(axis=1, keepdims=True) + 1e-8)
        x = torch.tensor(leads, dtype=torch.float32).unsqueeze(0).to(device)
        with torch.no_grad():
            probs = torch.softmax(model(x)[0], dim=0).cpu().numpy()
        pred_idx = int(np.argmax(probs))

        return _format_results_markdown(probs, pred_idx), _build_probability_chart(probs, pred_idx)
    except FileNotFoundError:
        return "**File Error:** Could not read uploaded file", None
    except Exception as e:
        import traceback

        error_msg = f"**Error:** {str(e)}\n\nDebug: {traceback.format_exc()}"
        return error_msg, None
# Create interface
# Markdown blocks are kept at column 0 (and separated by blank lines) so that
# paragraphs and lists render as intended rather than merging or turning into
# indented code blocks.
_INTRO_MD = """
# ECG Classification with Self-Supervised Learning

**Test ECG cardiovascular disease classification** using a SimCLR pre-trained model fine-tuned on the PTB-XL dataset.

**Model Performance:** AUROC 0.8717 | Accuracy 0.8234 | 10% labeled data

---
"""

_UPLOAD_HELP_MD = """
### Upload Your ECG

**Multi-file upload supported!** Upload multiple files at once, especially for WFDB pairs.

**Clinical & Standardized Formats:**

- `.dcm` — DICOM (medical imaging, PACS systems)
- `.xml` — HL7 aECG / FDA XML (clinical trials, regulatory)

**Research & PhysioNet Formats:**

- `.hea` + `.dat` — WFDB (MIT-BIH, PhysioNet). **Upload both files together**
- `.edf` — European Data Format (multi-channel biosignals)

**Generic / Export Formats:**

- `.csv / .txt / .tsv` — Text formats (auto-detects delimiter)
- `.npy` — NumPy arrays
- `.mat` — MATLAB format
- `.h5 / .hdf5` — HDF5 (efficient large-scale datasets)
- `.raw / .bin / .bat / .ecg` — Binary ECG data
- `.zip` — Archive with multiple files

**Architecture Auto-Conversion:**

- Multi-lead (12 leads): Used directly
- Single-lead → Replicated to 12 leads
- Auto-pads/trims to 5000 samples per lead

**Supported Delimiters:** Space, comma, tab (auto-detected)

---

**💡 WFDB Tip:** Upload both `.hea` and `.dat` files together in one go. The system will automatically detect the pair and process them correctly!
"""

_RESULTS_HEADER_MD = """
### Results

Predictions appear here after classification.
"""

_ABOUT_MD = """
---

### About This Model

**DOI:** [`10.57967/hf/8469`](https://doi.org/10.57967/hf/8469) | [Model Card](https://huggingface.co/Tumo505/SSL-ECG-Classification-model-card) | [GitHub](https://github.com/Tumo505/SSL-for-ECG-classification)

**Architecture:** 1D CNN with SimCLR self-supervised pre-training

**Training:**

- Pre-training: SimCLR on 17.5K unlabeled PTB-XL ECGs
- Fine-tuning: Supervised on 1.7K labeled ECGs (10%)

**Classes Predicted:**

- NORM: Normal ECG
- MI: Myocardial Infarction
- STTC: ST/T Changes
- HYP: Hypertrophy
- CD: Conduction Disturbances

**Research Only** - Not validated for clinical use
"""

with gr.Blocks(title="ECG Classification with Self-Supervised Learning") as demo:
    gr.Markdown(_INTRO_MD)
    with gr.Row():
        with gr.Column():
            gr.Markdown(_UPLOAD_HELP_MD)
            # Extensions here must stay in sync with what load_ecg_file / predict_ecg handle.
            file_input = gr.File(
                label="ECG File(s)",
                file_count="multiple",
                file_types=[".csv", ".txt", ".tsv", ".npy", ".hea", ".dat",
                            ".dcm", ".mat", ".h5", ".hdf5", ".edf", ".xml",
                            ".raw", ".bin", ".bat", ".ecg", ".zip"],
                type="filepath",
            )
            submit_btn = gr.Button("Classify ECG", variant="primary", size="lg")
        with gr.Column():
            gr.Markdown(_RESULTS_HEADER_MD)
            output_text = gr.Markdown(
                "Upload an ECG file to see predictions",
                label="Classification Results",
            )
    with gr.Row():
        chart_output = gr.Plot(label="Probability Distribution")
    # Connect button: predict_ecg returns (markdown, figure) for these two outputs.
    submit_btn.click(
        fn=predict_ecg,
        inputs=[file_input],
        outputs=[output_text, chart_output],
    )
    # Info section
    gr.Markdown(_ABOUT_MD)
def _launch():
    """Start the Gradio server locally without creating a public share link."""
    demo.launch(share=False)


# Only launch when this module is executed as a script, not when imported.
if __name__ == "__main__":
    _launch()