varunkul committed on
Commit e2c61ce · verified · 1 Parent(s): e744be2

Upload 6 files

Files changed (6)
  1. app/api.py +65 -0
  2. app/app.py +138 -0
  3. app/elevenlabs_tools.py +38 -0
  4. app/inference_wav2vec.py +214 -0
  5. app/train.py +315 -0
  6. app/train_wav2vec.py +207 -0
app/api.py ADDED
@@ -0,0 +1,65 @@
+ import os, io, base64
+ import numpy as np
+ from fastapi import FastAPI, UploadFile, File, Form
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import Optional, Dict, Any
+ from PIL import Image
+ from matplotlib import cm
+
+ BACKEND = os.getenv("DETECTOR_BACKEND", "wav2vec2").lower()
+ try:
+     if BACKEND == "wav2vec2":
+         from .inference_wav2vec import Detector  # type: ignore
+     else:
+         from .inference import Detector  # type: ignore
+ except Exception:
+     if BACKEND == "wav2vec2":
+         from app.inference_wav2vec import Detector  # type: ignore
+     else:
+         from app.inference import Detector  # type: ignore
+
+ DEFAULT_WEIGHTS = "app/models/weights/wav2vec2_classifier.pth" if BACKEND=="wav2vec2" else "app/models/weights/cnn_melspec.pth"
+ WEIGHTS = os.getenv("MODEL_WEIGHTS_PATH", DEFAULT_WEIGHTS)
+ det = Detector(weights_path=WEIGHTS)
+
+ app = FastAPI(title="Voice Guard API", version="1.1.0")
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],  # tighten in prod
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ class AnalyzeResponse(BaseModel):
+     human: float
+     ai: float
+     label: str
+     threshold: float
+     threshold_source: Optional[str] = None
+     backend: str
+     source_hint: str
+     replay_score: Optional[float] = None
+     decision: Optional[str] = None
+     decision_details: Optional[Dict[str, Any]] = None
+     heatmap_b64: str
+
+ def heatmap_png_b64(cam: np.ndarray) -> str:
+     cam = np.clip(cam, 0.0, 1.0).astype(np.float32)
+     rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
+     im = Image.fromarray(rgb)
+     buf = io.BytesIO(); im.save(buf, format="PNG")
+     return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")
+
+ @app.post("/analyze", response_model=AnalyzeResponse)
+ async def analyze(file: UploadFile = File(...), source_hint: str = Form("auto")):
+     raw = await file.read()
+     proba = det.predict_proba(raw, source_hint=source_hint)
+     cam = np.array(det.explain(raw, source_hint=source_hint)["cam"], dtype=np.float32)
+     return {
+         **proba,
+         "heatmap_b64": heatmap_png_b64(cam),
+     }
+
+ @app.get("/health")
+ def health(): return {"ok": True, "backend": BACKEND}
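For quick testing, a minimal client sketch against the /analyze endpoint. It assumes the API is served locally (for example with `uvicorn app.api:app --port 8000`) and that a short WAV clip named sample.wav exists; the response fields follow the AnalyzeResponse model above.

# client_example.py: sketch only; assumes the API runs at http://localhost:8000
import requests

with open("sample.wav", "rb") as f:  # hypothetical 2-4 s test clip
    resp = requests.post(
        "http://localhost:8000/analyze",
        files={"file": ("sample.wav", f, "audio/wav")},
        data={"source_hint": "upload"},
    )
resp.raise_for_status()
result = resp.json()
print(result["label"], result["ai"], result["human"], result["decision"])
# result["heatmap_b64"] is a data-URL PNG that can be rendered directly in an <img> tag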
app/app.py ADDED
@@ -0,0 +1,138 @@
+ import os
+ import json
+ import numpy as np
+ import gradio as gr
+ from dotenv import load_dotenv
+ from matplotlib import cm
+
+ load_dotenv()
+
+ # -------------------------
+ # 0) Env & defaults
+ # -------------------------
+ BACKEND = os.getenv("DETECTOR_BACKEND", "wav2vec2").strip().lower()  # "wav2vec2" or "cnn"
+
+ DEFAULT_W2V_WEIGHTS = "app/models/weights/wav2vec2_classifier.pth"
+ DEFAULT_CNN_WEIGHTS = "app/models/weights/cnn_melspec.pth"
+ DEFAULT_WEIGHTS = DEFAULT_W2V_WEIGHTS if BACKEND == "wav2vec2" else DEFAULT_CNN_WEIGHTS
+ MODEL_WEIGHTS_PATH = os.getenv("MODEL_WEIGHTS_PATH", DEFAULT_WEIGHTS).strip()
+
+
+ # -------------------------
+ # 1) Import your Detector
+ # -------------------------
+ def _import_detector(backend):
+     """
+     Import the correct Detector class depending on backend and package layout.
+     Works both when run as a module ('.inference_*') and as a script ('app.inference_*').
+     """
+     try:
+         if backend == "wav2vec2":
+             from .inference_wav2vec import Detector  # type: ignore
+         else:
+             from .inference import Detector  # type: ignore
+     except Exception:
+         if backend == "wav2vec2":
+             from app.inference_wav2vec import Detector  # type: ignore
+         else:
+             from app.inference import Detector  # type: ignore
+     return Detector
+
+
+ try:
+     Detector = _import_detector(BACKEND)
+ except Exception as e:
+     # Fallback dummy keeps the UI alive even if the import fails, so the error shows in the JSON panel.
+     _IMPORT_ERR = f"Detector import failed: {e}"  # capture now: 'e' is cleared when the except block ends
+     class Detector:  # type: ignore
+         def __init__(self, *args, **kwargs):
+             self._err = _IMPORT_ERR
+
+         def predict_proba(self, *args, **kwargs):
+             return {"error": self._err}
+
+         def explain(self, *args, **kwargs):
+             return {"cam": np.zeros((128, 128), dtype=np.float32).tolist()}
+
+
+ # Single, shared detector (created lazily so startup is fast on Spaces)
+ _DET = None
+ def _get_detector():
+     global _DET
+     if _DET is None:
+         _DET = Detector(weights_path=MODEL_WEIGHTS_PATH)
+     return _DET
+
+
+ # -------------------------
+ # 2) Core functions
+ # -------------------------
+ def predict_and_explain(audio_path: str | None, source_hint: str):
+     """
+     audio_path: filepath from Gradio (since type='filepath')
+     source_hint: "Auto", "Microphone", "Upload"
+     """
+     source = (source_hint or "Auto").strip().lower()
+
+     if not audio_path or not os.path.exists(audio_path):
+         return {"error": "No audio received. Record or upload a 2–4s clip."}, None
+
+     det = _get_detector()
+
+     # Your Detector is expected to accept a file path and optional source hint
+     proba = det.predict_proba(audio_path, source_hint=source)
+     exp = det.explain(audio_path, source_hint=source)
+
+     # Explanation to heatmap (float [0,1] -> magma RGB uint8)
+     cam = np.array(exp.get("cam", []), dtype=np.float32)
+     if cam.ndim == 1:
+         # if model returned a 1D vector, tile to square-ish map
+         side = int(np.sqrt(cam.size))
+         side = max(side, 2)
+         cam = cam[: side * side].reshape(side, side)
+     cam = np.clip(cam, 0.0, 1.0)
+     cam_rgb = (cm.magma(cam)[..., :3] * 255).astype(np.uint8)
+
+     # Ensure proba is JSON-serializable
+     if not isinstance(proba, dict):
+         proba = {"result": proba}
+
+     return proba, cam_rgb
+
+
+ def provenance(audio_path: str | None):
+     # Stub (you can wire a provenance model or checksum here)
+     return {"ok": True, "note": "Provenance check not wired in this app.py."}
+
+
+ # -------------------------
+ # 3) UI
+ # -------------------------
+ with gr.Blocks(title=f"AI Voice Detector · {BACKEND.upper()}") as demo:
+     gr.Markdown(f"# 🔎 AI Voice Detector — Backend: **{BACKEND.upper()}**")
+     gr.Markdown(
+         "Record or upload a short clip (~3s). Get probabilities, a label, and an explanation heatmap."
+     )
+
+     with gr.Row():
+         audio_in = gr.Audio(sources=["microphone", "upload"], type="filepath", label="Audio")
+         with gr.Column():
+             src = gr.Radio(choices=["Auto", "Microphone", "Upload"], value="Auto", label="Source")
+             btn_predict = gr.Button("Analyze", variant="primary")
+             btn_prov = gr.Button("Provenance Check (optional)")
+
+     with gr.Row():
+         json_out = gr.JSON(label="Prediction (probabilities + label)")
+         cam_out = gr.Image(label="Explanation Heatmap (saliency)")
+         prov_out = gr.JSON(label="Provenance Result (if available)")
+
+     btn_predict.click(predict_and_explain, inputs=[audio_in, src], outputs=[json_out, cam_out])
+     btn_prov.click(provenance, inputs=audio_in, outputs=prov_out)
+
+
+ # -------------------------
+ # 4) Launch (Spaces-friendly)
+ # -------------------------
+ if __name__ == "__main__":
+     # queue() keeps UI responsive under load; host/port are Spaces-safe and local-friendly
+     demo.queue().launch(server_name="0.0.0.0", server_port=7860)
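The same prediction path can be exercised without the Gradio UI. A minimal smoke-test sketch, assuming the weights referenced by MODEL_WEIGHTS_PATH exist and that a local clip is available at the hypothetical path below:

# smoke_test_app.py: sketch; reuses predict_and_explain() outside the UI
import os
os.environ.setdefault("DETECTOR_BACKEND", "wav2vec2")  # must be set before app.app is imported

from app.app import predict_and_explain

proba, cam_rgb = predict_and_explain("data/raw/human/sample.wav", "Upload")  # hypothetical clip
print(proba)                                       # probabilities + label, or an error dict
print(None if cam_rgb is None else cam_rgb.shape)  # (H, W, 3) uint8 magma heatmap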
app/elevenlabs_tools.py ADDED
@@ -0,0 +1,38 @@
+
+ import os, time, hashlib, json, pathlib, random
+ from typing import List, Optional
+ from dotenv import load_dotenv
+ import requests
+
+ load_dotenv()
+ ELEVEN_API_KEY = os.getenv("ELEVEN_API_KEY", "")
+ ELEVEN_VOICE_ID = os.getenv("ELEVEN_VOICE_ID", "")
+ BASE = "https://api.elevenlabs.io/v1"
+
+ def _headers():
+     return {"xi-api-key": ELEVEN_API_KEY, "accept": "audio/mpeg", "Content-Type": "application/json"}
+
+ def generate_tts_dataset(texts: List[str], voice_id: Optional[str]=None, out_dir: str="data/raw/ai", model_id: str="eleven_monolingual_v1"):
+     """Generate AI speech MP3s from ElevenLabs into out_dir. Convert to WAV (16k mono) for training."""
+     voice_id = voice_id or ELEVEN_VOICE_ID
+     assert ELEVEN_API_KEY, "Set ELEVEN_API_KEY in .env"
+     assert voice_id, "Provide ELEVEN_VOICE_ID in .env or pass voice_id"
+     os.makedirs(out_dir, exist_ok=True)
+     for i, txt in enumerate(texts):
+         payload = {"text": txt, "model_id": model_id, "voice_settings": {"stability": 0.4, "similarity_boost": 0.7}}
+         url = f"{BASE}/text-to-speech/{voice_id}"
+         r = requests.post(url, headers=_headers(), json=payload)
+         if r.status_code != 200:
+             print("TTS error", r.status_code, r.text[:200]); continue
+         mp3_path = os.path.join(out_dir, f"elab_{i:04d}.mp3")
+         with open(mp3_path, "wb") as f:
+             f.write(r.content)
+         print("saved", mp3_path)
+     print("Done. Convert MP3 to WAV (16kHz mono) before training.")
+
+ def check_ai_speech(audio_bytes: bytes) -> dict:
+     """Stub: if your plan exposes a classifier API, call it here; else returns unsupported."""
+     return {"supported": False, "prob_ai": None, "provider": "elevenlabs", "note": "Classifier not enabled in this template."}
+
+ # if __name__ == "__main__":
+ #     generate_tts_dataset()
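generate_tts_dataset() saves MP3s, while the training scripts expect 16 kHz mono WAVs. A conversion sketch, assuming librosa (with an MP3-capable backend such as ffmpeg) and soundfile are installed:

# convert_mp3_to_wav.py: sketch for the MP3 -> 16 kHz mono WAV step mentioned in the docstring
import glob, os
import librosa
import soundfile as sf

def convert_folder(in_dir: str = "data/raw/ai", sr: int = 16000):
    for mp3 in sorted(glob.glob(os.path.join(in_dir, "*.mp3"))):
        y, _ = librosa.load(mp3, sr=sr, mono=True)  # decode, resample, downmix
        wav = os.path.splitext(mp3)[0] + ".wav"
        sf.write(wav, y, sr)
        print("converted", wav)

if __name__ == "__main__":
    convert_folder()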
app/inference_wav2vec.py ADDED
@@ -0,0 +1,214 @@
+ import os, json
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+
+ from .models.wav2vec_detector import Wav2VecClassifier
+ from .utils.audio import load_audio, pad_or_trim, TARGET_SR
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+ # ---------- Thresholds & biases ----------
+ AI_THRESHOLD_DEFAULT = float(os.getenv("DETECTOR_AI_THRESHOLD", "0.60"))
+ MIC_THRESHOLD = float(os.getenv("DETECTOR_MIC_THRESHOLD", "0.68"))
+ UPLOAD_THRESHOLD = float(os.getenv("DETECTOR_UPLOAD_THRESHOLD", str(AI_THRESHOLD_DEFAULT)))
+ AI_LOGIT_BIAS = float(os.getenv("DETECTOR_AI_LOGIT_BIAS", "0.00"))  # added to the AI logit globally
+
+ # ---------- Decision rule ----------
+ # 'threshold' -> AI if ai_prob >= threshold
+ # 'argmax'    -> AI if ai_prob > human_prob
+ # 'hybrid'    -> threshold, but if replay_score >= T1 and ai_prob >= 0.50 -> AI
+ DECISION_RULE = os.getenv("DECISION_RULE", "threshold").lower()
+
+ # ---------- Replay-attack heuristic ----------
+ REPLAY_ENABLE = os.getenv("REPLAY_ENABLE", "1") != "0"
+ REPLAY_AI_BONUS = float(os.getenv("REPLAY_AI_BONUS", "1.2"))
+ REPLAY_FORCE_LABEL = os.getenv("REPLAY_FORCE_LABEL", "0") == "1"
+ REPLAY_T1 = float(os.getenv("REPLAY_T1", "0.35"))  # soft start
+ REPLAY_T2 = float(os.getenv("REPLAY_T2", "0.55"))  # strong replay
+
+ # ---------- DSP helpers ----------
+ def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
+     m = float(np.max(np.abs(y)) + eps)
+     return (y / m) * peak if m > 0 else y
+
+ def rms_normalize(y: np.ndarray, target_rms: float = 0.03, eps: float = 1e-9) -> np.ndarray:
+     rms = float(np.sqrt(np.mean(y**2)) + eps)
+     g = target_rms / rms
+     return np.clip(y * g, -1.0, 1.0)
+
+ def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
+     if y.size == 0: return y
+     win = max(1, int(sr * 0.02))
+     pad = max(1, int(sr * (min_ms / 1000.0)))
+     energy = np.convolve(y ** 2, np.ones(win) / win, mode="same")
+     ref = np.max(energy) + 1e-12
+     mask = 10.0 * np.log10(energy / ref + 1e-12) > -thresh_db
+     if not np.any(mask): return y
+     idx = np.where(mask)[0]
+     start = max(0, int(idx[0] - pad))
+     end = min(len(y), int(idx[-1] + pad))
+     return y[start:end]
+
+ def noise_gate(y, sr, gate_db=-42.0):
+     m = np.max(np.abs(y)) + 1e-9
+     thr = m * (10.0 ** (gate_db / 20.0))
+     y2 = y.copy()
+     y2[np.abs(y2) < thr] = 0.0
+     return y2
+
+ def bandpass_fft(y: np.ndarray, sr: int, low=100.0, high=3800.0):
+     n = int(2 ** np.ceil(np.log2(len(y) + 1)))
+     Y = np.fft.rfft(y, n=n)
+     freqs = np.fft.rfftfreq(n, d=1.0/sr)
+     mask = (freqs >= low) & (freqs <= high)
+     Y_filtered = Y * mask
+     y_filt = np.fft.irfft(Y_filtered, n=n)[:len(y)]
+     return y_filt.astype(np.float32, copy=False)
+
+ # ---------- Replay score ----------
+ def replay_score(y: np.ndarray, sr: int) -> float:
+     if len(y) < sr:
+         y = np.pad(y, (0, sr - len(y)))
+     N = 4096
+     if len(y) < N:
+         y = np.pad(y, (0, N - len(y)))
+     w = np.hanning(N)
+     seg = y[:N] * w
+     X = np.abs(np.fft.rfft(seg)) + 1e-9
+
+     cep = np.fft.irfft(np.log(X))
+     qmin = max(1, int(0.0003 * sr))
+     qmax = min(len(cep) - 1, int(0.0040 * sr))
+     cwin = np.abs(cep[qmin:qmax])
+     c_peak = float(np.max(cwin)); c_mean = float(np.mean(cwin) + 1e-9)
+     cep_score = np.clip((c_peak - c_mean) / (c_peak + c_mean), 0.0, 1.0)
+
+     F = np.fft.rfftfreq(N, 1.0 / sr)
+     total = float(np.sum(X))
+     hf = float(np.sum(X[F >= 5000.0]))
+     hf_ratio = hf / (total + 1e-9)
+     hf_term = np.clip((0.25 - hf_ratio) / 0.25, 0.0, 1.0)
+
+     return float(np.clip(0.6 * cep_score + 0.4 * hf_term, 0.0, 1.0))
+
+ # ---------- Detector ----------
+ class Detector:
+     def __init__(self, weights_path: str, encoder: str | None = None, unfreeze_last: int = 0):
+         cfg = None
+         js = weights_path.replace(".pth", ".json")
+         if os.path.exists(js):
+             try:
+                 with open(js, "r", encoding="utf-8") as f:
+                     cfg = json.load(f)
+             except Exception:
+                 cfg = None
+         enc = encoder or (cfg.get("encoder") if cfg else "facebook/wav2vec2-base")
+         unf = unfreeze_last or (int(cfg.get("unfreeze_last", 0)) if cfg else 0)
+
+         self.model = Wav2VecClassifier(encoder=enc, unfreeze_last=unf).to(DEVICE)
+         if weights_path and os.path.exists(weights_path):
+             state = torch.load(weights_path, map_location=DEVICE)
+             self.model.load_state_dict(state, strict=False)
+         self.model.eval()
+
+     def _preprocess(self, y: np.ndarray, sr: int, source_hint: str | None):
+         y = trim_silence(y, sr, 40.0, 30)
+         y = bandpass_fft(y, sr, 100.0, 3800.0)
+         if source_hint and source_hint.lower().startswith("micro"):
+             y = noise_gate(y, sr, -42.0)
+             y = rms_normalize(y, 0.035)
+             y = peak_normalize(y, 0.95)
+         else:
+             y = rms_normalize(y, 0.03)
+             y = peak_normalize(y, 0.95)
+         y = pad_or_trim(y, duration_s=3.0, sr=sr)
+         return y
+
+     @torch.inference_mode()
+     def predict_proba(self, wav_bytes_or_path, source_hint: str | None = None):
+         y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
+         rscore = replay_score(y0, sr) if REPLAY_ENABLE else 0.0
+
+         y = self._preprocess(y0, sr, source_hint)
+         x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
+
+         logits, _ = self.model(x)
+         logits = logits.clone()
+         logits[:, 1] += AI_LOGIT_BIAS
+
+         # Replay bonus on AI logit
+         if REPLAY_ENABLE and (source_hint and source_hint.lower().startswith("micro")) and (rscore >= REPLAY_T1):
+             ramp = np.clip((rscore - REPLAY_T1) / max(REPLAY_T2 - REPLAY_T1, 1e-6), 0.0, 1.0)
+             logits[:, 1] += REPLAY_AI_BONUS * ramp
+
+         probs = F.softmax(logits, dim=-1).cpu().numpy()[0]
+         p_h, p_ai = float(probs[0]), float(probs[1])
+
+         thr_source = "mic" if (source_hint and source_hint.lower().startswith("micro")) else "upload"
+         thr = MIC_THRESHOLD if thr_source == "mic" else UPLOAD_THRESHOLD
+
+         # Labels by different rules
+         label_thresh = "ai" if p_ai >= thr else "human"
+         label_argmax = "ai" if p_ai > p_h else "human"
+         label_hybrid = label_thresh
+         if REPLAY_ENABLE and rscore >= REPLAY_T1 and p_ai >= 0.50:
+             label_hybrid = "ai"
+         if REPLAY_ENABLE and rscore >= REPLAY_T2 and (source_hint and source_hint.lower().startswith("micro")):
+             if REPLAY_FORCE_LABEL or p_ai >= (thr - 0.05):
+                 label_hybrid = "ai"
+
+         if DECISION_RULE == "argmax":
+             label = label_argmax
+             rule_used = "argmax"
+         elif DECISION_RULE == "hybrid":
+             label = label_hybrid
+             rule_used = "hybrid(threshold+replay)"
+         else:
+             label = label_thresh
+             rule_used = "threshold"
+
+         return {
+             "human": p_h,
+             "ai": p_ai,
+             "label": label,
+             "threshold": float(thr),
+             "threshold_source": thr_source,
+             "backend": "wav2vec2",
+             "source_hint": (source_hint or "auto"),
+             "replay_score": float(rscore),
+             "decision": rule_used,
+             "decision_details": {
+                 "ai_prob": p_ai,
+                 "human_prob": p_h,
+                 "prob_margin": p_ai - p_h,
+                 "ai_vs_threshold_margin": p_ai - thr,
+                 "replay_score": rscore,
+                 "mic_threshold": MIC_THRESHOLD,
+                 "upload_threshold": UPLOAD_THRESHOLD,
+                 "force_label_AI": bool(REPLAY_FORCE_LABEL and rscore >= REPLAY_T2),
+             },
+         }
+
+     def explain(self, wav_bytes_or_path, source_hint: str | None = None):
+         self.model.eval()
+         y0, sr = load_audio(wav_bytes_or_path, target_sr=TARGET_SR)
+         y = self._preprocess(y0, sr, source_hint)
+         x = torch.from_numpy(y).float().unsqueeze(0).to(DEVICE)
+         x.requires_grad_(True)
+         logits, feats = self.model(x)
+         logits[:, 1].sum().backward(retain_graph=True)
+         if feats.grad is None:  # feats is a non-leaf tensor, so .grad is usually None -> use input saliency
+             s = x.grad.detach().abs().squeeze(0)
+             s = s / (s.max() + 1e-6)
+             H = 64
+             step = max(1, s.numel() // 256)
+             s_small = s[::step][:256].cpu().numpy()
+             cam = np.tile(s_small[None, :], (H, 1))
+         else:
+             g = feats.grad.detach().abs().sum(dim=-1).squeeze(0)
+             g = g / (g.max() + 1e-6)
+             H = 64
+             cam = np.tile(g.cpu().numpy()[None, :], (H, 1))
+         cam = (cam - cam.min()) / (cam.max() - cam.min() + 1e-6)
+         return {"cam": cam.tolist(), "probs": None}
app/train.py ADDED
@@ -0,0 +1,315 @@
+ import os
+ import argparse
+ import random
+ from pathlib import Path
+ from contextlib import nullcontext
+ import importlib.util
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.backends.cudnn as cudnn
+ from torch.utils.data import Dataset, DataLoader
+
+ # ---------- Local imports ----------
+ try:
+     from .models.cnn_melspec import TinyMelCNN
+     from .utils.audio import load_audio, pad_or_trim, logmel, TARGET_SR
+ except ImportError:
+     from app.models.cnn_melspec import TinyMelCNN
+     from app.utils.audio import load_audio, pad_or_trim, logmel, TARGET_SR
+
+ # ---------- Augmentations (robust across versions) ----------
+ from audiomentations import (
+     Compose, AddGaussianNoise, TimeStretch, PitchShift, BandPassFilter
+ )
+
+ def make_gain(min_db, max_db, p):
+     """Handle both min_gain_in_db/max_gain_in_db and min_gain_db/max_gain_db."""
+     from audiomentations import Gain as _Gain
+     try:
+         return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
+     except TypeError:
+         return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)
+
+ def make_clipping(p=0.3):
+     """
+     Build ClippingDistortion across versions.
+     Newer: min_percent/max_percent (0..20 typical)
+     Older: min_percentile_threshold/max_percentile_threshold in [0..100]
+     Returns None if not available.
+     """
+     try:
+         from audiomentations import ClippingDistortion as _Clip
+     except Exception:
+         return None
+
+     # Try newer signature
+     for kwargs in (
+         dict(min_percent=0.0, max_percent=20.0, p=p),
+         dict(min_percent=5.0, max_percent=30.0, p=p),
+     ):
+         try:
+             return _Clip(**kwargs)
+         except Exception:
+             pass
+
+     # Try older signature
+     for kwargs in (
+         dict(min_percentile_threshold=95, max_percentile_threshold=100, p=p),
+         dict(min_percentile_threshold=90, max_percentile_threshold=99, p=p),
+     ):
+         try:
+             return _Clip(**kwargs)
+         except Exception:
+             pass
+
+     return None
+
+ def have_fast_mp3():
+     return importlib.util.find_spec("fast_mp3_augment") is not None
+
+ def make_mp3_compression(min_bitrate=48, max_bitrate=96, p=0.6):
+     """
+     Only enable Mp3Compression when the fast backend is present.
+     On Windows without the extra package this often breaks; we skip it.
+     """
+     if not have_fast_mp3():
+         return None
+     try:
+         from audiomentations import Mp3Compression as _Mp3
+         # Prefer the fast backend; if API lacks backend arg, constructor still works.
+         try:
+             return _Mp3(min_bitrate=min_bitrate, max_bitrate=max_bitrate, p=p, backend="fast_mp3_augment")
+         except TypeError:
+             return _Mp3(min_bitrate=min_bitrate, max_bitrate=max_bitrate, p=p)
+     except Exception:
+         return None
+
+ # ---------- Repro ----------
+ def set_seed(seed: int = 42):
+     random.seed(seed); np.random.seed(seed)
+     torch.manual_seed(seed); torch.cuda.manual_seed_all(seed)
+
+ # ---------- Dataset ----------
+ class FolderDataset(Dataset):
+     """
+     data_dir/
+       human/*.wav
+       ai/*.wav
+     """
+     def __init__(self, root: str, split: str = "train", val_ratio: float = 0.15,
+                  seed: int = 42, clip_seconds: float = 3.0):
+         self.root = Path(root)
+         self.clip_seconds = float(clip_seconds)
+
+         human = sorted((self.root / "human").glob("*.wav"))
+         ai = sorted((self.root / "ai").glob("*.wav"))
+         pairs = [(p, 0) for p in human] + [(p, 1) for p in ai]
+
+         rng = random.Random(seed)
+         rng.shuffle(pairs)
+
+         n_val = int(len(pairs) * val_ratio)
+         self.items = pairs[n_val:] if split == "train" else pairs[:n_val]
+         self.is_train = split == "train"
+
+         self._len_h = sum(1 for _, y in self.items if y == 0)
+         self._len_a = sum(1 for _, y in self.items if y == 1)
+
+         # Human: mild, natural perturbations
+         self.aug_human = Compose([
+             AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.4),
+             TimeStretch(min_rate=0.96, max_rate=1.04, p=0.3),
+             PitchShift(min_semitones=-1, max_semitones=1, p=0.2),
+             make_gain(-4, 4, p=0.3),
+         ])
+
+         # AI: replay-aware chain (speaker/room/mic simulation)
+         ai_transforms = [
+             BandPassFilter(min_center_freq=200.0, max_center_freq=3500.0, p=0.5),
+             AddGaussianNoise(min_amplitude=0.001, max_amplitude=0.01, p=0.3),
+             TimeStretch(min_rate=0.95, max_rate=1.05, p=0.25),
+             make_gain(-6, 6, p=0.3),
+         ]
+         clip = make_clipping(p=0.3)
+         if clip is not None:
+             ai_transforms.insert(1, clip)
+         mp3 = make_mp3_compression()
+         if mp3 is not None:
+             ai_transforms.insert(0, mp3)
+
+         self.aug_ai = Compose(ai_transforms)
+
+     def __len__(self):
+         return len(self.items)
+
+     def __getitem__(self, idx: int):
+         path, label = self.items[idx]
+         y, sr = load_audio(str(path), TARGET_SR)
+         y = pad_or_trim(y, duration_s=self.clip_seconds, sr=sr)
+
+         if self.is_train:
+             if label == 1:
+                 y = self.aug_ai(samples=y, sample_rate=sr)
+             else:
+                 y = self.aug_human(samples=y, sample_rate=sr)
+
+         mel = logmel(y, sr)                      # (n_mels, T)
+         x = torch.from_numpy(mel).unsqueeze(0)   # (1, n_mels, T)
+         y_t = torch.tensor(label, dtype=torch.long)
+         return x, y_t
+
+ # ---------- Dataloaders ----------
+ def make_dataloaders(args):
+     ds_tr = FolderDataset(args.data_dir, split="train", val_ratio=args.val_ratio,
+                           seed=args.seed, clip_seconds=args.clip_seconds)
+     ds_va = FolderDataset(args.data_dir, split="val", val_ratio=args.val_ratio,
+                           seed=args.seed, clip_seconds=args.clip_seconds)
+
+     # Windows is happier with workers=0; keep configurable
+     workers = args.workers if args.workers >= 0 else (0 if os.name == "nt" else max(1, (os.cpu_count() or 4)//2))
+     pin = (not args.cpu) and torch.cuda.is_available()
+
+     dl_tr = DataLoader(
+         ds_tr, batch_size=args.batch_size, shuffle=True,
+         num_workers=workers, pin_memory=pin,
+         persistent_workers=(workers > 0), drop_last=True,
+     )
+     dl_va = DataLoader(
+         ds_va, batch_size=max(1, args.batch_size // 2), shuffle=False,
+         num_workers=workers, pin_memory=pin,
+         persistent_workers=(workers > 0),
+     )
+     return ds_tr, ds_va, dl_tr, dl_va
+
+ def class_weights_from_dataset(ds: FolderDataset, eps: float = 1e-6):
+     n_h, n_a = max(ds._len_h, eps), max(ds._len_a, eps)
+     w_h = (n_h + n_a) / (2 * n_h)
+     w_a = (n_h + n_a) / (2 * n_a)
+     return torch.tensor([w_h, w_a], dtype=torch.float32)
+
+ # ---------- Training / Eval ----------
+ def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn, grad_accum=1):
+     model.train()
+     total_loss = 0.0
+     correct = 0
+     seen = 0
+     opt.zero_grad(set_to_none=True)
+
+     for step, (x, y) in enumerate(dl):
+         x = x.to(device, non_blocking=True)
+         y = y.to(device, non_blocking=True)
+
+         with autocast_ctx:
+             logits = model(x)
+             loss = loss_fn(logits, y)
+
+         loss = loss / grad_accum
+         if getattr(scaler, "is_enabled", lambda: False)():
+             scaler.scale(loss).backward()
+         else:
+             loss.backward()
+
+         if (step + 1) % grad_accum == 0:
+             if getattr(scaler, "is_enabled", lambda: False)():
+                 scaler.step(opt)
+                 scaler.update()
+             else:
+                 opt.step()
+             opt.zero_grad(set_to_none=True)
+
+         total_loss += float(loss) * x.size(0) * grad_accum
+         correct += int((logits.argmax(1) == y).sum().item())
+         seen += x.size(0)
+
+     return total_loss / max(seen, 1), correct / max(seen, 1)
+
+ @torch.no_grad()
+ def evaluate(model, dl, device, loss_fn):
+     model.eval()
+     total_loss = 0.0
+     correct = 0
+     seen = 0
+     for x, y in dl:
+         x = x.to(device, non_blocking=True)
+         y = y.to(device, non_blocking=True)
+         logits = model(x)
+         loss = loss_fn(logits, y)
+         total_loss += float(loss) * x.size(0)
+         correct += int((logits.argmax(1) == y).sum().item())
+         seen += x.size(0)
+     return total_loss / max(seen, 1), correct / max(seen, 1)
+
+ def main(args):
+     set_seed(args.seed)
+     device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
+     cudnn.benchmark = True
+
+     ds_tr, ds_va, dl_tr, dl_va = make_dataloaders(args)
+     print(f"Train items: {len(ds_tr)} (human={ds_tr._len_h}, ai={ds_tr._len_a})")
+     print(f"Val items: {len(ds_va)}")
+
+     model = TinyMelCNN().to(device)
+
+     weights = class_weights_from_dataset(ds_tr).to(device)
+     loss_fn = torch.nn.CrossEntropyLoss(weight=weights)
+
+     opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=1e-4)
+
+     # AMP (use new torch.amp if available, else fallback)
+     try:
+         from torch.amp import GradScaler, autocast as amp_autocast
+         scaler = GradScaler("cuda", enabled=(device == "cuda" and args.amp))
+         autocast_ctx = amp_autocast("cuda") if (device == "cuda" and args.amp) else nullcontext()
+     except Exception:
+         from torch.cuda.amp import GradScaler, autocast as amp_autocast  # deprecated but works
+         scaler = GradScaler(enabled=(device == "cuda" and args.amp))
+         autocast_ctx = amp_autocast() if (device == "cuda" and args.amp) else nullcontext()
+
+     best_va = -1.0
+     patience_counter = 0
+     Path(args.out).parent.mkdir(parents=True, exist_ok=True)
+
+     for epoch in range(args.epochs):
+         tr_loss, tr_acc = train_one_epoch(
+             model, dl_tr, device, opt, scaler, autocast_ctx, loss_fn,
+             grad_accum=args.grad_accum
+         )
+         va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)
+
+         print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")
+
+         # Save "last" every epoch
+         torch.save(model.state_dict(), args.out.replace(".pth", ".last.pth"))
+
+         if va_acc > best_va + 1e-4:
+             best_va = va_acc
+             torch.save(model.state_dict(), args.out)
+             patience_counter = 0
+             print(f"✅ Saved best to {args.out} (val_acc={best_va:.3f})")
+         else:
+             patience_counter += 1
+             if args.early_stop > 0 and patience_counter >= args.early_stop:
+                 print(f"⏹️ Early stopping at epoch {epoch+1} (best val_acc={best_va:.3f})")
+                 break
+
+     print("Done.")
+
+ if __name__ == "__main__":
+     p = argparse.ArgumentParser(description="Train AI Voice Detector (replay-aware, version-robust, no fast_mp3 required)")
+     p.add_argument("--data_dir", type=str, required=True, help="Folder with subfolders human/ and ai/")
+     p.add_argument("--out", type=str, default="app/models/weights/cnn_melspec.pth")
+     p.add_argument("--epochs", type=int, default=10)
+     p.add_argument("--batch_size", type=int, default=32)
+     p.add_argument("--grad_accum", type=int, default=2)
+     p.add_argument("--lr", type=float, default=1e-3)
+     p.add_argument("--val_ratio", type=float, default=0.15)
+     p.add_argument("--clip_seconds", type=float, default=3.0)
+     p.add_argument("--workers", type=int, default=-1)  # try --workers 0 on Windows if you see issues
+     p.add_argument("--amp", action="store_true", default=True)
+     p.add_argument("--cpu", action="store_true")
+     p.add_argument("--early_stop", type=int, default=0)
+     p.add_argument("--seed", type=int, default=42)
+     args = p.parse_args()
+     main(args)
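A typical run is `python -m app.train --data_dir data/raw --epochs 10`. Before a long run it can help to sanity-check the dataset plumbing; a small sketch, assuming WAVs are already in place under data/raw/human and data/raw/ai as the docstring above describes:

# dataset_check.py: sketch; verifies FolderDataset yields (1, n_mels, T) log-mel tensors
from app.train import FolderDataset

ds = FolderDataset("data/raw", split="train")
x, y = ds[0]
print(len(ds), x.shape, int(y))  # dataset size, mel tensor shape (depends on utils.audio.logmel), 0/1 label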
app/train_wav2vec.py ADDED
@@ -0,0 +1,207 @@
+ import os, json, argparse, random
+ from pathlib import Path
+ from contextlib import nullcontext
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import torch.backends.cudnn as cudnn
+ from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
+
+ try:
+     from .models.wav2vec_detector import Wav2VecClassifier
+     from .utils.audio import load_audio, pad_or_trim, TARGET_SR
+ except ImportError:
+     from app.models.wav2vec_detector import Wav2VecClassifier
+     from app.utils.audio import load_audio, pad_or_trim, TARGET_SR
+
+ from audiomentations import Compose, AddGaussianNoise, BandPassFilter
+ def make_gain(min_db, max_db, p):
+     from audiomentations import Gain as _Gain
+     try: return _Gain(min_gain_in_db=min_db, max_gain_in_db=max_db, p=p)
+     except TypeError: return _Gain(min_gain_db=min_db, max_gain_db=max_db, p=p)
+
+ def set_seed(s=42):
+     random.seed(s); np.random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
+
+ def peak_normalize(y: np.ndarray, peak: float = 0.95, eps: float = 1e-9) -> np.ndarray:
+     m = float(np.max(np.abs(y)) + eps)
+     return (y / m) * peak if m > 0 else y
+
+ def trim_silence(y: np.ndarray, sr: int, thresh_db: float = 40.0, min_ms: int = 30) -> np.ndarray:
+     if y.size == 0: return y
+     win = max(1, int(sr * 0.02)); pad = max(1, int(sr * (min_ms / 1000.0)))
+     energy = np.convolve(y**2, np.ones(win)/win, mode="same")
+     ref = np.max(energy) + 1e-12
+     mask = 10*np.log10(energy/ref + 1e-12) > -thresh_db
+     if not np.any(mask): return y
+     idx = np.where(mask)[0]; start = max(0, int(idx[0]-pad)); end = min(len(y), int(idx[-1]+pad))
+     return y[start:end]
+
+ class WavDataset(Dataset):
+     """data_dir/{human,ai}/*.wav"""
+     def __init__(self, root, split="train", val_ratio=0.15, seed=42, clip_seconds=3.0):
+         self.root = Path(root); self.clip = float(clip_seconds)
+         human = sorted((self.root/"human").glob("*.wav"))
+         ai = sorted((self.root/"ai").glob("*.wav"))
+         items = [(p, 0) for p in human] + [(p, 1) for p in ai]
+         rng = random.Random(seed); rng.shuffle(items)
+         n_val = int(len(items)*val_ratio)
+         self.items = items[n_val:] if split == "train" else items[:n_val]
+         self.is_train = split == "train"
+         self.nh = sum(1 for _, y in self.items if y == 0)
+         self.na = sum(1 for _, y in self.items if y == 1)
+         self.aug_h = Compose([AddGaussianNoise(0.001, 0.01, p=0.3), make_gain(-4, 4, p=0.3)])
+         self.aug_a = Compose([BandPassFilter(200.0, 3500.0, p=0.5), AddGaussianNoise(0.001, 0.01, p=0.3), make_gain(-6, 6, p=0.3)])
+
+     def __len__(self): return len(self.items)
+     def __getitem__(self, idx):
+         path, label = self.items[idx]
+         y, sr = load_audio(str(path), TARGET_SR)
+         y = trim_silence(y, sr, 40.0, 30)
+         y = peak_normalize(y, 0.95)
+         y = pad_or_trim(y, duration_s=self.clip, sr=sr)
+         if self.is_train:
+             y = (self.aug_a if label == 1 else self.aug_h)(samples=y, sample_rate=sr)
+         return torch.from_numpy(y).float(), torch.tensor(label, dtype=torch.long)
+
+ def make_loaders(args):
+     ds_tr = WavDataset(args.data_dir, "train", args.val_ratio, args.seed, args.clip_seconds)
+     ds_va = WavDataset(args.data_dir, "val", args.val_ratio, args.seed, args.clip_seconds)
+
+     # Weighted sampler to balance classes
+     labels = [y for _, y in ds_tr.items]
+     n0 = max(1, labels.count(0)); n1 = max(1, labels.count(1))
+     w0 = (n0 + n1) / (2 * n0); w1 = (n0 + n1) / (2 * n1)
+     sample_weights = [w0 if y == 0 else w1 for y in labels]
+     sampler = WeightedRandomSampler(sample_weights, num_samples=len(labels), replacement=True)
+
+     workers = args.workers if args.workers >= 0 else (0 if os.name == "nt" else max(1, (os.cpu_count() or 4)//2))
+     pin = (not args.cpu) and torch.cuda.is_available()
+     dl_tr = DataLoader(ds_tr, batch_size=args.batch_size, sampler=sampler,
+                        num_workers=workers, pin_memory=pin, drop_last=True)
+     dl_va = DataLoader(ds_va, batch_size=max(1, args.batch_size//2), shuffle=False,
+                        num_workers=workers, pin_memory=pin)
+     return ds_tr, ds_va, dl_tr, dl_va
+
+ class FocalLoss(torch.nn.Module):
+     def __init__(self, alpha=None, gamma=1.5):
+         super().__init__()
+         self.alpha = alpha
+         self.gamma = gamma
+         self.ce = torch.nn.CrossEntropyLoss(weight=alpha)
+     def forward(self, logits, target):
+         ce = self.ce(logits, target)
+         with torch.no_grad():
+             pt = torch.exp(-ce)
+         return ((1 - pt) ** self.gamma) * ce
+
+ def train_one_epoch(model, dl, device, opt, scaler, autocast_ctx, loss_fn, grad_accum=1):
+     model.train(); total = 0.0; correct = 0; seen = 0
+     opt.zero_grad(set_to_none=True)
+     for step, (x, y) in enumerate(dl):
+         x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
+         with autocast_ctx:
+             logits, _ = model(x); loss = loss_fn(logits, y)
+         loss = loss / grad_accum
+         if getattr(scaler, "is_enabled", lambda: False)(): scaler.scale(loss).backward()
+         else: loss.backward()
+         if (step+1) % grad_accum == 0:
+             if getattr(scaler, "is_enabled", lambda: False)():
+                 scaler.step(opt); scaler.update()
+             else:
+                 opt.step()
+             opt.zero_grad(set_to_none=True)
+         total += float(loss) * x.size(0) * grad_accum
+         correct += int((logits.argmax(1) == y).sum().item()); seen += x.size(0)
+     return total/max(seen, 1), correct/max(seen, 1)
+
+ @torch.no_grad()
+ def evaluate(model, dl, device, loss_fn):
+     model.eval(); total = 0.0; correct = 0; seen = 0
+     for x, y in dl:
+         x = x.to(device, non_blocking=True); y = y.to(device, non_blocking=True)
+         logits, _ = model(x); loss = loss_fn(logits, y)
+         total += float(loss) * x.size(0); correct += int((logits.argmax(1) == y).sum().item()); seen += x.size(0)
+     return total/max(seen, 1), correct/max(seen, 1)
+
+ def main(args):
+     set_seed(args.seed)
+     device = "cuda" if (torch.cuda.is_available() and not args.cpu) else "cpu"
+     cudnn.benchmark = True
+
+     ds_tr, ds_va, dl_tr, dl_va = make_loaders(args)
+     print(f"Train items: {len(ds_tr)} (human={ds_tr.nh}, ai={ds_tr.na})")
+     print(f"Val items: {len(ds_va)}")
+
+     model = Wav2VecClassifier(
+         encoder=args.encoder,
+         unfreeze_last=args.unfreeze_last,
+         dropout=args.dropout,
+         hidden=args.hidden
+     ).to(device)
+
+     # Focal loss with class weights
+     nh, na = ds_tr.nh, ds_tr.na
+     w = torch.tensor([(nh+na)/(2*nh+1e-6), (nh+na)/(2*na+1e-6)], dtype=torch.float32).to(device)
+     loss_fn = FocalLoss(alpha=w, gamma=1.5)
+
+     head_params = list(model.head.parameters())
+     enc_params = [p for p in model.encoder.parameters() if p.requires_grad]
+     param_groups = [{"params": head_params, "lr": args.lr_head}]
+     if enc_params:
+         param_groups.append({"params": enc_params, "lr": args.lr_encoder})
+     opt = torch.optim.AdamW(param_groups, weight_decay=1e-4)
+
+     try:
+         from torch.amp import GradScaler, autocast as amp_autocast
+         scaler = GradScaler("cuda", enabled=(device == "cuda" and args.amp))
+         autocast_ctx = amp_autocast("cuda") if (device == "cuda" and args.amp) else nullcontext()
+     except Exception:
+         from torch.cuda.amp import GradScaler, autocast as amp_autocast
+         scaler = GradScaler(enabled=(device == "cuda" and args.amp))
+         autocast_ctx = amp_autocast() if (device == "cuda" and args.amp) else nullcontext()
+
+     best = -1.0; patience = 0
+     Path(args.out).parent.mkdir(parents=True, exist_ok=True)
+     with open(args.out.replace(".pth", ".json"), "w", encoding="utf-8") as f:
+         json.dump({"encoder": args.encoder, "unfreeze_last": args.unfreeze_last}, f)
+
+     for epoch in range(args.epochs):
+         tr_loss, tr_acc = train_one_epoch(model, dl_tr, device, opt, scaler, autocast_ctx, loss_fn, args.grad_accum)
+         va_loss, va_acc = evaluate(model, dl_va, device, loss_fn)
+         print(f"epoch {epoch+1:02d}/{args.epochs} | train {tr_loss:.3f}/{tr_acc:.3f} | val {va_loss:.3f}/{va_acc:.3f}")
+         torch.save(model.state_dict(), args.out.replace(".pth", ".last.pth"))
+         if va_acc > best + 1e-4:
+             best = va_acc; patience = 0
+             torch.save(model.state_dict(), args.out)
+             print(f"✅ Saved best to {args.out} (val_acc={best:.3f})")
+         else:
+             patience += 1
+             if args.early_stop > 0 and patience >= args.early_stop:
+                 print(f"⏹️ Early stopping at epoch {epoch+1} (best={best:.3f})")
+                 break
+     print("Done.")
+
+ if __name__ == "__main__":
+     ap = argparse.ArgumentParser(description="Train Wav2Vec2-based AI Voice Detector (balanced)")
+     ap.add_argument("--data_dir", required=True, help="Folder with human/ and ai/ WAVs")
+     ap.add_argument("--out", default="app/models/weights/wav2vec2_classifier.pth")
+     ap.add_argument("--encoder", default="facebook/wav2vec2-base")
+     ap.add_argument("--unfreeze_last", type=int, default=0)
+     ap.add_argument("--epochs", type=int, default=8)
+     ap.add_argument("--batch_size", type=int, default=16)
+     ap.add_argument("--grad_accum", type=int, default=2)
+     ap.add_argument("--lr_head", type=float, default=1e-3)
+     ap.add_argument("--lr_encoder", type=float, default=1e-5)
+     ap.add_argument("--val_ratio", type=float, default=0.15)
+     ap.add_argument("--clip_seconds", type=float, default=3.0)
+     ap.add_argument("--workers", type=int, default=-1)
+     ap.add_argument("--amp", action="store_true", default=True)
+     ap.add_argument("--cpu", action="store_true")
+     ap.add_argument("--dropout", type=float, default=0.2)
+     ap.add_argument("--hidden", type=int, default=256)
+     ap.add_argument("--early_stop", type=int, default=0)
+     ap.add_argument("--seed", type=int, default=42)
+     args = ap.parse_args()
+     main(args)
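Training is launched the same way, e.g. `python -m app.train_wav2vec --data_dir data/raw --epochs 8` (judging by the lr_encoder parameter group above, `--unfreeze_last` controls how many encoder layers are fine-tuned). A quick dataset sanity check mirroring the CNN check, assuming TARGET_SR is 16 kHz so a 3 s clip is 48000 samples:

# wav_dataset_check.py: sketch; the sample count is an assumption based on TARGET_SR
from app.train_wav2vec import WavDataset

ds = WavDataset("data/raw", split="train")
wav, label = ds[0]
print(len(ds), wav.shape, int(label))  # dataset size, torch.Size([48000]) if TARGET_SR is 16 kHz, 0/1 label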