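"""T2S (text-to-speech) inference and evaluation.

Loads a model checkpoint (via ``load_omada_from_checkpoint``), synthesizes speech
for LibriSpeech transcripts in 'fixed', 'free', or 'mmu' generation mode, decodes
the generated speech tokens to WAV files, logs samples to Weights & Biases, and
optionally transcribes the audio with an ASR pipeline to report word error rate.

Example invocation (script and config paths are illustrative):
    python t2s_inference.py \
        --train_config configs/train.yaml \
        --ckpt_root outputs/exp_t2s \
        --mode fixed --timesteps 24 --seq_len 254 \
        --wer_asr_model openai/whisper-large-v3 --text_norm whisper
"""
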
import os
import argparse
import sys

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from typing import Callable, List
import re

import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import wandb
from omegaconf import OmegaConf
from transformers import pipeline

from training.data import T2S_INSTRUCTION
from inference.common import (
    load_train_config,
    get_vq_model_audio,
    build_uni_prompting,
    load_omada_from_checkpoint,
    list_checkpoints,
    grid_dict,
    init_wandb,
    safe_log_table,
)
from models import get_mask_schedule

_ANGLE_TOKEN_RE = re.compile(r"<[^>]+>")
_EXCLAMATIONPOINT_RE = re.compile(r"exclamationpoint", flags=re.IGNORECASE)
_PUNCT_RE = re.compile(r"[^\w\s']")
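

# Text cleanup used on transcripts and ASR output: drop angle-bracket special
# tokens (anything matching "<...>"), spelled-out "exclamationpoint" markers,
# and punctuation, then collapse whitespace so none of it counts as words.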
def _strip_custom_markers(text: str) -> str:
    had_exclamationpoint = bool(_EXCLAMATIONPOINT_RE.search(text))
    text = _ANGLE_TOKEN_RE.sub(" ", text)
    if had_exclamationpoint:
        text = _EXCLAMATIONPOINT_RE.sub(" ", text)
        text = text.replace(".", "")
    text = _PUNCT_RE.sub(" ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
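

# Evaluation only needs the reference transcript and a sample id per example;
# the audio itself is synthesized from the text by the model under test.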
class T2SEvalDataset(Dataset):
    def __init__(self, hf_dataset):
        self.hf_dataset = hf_dataset

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        ex = self.hf_dataset[idx]
        return {"gt_text": ex["text"], "sample_id": ex["id"]}


def ensure_dir(path: str):
    os.makedirs(path, exist_ok=True)


def _basic_normalize(text: str) -> str:
    text = _strip_custom_markers(text)
    text = text.lower()
    return text


def build_normalize_fn(mode: str) -> Callable[[str], str]:
    mode = (mode or "basic").strip().lower()
    if mode in {"off", "none", "no"}:
        return lambda s: s
    if mode in {"english", "whisper", "whisper_en"}:
        try:
            from normalizer.normalizer import EnglishTextNormalizer

            n = EnglishTextNormalizer()

            def _fn(s: str) -> str:
                return re.sub(r"\s+", " ", n(s)).strip()

            return _fn
        except Exception:
            return _basic_normalize
    return _basic_normalize
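

# WER = word-level edit distance / number of reference words, e.g. prediction
# "the cat sat" against reference "the cat sat down" gives 1 / 4 = 0.25.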
def calculate_wer(predictions: List[str], references: List[str], normalize: Callable[[str], str] = _basic_normalize):
    import editdistance

    # Normalize texts before computing WER
    predictions = [normalize(p) for p in predictions]
    references = [normalize(r) for r in references]
    total_errors = 0
    total_words = 0
    for pred, ref in zip(predictions, references):
        pw = pred.split()
        rw = ref.split()
        total_errors += editdistance.eval(pw, rw)
        total_words += len(rw)
    wer = total_errors / total_words if total_words > 0 else 0.0
    return wer, total_errors, total_words


def run_once(ckpt_path: str, hparams: dict, train_cfg, device):
    uni_prompting, tokenizer = build_uni_prompting(train_cfg)
    vq_audio = get_vq_model_audio(train_cfg, device)
    model = load_omada_from_checkpoint(ckpt_path, device)

    # Dataset
    dcfg = hparams.get("dataset", {})
    subset = dcfg.get("subset", "clean")
    split = dcfg.get("split", "test")
    limit = int(dcfg.get("limit", 32))
    ds_raw = load_dataset("librispeech_asr", subset, split=split)
    if limit > 0:
        ds_raw = ds_raw.select(range(min(limit, len(ds_raw))))
    ds = T2SEvalDataset(ds_raw)
    batch_size = int(hparams.get("batch_size", train_cfg.training.batch_size_t2s))
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    # Generation params
    mode = str(hparams.get("mode", "fixed")).lower()  # 'fixed', 'free', or 'mmu'
    guidance_scale = float(hparams.get("guidance_scale", train_cfg.training.guidance_scale))
    temperature = float(hparams.get("temperature", 1.0))
    timesteps = int(hparams.get("timesteps", 24 if mode != "mmu" else 256))
    default_seq = 254 if mode == "fixed" else (511 if mode == "mmu" else 255)
    seq_len = int(hparams.get("seq_len", default_seq))
    block_length = int(hparams.get("block_length", 128))
    max_new_tokens = int(hparams.get("max_new_tokens", seq_len)) if seq_len > 0 else int(hparams.get("max_new_tokens", 512))
    audio_codebook_size = int(hparams.get("audio_codebook_size", 4096))
    noise_schedule = hparams.get("noise_schedule", train_cfg.training.get("mask_schedule", "cosine"))
    # Convert a schedule name into the callable schedule function expected by the model
    noise_schedule_fn = get_mask_schedule(noise_schedule) if isinstance(noise_schedule, str) else noise_schedule
    noise_type = hparams.get("noise_type", "mask")
    out_root = hparams.get("output_dir", os.path.join("outputs", "t2s"))
    ensure_dir(out_root)

    # W&B
    init_wandb(hparams.get("_infer_cfg", {}), "t2s", ckpt_path, {
        "mode": mode,
        "guidance_scale": guidance_scale,
        "temperature": temperature,
        "timesteps": timesteps,
        "seq_len": seq_len,
        "batch_size": batch_size,
    })

    mask_token_id = model.config.mask_token_id
    rows = []
    for batch in loader:
        gt_texts: List[str] = batch["gt_text"]
        clean_gt_texts = [_strip_custom_markers(text) for text in gt_texts]
        sample_ids: List[str] = batch["sample_id"]

        # Build chat prompts
        prompts = [
            f"<|start_header_id|>user<|end_header_id|>\n{T2S_INSTRUCTION[0]}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
            for text in clean_gt_texts
        ]
        bsz = len(prompts)
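
        # All target audio positions start as mask tokens; the generate calls
        # below fill them in with speech-unit ids.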
        audio_tokens = torch.ones((bsz, seq_len), dtype=torch.long, device=device) * mask_token_id
        if mode == "fixed":
            input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen')
        else:
            input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen')

        # Unconditional (empty-prompt) inputs for classifier-free guidance
        if guidance_scale and guidance_scale > 0 and mode != "mmu":
            if mode == "fixed":
                uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * bsz, audio_tokens), 't2s_fixed_gen')
            else:
                uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * bsz, audio_tokens), 't2s_gen')
        else:
            uncond_input_ids, uncond_attention_mask = None, None
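
        # Run the mode-specific generation routine on the masked canvas.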
        with torch.no_grad():
            if mode == "fixed":
                outputs = model.t2s_fixed_generate(
                    input_ids=input_ids.to(device),
                    uncond_input_ids=None if uncond_input_ids is None else uncond_input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    uncond_attention_mask=None if uncond_attention_mask is None else uncond_attention_mask.to(device),
                    guidance_scale=guidance_scale,
                    temperature=temperature,
                    timesteps=timesteps,
                    noise_schedule=noise_schedule_fn,
                    noise_type=noise_type,
                    seq_len=seq_len,
                    uni_prompting=uni_prompting,
                    config=train_cfg,
                )
            elif mode == "mmu":
                outputs = model.t2s_generate_mmu_like(
                    input_ids=input_ids.to(device),
                    max_new_tokens=max_new_tokens,
                    steps=timesteps,
                    block_length=block_length,
                    temperature=temperature,
                    cfg_scale=guidance_scale,
                    mask_token_id=mask_token_id,
                    attention_mask=attention_mask.to(device),
                    uni_prompting=uni_prompting,
                    codebook_size=train_cfg.model.omada.codebook_size,
                    audio_codebook_size=audio_codebook_size,
                )
            else:
                outputs = model.t2s_generate(
                    input_ids=input_ids.to(device),
                    uncond_input_ids=None if uncond_input_ids is None else uncond_input_ids.to(device),
                    attention_mask=attention_mask.to(device),
                    uncond_attention_mask=None if uncond_attention_mask is None else uncond_attention_mask.to(device),
                    guidance_scale=guidance_scale,
                    temperature=temperature,
                    timesteps=timesteps,
                    noise_schedule=noise_schedule_fn,
                    noise_type=noise_type,
                    seq_len=seq_len,
                    uni_prompting=uni_prompting,
                    config=train_cfg,
                )

        # Decode each sample
        for i in range(bsz):
            if mode == "mmu":
                gen_tokens = outputs[i]
                if isinstance(gen_tokens, torch.Tensor):
                    rel_ids = gen_tokens.detach().cpu().tolist()
                else:
                    rel_ids = list(gen_tokens)
            else:
                rel_ids = outputs[i].tolist()
            if not rel_ids:
                continue
            unit_str = " ".join(map(str, rel_ids))
            speech_unit = "".join([f"<|speech_{u}|>" for u in unit_str.split(" ")])
            wav_name = f"{os.path.basename(os.path.dirname(ckpt_path))}_{sample_ids[i]}_{mode}.wav"
            wav_path = os.path.join(out_root, wav_name)
            _ = vq_audio.decode(speech_unit, condition='gender-female_emotion-neutral_speed-normal_pitch-normal', output_wav_file=wav_path)
            rows.append([sample_ids[i], clean_gt_texts[i], wav_path])

    # Log audio samples
    aud_rows = []
    for sid, gt, wav in rows[:64]:
        aud_rows.append([sid, gt, wandb.Audio(wav, caption=gt)])
    safe_log_table("samples/t2s", ["ID", "GT", "Audio"], aud_rows)

    # Optional WER evaluation via Whisper (or any ASR pipeline)
    asr_model = hparams.get("wer_asr_model")
    if asr_model:
        try:
            lang_in = hparams.get("wer_language", "english")

            # Normalize the language hint to avoid locale strings like C.UTF-8
            def _norm_lang(x: str) -> str:
                if not isinstance(x, str) or not x:
                    return "english"
                x = x.strip().lower()
                if "utf" in x or x.startswith("c.") or x == "c":
                    return "english"
                aliases = {
                    "en": "english", "eng": "english", "english": "english",
                    "ko": "korean", "kor": "korean", "korean": "korean",
                    "zh": "chinese", "cmn": "chinese", "chinese": "chinese",
                    "ja": "japanese", "jpn": "japanese", "japanese": "japanese",
                }
                return aliases.get(x, "english")

            lang = _norm_lang(lang_in)
            max_samples = int(hparams.get("wer_max_samples", 1024))
            use_cuda = torch.cuda.is_available()
            asr_pipe = pipeline("automatic-speech-recognition", model=asr_model, device=0 if use_cuda else -1)
            preds, refs = [], []
            norm_mode = str(hparams.get("text_norm", "basic"))
            normalize_fn = build_normalize_fn(norm_mode)
            trans_rows = []
            for i, (sid, gt, wav) in enumerate(rows):
                if i >= max_samples:
                    break
                try:
                    out = asr_pipe(wav, generate_kwargs={"language": lang, "task": "transcribe"})
                    text = out.get("text", "")
                except Exception:
                    text = ""
                base_pred = _strip_custom_markers(text)
                base_ref = _strip_custom_markers(gt)
                preds.append(base_pred)
                refs.append(base_ref)
                if i < 32:
                    trans_rows.append([sid, base_ref, base_pred, wandb.Audio(wav, caption=base_pred)])

            # Compute WER using normalized text
            wer, errors, words = calculate_wer(preds, refs, normalize=normalize_fn)
            wandb.log({
                "metrics/t2s_wer": wer,
                "metrics/t2s_word_errors": errors,
                "metrics/t2s_total_words": words,
            })
            safe_log_table("samples/t2s_transcriptions", ["ID", "GT", "ASR", "Audio"], trans_rows)
        except Exception as e:
            wandb.log({"warn/t2s_wer_error": str(e)})

    wandb.finish()


def main():
    parser = argparse.ArgumentParser(description="T2S inference (fixed/free/mmu) with CLI overrides or config grids")
    # Required basics
    parser.add_argument("--train_config", required=True)
    parser.add_argument("--ckpt_root", required=True, help="Experiment output dir or specific checkpoint path")
    parser.add_argument("--infer_config", required=False, help="Optional YAML with wandb and/or grid configs")
    parser.add_argument("--checkpoint", action="append", help="Repeatable: explicit checkpoint path(s); each may be '.../unwrapped_model', '.../checkpoint-XXXX', or an experiment dir")
    # Optional generation overrides (single run when provided)
    parser.add_argument("--mode", choices=["fixed", "free", "mmu"], help="T2S mode: fixed, free, or mmu")
    parser.add_argument("--guidance_scale", type=float)
    parser.add_argument("--temperature", type=float)
    parser.add_argument("--timesteps", type=int)
    parser.add_argument("--seq_len", type=int)
    parser.add_argument("--block_length", type=int)
    parser.add_argument("--max_new_tokens", type=int)
    parser.add_argument("--noise_schedule")
    parser.add_argument("--noise_type")
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--output_dir")
    parser.add_argument("--text_norm", choices=["off", "basic", "english", "whisper", "whisper_en"], help="Text normalization for WER")
    # Optional dataset overrides
    parser.add_argument("--subset")
    parser.add_argument("--split")
    parser.add_argument("--limit", type=int)
    # Optional WER logging via ASR
    parser.add_argument("--wer_asr_model", help="HF model id for ASR, e.g., openai/whisper-large-v3")
    parser.add_argument("--wer_language", help="Language hint for ASR generation")
    parser.add_argument("--wer_max_samples", type=int, help="Max number of samples for WER computation")
    parser.add_argument("--audio_codebook_size", type=int, help="Override audio codebook size for MMU mode")
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_cfg = load_train_config(args.train_config)
    infer_cfg = {}
    if args.infer_config:
        infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True)

    # Build checkpoint list by priority: explicit --checkpoint > infer_config.checkpoints > --ckpt_root
    if args.checkpoint:
        ckpt_list = []
        for p in args.checkpoint:
            ckpt_list.extend(list_checkpoints(p))
    else:
        ckpts = infer_cfg.get("checkpoints") if infer_cfg else None
        if ckpts:
            ckpt_list = []
            for p in ckpts:
                ckpt_list.extend(list_checkpoints(p))
        else:
            ckpt_list = list_checkpoints(args.ckpt_root)
    if not ckpt_list:
        raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in the infer config.")

    # Decide between single-run overrides or a grid from config
    override_present = any([
        args.mode is not None, args.guidance_scale is not None, args.temperature is not None,
        args.timesteps is not None, args.seq_len is not None, args.noise_schedule is not None,
        args.noise_type is not None, args.batch_size is not None, args.output_dir is not None,
        args.block_length is not None, args.max_new_tokens is not None,
        args.text_norm is not None,
        args.subset is not None, args.split is not None, args.limit is not None,
    ])
    if override_present or not infer_cfg:
        # Build a single combination from CLI overrides with fallbacks
        single = {
            "mode": args.mode or "fixed",
            "guidance_scale": args.guidance_scale if args.guidance_scale is not None else float(train_cfg.training.guidance_scale),
            "temperature": args.temperature if args.temperature is not None else 1.0,
            "timesteps": args.timesteps if args.timesteps is not None else 24,
            "seq_len": args.seq_len if args.seq_len is not None else 254,
            "batch_size": args.batch_size if args.batch_size is not None else int(train_cfg.training.batch_size_t2s),
            "output_dir": args.output_dir or os.path.join("outputs", "t2s"),
            "noise_schedule": args.noise_schedule if args.noise_schedule is not None else train_cfg.training.get("mask_schedule", "cosine"),
            "noise_type": args.noise_type if args.noise_type is not None else "mask",
        }
        if args.text_norm is not None:
            single["text_norm"] = args.text_norm
        if args.block_length is not None:
            single["block_length"] = args.block_length
        if args.max_new_tokens is not None:
            single["max_new_tokens"] = args.max_new_tokens
        if args.audio_codebook_size is not None:
            single["audio_codebook_size"] = args.audio_codebook_size
        # WER options
        if args.wer_asr_model is not None:
            single["wer_asr_model"] = args.wer_asr_model
        if args.wer_language is not None:
            single["wer_language"] = args.wer_language
        if args.wer_max_samples is not None:
            single["wer_max_samples"] = args.wer_max_samples
        dcfg = {
            "subset": args.subset or "clean",
            "split": args.split or "test",
            "limit": args.limit if args.limit is not None else 32,
        }
        single["dataset"] = dcfg
        single["_infer_cfg"] = infer_cfg
        combos = [single]
    else:
        # Grid from config; CLI overrides force values across the whole grid
        gen_grid = infer_cfg.get("generation", {
            "mode": ["fixed"],
            "guidance_scale": [float(train_cfg.training.guidance_scale)],
            "temperature": [1.0],
            "timesteps": [24],
            "seq_len": [254],
            "batch_size": [int(train_cfg.training.batch_size_t2s)],
            "output_dir": [os.path.join("outputs", "t2s")],
        })
        combos = grid_dict(gen_grid)
        dcfg = infer_cfg.get("dataset", {
            "subset": "clean",
            "split": "test",
            "limit": 32,
        })
        # Apply dataset overrides if given
        if args.subset is not None:
            dcfg["subset"] = args.subset
        if args.split is not None:
            dcfg["split"] = args.split
        if args.limit is not None:
            dcfg["limit"] = args.limit
        # Apply generation overrides across combos if provided
        for c in combos:
            if args.mode is not None:
                c["mode"] = args.mode
            if args.guidance_scale is not None:
                c["guidance_scale"] = args.guidance_scale
            if args.temperature is not None:
                c["temperature"] = args.temperature
            if args.timesteps is not None:
                c["timesteps"] = args.timesteps
            if args.seq_len is not None:
                c["seq_len"] = args.seq_len
            if args.batch_size is not None:
                c["batch_size"] = args.batch_size
            if args.output_dir is not None:
                c["output_dir"] = args.output_dir
            if args.noise_schedule is not None:
                c["noise_schedule"] = args.noise_schedule
            if args.noise_type is not None:
                c["noise_type"] = args.noise_type
            if args.text_norm is not None:
                c["text_norm"] = args.text_norm
            if args.block_length is not None:
                c["block_length"] = args.block_length
            if args.max_new_tokens is not None:
                c["max_new_tokens"] = args.max_new_tokens
            if args.audio_codebook_size is not None:
                c["audio_codebook_size"] = args.audio_codebook_size
            if args.wer_asr_model is not None:
                c["wer_asr_model"] = args.wer_asr_model
            if args.wer_language is not None:
                c["wer_language"] = args.wer_language
            if args.wer_max_samples is not None:
                c["wer_max_samples"] = args.wer_max_samples
            c["dataset"] = dcfg
            c["_infer_cfg"] = infer_cfg

    for ckpt in ckpt_list:
        for hp in combos:
            run_once(ckpt, hp, train_cfg, device)


if __name__ == "__main__":
    main()