Spaces:
Running
Running
| import os, io, pathlib, urllib.request | |
| import numpy as np | |
| import streamlit as st | |
| from PIL import Image | |
| from matplotlib import cm | |
st.write("### ✅ Voice Guard Streamlit — env-only v4 (no st.secrets)")

# ---- import Detector from app/ or src/ ----
# Candidate modules are tried in order; the first one that imports cleanly and
# exposes a `Detector` attribute wins.
Detector, _last_err = None, None
_import_errors = []  # (module name, exception) for every candidate that failed
for mod in ["app.inference_wav2vec", "app.inference",
            "src.inference_wav2vec", "src.inference"]:
    try:
        Detector = __import__(mod, fromlist=["Detector"]).Detector
        break
    except Exception as e:
        # Broad on purpose: ImportError, AttributeError (module has no
        # Detector), or any error raised while the candidate module initialises.
        _last_err = e
        _import_errors.append((mod, e))
if Detector is None:
    # Report every candidate's failure. The last error alone is usually just
    # "No module named ...", which would hide the real problem raised by an
    # earlier candidate (for example a missing dependency inside it).
    _details = "\n".join(
        f"- `{m}`: {type(err).__name__}: {err}" for m, err in _import_errors
    )
    st.error(f"Could not import Detector from app/ or src/. Last error: {_last_err}\n\n{_details}")
    st.stop()
# ---- ENV config only ----
def cfg(name: str, default: str = "") -> str:
    """Return environment variable *name*, or *default* when it is unset or empty.

    Configuration is read only from the process environment (no st.secrets).
    """
    value = os.getenv(name)
    if value is None or value == "":
        return default
    return value
def ensure_weights() -> str:
    """Return the local path of the model weights, downloading them if needed.

    The path comes from ``MODEL_WEIGHTS_PATH`` (default
    ``app/models/weights/wav2vec2_classifier.pth``). If that file is missing
    and ``MODEL_WEIGHTS_URL`` is set, the weights are downloaded to it first.
    If the file is missing and no URL is configured, a warning is shown and the
    (non-existent) path is still returned, leaving the decision to the caller.
    If the download fails, the error is shown and the script run is stopped.
    """
    wp = cfg("MODEL_WEIGHTS_PATH", "app/models/weights/wav2vec2_classifier.pth")
    url = cfg("MODEL_WEIGHTS_URL", "")
    dest = pathlib.Path(wp)

    if dest.exists():
        return str(dest)

    if not url:
        st.warning(
            f"Model weights not found at '{wp}'. "
            "Upload the .pth there OR set MODEL_WEIGHTS_URL in Settings → Variables & secrets."
        )
        return str(dest)

    dest.parent.mkdir(parents=True, exist_ok=True)
    # Download to a sibling temp file and move it into place only on success.
    # Otherwise an interrupted download would leave a truncated .pth that later
    # runs see via dest.exists() and treat as valid weights.
    tmp = dest.with_name(dest.name + ".part")
    try:
        with st.spinner(f"Downloading model weights to {dest} …"):
            urllib.request.urlretrieve(url, str(tmp))
        os.replace(tmp, dest)  # atomic rename on the same filesystem
    except Exception as e:
        tmp.unlink(missing_ok=True)
        # The URL itself is deliberately not echoed: it may embed an access token.
        st.error(f"Failed to download model weights from MODEL_WEIGHTS_URL: {e}")
        st.stop()
    st.toast("Weights downloaded", icon="✅")
    return str(dest)
@st.cache_resource(show_spinner="Loading Voice Guard model …")
def _build_detector(weights_path: str):
    """Construct the Detector once per weights path and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction. Without
    this cache the model weights would be reloaded each time.
    NOTE(review): st.cache_resource shares this single instance across all
    sessions; this assumes Detector.predict_proba / Detector.explain are safe
    to call concurrently — confirm.
    """
    return Detector(weights_path=weights_path)


def load_detector():
    """Return the (cached) Detector for the configured weights file.

    ensure_weights() still runs on every call. It only downloads when the file
    is missing, so its UI messages behave as before. Only the expensive model
    construction is cached, keyed on the resolved weights path.
    """
    return _build_detector(ensure_weights())


det = load_detector()
# ---- helpers ----
def cam_to_png_bytes(cam: np.ndarray) -> bytes:
    """Render an importance map as a magma-coloured RGB PNG.

    Values are treated as already normalised to [0, 1]. NaNs become 0 and
    everything is clipped into [0, 1] before colour mapping, so infinities
    saturate to the ends of the colormap.

    Args:
        cam: importance map. A 2-D array is rendered as-is (rows x columns).
            Leading singleton axes, e.g. a ``(1, H, W)`` batch/channel axis,
            are dropped. 0-D/1-D input is rendered as a single row.

    Returns:
        PNG-encoded bytes of an RGB image.

    Raises:
        ValueError: if the map is still more than 2-D after dropping leading
            singleton axes.
    """
    cam = np.asarray(cam, dtype=np.float32)
    # Drop leading length-1 axes only, so a genuine (H, 1) map keeps its orientation.
    while cam.ndim > 2 and cam.shape[0] == 1:
        cam = cam[0]
    cam = np.atleast_2d(cam)
    if cam.ndim != 2:
        raise ValueError(
            f"cam must be 2-D after dropping leading singleton axes, got shape {cam.shape}"
        )
    cam = np.clip(np.nan_to_num(cam, nan=0.0), 0.0, 1.0)
    # cm.magma returns RGBA floats in [0, 1]; keep RGB and scale to 8-bit.
    rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
    buf = io.BytesIO()
    Image.fromarray(rgb).save(buf, format="PNG")
    return buf.getvalue()
def analyze(wav_bytes: bytes, source_hint: str):
    """Run the module-level detector on raw audio bytes.

    Calls ``det.predict_proba`` and then ``det.explain`` with the same bytes
    and ``source_hint``. Returns their results as a ``(proba, exp)`` tuple.
    """
    shared_kwargs = {"source_hint": source_hint}
    scores = det.predict_proba(wav_bytes, **shared_kwargs)
    explanation = det.explain(wav_bytes, **shared_kwargs)
    return scores, explanation
# ---- UI ----
from streamlit.errors import StreamlitAPIException

try:
    st.set_page_config(page_title="Voice Guard", page_icon="🛡️", layout="wide")
except StreamlitAPIException:
    # Older Streamlit releases only accept set_page_config as the very first
    # st.* command of a run. Any st.* output earlier in this script would then
    # abort the whole app here, so fall back to the default page config.
    # NOTE(review): the proper fix is to call this before any other st.* command.
    pass

st.title("🛡️ Voice Guard — Human vs AI Speech")
left, right = st.columns([1, 2], gap="large")

with left:
    st.subheader("Input")
    tab_rec, tab_up = st.tabs(["🎙️ Microphone", "📁 Upload"])
    # Audio to analyze and where it came from. The Upload tab is evaluated
    # after the Microphone tab, so an uploaded file wins if both are present.
    wav_bytes, source_hint = None, None
    with tab_rec:
        st.caption("Record ~3–7 s. If mic fails, use Upload.")
        try:
            # Optional third-party component; any failure falls back to Upload.
            from audio_recorder_streamlit import audio_recorder
            audio = audio_recorder(text="Record",
                                   recording_color="#ff6a00",
                                   neutral_color="#2b2b2b",
                                   icon_size="2x")
            if audio:
                wav_bytes, source_hint = audio, "microphone"
                st.audio(wav_bytes, format="audio/wav")
        except Exception:
            st.info("Recorder not available—use Upload tab.")
    with tab_up:
        f = st.file_uploader("Upload wav/mp3/m4a/aac", type=["wav", "mp3", "m4a", "aac"])
        if f:
            # getvalue() returns the whole upload regardless of the read position.
            wav_bytes, source_hint = f.getvalue(), "upload"
            # Play with the upload's own MIME type; st.audio assumes WAV otherwise.
            st.audio(wav_bytes, format=f.type or "audio/wav")
    st.markdown("---")
    run = st.button("🔍 Analyze", type="primary", use_container_width=True,
                    disabled=wav_bytes is None)

with right:
    st.subheader("Results")
    if run and wav_bytes:
        try:
            with st.spinner("Analyzing…"):
                proba, exp = analyze(wav_bytes, source_hint or "auto")
            ph = float(proba.get("human", 0.0))
            pa = float(proba.get("ai", 0.0))
            # str() guards against a non-string label before upper-casing it.
            label = str(proba.get("label", "human") or "human").upper()
            thr = float(proba.get("threshold", 0.5))
            rule = proba.get("decision", "threshold")
            thr_src = proba.get("threshold_source", "—")
            rscore = proba.get("replay_score", None)
            replay_txt = "—" if rscore is None else f"{float(rscore):.2f}"

            c1, c2, c3 = st.columns(3)
            with c1:
                st.metric("Human", f"{ph*100:.1f}%")
            with c2:
                st.metric("AI", f"{pa*100:.1f}%")
            with c3:
                color = "#22c55e" if label == "HUMAN" else "#fb7185"
                st.markdown(f"**Final Label:** <span style='color:{color}'>{label}</span>",
                            unsafe_allow_html=True)
                st.caption(f"thr({thr_src})={thr:.2f} • rule={rule} • replay={replay_txt}")

            st.markdown("##### Explanation Heatmap")
            cam_raw = exp.get("cam") if exp is not None else None
            cam_shape = None
            if cam_raw is None:
                st.info("The detector returned no explanation heatmap.")
            else:
                cam = np.asarray(cam_raw, dtype=np.float32)
                cam_shape = list(cam.shape)
                try:
                    st.image(cam_to_png_bytes(cam), caption="Spectrogram importance",
                             use_column_width=True)
                except Exception as e:
                    # The scores above are still valid; only the visualisation failed.
                    st.warning(f"Could not render heatmap: {e}")
            with st.expander("Raw JSON (debug)"):
                st.json({"proba": proba, "explain": {"cam_shape": cam_shape}})
        except Exception as e:
            st.error(f"Analyze failed: {e}")

st.caption("Upload 3–7s clips for the most reliable experience across browsers.")