firestoretest / app.py
mistermprah's picture
Update app.py
6f4b373 verified
import hashlib
import hmac
import json
import os
import secrets
import tempfile
from datetime import datetime, timezone

import firebase_admin
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf
from firebase_admin import credentials, firestore
# Optional: microphone recording via the `streamlit-audiorecorder` package.
# That package installs its module as `audiorecorder`; the previously used
# `streamlit_audiorecorder` name is kept as a fallback. RECORDING_ENABLED
# records whether either import succeeded.
try:
    from audiorecorder import audiorecorder
    RECORDING_ENABLED = True
except ImportError:
    try:
        from streamlit_audiorecorder import audiorecorder
        RECORDING_ENABLED = True
    except ImportError:
        RECORDING_ENABLED = False
# --- Firebase Initialization ---
# Put the Firebase service-account credentials in the app's Streamlit secrets
# (e.g. secrets.toml) under "firebase_key", either as a TOML table holding the
# service-account fields or as the service-account JSON inside a string.
@st.cache_resource
def init_firestore():
    """Initialise the default Firebase app (at most once) and return a Firestore client.

    Returns:
        A Firestore client bound to the default Firebase app.
    """
    key = st.secrets["firebase_key"]
    if isinstance(key, str):
        # A pasted JSON document is parsed; any other string is passed through
        # unchanged, which credentials.Certificate interprets as a file path.
        cred_source = json.loads(key) if key.lstrip().startswith("{") else key
    else:
        # The secrets section is a mapping object that is not necessarily a
        # plain dict, and credentials.Certificate only accepts a dict or a path.
        cred_source = dict(key)

    # initialize_app() raises ValueError when the default app already exists
    # (e.g. after the st.cache_resource entry is cleared), so reuse it then.
    try:
        firebase_admin.get_app()
    except ValueError:
        firebase_admin.initialize_app(credentials.Certificate(cred_source))
    return firestore.client()


db = init_firestore()
# --- Load ML Model ---
@st.cache_resource
def load_model():
    """Load the Keras heartbeat classifier; st.cache_resource lets reruns reuse it."""
    weights_file = "Heart_ResNet.h5"
    return tf.keras.models.load_model(weights_file)


model = load_model()
# --- Audio Processing & Classification ---
# Clips are loaded at SAMPLE_RATE Hz and limited to DURATION seconds, so a
# full-length clip holds INPUT_LEN samples.
SAMPLE_RATE = 22050  # Hz, passed to librosa.load(sr=...)
DURATION = 10  # seconds, passed to librosa.load(duration=...)
INPUT_LEN = DURATION * SAMPLE_RATE
# Labels, indexed in the same order as the model's output scores.
CLASS_NAMES = ["artifact", "murmur", "normal"]
# st.cache is deprecated; st.cache_data is its replacement for functions that
# return data. Entries are bounded because every distinct clip adds one.
@st.cache_data(max_entries=32)
def process_audio(raw_bytes):
    """Decode an audio clip and compute the MFCC features the model expects.

    Args:
        raw_bytes: Contents of the uploaded or recorded audio file.

    Returns:
        (mfccs, waveform, sr):
        - ``mfccs`` is the 52 MFCC coefficients averaged over time.
        - ``waveform`` is the signal loaded at SAMPLE_RATE, cut to at most
          DURATION seconds and zero-padded to INPUT_LEN samples.
        - ``sr`` is the sample rate reported by librosa.load.
    """
    # Use a private temporary file rather than a shared "temp.wav" in the
    # working directory, so concurrent sessions cannot overwrite each other's
    # audio. The file is always removed afterwards. Decoding still goes
    # through a real file path, exactly as before.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(raw_bytes)
        X, sr = librosa.load(tmp_path, sr=SAMPLE_RATE, duration=DURATION)
    finally:
        os.remove(tmp_path)

    if len(X) < INPUT_LEN:
        X = np.pad(X, (0, INPUT_LEN - len(X)), mode="constant")
    # Transpose so frames are rows, then average each coefficient over frames.
    mfccs = np.mean(
        librosa.feature.mfcc(y=X, sr=sr, n_mfcc=52, n_fft=512, hop_length=256).T,
        axis=0,
    )
    return mfccs, X, sr
def classify(mfccs):
    """Score one MFCC feature vector with the loaded model.

    Args:
        mfccs: 1-D feature vector as returned by process_audio.

    Returns:
        dict mapping each name in CLASS_NAMES to the model's score for it
        (as a float), matched by output position.
    """
    # Shape (batch=1, n_features, channels=1). The feature count is inferred
    # rather than hard-coded, so it follows whatever process_audio emits.
    features = np.asarray(mfccs).reshape(1, -1, 1)
    # verbose=0 stops Keras printing a progress bar on every request.
    scores = model.predict(features, verbose=0)[0]
    return {name: float(scores[i]) for i, name in enumerate(CLASS_NAMES)}
# --- User Authentication Helpers ---
# Passwords are stored as salted PBKDF2-HMAC-SHA256 hashes in the form
# "pbkdf2_sha256$<iterations>$<salt hex>$<hash hex>". Records written before
# hashing was introduced hold the plaintext password. Those are still accepted
# at login and are re-hashed on the first successful login.
_HASH_SCHEME = "pbkdf2_sha256"
_PBKDF2_ITERATIONS = 600_000


def _hash_password(password, iterations=_PBKDF2_ITERATIONS):
    """Return a freshly salted PBKDF2 hash string for *password*."""
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
    return f"{_HASH_SCHEME}${iterations}${salt.hex()}${digest.hex()}"


def _parse_hash(stored):
    """Split a stored hash into (iterations, salt, digest); None if not a hash."""
    parts = stored.split("$")
    if len(parts) != 4 or parts[0] != _HASH_SCHEME:
        return None
    try:
        iterations = int(parts[1])
        salt, digest = bytes.fromhex(parts[2]), bytes.fromhex(parts[3])
    except ValueError:
        return None
    return (iterations, salt, digest) if iterations > 0 else None


def _verify_password(password, stored):
    """Check *password* against a stored hash, or a legacy plaintext value."""
    if not isinstance(stored, str):
        return False
    parsed = _parse_hash(stored)
    if parsed is None:
        # Legacy record created before hashing: compare the plaintext values.
        return hmac.compare_digest(stored.encode("utf-8"), password.encode("utf-8"))
    iterations, salt, expected = parsed
    actual = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
    return hmac.compare_digest(actual, expected)


def register_user(email, password):
    """Create a user document in "users", keyed by email.

    Returns:
        (success, message): message is suitable for showing to the user.
    """
    if not email or not password:
        return False, "Email and password are required"
    user_ref = db.collection("users").document(email)
    # NOTE(review): this check-then-set is not atomic; two simultaneous
    # sign-ups for the same email could both pass the check.
    if user_ref.get().exists:
        return False, "User already exists"
    user_ref.set({"email": email, "password": _hash_password(password), "history": []})
    return True, "Registration successful"


def login_user(email, password):
    """Validate credentials against the user's Firestore document.

    Returns:
        (True, user_dict) on success, where user_dict is the stored document
        without its "password" field; (False, None) otherwise.
    """
    if not email or not password:
        return False, None
    user_ref = db.collection("users").document(email)
    snapshot = user_ref.get()
    if not snapshot.exists:
        return False, None
    user = snapshot.to_dict() or {}
    # Remove the credential so it is not kept in the caller's session state.
    stored = user.pop("password", None)
    if not _verify_password(password, stored):
        return False, None
    if _parse_hash(stored) is None:
        # Migrate a legacy plaintext record now that the password is known.
        user_ref.update({"password": _hash_password(password)})
    return True, user
# --- Firestore History Functions ---
def save_history(email, result):
    """Append one timestamped classification result to the user's history.

    Args:
        email: Id of the user's document in the "users" collection.
        result: Mapping of class name to score, as returned by classify().
    """
    # Timezone-aware UTC timestamp; datetime.utcnow() is deprecated and naive.
    # The ISO strings still sort correctly next to older naive-UTC entries,
    # because they share the same date/time prefix.
    record = {"timestamp": datetime.now(timezone.utc).isoformat(), "result": result}
    # NOTE(review): history is an array on the user document and grows with
    # every classification. Firestore limits document size, so consider a
    # subcollection if histories get long.
    db.collection("users").document(email).update(
        {"history": firestore.ArrayUnion([record])}
    )
def load_history(email):
    """Return the list of history records stored on the user's document.

    Returns an empty list when the document or its history is missing.
    """
    snapshot = db.collection("users").document(email).get()
    # to_dict() returns None for a document that does not exist.
    data = snapshot.to_dict() or {}
    return data.get("history") or []
# --- Streamlit App Layout ---
st.title("🩺 Heartbeat Sound Classifier with Firestore")

# Per-session login state: the logged-in user's record, or None when logged out.
st.session_state.setdefault("user", None)
# Sidebar: Auth or Logout
def _rerun():
    """Restart the script run.

    Newer Streamlit releases provide st.rerun and no longer ship
    st.experimental_rerun, so use st.rerun when present and fall back
    otherwise.
    """
    rerun = getattr(st, "rerun", None) or getattr(st, "experimental_rerun")
    rerun()


with st.sidebar:
    if st.session_state.user:
        st.markdown(f"**Logged in as:** {st.session_state.user['email']}")
        if st.button("Logout"):
            st.session_state.user = None
            _rerun()
    else:
        tab = st.radio("Account", ["Login", "Register"])
        email = st.text_input("Email")
        password = st.text_input("Password", type="password")
        if tab == "Register":
            if st.button("Sign Up"):
                success, msg = register_user(email, password)
                # Use a real if/else: a bare conditional expression statement
                # would also be handed to st.write() by Streamlit "magic".
                if success:
                    st.success(msg)
                else:
                    st.error(msg)
        elif st.button("Login"):
            success, user = login_user(email, password)
            if success:
                st.session_state.user = user
                _rerun()
            else:
                st.error("Invalid credentials")
# Main: show after login
def _recording_to_bytes(audio_data):
    """Convert the recorder component's return value to audio bytes, or None.

    NOTE(review): depending on the installed streamlit-audiorecorder version,
    the component returns either a NumPy array of WAV bytes or a pydub
    AudioSegment. Both are handled here; confirm against the pinned version.
    """
    if audio_data is None or len(audio_data) == 0:
        return None
    if hasattr(audio_data, "export"):  # pydub.AudioSegment
        return audio_data.export(format="wav").read()
    return audio_data.tobytes()


def _read_audio_input():
    """Render the upload/record widgets; return the chosen audio bytes or None."""
    # Offer the microphone only when the recorder imported. Otherwise the
    # radio previously listed "Upload File" twice.
    modes = ["Upload File"]
    if RECORDING_ENABLED:
        modes.append("Record (mic)")
    mode = st.radio("Input Mode:", modes)

    if mode == "Upload File":
        up = st.file_uploader("Select WAV/MP3 file", type=["wav", "mp3"])
        if not up:
            return None
        raw_audio = up.read()
        # Use the uploaded file's MIME type so MP3s are not labelled as WAV.
        st.audio(raw_audio, format=up.type or "audio/wav")
        return raw_audio

    raw_audio = _recording_to_bytes(audiorecorder())
    if raw_audio:
        st.audio(raw_audio, format="audio/wav")
    return raw_audio


def _render_results(results, waveform, sr):
    """Show per-class scores as metrics and plot the analysed waveform."""
    st.subheader("Results")
    cols = st.columns(len(CLASS_NAMES))
    for col, name in zip(cols, CLASS_NAMES):
        col.metric(name.title(), f"{results[name]*100:.2f}%")

    fig, ax = plt.subplots(figsize=(8, 3))
    librosa.display.waveshow(waveform, sr=sr, ax=ax)
    ax.set(title="Heartbeat Waveform", xlabel="Time (s)", ylabel="Amplitude")
    st.pyplot(fig)
    # pyplot keeps every figure alive until closed; close it after rendering.
    plt.close(fig)


def _render_history(email):
    """List the user's stored classifications, newest first."""
    st.header("Your Classification History")
    history = load_history(email)
    if not history:
        st.info("No history yet. Classify your first heartbeat!")
        return
    for rec in sorted(history, key=lambda r: r["timestamp"], reverse=True):
        st.write(f"**{rec['timestamp']}**")
        st.json(rec["result"])


if st.session_state.user:
    st.header("Upload or Record Your Heartbeat")
    raw_audio = _read_audio_input()
    if raw_audio and st.button("Classify Heartbeat"):
        with st.spinner("Analyzing..."):
            mfccs, waveform, sr = process_audio(raw_audio)
            results = classify(mfccs)
            save_history(st.session_state.user["email"], results)
        _render_results(results, waveform, sr)
    _render_history(st.session_state.user["email"])
else:
    st.info("Please login or register to continue.")
# Footer: a horizontal rule followed by the credit line.
for footer_line in ("---", "Built with ❤️ and deployed on Hugging Face Spaces"):
    st.markdown(footer_line)