File size: 5,701 Bytes
ba0eaa1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f4b373
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
import hashlib
import hmac
import json
import os
import secrets
import tempfile
from datetime import datetime, timezone

import firebase_admin
import librosa
import librosa.display
import matplotlib.pyplot as plt
import numpy as np
import streamlit as st
import tensorflow as tf
from firebase_admin import credentials, firestore

# Optional: for microphone recording.
# The PyPI package "streamlit-audiorecorder" installs the module "audiorecorder";
# the "streamlit_audiorecorder" name is still tried first in case a differently
# packaged copy provides it. RECORDING_ENABLED records whether either worked.
try:
    from streamlit_audiorecorder import audiorecorder
    RECORDING_ENABLED = True
except ImportError:
    try:
        from audiorecorder import audiorecorder
        RECORDING_ENABLED = True
    except ImportError:
        RECORDING_ENABLED = False

# --- Firebase Initialization ---
# Put the service-account credentials in Streamlit secrets under "firebase_key",
# either as a TOML table or as the raw JSON string.
@st.cache_resource
def init_firestore():
    """Return a Firestore client bound to the default Firebase app.

    The credentials come from ``st.secrets["firebase_key"]``:

    * a secrets table is converted to a plain ``dict`` because
      ``credentials.Certificate`` accepts only a file path or a ``dict``;
    * a string is parsed as JSON, and if it is not valid JSON it is passed
      through unchanged (``Certificate`` then treats it as a file path).

    If the default app was already initialized (for example after the
    ``st.cache_resource`` cache is cleared), it is reused instead of calling
    ``initialize_app`` again, which would raise ``ValueError``.
    """
    raw_key = st.secrets["firebase_key"]
    if isinstance(raw_key, str):
        try:
            cert = json.loads(raw_key)
        except json.JSONDecodeError:
            cert = raw_key  # not JSON: let Certificate treat it as a path
    else:
        cert = dict(raw_key)

    try:
        app = firebase_admin.get_app()
    except ValueError:
        # No default app yet: this is the first initialization.
        app = firebase_admin.initialize_app(credentials.Certificate(cert))
    return firestore.client(app)


db = init_firestore()

# --- Load ML Model ---
@st.cache_resource
def load_model():
    """Load the pretrained Keras classifier stored in ``Heart_ResNet.h5``.

    ``st.cache_resource`` keeps the loaded model and hands the same object back
    on later reruns instead of re-reading the file each time the script runs.
    """
    weights_file = 'Heart_ResNet.h5'
    keras_model = tf.keras.models.load_model(weights_file)
    return keras_model


model = load_model()

# --- Audio Processing & Classification ---
SAMPLE_RATE = 22050
DURATION = 10
INPUT_LEN = SAMPLE_RATE * DURATION
CLASS_NAMES = ["artifact", "murmur", "normal"]


@st.cache_data
def process_audio(raw_bytes):
    """Decode an audio clip and compute its time-averaged MFCC vector.

    Args:
        raw_bytes: encoded audio file contents (an uploaded WAV/MP3 or the
            recorder's output).

    Returns:
        ``(mfccs, X, sr)``:
        ``mfccs`` is the 52 MFCC coefficients averaged over all frames,
        ``X`` is the mono waveform loaded at ``SAMPLE_RATE`` (at most
        ``DURATION`` seconds, zero-padded up to ``INPUT_LEN`` samples if
        shorter), and ``sr`` is the sample rate reported by librosa.
    """
    # Decode through a private temporary file. The previous fixed "temp.wav"
    # in the working directory was shared by every session (concurrent users
    # could overwrite each other's audio) and was never deleted. A real path,
    # rather than an in-memory buffer, keeps librosa's path-based decoding
    # fallbacks available. The ".wav" suffix matches the old file name.
    fd, tmp_path = tempfile.mkstemp(suffix=".wav")
    try:
        with os.fdopen(fd, "wb") as tmp:
            tmp.write(raw_bytes)
        X, sr = librosa.load(tmp_path, sr=SAMPLE_RATE, duration=DURATION)
    finally:
        os.remove(tmp_path)

    if len(X) < INPUT_LEN:
        # Short clips are right-padded with silence up to the model's input length.
        X = np.pad(X, (0, INPUT_LEN - len(X)), mode='constant')

    # mfcc() yields (n_mfcc, frames); transpose and average over frames -> (52,).
    mfccs = np.mean(librosa.feature.mfcc(y=X, sr=sr, n_mfcc=52, n_fft=512, hop_length=256).T, axis=0)
    return mfccs, X, sr


def classify(mfccs):
    """Run the model on one MFCC vector and label its outputs.

    Args:
        mfccs: the 52-element feature vector from ``process_audio``; it is
            reshaped into a single-sample batch of shape ``(1, 52, 1)``.

    Returns:
        dict mapping each entry of ``CLASS_NAMES`` to the model's output for
        that position in the first (only) prediction row, as a Python float.
    """
    first_row = model.predict(mfccs.reshape(1, 52, 1))[0]
    scores = {}
    for position, label in enumerate(CLASS_NAMES):
        scores[label] = float(first_row[position])
    return scores

# --- User Authentication Helpers ---
# Passwords are stored as salted PBKDF2-HMAC-SHA256 hashes in the "password"
# field. Accounts created before hashing still hold the raw password there;
# they are accepted and upgraded to a hash on their next successful login.

_PBKDF2_ALGORITHM = "pbkdf2_sha256"
_PBKDF2_ITERATIONS = 600_000  # OWASP Password Storage Cheat Sheet guidance for PBKDF2-SHA256


def _hash_password(password, *, iterations=_PBKDF2_ITERATIONS):
    """Return a salted hash string of the form ``algo$iterations$salt_hex$hash_hex``."""
    salt = secrets.token_bytes(16)
    digest = hashlib.pbkdf2_hmac("sha256", password.encode("utf-8"), salt, iterations)
    return f"{_PBKDF2_ALGORITHM}${iterations}${salt.hex()}${digest.hex()}"


def _verify_password(password, stored):
    """Return True if *password* matches the stored ``password`` field value.

    Accepts both hashes produced by ``_hash_password`` and legacy plain-text
    values. All comparisons use ``hmac.compare_digest`` (constant time).
    """
    if not isinstance(stored, str):
        return False
    parts = stored.split("$")
    if len(parts) == 4 and parts[0] == _PBKDF2_ALGORITHM:
        try:
            iterations = int(parts[1])
            salt = bytes.fromhex(parts[2])
            expected = bytes.fromhex(parts[3])
        except ValueError:
            pass  # not a hash we wrote; fall through to the legacy comparison
        else:
            if iterations > 0:
                candidate = hashlib.pbkdf2_hmac(
                    "sha256", password.encode("utf-8"), salt, iterations
                )
                return hmac.compare_digest(candidate, expected)
    # Legacy account: the field holds the plain-text password. Compare bytes so
    # non-ASCII passwords work with compare_digest.
    return hmac.compare_digest(password.encode("utf-8"), stored.encode("utf-8"))


def _is_valid_user_id(email):
    """Return True if *email* can be used as a Firestore document ID here."""
    # Firestore document IDs cannot be empty or contain "/" (path separator).
    return bool(email) and "/" not in email


def register_user(email, password):
    """Create the account document ``users/<email>``.

    Returns:
        ``(success, message)``. Fails if the email/password is missing, the
        email cannot be a document ID, or the user already exists.
    """
    if not email or not password:
        return False, "Email and password are required"
    if not _is_valid_user_id(email):
        return False, "Email must not contain '/'"
    user_ref = db.collection("users").document(email)
    # NOTE(review): get-then-set is not atomic; two simultaneous sign-ups for
    # the same email could both pass this existence check.
    if user_ref.get().exists:
        return False, "User already exists"
    user_ref.set({"email": email, "password": _hash_password(password), "history": []})
    return True, "Registration successful"


def login_user(email, password):
    """Authenticate *email* / *password* against ``users/<email>``.

    Returns:
        ``(True, user_data)`` on success, where ``user_data`` is the stored
        document minus its ``password`` field (so the secret is not kept in
        session state); otherwise ``(False, None)``.
    """
    if not _is_valid_user_id(email):
        return False, None
    user_ref = db.collection("users").document(email)
    snapshot = user_ref.get()
    if not snapshot.exists:
        return False, None
    data = snapshot.to_dict() or {}
    stored = data.pop("password", None)
    if not _verify_password(password, stored):
        return False, None
    if not stored.startswith(_PBKDF2_ALGORITHM + "$"):
        # Legacy plain-text password verified: replace it with a hash.
        user_ref.update({"password": _hash_password(password)})
    data.setdefault("email", email)
    return True, data

# --- Firestore History Functions ---

def save_history(email, result):
    """Append one classification record to ``users/<email>``'s ``history`` array.

    Args:
        email: document ID of the user in the ``users`` collection.
        result: the class-name -> score mapping produced by ``classify``.

    The record's ``timestamp`` is an ISO-8601 string in UTC with an explicit
    ``+00:00`` offset (``datetime.utcnow()`` is deprecated since Python 3.12
    and produced an offset-less value). ``ArrayUnion`` adds the record to the
    array server-side without rewriting the rest of the document.
    """
    record = {"timestamp": datetime.now(timezone.utc).isoformat(), "result": result}
    db.collection("users").document(email).update({
        "history": firestore.ArrayUnion([record])
    })


def load_history(email):
    """Return the list of stored classification records for *email*.

    Returns an empty list when the user document or its ``history`` field is
    missing; ``DocumentSnapshot.to_dict()`` returns ``None`` for a
    non-existent document, which previously raised ``AttributeError``.
    """
    snapshot = db.collection("users").document(email).get()
    data = snapshot.to_dict() or {}
    return data.get("history") or []

# --- Streamlit App Layout ---
st.title("🩺 Heartbeat Sound Classifier with Firestore")

# Session state for user
if "user" not in st.session_state:
    st.session_state.user = None

# st.rerun supersedes the deprecated st.experimental_rerun; use it when this
# Streamlit version has it and fall back to the old name otherwise.
_rerun = getattr(st, "rerun", None) or st.experimental_rerun

# Sidebar: Auth or Logout
with st.sidebar:
    if st.session_state.user:
        st.markdown(f"**Logged in as:** {st.session_state.user['email']}")
        if st.button("Logout"):
            st.session_state.user = None
            _rerun()
    else:
        tab = st.radio("Account", ["Login", "Register"])
        email = st.text_input("Email")
        password = st.text_input("Password", type="password")
        if tab == "Register":
            if st.button("Sign Up"):
                success, msg = register_user(email, password)
                # Use a statement, not a bare conditional expression: Streamlit
                # "magic" writes non-call expression statements to the page,
                # which would also render the returned DeltaGenerator.
                if success:
                    st.success(msg)
                else:
                    st.error(msg)
        else:
            if st.button("Login"):
                success, user = login_user(email, password)
                if success:
                    st.session_state.user = user
                    _rerun()
                else:
                    st.error("Invalid credentials")

# Main: show after login
if st.session_state.user:
    st.header("Upload or Record Your Heartbeat")
    # Offer the microphone only when the recorder component imported. The old
    # fallback listed "Upload File" twice, showing a duplicate radio option.
    input_modes = ["Upload File"]
    if RECORDING_ENABLED:
        input_modes.append("Record (mic)")
    mode = st.radio("Input Mode:", input_modes)

    raw_audio = None
    if mode == "Upload File":
        up = st.file_uploader("Select WAV/MP3 file", type=["wav", "mp3"])
        if up:
            raw_audio = up.read()
            # Play back using the upload's own MIME type so MP3s are not
            # labelled as WAV; fall back to WAV if none was reported.
            st.audio(raw_audio, format=up.type or "audio/wav")
    else:
        audio_data = audiorecorder()
        # Newer streamlit-audiorecorder releases return a pydub AudioSegment
        # (encode it via export()); the original code assumed an object with
        # .tobytes(). Both are assumed empty (len 0) until a recording exists.
        if audio_data is not None and len(audio_data) > 0:
            if hasattr(audio_data, "export"):
                raw_audio = audio_data.export(format="wav").read()
            else:
                raw_audio = audio_data.tobytes()
            st.audio(raw_audio, format="audio/wav")

    if raw_audio:
        if st.button("Classify Heartbeat"):
            with st.spinner("Analyzing..."):
                mfccs, waveform, sr = process_audio(raw_audio)
                results = classify(mfccs)
                save_history(st.session_state.user["email"], results)

                # Display metrics
                st.subheader("Results")
                cols = st.columns(len(CLASS_NAMES))
                for col, name in zip(cols, CLASS_NAMES):
                    col.metric(name.title(), f"{results[name]*100:.2f}%")

                # Waveform plot
                fig, ax = plt.subplots(figsize=(8, 3))
                librosa.display.waveshow(waveform, sr=sr, ax=ax)
                ax.set(title="Heartbeat Waveform", xlabel="Time (s)", ylabel="Amplitude")
                st.pyplot(fig)
                # pyplot keeps every figure alive until closed; without this the
                # long-running server accumulates one figure per classification.
                plt.close(fig)

    # Show history
    st.header("Your Classification History")
    history = load_history(st.session_state.user["email"])
    if history:
        for rec in sorted(history, key=lambda x: x['timestamp'], reverse=True):
            st.write(f"**{rec['timestamp']}**")
            st.json(rec['result'])
    else:
        st.info("No history yet. Classify your first heartbeat!")
else:
    st.info("Please login or register to continue.")

# Footer
st.markdown("---")
st.markdown("Built with ❤️ and deployed on Hugging Face Spaces")