audio-edit / audiojob.py
liuyang
Enhance dual-mono detection in _stereo_decision_quick method of AudioJobRunner by implementing VAD-based similarity metrics, improving audio processing accuracy and robustness with fallback to previous methods if VAD is unavailable.
00d0703
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Audio preprocess runner (plan-only, fixed split with configurable overlap).
Design:
- Probe with ffprobe (duration/channels/codec) from a URL source.
- Quick stereo check on a few short samples to decide split vs downmix (per-channel VAD-mask agreement; mid/side RMS + L/R correlation as fallback).
- "Preprocess" is plan-only: record ingest recipe for the transcriber (16 kHz mono WAV).
- Split uses fixed windows with configurable overlap (virtual slices).
- No audio decoding/materialization on this node.
- Keeps the 'transcribe' stage (initialized for downstream use).
Note: source_uri is always a URL. No storage adapter.
"""
import array
import datetime as dt
import hashlib
import json
import logging
import math
import os
import re
import subprocess
import time
import uuid
from typing import Any, Dict, List, Optional, Tuple
logger = logging.getLogger(__name__)
# -------------------------- Utils ----------------------------
def utc_now_iso() -> str:
    """Return the current UTC time as ISO-8601 with second precision and a ``Z`` suffix.

    Example: ``2024-01-31T12:34:56Z``.
    """
    # datetime.utcnow() is deprecated since Python 3.12; take an aware UTC "now" and drop
    # tzinfo so isoformat() does not emit "+00:00" before we append the "Z" designator.
    now = dt.datetime.now(dt.timezone.utc).replace(microsecond=0, tzinfo=None)
    return now.isoformat() + "Z"
def clamp(v: float, lo: float, hi: float) -> float:
    """Bound *v* to the closed range [lo, hi]; if lo > hi, ``lo`` takes precedence."""
    capped = v if v < hi else hi  # apply the upper bound first
    return capped if capped > lo else lo
def run(cmd: List[str], timeout: Optional[int] = None) -> Tuple[int, str, str]:
    """Execute *cmd* (argument list, no shell) and capture its text output.

    Returns:
        ``(returncode, stdout, stderr)``. If *timeout* elapses, the child is killed, its
        remaining output is collected, and ``(124, stdout, stderr + "\\nTIMEOUT")`` is
        returned (124 matches the exit code of coreutils ``timeout``).
    """
    child = subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    timed_out = False
    try:
        stdout_txt, stderr_txt = child.communicate(timeout=timeout)
    except subprocess.TimeoutExpired:
        child.kill()
        stdout_txt, stderr_txt = child.communicate()  # reap the child and drain pipes
        timed_out = True
    if timed_out:
        return 124, stdout_txt, stderr_txt + "\nTIMEOUT"
    return child.returncode, stdout_txt, stderr_txt
def run_with_retry_collect(cmd: List[str], retries: int, timeout: Optional[int]) -> str:
    """Run *cmd* up to ``max(1, retries)`` times; return stdout+stderr of the first success.

    Useful for tools that report on stderr (e.g. ffmpeg filters such as volumedetect/astats).
    Between failed attempts it waits ``1.2 ** attempt`` seconds; there is no wait after the
    final attempt, so the error surfaces immediately.

    Raises:
        RuntimeError: if every attempt exits non-zero (message includes the last output).
    """
    attempts = max(1, retries)
    last = ""
    for attempt in range(attempts):
        code, out, err = run(cmd, timeout)
        combined = (out or "") + (err or "")
        if code == 0:
            return combined
        last = combined
        if attempt + 1 < attempts:  # only back off when another attempt follows
            time.sleep(1.2 ** attempt)
    raise RuntimeError(f"Command failed: {' '.join(cmd)}\n{last}")
def run_json_stdout(cmd: List[str], retries: int, timeout: Optional[int]) -> str:
    """Run *cmd* up to ``max(1, retries)`` times; return only stdout of the first success.

    Stderr is excluded from the result so diagnostic noise cannot corrupt JSON output
    (e.g. ``ffprobe -of json``). Between failed attempts it waits ``1.2 ** attempt``
    seconds; there is no wait after the final attempt.

    Raises:
        RuntimeError: if every attempt exits non-zero (message includes the last stdout+stderr).
    """
    attempts = max(1, retries)
    last = ""
    for attempt in range(attempts):
        code, out, err = run(cmd, timeout)
        if code == 0:
            return out or ""
        last = (out or "") + (err or "")
        if attempt + 1 < attempts:  # only back off when another attempt follows
            time.sleep(1.2 ** attempt)
    raise RuntimeError(f"Command failed: {' '.join(cmd)}\n{last}")
# ---------------------- Defaults / Presets -------------------
# Default job policy; copied into manifest["policy"] for new jobs (per-job presets override).
DEFAULT_PRESETS = {
    "sample_rate_target": 16000,            # ingest sample rate for the transcriber (Hz)
    "channel_policy": "auto",               # auto | split | downmix
    "stereo_side_mid_threshold_db": 20.0,   # if SIDE <= MID - thr => safe to downmix
                                            # NOTE(review): not referenced in this module; confirm use
    "chunk_target_ms": 15 * 60000,          # fixed window size (15 min)
    "overlap_ms": 500,                      # overlap between consecutive windows
    "ff_timeout_sec": 120,                  # timeout per ffmpeg/ffprobe invocation
    "ff_retries": 2,                        # attempts per ffmpeg/ffprobe command
    # --- mid/side + correlation thresholds (fallback stereo check) ---
    "dual_mono_side_mid_db": 15.0,          # near-dual-mono sources often sit ~15-20 dB down
    "dual_mono_rms_delta_db": 1.0,          # tolerate tiny L/R level imbalances
    "dual_mono_corr": 0.90,                 # L/R correlation gate (combined with Side/Mid check)
    "corr_probe_ms": 30000,                 # correlation probe length (also bounded by stereo_probe_win_s)
    "stereo_probe_win_s": 12,               # length (sec) of each mid/side RMS sample window
    # --- VAD-based dual-mono check (needs the optional webrtcvad package) ---
    "vad_aggressiveness": 2,                # webrtcvad mode 0..3
    "vad_similarity_thr": 0.95,             # median L/R VAD-mask agreement at/above => dual-mono
    "vad_max_lag_frames": 1,                # tolerated L/R misalignment in 30 ms frames
    "vad_probe_win_s": 10.0,                # length (sec) of each VAD probe window
}
# --------------------------- Runner --------------------------
class AudioJobRunner:
    """
    Plan-only audio job runner: probe -> preprocess (ingest plan) -> split (virtual chunks).

    All job state lives in ``self.manifest`` (a JSON-serializable dict). Stage transitions go
    through ``_set_stage``, which bumps ``manifest["rev"]`` and ``manifest["updated_at"]``.

    Usage:
        runner = AudioJobRunner(
            manifest=None,
            source_uri="https://example.com/audio.mp3",
            work_root="/tmp/jobwork",
            presets={"chunk_target_ms": 45000, "overlap_ms": 700}
        )
        manifest = runner.run_until_split()
    """

    # Upper bound on chunk entries embedded in the split plan (keeps the manifest small).
    _MAX_EMBED_CHUNKS = 2000
    # dB value substituted for "-inf" RMS readings (digital silence) so the manifest stays
    # strict JSON; far below the ~-96 dBFS noise floor of 16-bit audio.
    _RMS_FLOOR_DB = -150.0
    # Per-frame astats metadata printed by ametadata (cumulative with reset=0; last wins).
    _RMS_META_RE = re.compile(r"lavfi\.astats\.Overall\.RMS_level=(-?inf|-?\d+(?:\.\d+)?)")
    # astats end-of-stream summary ("Channel" sections first, "Overall" last).
    _RMS_SUMMARY_RE = re.compile(r"RMS level dB:\s*(-?inf|-?\d+(?:\.\d+)?)")

    def __init__(
        self,
        manifest: Optional[Dict[str, Any]],
        source_uri: Optional[str],
        work_root: str,
        presets: Optional[Dict[str, Any]] = None,
    ):
        """
        Args:
            manifest: Existing manifest to resume, or None to start a new job.
            source_uri: Audio URL; required when ``manifest`` is None, ignored otherwise.
            work_root: Working directory (created if missing).
            presets: Overrides merged over ``DEFAULT_PRESETS``. They seed the ``policy`` of a
                NEW manifest only; a resumed manifest keeps its own stored policy.

        Raises:
            ValueError: if ``manifest`` is None and no ``source_uri`` is given.
        """
        self.work_root = os.path.abspath(work_root)
        os.makedirs(self.work_root, exist_ok=True)
        self.presets = dict(DEFAULT_PRESETS)
        if presets:
            self.presets.update(presets)
        if manifest is None:
            if not source_uri:
                raise ValueError("source_uri is required for a new job.")
            self.manifest = self._init_manifest(source_uri)
        else:
            self.manifest = manifest
            self.manifest.setdefault("version", "2.3")
            self.manifest.setdefault("rev", 0)
        self._touch()

    # -------- Public API --------
    def run_until_split(self) -> Dict[str, Any]:
        """
        Run every stage up to and including "split" that is not already "done".

        Returns the (mutated) manifest. On error, the first stage still marked "running"
        is set to "failed" with ``last_error``/``ended_at`` and the exception is re-raised.
        """
        try:
            if self._stage("probe") != "done":
                self._run_probe()
            if self._stage("preprocess") != "done":
                self._build_ingest_plan()  # renamed from _run_preprocess
            if self._stage("split") != "done":
                self._run_split_plan()
            return self.manifest
        except Exception as e:
            for s in ("split", "preprocess", "probe"):
                if self._stage(s) == "running":
                    self._set_stage(s, "failed", 0.0, {"last_error": str(e), "ended_at": utc_now_iso()})
                    break
            raise

    # -------- Manifest helpers --------
    def _init_manifest(self, src: str) -> Dict[str, Any]:
        """Build a fresh manifest for ``src`` with all stages "pending"."""
        jid = str(uuid.uuid4())
        return {
            "version": "2.3", "rev": 0, "job_id": jid,
            "created_at": utc_now_iso(), "updated_at": utc_now_iso(),
            "source": {"uri": src, "container": None, "codec": None, "duration_ms": None,
                       "sample_rate": None, "channels": None},
            "policy": dict(self.presets),
            "stages": {
                "probe": {"status": "pending", "progress": 0.0},
                "preprocess": {"status": "pending", "progress": 0.0},  # name kept for compatibility
                "split": {"status": "pending", "progress": 0.0},
                "transcribe": {"status": "pending", "progress": 0.0},
            },
        }

    def _touch(self) -> None:
        """Refresh ``updated_at`` and increment the manifest revision."""
        self.manifest["updated_at"] = utc_now_iso()
        self.manifest["rev"] = int(self.manifest.get("rev", 0)) + 1

    def _stage(self, name: str) -> str:
        """Status of stage ``name`` ("pending" if the stage or status is missing)."""
        return self.manifest.get("stages", {}).get(name, {}).get("status", "pending")

    def _set_stage(self, n: str, st: str, prog: float, extra: Optional[Dict[str, Any]] = None) -> None:
        """Set status/progress (clamped to [0, 1]) of stage ``n``, merge ``extra``, and touch."""
        s = self.manifest["stages"].setdefault(n, {})
        s["status"] = st
        s["progress"] = clamp(prog, 0.0, 1.0)
        if extra:
            s.update(extra)
        self._touch()

    @staticmethod
    def _parse_number(value: Any) -> Optional[float]:
        """Return ``value`` as a finite float, or None if missing/"N/A"/non-numeric."""
        if value is None:
            return None
        try:
            num = float(value)
        except (TypeError, ValueError):
            return None
        return num if math.isfinite(num) else None

    @staticmethod
    def _median(xs: List[float]) -> Optional[float]:
        """Median of ``xs`` (mean of the two middle values for even length); None if empty."""
        if not xs:
            return None
        s = sorted(xs)
        mid = len(s) // 2
        return s[mid] if len(s) % 2 else (s[mid - 1] + s[mid]) / 2.0

    # -------- Probe --------
    def _run_probe(self) -> None:
        """
        ffprobe the first audio stream, record container/codec/duration/rate/channels in
        ``manifest["source"]``, then resolve the channel policy ("split" or "downmix").

        Raises:
            RuntimeError: if ffprobe fails or no positive duration can be determined.
        """
        self._set_stage("probe", "running", 0.05, {"started_at": utc_now_iso()})
        policy = self.manifest["policy"]
        src = self.manifest["source"]["uri"]
        # stdout-only JSON so stderr noise cannot break parsing
        txt = run_json_stdout(
            ["ffprobe", "-v", "error", "-select_streams", "a:0",
             "-show_streams", "-show_format", "-of", "json", src],
            retries=policy["ff_retries"], timeout=policy["ff_timeout_sec"],
        )
        info = json.loads(txt or "{}")
        fmt = info.get("format") or {}
        streams = info.get("streams") or []
        stream = next((s for s in streams if s.get("codec_type") == "audio"), {})

        # Prefer the container duration; some inputs report "N/A" there but set it per stream.
        duration_s: Optional[float] = None
        for candidate in (fmt.get("duration"), stream.get("duration")):
            parsed = self._parse_number(candidate)
            if parsed is not None and parsed > 0:
                duration_s = parsed
                break
        sample_rate = self._parse_number(stream.get("sample_rate"))
        channels = self._parse_number(stream.get("channels"))
        self.manifest["source"].update({
            "container": fmt.get("format_name"),
            "codec": stream.get("codec_name"),
            "duration_ms": int(duration_s * 1000) if duration_s is not None else None,
            "sample_rate": int(sample_rate) if sample_rate else None,
            "channels": int(channels) if channels else None,
        })
        if self.manifest["source"]["duration_ms"] is None:
            raise RuntimeError("Could not determine duration from ffprobe.")

        # Channel policy: explicit split/downmix, otherwise ("auto") a quick stereo check.
        ch = self.manifest["source"]["channels"] or 1
        mode = policy.get("channel_policy", "auto")
        metrics: Dict[str, Any] = {}
        if mode == "split":
            resolved = "split" if ch == 2 else "downmix"
        elif mode == "downmix":
            resolved = "downmix"
        elif ch == 2:
            resolved, metrics = self._stereo_decision_quick(src)
        else:
            resolved = "downmix"
        self.manifest["stages"]["probe"]["actions"] = {
            "channel_policy": resolved,
            "will_split_stereo": (resolved == "split" and ch == 2),
            "will_downmix": (resolved != "split"),
        }
        self.manifest["stages"]["probe"]["metrics"] = metrics
        self._set_stage("probe", "done", 1.0, {"ended_at": utc_now_iso()})

    # -------- Stereo decision --------
    def _stereo_decision_quick(self, uri: str) -> Tuple[str, Dict[str, Any]]:
        """
        Decide split vs downmix for a 2-channel source.

        Primary method (optional ``webrtcvad`` package): decode short windows at 20%/50%/80%
        of the duration, build per-channel VAD masks on 30 ms frames, allow a ±lag alignment,
        and call it dual-mono (=> "downmix") when the median agreement >= vad_similarity_thr.

        Fallback (webrtcvad missing, VAD error, or no usable window): mid/side RMS plus L/R
        correlation; the metrics then carry ``vad_fallback_reason``.

        Returns:
            ("split" | "downmix", metrics_dict)
        """
        pol = self.manifest["policy"]
        dur_s = max(1, int(self.manifest["source"].get("duration_ms") or 0) // 1000)
        offsets = self._probe_offsets(dur_s, float(pol.get("vad_probe_win_s", 10.0)))
        try:
            import webrtcvad  # optional third-party dependency
        except ImportError:
            fallback_reason = "webrtcvad not available"
        else:
            try:
                result = self._stereo_decision_vad(uri, offsets, webrtcvad)
            except Exception as err:  # deliberate best-effort: any VAD failure -> fallback
                logger.warning("VAD stereo check failed (%s); using mid/side fallback", err)
                fallback_reason = f"vad error: {err}"
            else:
                if result is not None:
                    return result
                fallback_reason = "no usable VAD windows"
        rec, metrics = self._stereo_decision_midside(uri, offsets)
        metrics["vad_fallback_reason"] = fallback_reason
        return rec, metrics

    @staticmethod
    def _probe_offsets(dur_s: float, win_s: float,
                       fractions: Tuple[float, ...] = (0.20, 0.50, 0.80)) -> List[float]:
        """Window start offsets (sec) at ``fractions`` of the duration, clamped so a ``win_s``
        window fits; duplicates (e.g. every offset 0.0 for short files) are dropped."""
        if dur_s <= win_s:
            return [0.0]
        raw = (max(0.0, min(dur_s - win_s, f * dur_s)) for f in fractions)
        return list(dict.fromkeys(raw))  # de-duplicate, keep order

    @staticmethod
    def _best_mask_similarity(a: List[bool], b: List[bool], max_lag: int) -> float:
        """Best fraction of agreeing frames between masks ``a``/``b`` over lags in ±max_lag."""
        n = min(len(a), len(b))
        if n == 0:
            return 0.0
        a, b = a[:n], b[:n]
        best = 0.0
        for lag in range(-max_lag, max_lag + 1):
            if lag >= 0:
                a2, b2 = a[lag:], b[:n - lag]
            else:
                a2, b2 = a[:n + lag], b[-lag:]
            if not a2:
                continue
            matches = sum(x == y for x, y in zip(a2, b2))
            best = max(best, matches / float(len(a2)))
        return best

    @staticmethod
    def _decode_pcm_stereo(uri: str, start_s: Optional[float], dur_s: float,
                           sample_rate: int, timeout: Optional[float]) -> array.array:
        """
        Decode ``dur_s`` seconds of the first audio stream (from ``start_s`` when given) to
        interleaved signed 16-bit stereo PCM; returns ``array('h')`` (empty on timeout).

        ffmpeg's exit status is deliberately ignored: any decoded samples are still usable
        for a statistical probe. Output is trimmed to whole L/R sample pairs.
        """
        cmd = ["ffmpeg", "-nostdin", "-hide_banner", "-v", "error"]
        if start_s is not None:
            cmd += ["-ss", f"{start_s:.3f}"]
        cmd += ["-t", f"{dur_s:.3f}", "-i", uri, "-map", "0:a:0",
                "-ac", "2", "-ar", str(sample_rate), "-f", "s16le", "-"]
        pcm = array.array("h")
        try:
            # run() drains both pipes (no stderr deadlock) and reaps/kills the child.
            proc = subprocess.run(cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, timeout=timeout)
        except subprocess.TimeoutExpired:
            logger.warning("ffmpeg PCM probe timed out after %ss: %s", timeout, uri)
            return pcm
        raw = proc.stdout or b""
        # NOTE: array('h') uses native byte order; matches s16le on little-endian hosts.
        pcm.frombytes(raw[: len(raw) - len(raw) % 4])
        return pcm

    def _stereo_decision_vad(self, uri: str, offsets: List[float],
                             vad_module: Any) -> Optional[Tuple[str, Dict[str, Any]]]:
        """VAD-mask agreement check; returns None when no probe window was usable."""
        pol = self.manifest["policy"]
        vad_aggr = int(pol.get("vad_aggressiveness", 2))      # 0..3
        vad_frame_ms = 30                                      # a frame length webrtcvad supports
        vad_sim_thr = float(pol.get("vad_similarity_thr", 0.95))
        vad_max_lag = int(pol.get("vad_max_lag_frames", 1))    # ±frames of L/R misalignment
        probe_win_s = float(pol.get("vad_probe_win_s", 10.0))
        sample_rate = 16000
        frame_samples = sample_rate * vad_frame_ms // 1000     # 480 samples per channel
        vad = vad_module.Vad(vad_aggr)

        similarities: List[float] = []
        for ss in offsets:
            pcm = self._decode_pcm_stereo(uri, ss, probe_win_s, sample_rate, pol.get("ff_timeout_sec"))
            n = len(pcm) // 2  # samples per channel
            if n < frame_samples:
                continue  # too short for a single VAD frame
            left, right = pcm[0::2], pcm[1::2]  # de-interleave L/R
            starts = range(0, n - frame_samples + 1, frame_samples)
            l_mask = [vad.is_speech(left[i:i + frame_samples].tobytes(), sample_rate) for i in starts]
            r_mask = [vad.is_speech(right[i:i + frame_samples].tobytes(), sample_rate) for i in starts]
            similarities.append(self._best_mask_similarity(l_mask, r_mask, vad_max_lag))
        if not similarities:
            return None

        # Median is robust to a single odd window.
        med_sim = self._median(similarities)
        dual_mono = med_sim >= vad_sim_thr
        metrics = {
            "mid_db": None,
            "side_db": None,
            "L_db": None,
            "R_db": None,
            "near_silent": False,
            "corr": None,
            "dual_mono": dual_mono,
            "side_mid_gap_db": None,
            # Fallback thresholds echoed for schema parity with the mid/side path (unused here).
            "side_mid_thr_db": float(pol.get("dual_mono_side_mid_db", 18.0)),
            "rms_delta_thr_db": float(pol.get("dual_mono_rms_delta_db", 1.0)),
            "corr_thr": float(pol.get("dual_mono_corr", 0.93)),
            "windows_used": len(similarities),
            "vad_similarities": similarities,
            "vad_similarity_median": med_sim,
            "vad_params": {
                "aggressiveness": vad_aggr,
                "frame_ms": vad_frame_ms,
                "sim_thr": vad_sim_thr,
                "max_lag_frames": vad_max_lag,
                "probe_win_s": probe_win_s,
            },
        }
        return ("downmix" if dual_mono else "split"), metrics

    @classmethod
    def _parse_rms_db(cls, txt: str) -> Optional[float]:
        """Last Overall RMS level (dB) in ffmpeg output; "-inf" is clamped to _RMS_FLOOR_DB."""
        found = cls._RMS_META_RE.findall(txt) or cls._RMS_SUMMARY_RE.findall(txt)
        if not found:
            return None
        return max(float(found[-1]), cls._RMS_FLOOR_DB)

    def _rms_db_window(self, uri: str, start_s: float, dur_s: float, pan: str) -> Optional[float]:
        """RMS level (dB) of the ``pan``-mixed mono signal over one window, or None."""
        pol = self.manifest["policy"]
        txt = run_with_retry_collect(
            [
                "ffmpeg", "-nostdin", "-hide_banner", "-nostats",
                # astats/ametadata report at INFO level; "-v error" would suppress them entirely.
                "-v", "info",
                "-ss", f"{start_s:.3f}", "-t", f"{dur_s:.3f}",
                "-i", uri, "-map", "0:a:0",
                "-af", f"{pan},astats=metadata=1:reset=0,"
                       "ametadata=print:key=lavfi.astats.Overall.RMS_level",
                "-f", "null", "-",
            ],
            pol["ff_retries"], pol["ff_timeout_sec"],
        )
        return self._parse_rms_db(txt)

    def _lr_correlation(self, uri: str, probe_s: float) -> Optional[float]:
        """Zero-lag normalized L/R correlation (sum LR / sqrt(sum L^2 * sum R^2), no mean
        removal) over the first ``probe_s`` seconds; None if undecodable or a channel is all zero."""
        try:
            pcm = self._decode_pcm_stereo(uri, None, probe_s, 16000,
                                          self.manifest["policy"].get("ff_timeout_sec"))
        except OSError as err:  # e.g. ffmpeg binary missing
            logger.warning("L/R correlation probe failed: %s", err)
            return None
        left, right = pcm[0::2], pcm[1::2]
        sum_l2 = sum(x * x for x in left)
        sum_r2 = sum(x * x for x in right)
        if sum_l2 <= 0 or sum_r2 <= 0:
            return None
        sum_lr = sum(x * y for x, y in zip(left, right))
        return float(sum_lr / math.sqrt(sum_l2 * sum_r2))

    def _stereo_decision_midside(self, uri: str, offsets: List[float]) -> Tuple[str, Dict[str, Any]]:
        """Fallback decision from median mid/side/L/R RMS levels plus L/R correlation."""
        pol = self.manifest["policy"]
        side_mid_thr = float(pol.get("dual_mono_side_mid_db", 18.0))
        rms_delta_thr = float(pol.get("dual_mono_rms_delta_db", 1.0))
        corr_thr = float(pol.get("dual_mono_corr", 0.93))
        corr_probe_ms = int(pol.get("corr_probe_ms", 30000))
        win_s = float(pol.get("stereo_probe_win_s", 12.0))
        pans = {
            "mid": "pan=mono|c0=0.5*c0+0.5*c1",
            "side": "pan=mono|c0=0.5*c0-0.5*c1",
            "L": "pan=mono|c0=c0",
            "R": "pan=mono|c0=c1",
        }
        vals: Dict[str, List[float]] = {k: [] for k in pans}
        for ss in offsets:
            for name, pan in pans.items():
                v = self._rms_db_window(uri, ss, win_s, pan)
                if v is not None:
                    vals[name].append(v)
        mid, side = self._median(vals["mid"]), self._median(vals["side"])
        l_db, r_db = self._median(vals["L"]), self._median(vals["R"])
        # Correlation is probed from the start of the file; its length is bounded by win_s.
        corr = self._lr_correlation(uri, min(win_s, max(5.0, corr_probe_ms / 1000.0)))

        # One channel much quieter than the other and below -45 dB => effectively mono.
        near_silent = (
            l_db is not None and r_db is not None
            and ((l_db < -45.0 and (r_db - l_db) > 15.0) or (r_db < -45.0 and (l_db - r_db) > 15.0))
        )
        side_mid_gap = None if (mid is None or side is None) else (mid - side)
        cond_side = side_mid_gap is not None and side_mid_gap >= 15.0
        cond_strong_side = side_mid_gap is not None and side_mid_gap >= side_mid_thr
        cond_rms = l_db is not None and r_db is not None and abs(l_db - r_db) <= rms_delta_thr
        cond_corr = corr is not None and abs(corr) >= corr_thr
        dual_mono = (cond_corr and cond_side) or (cond_strong_side and cond_rms)
        rec = "downmix" if (near_silent or dual_mono or cond_strong_side) else "split"
        metrics = {
            "mid_db": mid, "side_db": side, "L_db": l_db, "R_db": r_db,
            "near_silent": near_silent, "corr": corr, "dual_mono": dual_mono,
            "side_mid_gap_db": side_mid_gap, "side_mid_thr_db": side_mid_thr,
            "rms_delta_thr_db": rms_delta_thr, "corr_thr": corr_thr,
            "windows_used": len(vals["mid"]),
        }
        return rec, metrics

    # -------- Preprocess (plan-only) --------
    def _build_ingest_plan(self) -> None:
        """
        Record (not execute) how the transcriber should ingest the source: ffmpeg to mono
        16-bit PCM WAV at ``sample_rate_target``, plus per-channel pan filters when the probe
        resolved to "split". Also stores a SHA-256 idempotency key over (uri, sr, policy).
        """
        self._set_stage("preprocess", "running", 0.1, {"started_at": utc_now_iso()})
        sr = int(self.manifest["policy"]["sample_rate_target"])
        resolved = self.manifest["stages"]["probe"]["actions"]["channel_policy"]
        working = {
            "sample_rate": sr,
            "channel_map": ["L", "R"] if resolved == "split" else ["mono"],
            "skipped": True,  # plan-only: no working audio is materialized on this node
        }
        ingest_recipe: Dict[str, Any] = {
            "decoder": "ffmpeg",
            "args": ["-vn", "-sn", "-map", "0:a:0", "-ar", str(sr), "-ac", "1", "-c:a", "pcm_s16le"],
            "container": "wav",
        }
        if resolved == "split":
            ingest_recipe["channel_extract_filters"] = {"L": "pan=mono|c0=c0", "R": "pan=mono|c0=c1"}
        idem_payload = json.dumps(
            {"src": self.manifest["source"]["uri"], "sr": sr, "policy": resolved}, sort_keys=True
        ).encode("utf-8")
        self.manifest["stages"]["preprocess"].update({
            "idempotency_key": hashlib.sha256(idem_payload).hexdigest(),
            "working": working,
            "ingest_recipe": ingest_recipe,
            "ended_at": utc_now_iso(),
        })
        self._set_stage("preprocess", "done", 1.0)

    # -------- Split (fixed windows with overlap) --------
    @staticmethod
    def _split_ranges(dur_ms: int, target: int, step: int) -> List[Tuple[int, int]]:
        """
        ``(start_ms, dur_ms)`` windows of length ``target`` every ``step`` ms covering
        [0, dur_ms). Stops once a window reaches the end, so no trailing window lies
        entirely inside the previous window's overlap.

        Raises:
            ValueError: if ``step`` is not positive (would never advance).
        """
        if step <= 0:
            raise ValueError(f"split step must be positive, got {step}")
        ranges: List[Tuple[int, int]] = []
        start = 0
        while start < dur_ms:
            length = min(target, dur_ms - start)
            ranges.append((start, length))
            if start + length >= dur_ms:
                break
            start += step
        return ranges

    def _run_split_plan(self) -> None:
        """
        Build a virtual split plan (no audio is cut): fixed windows with overlap for every
        channel in the ingest plan, and initialize the transcribe stage counters.

        Raises:
            ValueError: if ``chunk_target_ms`` is not positive.
        """
        self._set_stage("split", "running", 0.1, {"started_at": utc_now_iso()})
        policy = self.manifest["policy"]
        dur_ms = int(self.manifest["source"].get("duration_ms") or 0)
        target = int(policy["chunk_target_ms"])
        if target <= 0:
            raise ValueError(f"chunk_target_ms must be positive, got {target}")
        # Clamp overlap into [0, target-1] so each step advances by at least 1 ms.
        overlap = min(max(0, int(policy["overlap_ms"])), target - 1)
        ranges = self._split_ranges(dur_ms, target, target - overlap)

        channels = self.manifest["stages"]["preprocess"]["working"]["channel_map"]
        src = self.manifest["source"]["uri"]
        chunks: List[Dict[str, Any]] = []
        idx = 0
        for c in channels:
            for st, du in ranges:
                chunks.append({
                    "idx": idx if len(channels) == 1 else f"{idx}{c}",
                    "chan": c, "start_ms": int(st), "dur_ms": int(du), "status": "queued",
                })
                idx += 1
        max_embed = self._MAX_EMBED_CHUNKS
        self.manifest["stages"]["split"]["plan"] = {
            "mode": "virtual",
            "channels": channels,
            "source_uris": {c: src for c in channels},
            "chunk_policy": "fixed_overlap",
            "chunk_target_ms": target,
            "overlap_ms": overlap,
            "total_chunks": len(chunks),
            "execution": "transcriber",
            "chunks": chunks[:max_embed],
            # True when only the first max_embed chunk entries are embedded above.
            "chunks_truncated": len(chunks) > max_embed,
        }
        # Keep transcribe stage for downstream processing (created if a resumed manifest lacks it).
        self.manifest["stages"].setdefault("transcribe", {}).update({
            "status": "pending", "progress": 0.0,
            "chunks": {"total": len(chunks), "done": 0, "running": 0, "failed": 0, "queued": len(chunks)},
        })
        self._set_stage("split", "done", 1.0, {"ended_at": utc_now_iso()})
# --------------------------- CLI -----------------------------
if __name__ == "__main__":
    import argparse

    def _save_manifest_atomic(data: Dict[str, Any], path: str) -> None:
        """Write *data* as pretty JSON to *path* via a temp file + os.replace (no torn writes)."""
        tmp_path = path + ".tmp"
        with open(tmp_path, "w", encoding="utf-8") as f:
            json.dump(data, f, ensure_ascii=False, indent=2)
        os.replace(tmp_path, path)

    ap = argparse.ArgumentParser(description="Audio preprocess runner (plan-only, fixed split + overlap).")
    ap.add_argument("source", help="URL to audio")
    ap.add_argument("work_root", help="Working root directory (for manifest)")
    ap.add_argument("--manifest", help="Path to existing manifest.json to resume", default=None)
    ap.add_argument("--chunk_ms", type=int, default=60000)
    ap.add_argument("--overlap_ms", type=int, default=500)
    args = ap.parse_args()

    presets = {"chunk_target_ms": args.chunk_ms, "overlap_ms": args.overlap_ms}
    manifest = None
    if args.manifest and os.path.exists(args.manifest):
        with open(args.manifest, "r", encoding="utf-8") as f:
            manifest = json.load(f)

    runner = AudioJobRunner(
        manifest=manifest,
        source_uri=None if manifest else args.source,
        work_root=args.work_root,
        presets=presets,
    )
    out_path = os.path.join(args.work_root, "manifest.json")
    os.makedirs(args.work_root, exist_ok=True)
    try:
        runner.run_until_split()
    finally:
        # Persist even when a stage raised: run_until_split() has already recorded the
        # failed stage status and last_error in runner.manifest, which would otherwise be lost.
        _save_manifest_atomic(runner.manifest, out_path)
        print(f"Saved manifest -> {out_path}")