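"""T2S (text-to-speech) inference: generates speech tokens from LibriSpeech reference
texts with a checkpointed Omada model, decodes them to waveforms, logs samples to
W&B, and optionally reports WER by transcribing the generated audio with an ASR pipeline."""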
import os
import argparse
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from typing import Callable, List
import re
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import wandb
from omegaconf import OmegaConf
from transformers import pipeline
from training.data import T2S_INSTRUCTION
from inference.common import (
load_train_config,
get_vq_model_audio,
build_uni_prompting,
load_omada_from_checkpoint,
list_checkpoints,
grid_dict,
init_wandb,
safe_log_table,
)
from models import get_mask_schedule
_ANGLE_TOKEN_RE = re.compile(r"<[^>]+>")
_EXCLAMATIONPOINT_RE = re.compile(r"exclamationpoint", flags=re.IGNORECASE)
_PUNCT_RE = re.compile(r"[^\w\s']")
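# Strip angle-bracket special tokens, the literal "exclamationpoint" marker, and
# remaining punctuation so reference and generated texts compare cleanly.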
def _strip_custom_markers(text: str) -> str:
had_exclamationpoint = bool(_EXCLAMATIONPOINT_RE.search(text))
text = _ANGLE_TOKEN_RE.sub(" ", text)
    if had_exclamationpoint:
        text = _EXCLAMATIONPOINT_RE.sub(" ", text)
        text = text.replace(".", "")
text = _PUNCT_RE.sub(" ", text)
text = re.sub(r"\s+", " ", text).strip()
return text
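# Thin wrapper around the HF dataset exposing only the reference text and sample id.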
class T2SEvalDataset(Dataset):
def __init__(self, hf_dataset):
self.hf_dataset = hf_dataset
def __len__(self):
return len(self.hf_dataset)
def __getitem__(self, idx):
ex = self.hf_dataset[idx]
return {"gt_text": ex["text"], "sample_id": ex["id"]}
def ensure_dir(path: str):
os.makedirs(path, exist_ok=True)
def _basic_normalize(text: str) -> str:
text = _strip_custom_markers(text)
text = text.lower()
return text
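# Select the text normalization used for WER: "off" keeps text untouched, the
# whisper-style English normalizer is used when available, and everything else
# falls back to the basic lowercase/punctuation-stripping normalizer above.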
def build_normalize_fn(mode: str) -> Callable[[str], str]:
mode = (mode or "basic").strip().lower()
if mode in {"off", "none", "no"}:
return lambda s: s
if mode in {"english", "whisper", "whisper_en"}:
try:
from normalizer.normalizer import EnglishTextNormalizer
n = EnglishTextNormalizer()
def _fn(s: str) -> str:
return re.sub(r"\s+", " ", n(s)).strip()
return _fn
except Exception:
return _basic_normalize
return _basic_normalize
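# Corpus-level WER: total word-level edit distance divided by total reference words.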
def calculate_wer(predictions: List[str], references: List[str], normalize: Callable[[str], str] = _basic_normalize):
import editdistance
# Normalize texts before WER
predictions = [normalize(p) for p in predictions]
references = [normalize(r) for r in references]
total_errors = 0
total_words = 0
for pred, ref in zip(predictions, references):
pw = pred.split()
rw = ref.split()
total_errors += editdistance.eval(pw, rw)
total_words += len(rw)
wer = total_errors / total_words if total_words > 0 else 0.0
return wer, total_errors, total_words
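# Run one (checkpoint, hyperparameter) combination: synthesize speech for the selected
# LibriSpeech split, log audio samples to W&B, and optionally compute WER by
# transcribing the generated waveforms with an ASR pipeline.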
def run_once(ckpt_path: str, hparams: dict, train_cfg, device):
uni_prompting, tokenizer = build_uni_prompting(train_cfg)
vq_audio = get_vq_model_audio(train_cfg, device)
model = load_omada_from_checkpoint(ckpt_path, device)
# Dataset
dcfg = hparams.get("dataset", {})
subset = dcfg.get("subset", "clean")
split = dcfg.get("split", "test")
limit = int(dcfg.get("limit", 32))
ds_raw = load_dataset("librispeech_asr", subset, split=split)
if limit > 0:
ds_raw = ds_raw.select(range(min(limit, len(ds_raw))))
ds = T2SEvalDataset(ds_raw)
batch_size = int(hparams.get("batch_size", train_cfg.training.batch_size_t2s))
loader = DataLoader(ds, batch_size=batch_size, shuffle=False)
# Generation params
mode = str(hparams.get("mode", "fixed")).lower() # 'fixed', 'free', or 'mmu'
guidance_scale = float(hparams.get("guidance_scale", train_cfg.training.guidance_scale))
temperature = float(hparams.get("temperature", 1.0))
timesteps = int(hparams.get("timesteps", 24 if mode != "mmu" else 256))
default_seq = 254 if mode == "fixed" else (511 if mode == "mmu" else 255)
seq_len = int(hparams.get("seq_len", default_seq))
block_length = int(hparams.get("block_length", 128))
    max_new_tokens = int(hparams.get("max_new_tokens", seq_len if seq_len > 0 else 512))
audio_codebook_size = int(hparams.get("audio_codebook_size", 4096))
noise_schedule = hparams.get("noise_schedule", train_cfg.training.get("mask_schedule", "cosine"))
# Convert string name to callable schedule function expected by model
noise_schedule_fn = get_mask_schedule(noise_schedule) if isinstance(noise_schedule, str) else noise_schedule
noise_type = hparams.get("noise_type", "mask")
out_root = hparams.get("output_dir", os.path.join("outputs", "t2s"))
ensure_dir(out_root)
# W&B
init_wandb(hparams.get("_infer_cfg", {}), "t2s", ckpt_path, {
"mode": mode,
"guidance_scale": guidance_scale,
"temperature": temperature,
"timesteps": timesteps,
"seq_len": seq_len,
"batch_size": batch_size,
})
mask_token_id = model.config.mask_token_id
rows = []
for batch in loader:
gt_texts: List[str] = batch["gt_text"]
clean_gt_texts = [_strip_custom_markers(text) for text in gt_texts]
sample_ids: List[str] = batch["sample_id"]
# Build chat prompts
prompts = [
f"<|start_header_id|>user<|end_header_id|>\n{T2S_INSTRUCTION[0]}\n{text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n"
for text in clean_gt_texts
]
bsz = len(prompts)
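        # Start from a fully masked audio sequence; generation progressively fills it in.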
audio_tokens = torch.ones((bsz, seq_len), dtype=torch.long, device=device) * mask_token_id
if mode == "fixed":
input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_fixed_gen')
else:
input_ids, attention_mask = uni_prompting((prompts, audio_tokens), 't2s_gen')
        if guidance_scale > 0 and mode != "mmu":
if mode == "fixed":
uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * bsz, audio_tokens), 't2s_fixed_gen')
else:
uncond_input_ids, uncond_attention_mask = uni_prompting(([''] * bsz, audio_tokens), 't2s_gen')
else:
uncond_input_ids, uncond_attention_mask = None, None
with torch.no_grad():
if mode == "fixed":
outputs = model.t2s_fixed_generate(
input_ids=input_ids.to(device),
uncond_input_ids=None if uncond_input_ids is None else uncond_input_ids.to(device),
attention_mask=attention_mask.to(device),
uncond_attention_mask=None if uncond_attention_mask is None else uncond_attention_mask.to(device),
guidance_scale=guidance_scale,
temperature=temperature,
timesteps=timesteps,
noise_schedule=noise_schedule_fn,
noise_type=noise_type,
seq_len=seq_len,
uni_prompting=uni_prompting,
config=train_cfg,
)
elif mode == "mmu":
outputs = model.t2s_generate_mmu_like(
input_ids=input_ids.to(device),
max_new_tokens=max_new_tokens,
steps=timesteps,
block_length=block_length,
temperature=temperature,
cfg_scale=guidance_scale,
mask_token_id=mask_token_id,
attention_mask=attention_mask.to(device),
uni_prompting=uni_prompting,
codebook_size=train_cfg.model.omada.codebook_size,
audio_codebook_size=audio_codebook_size,
)
else:
outputs = model.t2s_generate(
input_ids=input_ids.to(device),
uncond_input_ids=None if uncond_input_ids is None else uncond_input_ids.to(device),
attention_mask=attention_mask.to(device),
uncond_attention_mask=None if uncond_attention_mask is None else uncond_attention_mask.to(device),
guidance_scale=guidance_scale,
temperature=temperature,
timesteps=timesteps,
noise_schedule=noise_schedule_fn,
noise_type=noise_type,
seq_len=seq_len,
uni_prompting=uni_prompting,
config=train_cfg,
)
# Decode each sample
for i in range(bsz):
if mode == "mmu":
gen_tokens = outputs[i]
if isinstance(gen_tokens, torch.Tensor):
rel_ids = gen_tokens.detach().cpu().tolist()
else:
rel_ids = list(gen_tokens)
else:
rel_ids = outputs[i].tolist()
if not rel_ids:
continue
            speech_unit = "".join(f"<|speech_{u}|>" for u in rel_ids)
wav_name = f"{os.path.basename(os.path.dirname(ckpt_path))}_{sample_ids[i]}_{mode}.wav"
wav_path = os.path.join(out_root, wav_name)
_ = vq_audio.decode(speech_unit, condition='gender-female_emotion-neutral_speed-normal_pitch-normal', output_wav_file=wav_path)
rows.append([sample_ids[i], clean_gt_texts[i], wav_path])
# Log audio samples
aud_rows = []
for sid, gt, wav in rows[:64]:
aud_rows.append([sid, gt, wandb.Audio(wav, caption=gt)])
safe_log_table("samples/t2s", ["ID", "GT", "Audio"], aud_rows)
# Optional WER evaluation via Whisper (or any ASR pipeline)
asr_model = hparams.get("wer_asr_model")
if asr_model:
try:
lang_in = hparams.get("wer_language", "english")
# Normalize language to avoid locale strings like C.UTF-8
def _norm_lang(x: str) -> str:
if not isinstance(x, str) or not x:
return "english"
x = x.strip().lower()
if "utf" in x or x.startswith("c.") or x == "c":
return "english"
aliases = {
"en": "english", "eng": "english", "english": "english",
"ko": "korean", "kor": "korean", "korean": "korean",
"zh": "chinese", "cmn": "chinese", "chinese": "chinese",
"ja": "japanese", "jpn": "japanese", "japanese": "japanese",
}
return aliases.get(x, "english")
lang = _norm_lang(lang_in)
max_samples = int(hparams.get("wer_max_samples", 1024))
use_cuda = torch.cuda.is_available()
asr_pipe = pipeline("automatic-speech-recognition", model=asr_model, device=0 if use_cuda else -1)
preds, refs = [], []
norm_mode = str(hparams.get("text_norm", "basic"))
normalize_fn = build_normalize_fn(norm_mode)
trans_rows = []
for i, (sid, gt, wav) in enumerate(rows):
if i >= max_samples:
break
try:
out = asr_pipe(wav, generate_kwargs={"language": lang, "task": "transcribe"})
text = out.get("text", "")
except Exception:
text = ""
base_pred = _strip_custom_markers(text)
base_ref = _strip_custom_markers(gt)
preds.append(base_pred)
refs.append(base_ref)
if i < 32:
trans_rows.append([sid, base_ref, base_pred, wandb.Audio(wav, caption=base_pred)])
# Compute WER using normalized text
wer, errors, words = calculate_wer(preds, refs, normalize=normalize_fn)
wandb.log({
"metrics/t2s_wer": wer,
"metrics/t2s_word_errors": errors,
"metrics/t2s_total_words": words,
})
safe_log_table("samples/t2s_transcriptions", ["ID", "GT", "ASR", "Audio"], trans_rows)
except Exception as e:
wandb.log({"warn/t2s_wer_error": str(e)})
wandb.finish()
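# CLI entry point: builds the checkpoint list, assembles either a single hyperparameter
# combination from CLI overrides or a grid from the inference config, and runs every
# (checkpoint, combination) pair.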
def main():
parser = argparse.ArgumentParser(description="T2S Inference (fixed/free) with CLI overrides or config grids")
# Required basics
parser.add_argument("--train_config", required=True)
parser.add_argument("--ckpt_root", required=True, help="Experiment output dir or specific checkpoint path")
parser.add_argument("--infer_config", required=False, help="Optional YAML with wandb and/or grid configs")
parser.add_argument("--checkpoint", action="append", help="Repeatable: explicit checkpoint path(s). Can be '.../unwrapped_model', '.../checkpoint-XXXX', or experiment dir")
# Optional generation overrides (single run when provided)
parser.add_argument("--mode", choices=["fixed", "free", "mmu"], help="T2S mode: fixed, free, or mmu")
parser.add_argument("--guidance_scale", type=float)
parser.add_argument("--temperature", type=float)
parser.add_argument("--timesteps", type=int)
parser.add_argument("--seq_len", type=int)
parser.add_argument("--block_length", type=int)
parser.add_argument("--max_new_tokens", type=int)
parser.add_argument("--noise_schedule")
parser.add_argument("--noise_type")
parser.add_argument("--batch_size", type=int)
parser.add_argument("--output_dir")
parser.add_argument("--text_norm", choices=["off", "basic", "english", "whisper", "whisper_en"], help="Text normalization for WER")
# Optional dataset overrides
parser.add_argument("--subset")
parser.add_argument("--split")
parser.add_argument("--limit", type=int)
# Optional WER logging via ASR
parser.add_argument("--wer_asr_model", help="HF model id for ASR, e.g., openai/whisper-large-v3")
parser.add_argument("--wer_language", help="Language hint for ASR generation")
parser.add_argument("--wer_max_samples", type=int, help="Max number of samples for WER computation")
parser.add_argument("--audio_codebook_size", type=int, help="Override audio codebook size for MMU mode")
args = parser.parse_args()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
train_cfg = load_train_config(args.train_config)
infer_cfg = {}
if args.infer_config:
infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True)
    # Build the checkpoint list by priority: explicit --checkpoint > infer_config.checkpoints > --ckpt_root
if args.checkpoint:
ckpt_list = []
for p in args.checkpoint:
ckpt_list.extend(list_checkpoints(p))
else:
ckpts = infer_cfg.get("checkpoints") if infer_cfg else None
if ckpts:
ckpt_list = []
for p in ckpts:
ckpt_list.extend(list_checkpoints(p))
else:
ckpt_list = list_checkpoints(args.ckpt_root)
if not ckpt_list:
raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in infer config.")
# Decide between single-run overrides or grid from config
override_present = any([
args.mode is not None, args.guidance_scale is not None, args.temperature is not None,
args.timesteps is not None, args.seq_len is not None, args.noise_schedule is not None,
args.noise_type is not None, args.batch_size is not None, args.output_dir is not None,
args.block_length is not None, args.max_new_tokens is not None,
args.text_norm is not None,
args.subset is not None, args.split is not None, args.limit is not None,
])
if override_present or not infer_cfg:
# Build single combination from CLI overrides with fallbacks
single = {
"mode": args.mode or "fixed",
"guidance_scale": args.guidance_scale if args.guidance_scale is not None else float(train_cfg.training.guidance_scale),
"temperature": args.temperature if args.temperature is not None else 1.0,
"timesteps": args.timesteps if args.timesteps is not None else 24,
"seq_len": args.seq_len if args.seq_len is not None else 254,
"batch_size": args.batch_size if args.batch_size is not None else int(train_cfg.training.batch_size_t2s),
"output_dir": args.output_dir or os.path.join("outputs", "t2s"),
"noise_schedule": args.noise_schedule if args.noise_schedule is not None else train_cfg.training.get("mask_schedule", "cosine"),
"noise_type": args.noise_type if args.noise_type is not None else "mask",
}
if args.text_norm is not None:
single["text_norm"] = args.text_norm
if args.block_length is not None:
single["block_length"] = args.block_length
if args.max_new_tokens is not None:
single["max_new_tokens"] = args.max_new_tokens
if args.audio_codebook_size is not None:
single["audio_codebook_size"] = args.audio_codebook_size
# WER options
if args.wer_asr_model is not None:
single["wer_asr_model"] = args.wer_asr_model
if args.wer_language is not None:
single["wer_language"] = args.wer_language
if args.wer_max_samples is not None:
single["wer_max_samples"] = args.wer_max_samples
dcfg = {
"subset": args.subset or "clean",
"split": args.split or "test",
"limit": args.limit if args.limit is not None else 32,
}
single["dataset"] = dcfg
single["_infer_cfg"] = infer_cfg
combos = [single]
else:
# Grid from config, allow CLI overrides to force values across the grid
gen_grid = infer_cfg.get("generation", {
"mode": ["fixed"],
"guidance_scale": [float(train_cfg.training.guidance_scale)],
"temperature": [1.0],
"timesteps": [24],
"seq_len": [254],
"batch_size": [int(train_cfg.training.batch_size_t2s)],
"output_dir": [os.path.join("outputs", "t2s")],
})
combos = grid_dict(gen_grid)
dcfg = infer_cfg.get("dataset", {
"subset": "clean",
"split": "test",
"limit": 32,
})
# Apply dataset overrides if given
if args.subset is not None:
dcfg["subset"] = args.subset
if args.split is not None:
dcfg["split"] = args.split
if args.limit is not None:
dcfg["limit"] = args.limit
# Apply generation overrides across combos if provided
for c in combos:
if args.mode is not None:
c["mode"] = args.mode
if args.guidance_scale is not None:
c["guidance_scale"] = args.guidance_scale
if args.temperature is not None:
c["temperature"] = args.temperature
if args.timesteps is not None:
c["timesteps"] = args.timesteps
if args.seq_len is not None:
c["seq_len"] = args.seq_len
if args.batch_size is not None:
c["batch_size"] = args.batch_size
if args.output_dir is not None:
c["output_dir"] = args.output_dir
if args.noise_schedule is not None:
c["noise_schedule"] = args.noise_schedule
if args.noise_type is not None:
c["noise_type"] = args.noise_type
if args.text_norm is not None:
c["text_norm"] = args.text_norm
if args.block_length is not None:
c["block_length"] = args.block_length
if args.max_new_tokens is not None:
c["max_new_tokens"] = args.max_new_tokens
if args.audio_codebook_size is not None:
c["audio_codebook_size"] = args.audio_codebook_size
if args.wer_asr_model is not None:
c["wer_asr_model"] = args.wer_asr_model
if args.wer_language is not None:
c["wer_language"] = args.wer_language
if args.wer_max_samples is not None:
c["wer_max_samples"] = args.wer_max_samples
c["dataset"] = dcfg
c["_infer_cfg"] = infer_cfg
for ckpt in ckpt_list:
for hp in combos:
run_once(ckpt, hp, train_cfg, device)
if __name__ == "__main__":
main()