import os
import argparse
import sys
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from functools import partial
from typing import Callable, List
import re
import torch
from torch.utils.data import Dataset, DataLoader
from datasets import load_dataset
import wandb
from omegaconf import OmegaConf
from training.data import S2T_INSTRUCTION
from inference.common import (
    load_train_config,
    get_vq_model_audio,
    build_uni_prompting,
    load_omada_from_checkpoint,
    list_checkpoints,
    grid_dict,
    init_wandb,
    safe_log_table,
)
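# Regexes used to clean decoder output and reference transcripts before WER scoring:
#   _ANGLE_TOKEN_RE: removes angle-bracketed special tokens (e.g. <|eot_id|>)
#   _EXCLAMATIONPOINT_RE: matches the literal "EXCLAMATIONPOINT" marker
#   _PUNCT_RE: matches punctuation other than apostrophes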
_ANGLE_TOKEN_RE = re.compile(r"<[^>]+>")
_EXCLAMATIONPOINT_RE = re.compile(r"EXCLAMATIONPOINT", flags=re.IGNORECASE)
_PUNCT_RE = re.compile(r"[^\w\s']")
def _strip_custom_markers(text: str) -> str:
    had_exclamationpoint = bool(_EXCLAMATIONPOINT_RE.search(text))
    text = _ANGLE_TOKEN_RE.sub(" ", text)
    if had_exclamationpoint:
        text = _EXCLAMATIONPOINT_RE.sub(" ", text)
        text = text.replace(".", "")
    text = _PUNCT_RE.sub(" ", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
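# Default normalization: strip markers, lowercase, drop punctuation, collapse whitespace.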
def _basic_normalize(text: str) -> str:
    text = _strip_custom_markers(text)
    text = text.lower()
    text = re.sub(r"[^\w\s']", "", text)
    text = re.sub(r"\s+", " ", text).strip()
    return text
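# Select the text normalizer: "off"/"none"/"no" -> identity, "english"/"whisper"/"whisper_en"
# -> Whisper-style EnglishTextNormalizer (if importable), anything else -> _basic_normalize.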
def build_normalize_fn(mode: str) -> Callable[[str], str]:
    mode = (mode or "basic").strip().lower()
    if mode in {"off", "none", "no"}:
        return lambda s: s
    if mode in {"english", "whisper", "whisper_en"}:
        try:
            from normalizer.normalizer import EnglishTextNormalizer
            n = EnglishTextNormalizer()

            def _fn(s: str) -> str:
                return re.sub(r"\s+", " ", n(s)).strip()

            return _fn
        except Exception:
            # Fall back to basic normalization if the normalizer package import fails
            return _basic_normalize
    # Default: basic normalization
    return _basic_normalize
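# Corpus-level WER: total word-level edit distance divided by total reference words.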
def calculate_wer(predictions: List[str], references: List[str], normalize: Callable[[str], str] = _basic_normalize):
    import editdistance
    # Normalize texts before WER
    predictions = [normalize(p) for p in predictions]
    references = [normalize(r) for r in references]
    total_errors = 0
    total_words = 0
    for pred, ref in zip(predictions, references):
        pred_words = pred.split()
        ref_words = ref.split()
        total_errors += editdistance.eval(pred_words, ref_words)
        total_words += len(ref_words)
    wer = total_errors / total_words if total_words > 0 else 0.0
    return wer, total_errors, total_words
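# LibriSpeech evaluation dataset: resolves each sample id ("speaker-chapter-utt")
# to <root_path>/<speaker>/<chapter>/<id>.flac and returns the reference text.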
class S2TEvalDataset(Dataset):
    def __init__(self, hf_dataset, root_path: str):
        self.hf_dataset = hf_dataset
        self.root_path = root_path

    def __len__(self):
        return len(self.hf_dataset)

    def __getitem__(self, idx):
        ex = self.hf_dataset[idx]
        sample_id = ex["id"]
        speaker_id, chapter_id, _ = sample_id.split("-")
        audio_path = os.path.join(self.root_path, speaker_id, chapter_id, f"{sample_id}.flac")
        return {"audio_path": audio_path, "gt_text": ex["text"], "sample_id": sample_id}
def s2t_eval_collate_fn(batch, vq_model_audio, tokenizer, uni_prompting, cfg):
    import random
    audio_tokens_batch = []
    # Shift audio codebook indices past the text vocabulary
    offset = len(uni_prompting.text_tokenizer) + cfg.model.omada.codebook_size
    for item in batch:
        path = item['audio_path']
        tokens = vq_model_audio.encode(path)
        tokens_with_offset = tokens + offset
        audio_tokens_batch.append(tokens_with_offset)
    sptids = uni_prompting.sptids_dict
    device = audio_tokens_batch[0].device
    batched_input_ids = []
    for audio_tokens in audio_tokens_batch:
        task_tensor = sptids['<|s2t|>'].to(device).unsqueeze(0)
        soa_tensor = sptids['<|soa|>'].to(device).unsqueeze(0)
        eoa_tensor = sptids['<|eoa|>'].to(device).unsqueeze(0)
        audio_block = torch.cat([task_tensor, soa_tensor, audio_tokens, eoa_tensor], dim=1)
        prompt_text = random.choice(S2T_INSTRUCTION)
        full_prompt_text = f'<|start_header_id|>user<|end_header_id|>\n{prompt_text}<|eot_id|><|start_header_id|>assistant<|end_header_id|>\n'
        prompt_tensor = tokenizer(full_prompt_text, return_tensors="pt").input_ids.to(device)
        final_seq = torch.cat([audio_block, prompt_tensor], dim=1)
        batched_input_ids.append(final_seq.squeeze(0))
    max_len = max(seq.size(0) for seq in batched_input_ids)
    pad_token_id = 126093  # hardcoded pad token id; sequences are left-padded below
    final_batch_input_ids = torch.full(
        (len(batched_input_ids), max_len),
        pad_token_id,
        dtype=torch.long,
        device=device,
    )
    for i, seq in enumerate(batched_input_ids):
        final_batch_input_ids[i, -len(seq):] = seq
    return {
        "input_ids": final_batch_input_ids,
        "gt_texts": [item['gt_text'] for item in batch],
        "sample_ids": [item['sample_id'] for item in batch],
    }
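# One evaluation run for a single checkpoint / hyperparameter combination:
# loads the model and audio VQ, tokenizes LibriSpeech audio, generates transcripts,
# and logs WER plus a sample table to W&B.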
def run_once(ckpt_path: str, hparams: dict, train_cfg, device):
    # Models and prompting
    uni_prompting, tokenizer = build_uni_prompting(train_cfg)
    vq_audio = get_vq_model_audio(train_cfg, device)
    model = load_omada_from_checkpoint(ckpt_path, device)

    # Dataset
    dcfg = hparams.get("dataset", {})
    subset = dcfg.get("subset", "clean")
    split = dcfg.get("split", "test")
    limit = int(dcfg.get("limit", 128))
    root_path = dcfg.get("root_path", "/home/work/AIDAS/data/audio/LibriSpeech/test-clean")
    ds_raw = load_dataset("librispeech_asr", subset, split=split)
    if limit > 0:
        ds_raw = ds_raw.select(range(min(limit, len(ds_raw))))
    ds = S2TEvalDataset(ds_raw, root_path=root_path)
    collate = partial(
        s2t_eval_collate_fn,
        vq_model_audio=vq_audio,
        tokenizer=uni_prompting.text_tokenizer,
        uni_prompting=uni_prompting,
        cfg=train_cfg,
    )
    batch_size = int(hparams.get("batch_size", train_cfg.training.batch_size_s2t))
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False, collate_fn=collate)

    # Generation hparams
    steps = int(hparams.get("steps", 128))
    block_length = int(hparams.get("block_length", 64))
    max_new_tokens = int(hparams.get("max_new_tokens", 256))
    remasking = hparams.get("remasking", "low_confidence")

    # W&B
    init_wandb(hparams.get("_infer_cfg", {}), "s2t", ckpt_path, {
        "steps": steps,
        "block_length": block_length,
        "max_new_tokens": max_new_tokens,
        "remasking": remasking,
        "batch_size": batch_size,
    })

    preds, refs, rows = [], [], []
    norm_mode = str(hparams.get("text_norm", "basic"))
    normalize_fn = build_normalize_fn(norm_mode)
    for batch in loader:
        input_ids = batch["input_ids"].to(device)
        gt_texts = batch["gt_texts"]
        sample_ids = batch["sample_ids"]
        with torch.no_grad():
            output_ids = model.mmu_generate(
                input_ids,
                max_new_tokens=max_new_tokens,
                steps=steps,
                block_length=block_length,
                remasking=remasking,
            )
        # Decode only the newly generated tokens (everything past the prompt)
        decoded = uni_prompting.text_tokenizer.batch_decode(
            output_ids[:, input_ids.shape[1]:], skip_special_tokens=True
        )
        clean_gts = [_strip_custom_markers(gt) for gt in gt_texts]
        clean_preds = [_strip_custom_markers(pred) for pred in decoded]
        print(clean_preds)
        for sid, clean_gt, clean_pred in zip(sample_ids, clean_gts, clean_preds):
            refs.append(clean_gt)
            preds.append(clean_pred)
            rows.append([sid, clean_gt, clean_pred])
    wer, errors, words = calculate_wer(preds, refs, normalize=normalize_fn)
    wandb.log({
        "metrics/s2t_wer": wer,
        "metrics/s2t_word_errors": errors,
        "metrics/s2t_total_words": words,
    })
    safe_log_table("samples/s2t", ["ID", "GT", "PRED"], rows[:64])
    wandb.finish()
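# Example invocation (script name and paths below are illustrative, not fixed by this file):
#   python s2t_infer.py \
#       --train_config configs/train_s2t.yaml \
#       --ckpt_root /path/to/experiment_output \
#       --steps 128 --block_length 64 --max_new_tokens 256 \
#       --text_norm english --limit 128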
def main():
    parser = argparse.ArgumentParser(description="S2T Inference with CLI overrides or config grids")
    parser.add_argument("--train_config", required=True, help="Path to training YAML used to build tokenizers and VQ models")
    parser.add_argument("--ckpt_root", required=True, help="Experiment output dir or a specific checkpoint path")
    parser.add_argument("--infer_config", required=False, help="Optional YAML for W&B and grids")
    parser.add_argument("--checkpoint", action="append", help="Repeatable: explicit checkpoint path(s). Can be '.../unwrapped_model', '.../checkpoint-XXXX', or an experiment dir")
    # Generation overrides
    parser.add_argument("--steps", type=int)
    parser.add_argument("--block_length", type=int)
    parser.add_argument("--max_new_tokens", type=int)
    parser.add_argument("--remasking")
    parser.add_argument("--batch_size", type=int)
    parser.add_argument("--text_norm", choices=["off", "basic", "english", "whisper", "whisper_en"], help="Text normalization for WER")
    # Dataset overrides
    parser.add_argument("--subset")
    parser.add_argument("--split")
    parser.add_argument("--root_path")
    parser.add_argument("--limit", type=int)
    args = parser.parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    train_cfg = load_train_config(args.train_config)
    infer_cfg = {}
    if args.infer_config:
        infer_cfg = OmegaConf.to_container(OmegaConf.load(args.infer_config), resolve=True)

    # Build checkpoint list; precedence: --checkpoint > infer_config.checkpoints > --ckpt_root
    if args.checkpoint:
        ckpt_list = []
        for p in args.checkpoint:
            ckpt_list.extend(list_checkpoints(p))
    else:
        ckpts = infer_cfg.get("checkpoints") if infer_cfg else None
        if ckpts:
            ckpt_list = []
            for p in ckpts:
                ckpt_list.extend(list_checkpoints(p))
        else:
            ckpt_list = list_checkpoints(args.ckpt_root)
    if not ckpt_list:
        raise FileNotFoundError(f"No checkpoints found under {args.ckpt_root} or in infer config.")

    override_present = any([
        args.steps is not None, args.block_length is not None, args.max_new_tokens is not None,
        args.remasking is not None, args.batch_size is not None,
        args.text_norm is not None,
        args.subset is not None, args.split is not None, args.root_path is not None, args.limit is not None,
    ])

    if override_present or not infer_cfg:
        # Single run assembled from CLI overrides (falling back to defaults)
        single = {
            "steps": args.steps if args.steps is not None else 128,
            "block_length": args.block_length if args.block_length is not None else 64,
            "max_new_tokens": args.max_new_tokens if args.max_new_tokens is not None else 256,
            "remasking": args.remasking if args.remasking is not None else "low_confidence",
            "batch_size": args.batch_size if args.batch_size is not None else int(train_cfg.training.batch_size_s2t),
        }
        if args.text_norm is not None:
            single["text_norm"] = args.text_norm
        dcfg = {
            "subset": args.subset or "clean",
            "split": args.split or "test",
            "root_path": args.root_path or "/home/work/AIDAS/data/audio/LibriSpeech/test-clean",
            "limit": args.limit if args.limit is not None else 128,
        }
        single["dataset"] = dcfg
        single["_infer_cfg"] = infer_cfg
        combos = [single]
    else:
        # Hyperparameter grid taken from the inference config
        gen_grid = infer_cfg.get("generation", {
            "steps": [128],
            "block_length": [64],
            "max_new_tokens": [256],
            "remasking": ["low_confidence"],
            "batch_size": [int(train_cfg.training.batch_size_s2t)],
        })
        combos = grid_dict(gen_grid)
        dcfg = infer_cfg.get("dataset", {
            "subset": "clean",
            "split": "test",
            "root_path": "/home/work/AIDAS/data/audio/LibriSpeech/test-clean",
            "limit": 128,
        })
        # Apply CLI overrides if provided
        if args.subset is not None:
            dcfg["subset"] = args.subset
        if args.split is not None:
            dcfg["split"] = args.split
        if args.root_path is not None:
            dcfg["root_path"] = args.root_path
        if args.limit is not None:
            dcfg["limit"] = args.limit
        for c in combos:
            if args.steps is not None:
                c["steps"] = args.steps
            if args.block_length is not None:
                c["block_length"] = args.block_length
            if args.max_new_tokens is not None:
                c["max_new_tokens"] = args.max_new_tokens
            if args.remasking is not None:
                c["remasking"] = args.remasking
            if args.batch_size is not None:
                c["batch_size"] = args.batch_size
            if args.text_norm is not None:
                c["text_norm"] = args.text_norm
            c["dataset"] = dcfg
            c["_infer_cfg"] = infer_cfg

    # Evaluate every checkpoint with every hyperparameter combination
    for ckpt in ckpt_list:
        for hp in combos:
            run_once(ckpt, hp, train_cfg, device)
if __name__ == "__main__":
    main()