Upload 6 files
- app/api.py +65 -0
- app/app.py +138 -0
- app/elevenlabs_tools.py +38 -0
- app/inference_wav2vec.py +214 -0
- app/train.py +315 -0
- app/train_wav2vec.py +207 -0
app/api.py
ADDED
@@ -0,0 +1,65 @@
import os, io, base64
import numpy as np
from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from typing import Optional, Dict, Any
from PIL import Image
from matplotlib import cm

BACKEND = os.getenv("DETECTOR_BACKEND", "wav2vec2").lower()
try:
    if BACKEND == "wav2vec2":
        from .inference_wav2vec import Detector  # type: ignore
    else:
        from .inference import Detector  # type: ignore
except Exception:
    if BACKEND == "wav2vec2":
        from app.inference_wav2vec import Detector  # type: ignore
    else:
        from app.inference import Detector  # type: ignore

DEFAULT_WEIGHTS = "app/models/weights/wav2vec2_classifier.pth" if BACKEND == "wav2vec2" else "app/models/weights/cnn_melspec.pth"
WEIGHTS = os.getenv("MODEL_WEIGHTS_PATH", DEFAULT_WEIGHTS)
det = Detector(weights_path=WEIGHTS)

app = FastAPI(title="Voice Guard API", version="1.1.0")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in prod
    allow_methods=["*"],
    allow_headers=["*"],
)

class AnalyzeResponse(BaseModel):
    human: float
    ai: float
    label: str
    threshold: float
    threshold_source: Optional[str] = None
    backend: str
    source_hint: str
    replay_score: Optional[float] = None
    decision: Optional[str] = None
    decision_details: Optional[Dict[str, Any]] = None
    heatmap_b64: str

def heatmap_png_b64(cam: np.ndarray) -> str:
    cam = np.clip(cam, 0.0, 1.0).astype(np.float32)
    rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
    im = Image.fromarray(rgb)
    buf = io.BytesIO(); im.save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")

@app.post("/analyze", response_model=AnalyzeResponse)
async def analyze(file: UploadFile = File(...), source_hint: str = Form("auto")):
    raw = await file.read()
    proba = det.predict_proba(raw, source_hint=source_hint)
    cam = np.array(det.explain(raw, source_hint=source_hint)["cam"], dtype=np.float32)
    return {
        **proba,
        "heatmap_b64": heatmap_png_b64(cam),
    }

@app.get("/health")
def health(): return {"ok": True, "backend": BACKEND}
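A minimal client sketch for the /analyze and /health endpoints above. It assumes the API has been started separately (for example with uvicorn app.api:app) and is reachable on localhost:8000; the host, port, and sample.wav path are placeholders and not part of this commit.

# Hypothetical client for the /analyze endpoint defined above.
# Assumes the API runs at http://localhost:8000 and "sample.wav" exists locally.
import requests

with open("sample.wav", "rb") as fh:
    resp = requests.post(
        "http://localhost:8000/analyze",
        files={"file": ("sample.wav", fh, "audio/wav")},   # UploadFile field
        data={"source_hint": "upload"},                     # Form field
    )
resp.raise_for_status()
result = resp.json()
print(result["label"], result["ai"], result["threshold"])

# Liveness check
print(requests.get("http://localhost:8000/health").json())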
app/app.py
ADDED
@@ -0,0 +1,138 @@
import os
import json
import numpy as np
import gradio as gr
from dotenv import load_dotenv
from matplotlib import cm

load_dotenv()

# -------------------------
# 0) Env & defaults
# -------------------------
BACKEND = os.getenv("DETECTOR_BACKEND", "wav2vec2").strip().lower()  # "wav2vec2" or "cnn"

DEFAULT_W2V_WEIGHTS = "app/models/weights/wav2vec2_classifier.pth"
DEFAULT_CNN_WEIGHTS = "app/models/weights/cnn_melspec.pth"
DEFAULT_WEIGHTS = DEFAULT_W2V_WEIGHTS if BACKEND == "wav2vec2" else DEFAULT_CNN_WEIGHTS
MODEL_WEIGHTS_PATH = os.getenv("MODEL_WEIGHTS_PATH", DEFAULT_WEIGHTS).strip()


# -------------------------
# 1) Import your Detector
# -------------------------
def _import_detector(backend):
    """
    Import the correct Detector class depending on backend and package layout.
    Works both when run as a module ('.inference_*') and as a script ('app.inference_*').
    """
    try:
        if backend == "wav2vec2":
            from .inference_wav2vec import Detector  # type: ignore
        else:
            from .inference import Detector  # type: ignore
    except Exception:
        if backend == "wav2vec2":
            from app.inference_wav2vec import Detector  # type: ignore
        else:
            from app.inference import Detector  # type: ignore
    return Detector


try:
    Detector = _import_detector(BACKEND)
except Exception as e:
    # Fallback dummy to keep the UI alive even if import fails,
    # so you can see the error in the JSON panel.
    class Detector:  # type: ignore
        def __init__(self, *args, **kwargs):
            self._err = f"Detector import failed: {e}"

        def predict_proba(self, *args, **kwargs):
            return {"error": self._err}

        def explain(self, *args, **kwargs):
            return {"cam": np.zeros((128, 128), dtype=np.float32).tolist()}


# Single, shared detector (created lazily so startup is fast on Spaces)
_DET = None
def _get_detector():
    global _DET
    if _DET is None:
        _DET = Detector(weights_path=MODEL_WEIGHTS_PATH)
    return _DET


# -------------------------
# 2) Core functions
# -------------------------
def predict_and_explain(audio_path: str | None, source_hint: str):
    """
    audio_path: filepath from Gradio (since type='filepath')
    source_hint: "Auto", "Microphone", "Upload"
    """
    source = (source_hint or "Auto").strip().lower()

    if not audio_path or not os.path.exists(audio_path):
        return {"error": "No audio received. Record or upload a 2–4s clip."}, None

    det = _get_detector()

    # Your Detector is expected to accept a file path and optional source hint
    proba = det.predict_proba(audio_path, source_hint=source)
    exp = det.explain(audio_path, source_hint=source)

    # Explanation to heatmap (float [0,1] -> magma RGB uint8)
    cam = np.array(exp.get("cam", []), dtype=np.float32)
    if cam.ndim == 1:
        # if model returned a 1D vector, tile to square-ish map
        side = int(np.sqrt(cam.size))
        side = max(side, 2)
        cam = cam[: side * side].reshape(side, side)
    cam = np.clip(cam, 0.0, 1.0)
    cam_rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)

    # Ensure proba is JSON-serializable
    if not isinstance(proba, dict):
        proba = {"result": proba}

    return proba, cam_rgb


def provenance(audio_path: str | None):
    # Stub (you can wire a provenance model or checksum here)
    return {"ok": True, "note": "Provenance check not wired in this app.py."}


# -------------------------
# 3) UI
# -------------------------
with gr.Blocks(title=f"AI Voice Detector · {BACKEND.upper()}") as demo:
    gr.Markdown(f"# 🔎 AI Voice Detector — Backend: **{BACKEND.upper()}**")
    gr.Markdown(
        "Record or upload a short clip (~3s). Get probabilities, a label, and an explanation heatmap."
    )

    with gr.Row():
        audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
        with gr.Column():
            src = gr.Radio(choices=["Auto", "Microphone", "Upload"], value="Auto", label="Source")
            btn_predict = gr.Button("Analyze", variant="primary")
            btn_prov = gr.Button("Provenance Check (optional)")

    with gr.Row():
        json_out = gr.JSON(label="Prediction (probabilities + label)")
        cam_out = gr.Image(label="Explanation Heatmap (saliency)")
        prov_out = gr.JSON(label="Provenance Result (if available)")

    btn_predict.click(predict_and_explain, inputs=[audio_in, src], outputs=[json_out, cam_out])
    btn_prov.click(provenance, inputs=audio_in, outputs=prov_out)


# -------------------------
# 4) Launch (Spaces-friendly)
# -------------------------
if __name__ == "__main__":
    # queue() keeps UI responsive under load; host/port are Spaces-safe and local-friendly
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
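The core prediction path can also be exercised without the Gradio UI. A small sketch, assuming the repository root is on the Python path and a local clip named sample.wav exists (both placeholders, not part of this commit):

# Illustrative direct call of predict_and_explain() from app/app.py.
# Importing app.app builds the Blocks UI but does not launch it.
from app.app import predict_and_explain

proba, cam_rgb = predict_and_explain("sample.wav", "Upload")
print(proba)                                          # probabilities + label, or an error dict
print(None if cam_rgb is None else cam_rgb.shape)     # (H, W, 3) uint8 heatmap, or None on error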
app/elevenlabs_tools.py
ADDED
@@ -0,0 +1,38 @@
import os, time, hashlib, json, pathlib, random
from typing import List, Optional
from dotenv import load_dotenv
import requests

load_dotenv()
ELEVEN_API_KEY = os.getenv("ELEVEN_API_KEY", "")
ELEVEN_VOICE_ID = os.getenv("ELEVEN_VOICE_ID", "")
BASE = "https://api.elevenlabs.io/v1"

def _headers():
    return {"xi-api-key": ELEVEN_API_KEY, "accept": "audio/mpeg", "Content-Type": "application/json"}

def generate_tts_dataset(texts: List[str], voice_id: Optional[str]=None, out_dir: str="data/raw/ai", model_id: str="eleven_monolingual_v1"):
    """Generate AI speech MP3s from ElevenLabs into out_dir. Convert to WAV (16k mono) for training."""
    voice_id = voice_id or ELEVEN_VOICE_ID
    assert ELEVEN_API_KEY, "Set ELEVEN_API_KEY in .env"
    assert voice_id, "Provide ELEVEN_VOICE_ID in .env or pass voice_id"
    os.makedirs(out_dir, exist_ok=True)
    for i, txt in enumerate(texts):
        payload = {"text": txt, "model_id": model_id, "voice_settings": {"stability": 0.4, "similarity_boost": 0.7}}
        url = f"{BASE}/text-to-speech/{voice_id}"
        r = requests.post(url, headers=_headers(), json=payload)
        if r.status_code != 200:
            print("TTS error", r.status_code, r.text[:200]); continue
        mp3_path = os.path.join(out_dir, f"elab_{i:04d}.mp3")
        with open(mp3_path, "wb") as f:
            f.write(r.content)
        print("saved", mp3_path)
    print("Done. Convert MP3 to WAV (16kHz mono) before training.")

def check_ai_speech(audio_bytes: bytes) -> dict:
    """Stub: if your plan exposes classifier API, call it here; else returns unsupported."""
    return {"supported": False, "prob_ai": None, "provider": "elevenlabs", "note": "Classifier not enabled in this template."}

# if __name__ == "__main__":
#     generate_tts_dataset()
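The docstring above asks for MP3 to 16 kHz mono WAV conversion before training. A sketch of one way to do both steps; librosa and soundfile are assumptions here (they are not part of this commit), and any resampler such as ffmpeg works equally well:

# Illustrative: generate a few ElevenLabs clips, then convert them to 16 kHz mono WAV.
from pathlib import Path
import librosa           # assumed dependency, not in this diff
import soundfile as sf   # assumed dependency, not in this diff

from app.elevenlabs_tools import generate_tts_dataset

generate_tts_dataset(["Hello there.", "This clip was synthesized."], out_dir="data/raw/ai")

for mp3 in Path("data/raw/ai").glob("*.mp3"):
    y, sr = librosa.load(str(mp3), sr=16000, mono=True)   # decode + resample to 16 kHz mono
    sf.write(str(mp3.with_suffix(".wav")), y, sr)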
app/inference_wav2vec.py
ADDED
@@ -0,0 +1,214 @@
import os, json
import numpy as np
import torch
import torch.nn.functional as F

from .models.wav2vec_detector import Wav2VecClassifier
from .utils.audio import load_audio, pad_or_trim, TARGET_SR

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ---------- Thresholds & biases ----------
AI_THRESHOLD_DEFAULT = float(os.getenv("DETECTOR_AI_THRESHOLD", "0.60"))
MIC_THRESHOLD = float(os.getenv("DETECTOR_MIC_THRESHOLD", "0.68"))
UPLOAD_THRESHOLD = float(os.getenv("DETECTOR_UPLOAD_THRESHOLD", str(AI_THRESHOLD_DEFAULT)))
AI_LOGIT_BIAS = float(os.getenv("DETECTOR_AI_LOGIT_BIAS", "0.00"))  # add to AI logit globally

# ---------- Decision rule ----------
# 'threshold' -> AI if ai_prob >= threshold
# 'argmax'    -> AI if ai_prob > human_prob
# 'hybrid'    -> threshold, but if replay_score >= T1 and ai_prob >= 0.50 -> AI
DECISION_RULE = os.getenv("DECISION_RULE", "threshold").lower()

# ---------- Replay-attack heuristic ----------
REPLAY_ENABLE = os.getenv("REPLAY_ENABLE", "1") != "0"
REPLAY_AI_BONUS = float(os.getenv("REPLAY_AI_BONUS", "1.2"))
REPLAY_FORCE_LABEL = os.getenv("REPLAY_FORCE_LABEL", "0") == "1"
REPLAY_T1 = float(os.getenv("REPLAY_T1", "0.35"))  # soft start
REPLAY_T2 = float(os.getenv("REPLAY_T2", "0.55"))  # strong replay

# ---------- DSP helpers ----------
def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
    m = float(np.max(np.abs(y)) + eps)
    return (y / m) * peak if m > 0 else y

def rms_normalize(y: np.ndarray, target_rms: float = 0.03, eps: float = 1e-9) -> np.ndarray:
    rms = float(np.sqrt(np.mean(y**2)) + eps)
    g = target_rms / rms
    return np.clip(y * g, -1.0, 1.0)

def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
    if y.size == 0: return y
    win = max(1, int(sr * 0.02))
    pad = max(1, int(sr * (min_ms / 1000.0)))
    energy = np.convolve(y ** 2, np.ones(win) / win, mode="same")
    ref = np.max(energy) + 1e-12
    mask = 10.0 * np.log10(energy / ref + 1e-12) > -thresh_db
    if not np.any(mask): return y
    idx = np.where(mask)[0]
    start = max(0, int(idx[0] - pad))
    end = min(len(y), int(idx[-1] + pad))
    return y[start:end]

def noise_gate(y, sr, gate_db=-42.0):
    m = np.max(np.abs(y)) + 1e-9
    thr = m * (10.0 ** (gate_db / 20.0))
    y2 = y.copy()
    y2[np.abs(y2) < thr] = 0.0
    return y2

def bandpass_fft(y: np.ndarray, sr: int, low=100.0, high=3800.0):
    n = int(2 ** np.ceil(np.log2(len(y) + 1)))
    Y = np.fft.rfft(y, n=n)
    freqs = np.fft.rfftfreq(n, d=1.0/sr)
    mask = (freqs >= low) & (freqs <= high)
    Y_filtered = Y * mask
    y_filt = np.fft.irfft(Y_filtered, n=n)[:len(y)]
    return y_filt.astype(np.float32, copy=False)

# ---------- Replay score ----------
def replay_score(y: np.ndarray, sr: int) -> float:
    if len(y) < sr:
        y = np.pad(y, (0, sr - len(y)))
    N = 4096
    if len(y) < N:
        y = np.pad(y, (0, N - len(y)))
    w = np.hanning(N)
    seg = y[:N] * w
    X = np.abs(np.fft.rfft(seg)) + 1e-9

    cep = np.fft.irfft(np.log(X))
    qmin = max(1, int(0.0003 * sr))
    qmax = min(len(cep) - 1, int(0.0040 * sr))
    cwin = np.abs(cep[qmin:qmax])
    c_peak = float(np.max(cwin)); c_mean = float(np.mean(cwin) + 1e-9)
    cep_score = np.clip((c_peak - c_mean) / (c_peak + c_mean), 0.0, 1.0)

    F = np.fft.rfftfreq(N, 1.0 / sr)
    total = float(np.sum(X))
    hf = float(np.sum(X[F >= 5000.0]))
    hf_ratio = hf / (total + 1e-9)
    hf_term = np.clip((0.25 - hf_ratio) / 0.25, 0.0, 1.0)

    return float(np.clip(0.6 * cep_score + 0.4 * hf_term, 0.0, 1.0))

# ---------- Detector ----------
class Detector:
    def __init__(self, weights_path: str, encoder: str | None = None, unfreeze_last: int = 0):
        cfg = None
        js = weights_path.replace(".pth", ".json")
        if os.path.exists(js):
            try:
                with open(js, "r", encoding="utf-8") as f:
                    cfg = json.load(f)
            except Exception:
                cfg = None
        enc = encoder or (cfg.get("encoder") if cfg else "facebook/wav2vec2-base")
        # parentheses so an explicit unfreeze_last is not discarded when no JSON config exists
        unf = unfreeze_last or (int(cfg.get("unfreeze_last", 0)) if cfg else 0)

        self.model = Wav2VecClassifier(encoder=enc, unfreeze_last=unf).to(DEVICE)
        if weights_path and os.path.exists(weights_path):
            state = torch.load(weights_path, map_location=DEVICE)
            self.model.load_state_dict(state, strict=False)
        self.model.eval()

    def _preprocess(self, y: np.ndarray, sr: int, source_hint: str | None):
        y = trim_silence(y, sr, 40.0, 30)
        y = bandpass_fft(y, sr, 100.0, 3800.0)
        if source_hint and source_hint.lower().startswith("micro"):
            y = noise_gate(y, sr, -42.0)
            y = rms_normalize(y, 0.035)
            y = peak_normalize(y, 0.95)
        else:
            y = rms_normalize(y, 0.03)
            y = peak_normalize(y, 0.95)
        y = pad_or_trim(y, duration_s=3.0, sr=sr)
        return y

    @torch.inference_mode()
    def predict_proba(self, wav_bytes_or_path, source_hint: str | None = None):
        y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
        rscore = replay_score(y0, sr) if REPLAY_ENABLE else 0.0

        y = self._preprocess(y0, sr, source_hint)
        x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)

        logits, _ = self.model(x)
        logits = logits.clone()
        logits[:, 1] += AI_LOGIT_BIAS

        # Replay bonus on AI logit
        if REPLAY_ENABLE and (source_hint and source_hint.lower().startswith("micro")) and (rscore >= REPLAY_T1):
            ramp = np.clip((rscore - REPLAY_T1) / max(REPLAY_T2 - REPLAY_T1, 1e-6), 0.0, 1.0)
            logits[:, 1] += REPLAY_AI_BONUS * ramp

        probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
        p_h, p_ai = float(probs[0]), float(probs[1])

        thr_source = "mic" if (source_hint and source_hint.lower().startswith("micro")) else "upload"
        thr = MIC_THRESHOLD if thr_source == "mic" else UPLOAD_THRESHOLD

        # Labels by different rules
        label_thresh = "ai" if p_ai >= thr else "human"
        label_argmax = "ai" if p_ai > p_h else "human"
        label_hybrid = label_thresh
        if REPLAY_ENABLE and rscore >= REPLAY_T1 and p_ai >= 0.50:
            label_hybrid = "ai"
        if REPLAY_ENABLE and rscore >= REPLAY_T2 and (source_hint and source_hint.lower().startswith("micro")):
            if REPLAY_FORCE_LABEL or p_ai >= (thr - 0.05):
                label_hybrid = "ai"

        if DECISION_RULE == "argmax":
            label = label_argmax
            rule_used = "argmax"
        elif DECISION_RULE == "hybrid":
            label = label_hybrid
            rule_used = "hybrid(threshold+replay)"
        else:
            label = label_thresh
            rule_used = "threshold"

        return {
            "human": p_h,
            "ai": p_ai,
            "label": label,
            "threshold": float(thr),
            "threshold_source": thr_source,
            "backend": "wav2vec2",
            "source_hint": (source_hint or "auto"),
            "replay_score": float(rscore),
            "decision": rule_used,
            "decision_details": {
                "ai_prob": p_ai,
                "human_prob": p_h,
                "prob_margin": p_ai - p_h,
                "ai_vs_threshold_margin": p_ai - thr,
                "replay_score": rscore,
                "mic_threshold": MIC_THRESHOLD,
                "upload_threshold": UPLOAD_THRESHOLD,
                "force_label_AI": bool(REPLAY_FORCE_LABEL and rscore >= REPLAY_T2),
            },
        }

    def explain(self, wav_bytes_or_path, source_hint: str | None = None):
        self.model.eval()
        y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
        y = self._preprocess(y0, sr, source_hint)
        x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
        x.requires_grad_(True)
        logits, feats = self.model(x)
        logits[:, 1].sum().backward(retain_graph=True)
        if feats.grad is None:
            s = x.grad.detach().abs().squeeze(0)
            s = s / (s.max() + 1e-6)
            H = 64
            step = max(1, s.numel() // 256)
            s_small = s[::step][:256].cpu().numpy()
            cam = np.tile(s_small[None, :], (H, 1))
        else:
            g = feats.grad.detach().abs().sum(dim=-1).squeeze(0)
            g = g / (g.max() + 1e-6)
            H = 64
            cam = np.tile(g.cpu().numpy()[None, :], (H, 1))
        cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)
        return {"cam": cam.tolist(), "probs": None}
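A sketch of using this Detector directly, assuming trained weights exist at the default path (with the optional wav2vec2_classifier.json beside them) and that the encoder checkpoint is downloadable. Note that DECISION_RULE and the other env vars are read at module import time, so they must be set before the import; the clip path is a placeholder:

# Illustrative standalone use of the wav2vec2 Detector.
import os
os.environ.setdefault("DECISION_RULE", "hybrid")   # threshold / argmax / hybrid

from app.inference_wav2vec import Detector

det = Detector(weights_path="app/models/weights/wav2vec2_classifier.pth")
out = det.predict_proba("sample.wav", source_hint="microphone")
print(out["label"], out["ai"], out["replay_score"], out["decision"])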
app/train.py
ADDED
@@ -0,0 +1,315 @@
import os
import argparse
import random
from pathlib import Path
from contextlib import nullcontext
import importlib.util

import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader

# ---------- Local imports ----------
try:
    from .models.cnn_melspec import TinyMelCNN
    from .utils.audio import load_audio, pad_or_trim, logmel, TARGET_SR
except ImportError:
    from app.models.cnn_melspec import TinyMelCNN
    from app.utils.audio import load_audio, pad_or_trim, logmel, TARGET_SR

# ---------- Augmentations (robust across versions) ----------
from audiomentations import (
    Compose, AddGaussianNoise, TimeStretch, PitchShift, BandPassFilter
)

def make_gain(min_db, max_db, p):
    """Handle both min_gain_in_db/max_gain_in_db and min_gain_db/max_gain_db."""
    from audiomentations import Gain as _Gain
    try:
        return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
    except TypeError:
        return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)

def make_clipping(p=0.3):
    """
    Build ClippingDistortion across versions.
    Newer: min_percent/max_percent (0..20 typical)
    Older: min_percentile_threshold/max_percentile_threshold in [0..100]
    Returns None if not available.
    """
    try:
        from audiomentations import ClippingDistortion as _Clip
    except Exception:
        return None

    # Try newer signature
    for kwargs in (
        dict(min_percent=0.0, max_percent=20.0, p=p),
        dict(min_percent=5.0, max_percent=30.0, p=p),
    ):
        try:
            return _Clip(**kwargs)
        except Exception:
            pass

    # Try older signature
    for kwargs in (
        dict(min_percentile_threshold=95, max_percentile_threshold=100, p=p),
        dict(min_percentile_threshold=90, max_percentile_threshold=99, p=p),
    ):
        try:
            return _Clip(**kwargs)
        except Exception:
            pass

    return None

def have_fast_mp3():
    return importlib.util.find_spec("fast_mp3_augment") is not None

def make_mp3_compression(min_bitrate=48, max_bitrate=96, p=0.6):
    """
    Only enable Mp3Compression when the fast backend is present.
    On Windows without the extra package this often breaks; we skip it.
    """
    if not have_fast_mp3():
        return None
    try:
        from audiomentations import Mp3Compression as _Mp3
        # Prefer the fast backend; if API lacks backend arg, constructor still works.
        try:
            return _Mp3(min_bitrate=min_bitrate, max_bitrate=max_bitrate, p=p, backend="fast_mp3_augment")
        except TypeError:
            return _Mp3(min_bitrate=min_bitrate, max_bitrate=max_bitrate, p=p)
    except Exception:
        return None

# ---------- Repro ----------
def set_seed(seed: int = 42):
    random.seed(seed); np.random.seed(seed)
    torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)

# ---------- Dataset ----------
class FolderDataset(Dataset):
    """
    data_dir/
      human/*.wav
      ai/*.wav
    """
    def __init__(self, root: str, split: str = "train", val_ratio: float = 0.15,
                 seed: int = 42, clip_seconds: float = 3.0):
        self.root = Path(root)
        self.clip_seconds = float(clip_seconds)

        human = sorted((self.root / "human").glob("*.wav"))
        ai = sorted((self.root / "ai").glob("*.wav"))
        pairs = [(p, 0) for p in human] + [(p, 1) for p in ai]

        rng = random.Random(seed)
        rng.shuffle(pairs)

        n_val = int(len(pairs) * val_ratio)
        self.items = pairs[n_val:] if split == "train" else pairs[:n_val]
        self.is_train = split == "train"

        self._len_h = sum(1 for _, y in self.items if y == 0)
        self._len_a = sum(1 for _, y in self.items if y == 1)

        # Human: mild, natural perturbations
        self.aug_human = Compose([
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.4),
            TimeStretch(min_rate=0.96, max_rate=1.04, p=0.3),
            PitchShift(min_semitones=-1, max_semitones=1, p=0.2),
            make_gain(-4, 4, p=0.3),
        ])

        # AI: replay-aware chain (speaker/room/mic simulation)
        ai_transforms = [
            BandPassFilter(min_center_freq=200.0, max_center_freq=3500.0, p=0.5),
            AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.3),
            TimeStretch(min_rate=0.95, max_rate=1.05, p=0.25),
            make_gain(-6, 6, p=0.3),
        ]
        clip = make_clipping(p=0.3)
        if clip is not None:
            ai_transforms.insert(1, clip)
        mp3 = make_mp3_compression()
        if mp3 is not None:
            ai_transforms.insert(0, mp3)

        self.aug_ai = Compose(ai_transforms)

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx: int):
        path, label = self.items[idx]
        y, sr = load_audio(str(path), TARGET_SR)
        y = pad_or_trim(y, duration_s=self.clip_seconds, sr=sr)

        if self.is_train:
            if label == 1:
                y = self.aug_ai(samples=y, sample_rate=sr)
            else:
                y = self.aug_human(samples=y, sample_rate=sr)

        mel = logmel(y, sr)                      # (n_mels, T)
        x = torch.from_numpy(mel).unsqueeze(0)   # (1, n_mels, T)
        y_t = torch.tensor(label, dtype=torch.long)
        return x, y_t

# ---------- Dataloaders ----------
def make_dataloaders(args):
    ds_tr = FolderDataset(args.data_dir, split="train", val_ratio=args.val_ratio,
                          seed=args.seed, clip_seconds=args.clip_seconds)
    ds_va = FolderDataset(args.data_dir, split="val", val_ratio=args.val_ratio,
                          seed=args.seed, clip_seconds=args.clip_seconds)

    # Windows is happier with workers=0; keep configurable
    workers = args.workers if args.workers >= 0 else (0 if os.name == "nt" else max(1, (os.cpu_count() or 4)//2))
    pin = (not args.cpu) and torch.cuda.is_available()

    dl_tr = DataLoader(
        ds_tr, batch_size=args.batch_size, shuffle=True,
        num_workers=workers, pin_memory=pin,
        persistent_workers=(workers > 0), drop_last=True,
    )
    dl_va = DataLoader(
        ds_va, batch_size=max(1, args.batch_size // 2), shuffle=False,
        num_workers=workers, pin_memory=pin,
        persistent_workers=(workers > 0),
    )
    return ds_tr, ds_va, dl_tr, dl_va

def class_weights_from_dataset(ds: FolderDataset, eps: float = 1e-6):
    n_h, n_a = max(ds._len_h, eps), max(ds._len_a, eps)
    w_h = (n_h + n_a) / (2 * n_h)
    w_a = (n_h + n_a) / (2 * n_a)
    return torch.tensor([w_h, w_a], dtype=torch.float32)

# ---------- Training / Eval ----------
def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn, grad_accum=1):
    model.train()
    total_loss = 0.0
    correct = 0
    seen = 0
    opt.zero_grad(set_to_none=True)

    for step, (x, y) in enumerate(dl):
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)

        with autocast_ctx:
            logits = model(x)
            loss = loss_fn(logits, y)

        loss = loss / grad_accum
        if getattr(scaler, "is_enabled", lambda: False)():
            scaler.scale(loss).backward()
        else:
            loss.backward()

        if (step + 1) % grad_accum == 0:
            if getattr(scaler, "is_enabled", lambda: False)():
                scaler.step(opt)
                scaler.update()
            else:
                opt.step()
            opt.zero_grad(set_to_none=True)

        total_loss += float(loss) * x.size(0) * grad_accum
        correct += int((logits.argmax(1) == y).sum().item())
        seen += x.size(0)

    return total_loss / max(seen, 1), correct / max(seen, 1)

@torch.no_grad()
def evaluate(model, dl, device, loss_fn):
    model.eval()
    total_loss = 0.0
    correct = 0
    seen = 0
    for x, y in dl:
        x = x.to(device, non_blocking=True)
        y = y.to(device, non_blocking=True)
        logits = model(x)
        loss = loss_fn(logits, y)
        total_loss += float(loss) * x.size(0)
        correct += int((logits.argmax(1) == y).sum().item())
        seen += x.size(0)
    return total_loss / max(seen, 1), correct / max(seen, 1)

def main(args):
    set_seed(args.seed)
    device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
    cudnn.benchmark = True

    ds_tr, ds_va, dl_tr, dl_va = make_dataloaders(args)
    print(f"Train items: {len(ds_tr)} (human={ds_tr._len_h}, ai={ds_tr._len_a})")
    print(f"Val items: {len(ds_va)}")

    model = TinyMelCNN().to(device)

    weights = class_weights_from_dataset(ds_tr).to(device)
    loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)

    # AMP (use new torch.amp if available, else fallback)
    try:
        from torch.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler("cuda", enabled=(device == "cuda" and args.amp))
        autocast_ctx = amp_autocast("cuda") if (device == "cuda" and args.amp) else nullcontext()
    except Exception:
        from torch.cuda.amp import GradScaler, autocast as amp_autocast  # deprecated but works
        scaler = GradScaler(enabled=(device == "cuda" and args.amp))
        autocast_ctx = amp_autocast() if (device == "cuda" and args.amp) else nullcontext()

    best_va = -1.0
    patience_counter = 0
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)

    for epoch in range(args.epochs):
        tr_loss, tr_acc = train_one_epoch(
            model, dl_tr, device, opt, scaler, autocast_ctx, loss_fn,
            grad_accum=args.grad_accum
        )
        va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)

        print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")

        # Save "last" every epoch
        torch.save(model.state_dict(), args.out.replace(".pth", ".last.pth"))

        if va_acc > best_va + 1e-4:
            best_va = va_acc
            torch.save(model.state_dict(), args.out)
            patience_counter = 0
            print(f"✅ Saved best to {args.out} (val_acc={best_va:.3f})")
        else:
            patience_counter += 1
            if args.early_stop > 0 and patience_counter >= args.early_stop:
                print(f"⏹️ Early stopping at epoch {epoch+1} (best val_acc={best_va:.3f})")
                break

    print("Done.")

if __name__ == "__main__":
    p = argparse.ArgumentParser(description="Train AI Voice Detector (replay-aware, version-robust, no fast_mp3 required)")
    p.add_argument("--data_dir", type=str, required=True, help="Folder with subfolders human/ and ai/")
    p.add_argument("--out", type=str, default="app/models/weights/cnn_melspec.pth")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=32)
    p.add_argument("--grad_accum", type=int, default=2)
    p.add_argument("--lr", type=float, default=1e-3)
    p.add_argument("--val_ratio", type=float, default=0.15)
    p.add_argument("--clip_seconds", type=float, default=3.0)
    p.add_argument("--workers", type=int, default=-1)  # try --workers 0 on Windows if you see issues
    p.add_argument("--amp", action="store_true", default=True)
    p.add_argument("--cpu", action="store_true")
    p.add_argument("--early_stop", type=int, default=0)
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()
    main(args)
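The usual entry point for this script is the CLI (something like python -m app.train --data_dir data/processed). A minimal programmatic sketch that mirrors the CLI defaults above; the data and output paths are placeholders:

# Illustrative programmatic invocation of app/train.py's main().
from argparse import Namespace
from app.train import main

args = Namespace(
    data_dir="data/processed", out="app/models/weights/cnn_melspec.pth",
    epochs=10, batch_size=32, grad_accum=2, lr=1e-3, val_ratio=0.15,
    clip_seconds=3.0, workers=-1, amp=True, cpu=False, early_stop=0, seed=42,
)
main(args)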
app/train_wav2vec.py
ADDED
@@ -0,0 +1,207 @@
import os, json, argparse, random
from pathlib import Path
from contextlib import nullcontext

import numpy as np
import torch
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler

try:
    from .models.wav2vec_detector import Wav2VecClassifier
    from .utils.audio import load_audio, pad_or_trim, TARGET_SR
except ImportError:
    from app.models.wav2vec_detector import Wav2VecClassifier
    from app.utils.audio import load_audio, pad_or_trim, TARGET_SR

from audiomentations import Compose, AddGaussianNoise, BandPassFilter

def make_gain(min_db, max_db, p):
    from audiomentations import Gain as _Gain
    try: return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
    except TypeError: return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)

def set_seed(s=42):
    random.seed(s); np.random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)

def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
    m = float(np.max(np.abs(y)) + eps)
    return (y / m) * peak if m > 0 else y

def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
    if y.size == 0: return y
    win = max(1, int(sr * 0.02)); pad = max(1, int(sr * (min_ms / 1000.0)))
    energy = np.convolve(y**2, np.ones(win)/win, mode="same")
    ref = np.max(energy) + 1e-12
    mask = 10*np.log10(energy/ref + 1e-12) > -thresh_db
    if not np.any(mask): return y
    idx = np.where(mask)[0]; start=max(0,int(idx[0]-pad)); end=min(len(y),int(idx[-1]+pad))
    return y[start:end]

class WavDataset(Dataset):
    """data_dir/{human,ai}/*.wav"""
    def __init__(self, root, split="train", val_ratio=0.15, seed=42, clip_seconds=3.0):
        self.root = Path(root); self.clip = float(clip_seconds)
        human = sorted((self.root/"human").glob("*.wav"))
        ai = sorted((self.root/"ai").glob("*.wav"))
        items = [(p,0) for p in human] + [(p,1) for p in ai]
        rng = random.Random(seed); rng.shuffle(items)
        n_val = int(len(items)*val_ratio)
        self.items = items[n_val:] if split=="train" else items[:n_val]
        self.is_train = split=="train"
        self.nh = sum(1 for _,y in self.items if y==0)
        self.na = sum(1 for _,y in self.items if y==1)
        self.aug_h = Compose([AddGaussianNoise(0.001,0.01,p=0.3), make_gain(-4,4,p=0.3)])
        self.aug_a = Compose([BandPassFilter(200.0,3500.0,p=0.5), AddGaussianNoise(0.001,0.01,p=0.3), make_gain(-6,6,p=0.3)])

    def __len__(self): return len(self.items)

    def __getitem__(self, idx):
        path, label = self.items[idx]
        y, sr = load_audio(str(path), TARGET_SR)
        y = trim_silence(y, sr, 40.0, 30)
        y = peak_normalize(y, 0.95)
        y = pad_or_trim(y, duration_s=self.clip, sr=sr)
        if self.is_train:
            y = (self.aug_a if label==1 else self.aug_h)(samples=y, sample_rate=sr)
        return torch.from_numpy(y).float(), torch.tensor(label, dtype=torch.long)

def make_loaders(args):
    ds_tr = WavDataset(args.data_dir, "train", args.val_ratio, args.seed, args.clip_seconds)
    ds_va = WavDataset(args.data_dir, "val", args.val_ratio, args.seed, args.clip_seconds)

    # Weighted sampler to balance classes
    labels = [y for _, y in ds_tr.items]
    n0 = max(1, labels.count(0)); n1 = max(1, labels.count(1))
    w0 = (n0 + n1) / (2 * n0); w1 = (n0 + n1) / (2 * n1)
    sample_weights = [w0 if y == 0 else w1 for y in labels]
    sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)

    workers = args.workers if args.workers >= 0 else (0 if os.name=="nt" else max(1,(os.cpu_count() or 4)//2))
    pin = (not args.cpu) and torch.cuda.is_available()
    dl_tr = DataLoader(ds_tr, batch_size=args.batch_size, sampler=sampler,
                       num_workers=workers, pin_memory=pin, drop_last=True)
    dl_va = DataLoader(ds_va, batch_size=max(1,args.batch_size//2), shuffle=False,
                       num_workers=workers, pin_memory=pin)
    return ds_tr, ds_va, dl_tr, dl_va

class FocalLoss(torch.nn.Module):
    def __init__(self, alpha=None, gamma=1.5):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.ce = torch.nn.CrossEntropyLoss(weight=alpha)

    def forward(self, logits, target):
        ce = self.ce(logits, target)
        with torch.no_grad():
            pt = torch.exp(-ce)
        return ((1 - pt) ** self.gamma) * ce

def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn, grad_accum=1):
    model.train(); total=0.0; correct=0; seen=0
    opt.zero_grad(set_to_none=True)
    for step,(x,y) in enumerate(dl):
        x=x.to(device,non_blocking=True); y=y.to(device,non_blocking=True)
        with autocast_ctx:
            logits,_=model(x); loss=loss_fn(logits,y)
        loss=loss/grad_accum
        if getattr(scaler,"is_enabled",lambda:False)(): scaler.scale(loss).backward()
        else: loss.backward()
        if (step+1)%grad_accum==0:
            if getattr(scaler,"is_enabled",lambda:False)():
                scaler.step(opt); scaler.update()
            else:
                opt.step()
            opt.zero_grad(set_to_none=True)
        total += float(loss) * x.size(0) * grad_accum
        correct += int((logits.argmax(1)==y).sum().item()); seen += x.size(0)
    return total/max(seen,1), correct/max(seen,1)

@torch.no_grad()
def evaluate(model, dl, device, loss_fn):
    model.eval(); total=0.0; correct=0; seen=0
    for x,y in dl:
        x=x.to(device,non_blocking=True); y=y.to(device,non_blocking=True)
        logits,_=model(x); loss=loss_fn(logits,y)
        total += float(loss) * x.size(0); correct += int((logits.argmax(1)==y).sum().item()); seen += x.size(0)
    return total/max(seen,1), correct/max(seen,1)

def main(args):
    set_seed(args.seed)
    device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
    cudnn.benchmark = True

    ds_tr, ds_va, dl_tr, dl_va = make_loaders(args)
    print(f"Train items: {len(ds_tr)} (human={ds_tr.nh}, ai={ds_tr.na})")
    print(f"Val items: {len(ds_va)}")

    model = Wav2VecClassifier(
        encoder=args.encoder,
        unfreeze_last=args.unfreeze_last,
        dropout=args.dropout,
        hidden=args.hidden
    ).to(device)

    # Focal loss with class weights
    nh, na = ds_tr.nh, ds_tr.na
    w = torch.tensor([(nh+na)/(2*nh+1e-6), (nh+na)/(2*na+1e-6)], dtype=torch.float32).to(device)
    loss_fn = FocalLoss(alpha=w, gamma=1.5)

    head_params = list(model.head.parameters())
    enc_params = [p for p in model.encoder.parameters() if p.requires_grad]
    param_groups = [{"params": head_params, "lr": args.lr_head}]
    if enc_params:
        param_groups.append({"params": enc_params, "lr": args.lr_encoder})
    opt = torch.optim.AdamW(param_groups, weight_decay=1e-4)

    try:
        from torch.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler("cuda", enabled=(device=="cuda" and args.amp))
        autocast_ctx = amp_autocast("cuda") if (device=="cuda" and args.amp) else nullcontext()
    except Exception:
        from torch.cuda.amp import GradScaler, autocast as amp_autocast
        scaler = GradScaler(enabled=(device=="cuda" and args.amp))
        autocast_ctx = amp_autocast() if (device=="cuda" and args.amp) else nullcontext()

    best=-1.0; patience=0
    Path(args.out).parent.mkdir(parents=True, exist_ok=True)
    with open(args.out.replace(".pth",".json"), "w", encoding="utf-8") as f:
        json.dump({"encoder": args.encoder, "unfreeze_last": args.unfreeze_last}, f)

    for epoch in range(args.epochs):
        tr_loss, tr_acc = train_one_epoch(model, dl_tr, device, opt, scaler, autocast_ctx, loss_fn, args.grad_accum)
        va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)
        print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")
        torch.save(model.state_dict(), args.out.replace(".pth",".last.pth"))
        if va_acc > best + 1e-4:
            best = va_acc; patience=0
            torch.save(model.state_dict(), args.out)
            print(f"✅ Saved best to {args.out} (val_acc={best:.3f})")
        else:
            patience += 1
            if args.early_stop>0 and patience>=args.early_stop:
                print(f"⏹️ Early stopping at epoch {epoch+1} (best={best:.3f})")
                break
    print("Done.")

if __name__ == "__main__":
    ap = argparse.ArgumentParser(description="Train Wav2Vec2-based AI Voice Detector (balanced)")
    ap.add_argument("--data_dir", required=True, help="Folder with human/ and ai/ WAVs")
    ap.add_argument("--out", default="app/models/weights/wav2vec2_classifier.pth")
    ap.add_argument("--encoder", default="facebook/wav2vec2-base")
    ap.add_argument("--unfreeze_last", type=int, default=0)
    ap.add_argument("--epochs", type=int, default=8)
    ap.add_argument("--batch_size", type=int, default=16)
    ap.add_argument("--grad_accum", type=int, default=2)
    ap.add_argument("--lr_head", type=float, default=1e-3)
    ap.add_argument("--lr_encoder", type=float, default=1e-5)
    ap.add_argument("--val_ratio", type=float, default=0.15)
    ap.add_argument("--clip_seconds", type=float, default=3.0)
    ap.add_argument("--workers", type=int, default=-1)
    ap.add_argument("--amp", action="store_true", default=True)
    ap.add_argument("--cpu", action="store_true")
    ap.add_argument("--dropout", type=float, default=0.2)
    ap.add_argument("--hidden", type=int, default=256)
    ap.add_argument("--early_stop", type=int, default=0)
    ap.add_argument("--seed", type=int, default=42)
    args = ap.parse_args()
    main(args)
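As with app/train.py, the normal entry point is the CLI (python -m app.train_wav2vec --data_dir ...). A programmatic sketch mirroring the CLI defaults above, except that it unfreezes the last two encoder blocks so the separate lr_encoder group actually applies; paths and the unfreeze choice are illustrative, not part of this commit:

# Illustrative programmatic invocation of app/train_wav2vec.py's main().
from argparse import Namespace
from app.train_wav2vec import main

args = Namespace(
    data_dir="data/processed", out="app/models/weights/wav2vec2_classifier.pth",
    encoder="facebook/wav2vec2-base", unfreeze_last=2, epochs=8, batch_size=16,
    grad_accum=2, lr_head=1e-3, lr_encoder=1e-5, val_ratio=0.15, clip_seconds=3.0,
    workers=-1, amp=True, cpu=False, dropout=0.2, hidden=256, early_stop=0, seed=42,
)
main(args)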