# Spaces: Sleeping  (page-status text captured together with this file; not part of the program)
import hashlib
import hmac
import json
import os
import secrets
import tempfile
from datetime import datetime, timezone

import firebase_admin
import librosa
import librosa.display  # explicit import: older librosa releases do not load this submodule automatically
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf
from firebase_admin import credentials, firestore

# Optional: for microphone recording
# NOTE(review): confirm the module name exported by the installed recorder
# component; if this import fails the app silently falls back to upload-only.
try:
    from streamlit_audiorecorder import audiorecorder
    RECORDING_ENABLED = True
except ImportError:
    RECORDING_ENABLED = False
# --- Firebase Initialization ---
# On Hugging Face Spaces, add your service account JSON to Secrets.toml as "firebase_key"
def _service_account_info(secret):
    """Convert the ``firebase_key`` secret into something ``credentials.Certificate`` accepts."""
    if isinstance(secret, str):
        try:
            # Secret stored as one JSON string.
            return json.loads(secret)
        except ValueError:
            # Not JSON: pass it through unchanged so Certificate can treat it as a file path.
            return secret
    # Secret stored as a TOML table: st.secrets exposes it as a read-only
    # mapping, while Certificate only accepts a real dict.
    return dict(secret)


def init_firestore():
    """Return a Firestore client, initialising the default Firebase app at most once.

    Streamlit re-executes this script on every interaction while the Python
    process (and firebase_admin's app registry) stays alive, so calling
    ``initialize_app`` unconditionally fails on the second run.

    Returns:
        The Firestore client bound to the default Firebase app.

    Raises:
        KeyError: if the ``firebase_key`` secret is not configured.
        ValueError: if the secret is not a valid service-account certificate.
    """
    try:
        # get_app() raises ValueError when the default app does not exist yet.
        firebase_admin.get_app()
    except ValueError:
        cred = credentials.Certificate(_service_account_info(st.secrets["firebase_key"]))
        firebase_admin.initialize_app(cred)
    return firestore.client()


db = init_firestore()
# --- Load ML Model ---
def _cache_resource(func):
    """Wrap *func* in Streamlit's resource cache when this Streamlit version provides one."""
    decorator = getattr(st, "cache_resource", None)
    # Older Streamlit releases have no cache_resource; fall back to the plain function.
    return func if decorator is None else decorator(func)


@_cache_resource
def load_model():
    """Load the Keras model stored in ``Heart_ResNet.h5``.

    The script is re-run on every widget interaction, so the loaded model is
    cached instead of being re-read from disk each time.

    Returns:
        The model object returned by ``tf.keras.models.load_model``.
    """
    return tf.keras.models.load_model('Heart_ResNet.h5')


model = load_model()
# --- Audio Processing & Classification ---
DURATION = 10        # seconds of audio loaded per clip (passed to librosa.load as `duration`)
SAMPLE_RATE = 22050  # sampling rate in Hz that every clip is loaded at
INPUT_LEN = DURATION * SAMPLE_RATE  # samples per clip; shorter clips are zero-padded to this length
# Labels are matched to model outputs by position in classify().
# NOTE(review): presumably this order is the one the model was trained with - confirm.
CLASS_NAMES = ["artifact", "murmur", "normal"]
def process_audio(raw_bytes):
    """Decode an audio clip and compute its time-averaged MFCC vector.

    Args:
        raw_bytes: Encoded audio file contents (bytes).

    Returns:
        A tuple ``(mfccs, X, sr)``: the 52 MFCC coefficients averaged over all
        frames, the waveform (at most ``DURATION`` seconds, zero-padded up to
        ``INPUT_LEN`` samples), and the sampling rate it was loaded at.
    """
    # Use a unique scratch file per call: a fixed "temp.wav" would be shared by
    # concurrent sessions and was never deleted.
    handle, tmp_path = tempfile.mkstemp(suffix=".wav")
    try:
        with os.fdopen(handle, "wb") as f:
            f.write(raw_bytes)
        X, sr = librosa.load(tmp_path, sr=SAMPLE_RATE, duration=DURATION)
    finally:
        os.remove(tmp_path)
    if len(X) < INPUT_LEN:
        # Right-pad short clips with silence up to the fixed clip length.
        X = np.pad(X, (0, INPUT_LEN - len(X)), mode='constant')
    # mfcc() yields (n_mfcc, frames); transpose and average over frames -> (52,).
    mfccs = np.mean(
        librosa.feature.mfcc(y=X, sr=sr, n_mfcc=52, n_fft=512, hop_length=256).T,
        axis=0,
    )
    return mfccs, X, sr
def classify(mfccs):
    """Run the model on one MFCC vector and label its outputs.

    Args:
        mfccs: Array holding 52 values; it is reshaped to ``(1, 52, 1)``.

    Returns:
        A dict mapping each name in ``CLASS_NAMES`` to the model output at the
        same index, converted to ``float``.
    """
    batch = mfccs.reshape(1, 52, 1)
    scores = model.predict(batch)[0]
    labelled = {}
    for idx, label in enumerate(CLASS_NAMES):
        labelled[label] = float(scores[idx])
    return labelled
# --- User Authentication Helpers ---
_PBKDF2_ROUNDS = 200_000


def _hash_password(password, salt):
    """Return the hex PBKDF2-HMAC-SHA256 digest of *password* under *salt* (bytes)."""
    return hashlib.pbkdf2_hmac(
        "sha256", password.encode("utf-8"), salt, _PBKDF2_ROUNDS
    ).hex()


def _usable_as_doc_id(email):
    """Tell whether *email* can serve as a Firestore document id (non-empty, no '/')."""
    return bool(email) and "/" not in email


def register_user(email, password):
    """Create a user document keyed by *email*.

    The password is stored as a salted PBKDF2 hash, never in plaintext.

    Args:
        email: Address used as the document id in the ``users`` collection.
        password: Plaintext password chosen by the user.

    Returns:
        ``(True, message)`` on success, ``(False, message)`` when the input is
        unusable or the user already exists.
    """
    if not _usable_as_doc_id(email) or not password:
        return False, "Please enter a valid email and a password"
    user_ref = db.collection("users").document(email)
    if user_ref.get().exists:
        return False, "User already exists"
    salt = secrets.token_bytes(16)
    user_ref.set({
        "email": email,
        "password_salt": salt.hex(),
        "password_hash": _hash_password(password, salt),
        "history": [],
    })
    return True, "Registration successful"


def login_user(email, password):
    """Check credentials against the stored user document.

    Accounts created before hashing was introduced only have a plaintext
    ``password`` field; those are still accepted so existing users can log in.

    Args:
        email: Address identifying the user document.
        password: Plaintext password to verify.

    Returns:
        ``(True, user_dict)`` when the credentials match, else ``(False, None)``.
    """
    if not _usable_as_doc_id(email) or not isinstance(password, str):
        return False, None
    snapshot = db.collection("users").document(email).get()
    if not snapshot.exists:
        return False, None
    record = snapshot.to_dict()
    stored_hash = record.get("password_hash")
    if stored_hash is not None:
        salt = bytes.fromhex(record.get("password_salt", ""))
        # Constant-time comparison avoids leaking how many characters matched.
        matches = hmac.compare_digest(_hash_password(password, salt), stored_hash)
    else:
        legacy = record.get("password")
        matches = isinstance(legacy, str) and hmac.compare_digest(
            legacy.encode("utf-8"), password.encode("utf-8")
        )
    if matches:
        return True, record
    return False, None
# --- Firestore History Functions ---
def save_history(email, result):
    """Append one classification result to the user's ``history`` array.

    Args:
        email: Document id of the user in the ``users`` collection.
        result: Classification result to store alongside the timestamp.
    """
    # datetime.utcnow() is deprecated. Take an aware UTC time and drop the
    # tzinfo so the stored ISO string keeps the same offset-free format as
    # records already written (history is sorted by this string).
    stamp = datetime.now(timezone.utc).replace(tzinfo=None).isoformat()
    record = {"timestamp": stamp, "result": result}
    db.collection("users").document(email).update({
        "history": firestore.ArrayUnion([record])
    })
def load_history(email):
    """Return the stored classification history for *email*.

    Args:
        email: Document id of the user in the ``users`` collection.

    Returns:
        The list stored under ``history``, or an empty list when the user
        document or the field is missing.
    """
    snapshot = db.collection("users").document(email).get()
    # to_dict() returns None for a document that does not exist.
    data = snapshot.to_dict() or {}
    return data.get("history") or []
# --- Streamlit App Layout ---
st.title("🩺 Heartbeat Sound Classifier with Firestore")

# Session state for user
if "user" not in st.session_state:
    st.session_state.user = None

# st.experimental_rerun was superseded by st.rerun; use whichever one this
# Streamlit version provides.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun

# Sidebar: Auth or Logout
with st.sidebar:
    if st.session_state.user:
        st.markdown(f"**Logged in as:** {st.session_state.user['email']}")
        if st.button("Logout"):
            st.session_state.user = None
            _rerun()
    else:
        tab = st.radio("Account", ["Login", "Register"])
        email = st.text_input("Email")
        password = st.text_input("Password", type="password")
        if tab == "Register":
            if st.button("Sign Up"):
                success, msg = register_user(email, password)
                # A statement rather than a conditional expression: this leaves
                # no bare expression value behind.
                # NOTE(review): Streamlit's "magic" display may echo such values - confirm.
                if success:
                    st.success(msg)
                else:
                    st.error(msg)
        else:
            if st.button("Login"):
                success, user = login_user(email, password)
                if success:
                    st.session_state.user = user
                    _rerun()
                else:
                    st.error("Invalid credentials")
# Main: show after login
if st.session_state.user:
    st.header("Upload or Record Your Heartbeat")
    # Offer the microphone option only when the recorder component imported.
    # The previous fallback listed "Upload File" twice.
    input_modes = ["Upload File"]
    if RECORDING_ENABLED:
        input_modes.append("Record (mic)")
    mode = st.radio("Input Mode:", input_modes)
    raw_audio = None
    if mode == "Upload File":
        up = st.file_uploader("Select WAV/MP3 file", type=["wav", "mp3"])
        if up:
            # getvalue() returns the whole buffer regardless of the read
            # position, so a rerun (e.g. after pressing the button below)
            # never sees an already-consumed, empty stream.
            raw_audio = up.getvalue()
            # Use the upload's own MIME type so MP3 files are not labelled as WAV.
            st.audio(raw_audio, format=up.type or 'audio/wav')
    else:
        audio_data = audiorecorder()
        if audio_data is not None:
            raw_audio = audio_data.tobytes()
            st.audio(raw_audio, format='audio/wav')
    if raw_audio:
        if st.button("Classify Heartbeat"):
            with st.spinner("Analyzing..."):
                mfccs, waveform, sr = process_audio(raw_audio)
                results = classify(mfccs)
                save_history(st.session_state.user["email"], results)
            # Display metrics
            st.subheader("Results")
            cols = st.columns(len(CLASS_NAMES))
            for col, name in zip(cols, CLASS_NAMES):
                col.metric(name.title(), f"{results[name]*100:.2f}%")
            # Waveform plot
            fig, ax = plt.subplots(figsize=(8, 3))
            librosa.display.waveshow(waveform, sr=sr, ax=ax)
            ax.set(title="Heartbeat Waveform", xlabel="Time (s)", ylabel="Amplitude")
            st.pyplot(fig)
            # pyplot keeps every figure registered until it is closed; close it
            # so figures do not accumulate across reruns.
            plt.close(fig)
    # Show history
    st.header("Your Classification History")
    history = load_history(st.session_state.user["email"])
    if history:
        # Newest first; timestamps are ISO strings, so string order is time order.
        for rec in sorted(history, key=lambda x: x['timestamp'], reverse=True):
            st.write(f"**{rec['timestamp']}**")
            st.json(rec['result'])
    else:
        st.info("No history yet. Classify your first heartbeat!")
else:
    st.info("Please login or register to continue.")

# Footer
st.markdown("---")
st.markdown("Built with ❤️ and deployed on Hugging Face Spaces")