import streamlit as st
import librosa
import librosa.display  # explicit import keeps librosa.display.waveshow available on older librosa versions
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
import firebase_admin
from firebase_admin import credentials, firestore
from datetime import datetime
# Optional: for microphone recording
try:
    from streamlit_audiorecorder import audiorecorder
    RECORDING_ENABLED = True
except ImportError:
    RECORDING_ENABLED = False
# --- Firebase Initialization ---
# On Hugging Face Spaces, add your Firebase service account JSON to the Space secrets
# as "firebase_key" so it is available through st.secrets.
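# Illustrative secrets layout (an assumption, not taken from the original app); the
# field names mirror a standard Firebase service-account JSON file:
#
# [firebase_key]
# type = "service_account"
# project_id = "your-project-id"
# private_key_id = "..."
# private_key = "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n"
# client_email = "firebase-adminsdk-xxxxx@your-project-id.iam.gserviceaccount.com"
# client_id = "..."
# token_uri = "https://oauth2.googleapis.com/token"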
@st.cache_resource
def init_firestore():
    # Load the service account from st.secrets; cast to a plain dict so
    # credentials.Certificate accepts it.
    cred_dict = dict(st.secrets["firebase_key"])
    cred = credentials.Certificate(cred_dict)
    firebase_admin.initialize_app(cred)
    return firestore.client()
db = init_firestore()
# --- Load ML Model ---
@st.cache_resource
def load_model():
    return tf.keras.models.load_model('Heart_ResNet.h5')
model = load_model()
# --- Audio Processing & Classification ---
SAMPLE_RATE = 22050
DURATION = 10
INPUT_LEN = SAMPLE_RATE * DURATION
CLASS_NAMES = ["artifact", "murmur", "normal"]
@st.cache_data
def process_audio(raw_bytes):
    # Save raw bytes to a temp file so librosa can read them
    with open("temp.wav", "wb") as f:
        f.write(raw_bytes)
    X, sr = librosa.load("temp.wav", sr=SAMPLE_RATE, duration=DURATION)
    # Zero-pad clips shorter than the fixed 10-second input length
    if len(X) < INPUT_LEN:
        X = np.pad(X, (0, INPUT_LEN - len(X)), mode='constant')
    # 52 MFCCs averaged over time -> a (52,) feature vector
    mfccs = np.mean(librosa.feature.mfcc(y=X, sr=sr, n_mfcc=52, n_fft=512, hop_length=256).T, axis=0)
    return mfccs, X, sr
def classify(mfccs):
    # The model expects input shaped (batch, 52, 1)
    feats = mfccs.reshape(1, 52, 1)
    preds = model.predict(feats)
    # Returns e.g. {"artifact": 0.03, "murmur": 0.12, "normal": 0.85} (illustrative values)
    return {name: float(preds[0][i]) for i, name in enumerate(CLASS_NAMES)}
# --- User Authentication Helpers ---
def register_user(email, password):
    user_ref = db.collection("users").document(email)
    if user_ref.get().exists:
        return False, "User already exists"
    # NOTE: the password is stored in plain text here; fine for a demo, not for production.
    user_ref.set({"email": email, "password": password, "history": []})
    return True, "Registration successful"
def login_user(email, password):
    user_ref = db.collection("users").document(email)
    user = user_ref.get()
    if user.exists and user.to_dict().get("password") == password:
        return True, user.to_dict()
    return False, None
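# Hedged sketch (not wired into the app): to avoid storing plain-text passwords, a
# salted hash could be stored by register_user and checked by login_user instead.
# hash_password / verify_password are illustrative helpers, not part of the original
# code; the imports would normally live at the top of the file.
import hashlib
import secrets
def hash_password(password, salt=None):
    # Derive a salted PBKDF2-SHA256 hash and return it as "salt$hexdigest".
    salt = salt or secrets.token_hex(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode(), bytes.fromhex(salt), 100_000)
    return f"{salt}${digest.hex()}"
def verify_password(password, stored):
    # Recompute the hash with the stored salt and compare against the stored value.
    salt, _ = stored.split("$", 1)
    return hash_password(password, salt) == stored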
# --- Firestore History Functions ---
def save_history(email, result):
    record = {"timestamp": datetime.utcnow().isoformat(), "result": result}
    db.collection("users").document(email).update({
        "history": firestore.ArrayUnion([record])
    })
def load_history(email):
    user = db.collection("users").document(email).get().to_dict()
    return user.get("history", [])
# --- Streamlit App Layout ---
st.title("🩺 Heartbeat Sound Classifier with Firestore")
# Session state for user
if "user" not in st.session_state:
    st.session_state.user = None
# Sidebar: Auth or Logout
with st.sidebar:
    if st.session_state.user:
        st.markdown(f"**Logged in as:** {st.session_state.user['email']}")
        if st.button("Logout"):
            st.session_state.user = None
            st.rerun()
    else:
        tab = st.radio("Account", ["Login", "Register"])
        email = st.text_input("Email")
        password = st.text_input("Password", type="password")
        if tab == "Register":
            if st.button("Sign Up"):
                success, msg = register_user(email, password)
                if success:
                    st.success(msg)
                else:
                    st.error(msg)
        else:
            if st.button("Login"):
                success, user = login_user(email, password)
                if success:
                    st.session_state.user = user
                    st.rerun()
                else:
                    st.error("Invalid credentials")
# Main: show after login
if st.session_state.user:
    st.header("Upload or Record Your Heartbeat")
    # Only offer the microphone option when the recorder component is installed
    input_modes = ["Upload File"] + (["Record (mic)"] if RECORDING_ENABLED else [])
    mode = st.radio("Input Mode:", input_modes)
    raw_audio = None
    if mode == "Upload File":
        up = st.file_uploader("Select WAV/MP3 file", type=["wav", "mp3"])
        if up:
            raw_audio = up.read()
            st.audio(raw_audio, format='audio/wav')
    else:
        audio_data = audiorecorder()
        # NOTE: depending on the recorder component, the returned object may need to be
        # exported to WAV bytes rather than converted with tobytes().
        if audio_data is not None and len(audio_data) > 0:
            raw_audio = audio_data.tobytes()
            st.audio(raw_audio, format='audio/wav')
    if raw_audio:
        if st.button("Classify Heartbeat"):
            with st.spinner("Analyzing..."):
                mfccs, waveform, sr = process_audio(raw_audio)
                results = classify(mfccs)
                save_history(st.session_state.user["email"], results)
            # Display metrics
            st.subheader("Results")
            cols = st.columns(len(CLASS_NAMES))
            for col, name in zip(cols, CLASS_NAMES):
                col.metric(name.title(), f"{results[name] * 100:.2f}%")
            # Waveform plot
            fig, ax = plt.subplots(figsize=(8, 3))
            librosa.display.waveshow(waveform, sr=sr, ax=ax)
            ax.set(title="Heartbeat Waveform", xlabel="Time (s)", ylabel="Amplitude")
            st.pyplot(fig)
    # Show history
    st.header("Your Classification History")
    history = load_history(st.session_state.user["email"])
    if history:
        for rec in sorted(history, key=lambda x: x['timestamp'], reverse=True):
            st.write(f"**{rec['timestamp']}**")
            st.json(rec['result'])
    else:
        st.info("No history yet. Classify your first heartbeat!")
else:
    st.info("Please login or register to continue.")
# Footer
st.markdown("---")
st.markdown("Built with ❤️ and deployed on Hugging Face Spaces")