Spaces:
Running
Running
import streamlit as st | |
import requests | |
import pandas as pd | |
import matplotlib.pyplot as plt | |
from huggingface_hub import hf_hub_download | |
from io import BytesIO | |
# Page header: app title followed by a short intro linking to the project repo.
_INTRO_MD = """
This is a demo of the Heart Arrhythmia Detection Tools (hadt) project.
The project is available on [GitHub](https://github.com/fabriciojm/hadt).
"""

st.title("Heart Arrhythmia Detection Tools (hadt)")
st.markdown(_INTRO_MD)
# Display name -> model identifier sent to the prediction API as its
# `model_name` query parameter. Keys follow "<architecture> <classification
# type>" so they can be built from the two select boxes below.
models = dict(
    [
        ("LSTM Multiclass", "lstm_multi_model.h5"),
        ("CNN Multiclass", "cnn_multi_model.h5"),
        ("PCA XGBoost Multiclass", "pca_xgboost_multi_model.pkl"),
        ("LSTM Binary", "lstm_binary_model.h5"),
        ("CNN Binary", "cnn_binary_model.h5"),
        ("PCA XGBoost Binary", "pca_xgboost_binary_model.pkl"),
    ]
)

# Class code returned by the API for each beat -> human-readable label.
# NOTE(review): "A" presumably comes from the binary models -- confirm with the API.
beat_labels = dict(
    N="Normal",
    Q="Unknown Beat",
    S="Supraventricular Ectopic",
    V="Ventricular Ectopic",
    A="Abnormal",
)
# Model selection: architecture and classification type are chosen separately
# and joined with a space to form a key of `models`.
model_list = ["LSTM", "CNN", "PCA XGBoost"]
classification = ["Multiclass", "Binary"]
model_selected = st.selectbox("Select a Model", model_list)
classification_selected = st.selectbox("Classification type", classification)
model_name = " ".join((model_selected, classification_selected))
def visualize_single(df, st):
    """Plot the first row of *df* as a line chart in the Streamlit app.

    Parameters
    ----------
    df : pandas.DataFrame
        Parsed CSV input. Only the first row is plotted; it is expected to
        hold one heartbeat (the UI asks for 180 points per row -- not
        validated here).
    st : module
        The Streamlit module (or a compatible object) used for output.

    A CSV with no data rows produces a warning instead of an IndexError.
    """
    st.write("Visualized Data:")
    if df.empty:
        st.warning("The CSV file contains no data rows to plot.")
        return
    fig, ax = plt.subplots(figsize=(10, 6))
    try:
        df.iloc[0].plot(ax=ax)
        st.pyplot(fig)
    finally:
        # pyplot keeps every figure it creates alive until it is closed, and
        # Streamlit re-executes the script on each interaction, so unclosed
        # figures would accumulate for the lifetime of the server process.
        plt.close(fig)
# This function will be used when the API is capable of returning extracted beats
# def visualize_multiple(beats, st):
#     st.write("Visualized Data:")
#     if len(beats) % 4 != 0:
#         nrows = len(beats) // 4 + 1
#     else:
#         nrows = len(beats) // 4
#     fig, axs = plt.subplots(nrows, 4, figsize=(10, nrows*2.5))
#     for i, beat in enumerate(beats):
#         axs.flatten()[i].plot(beat)
#     # delete every unused trailing plot (all axes after the last beat)
#     for ax in axs.flatten()[len(beats):]:
#         fig.delaxes(ax)
#     st.pyplot(fig)
def _read_csv_or_report(source):
    """Parse *source* with ``pd.read_csv``; on bad input show an error and return None.

    Handles the failures a user-supplied file can trigger: an empty file
    (``EmptyDataError``), malformed CSV (``ParserError``) and bytes that are
    not valid text (``UnicodeDecodeError``). Other exceptions propagate.
    """
    try:
        return pd.read_csv(source)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"Could not read the CSV file: {err}")
        return None


st.markdown("""Upload a CSV file with single heartbeat (csv with 180 points) or load from available examples
""")

# Option to upload or load a file
option = st.radio("Choose input method", ("Load example file", "Upload CSV file", "Upload Apple Watch ECG CSV file (EXPERIMENTAL)"))

# Explicit "nothing loaded yet" defaults; each branch assigns these only when
# it actually has data (this replaces probing `'df' in locals()`).
uploaded_file = None
df = None

if option == "Load example file":
    # Load example files from the Hugging Face dataset
    example_files = ["single_N.csv", "single_Q.csv", "single_S.csv", "single_V.csv"]
    example_selected = st.selectbox("Select an example file", example_files)
    file_path = hf_hub_download(repo_id='fabriciojm/ecg-examples', repo_type='dataset', filename=example_selected)
    with open(file_path, 'rb') as f:
        uploaded_file = BytesIO(f.read())
    # Mimic an uploaded file: the Predict request sends `.name` as the filename.
    uploaded_file.name = example_selected
    df = _read_csv_or_report(uploaded_file)
    if df is not None:
        visualize_single(df, st)
elif option == "Upload CSV file":
    # File uploader
    uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
    st.write("The CSV file should have 180 points per row, following the format in [the examples](https://huggingface.co/datasets/fabriciojm/ecg-examples)")
    if uploaded_file is not None:
        df = _read_csv_or_report(uploaded_file)
        if df is not None:
            visualize_single(df, st)
elif option == "Upload Apple Watch ECG CSV file (EXPERIMENTAL)":
    # File uploader
    st.write("DISCLAIMER: this is an experimental feature, and the results may not be accurate. This should not be used as professional medical advice.")
    uploaded_file = st.file_uploader("Upload a CSV file", type="csv")
    st.write("The Apple Watch CSV file should have the same format as [the examples](https://huggingface.co/datasets/fabriciojm/apple-ecg-examples)")
    if uploaded_file is not None:
        # Not plotted in this mode; parsing surfaces unreadable files early.
        df = _read_csv_or_report(uploaded_file)
def _api_error_detail(response):
    """Return a readable error message for a non-200 API *response*.

    Uses the JSON ``detail`` field when the body is a JSON object, otherwise
    ``"Unknown error"`` -- with the HTTP status when the body is not JSON
    at all (e.g. an HTML error page from a proxy).
    """
    try:
        payload = response.json()
    except ValueError:  # requests' JSON decode errors subclass ValueError
        return f"Unknown error (HTTP {response.status_code})"
    if isinstance(payload, dict):
        return payload.get("detail", "Unknown error")
    return f"Unknown error (HTTP {response.status_code})"


if st.button("Predict"):
    if uploaded_file is None:
        # Nothing uploaded yet; calling `.seek` on None would crash the app.
        st.warning("Please upload a CSV file before predicting.")
    else:
        model = models[model_name]
        # The CSV was already read by pd.read_csv above; rewind so the whole
        # file is uploaded.
        uploaded_file.seek(0)
        # Call the API with the file directly
        base_url = "https://fabriciojm-hadt-api.hf.space/predict"
        if option == "Upload Apple Watch ECG CSV file (EXPERIMENTAL)":
            base_url += "_multibeats"
        print(f"Request url: {base_url}?model_name={model}")
        try:
            response = requests.post(
                base_url,
                params={"model_name": model},
                files={"filepath_csv": (uploaded_file.name, uploaded_file, "text/csv")},
                # (connect, read) timeouts so a stalled API cannot hang the app;
                # the read timeout is generous in case the API is slow to
                # respond (e.g. a Space waking up -- assumption).
                timeout=(10, 300),
            )
        except requests.RequestException as err:
            st.error(f"Error: could not reach the prediction API ({err})")
        else:
            if response.status_code == 200:
                prediction = response.json()["prediction"]
                st.write(f"Prediction using {model_name}:")
                for beat_number, code in enumerate(prediction, start=1):
                    # Fall back gracefully if the API returns a code we don't know.
                    label = beat_labels.get(code, "Unrecognised class")
                    st.write(f"Beat {beat_number}: {label} (class {code})")
            else:
                st.error(f"Error: {_api_error_detail(response)}")