#!/usr/bin/env python3
r"""
inference_first.py — quick Stage-1 sanity-check for StyleTTS-2

Loads a first-stage checkpoint, encodes a phoneme string, takes the style
vector, F0 track and mel log-norm curve from a reference recording, and
decodes a short preview clip.

Example
-------
python inference_first.py \
    --ckpt logs/pod_90h_30k/epoch_1st_0004.pth \
    --ref  data/wavs/123_abcd_part042_00.wav \
    --text " ðɪs ɪz ɐ tɛst ˈsɛntəns"

It writes preview.wav in the current directory.
"""
import argparse

import torch
import torchaudio
import yaml

from models import build_model, load_ASR_models, load_F0_models
from Utils.PLBERT.util import load_plbert
from utils import recursive_munch, log_norm, length_to_mask
from meldataset import TextCleaner, preprocess

# ────────────────────────── helpers ────────────────────────────


def _restore_batch(x):
    """Give a batch-squeezed tensor its leading batch axis back.

    A 1-D tensor ``(T,)`` becomes ``(1, T)``; tensors with any other number
    of dimensions are returned unchanged.  Needed because the pitch
    extractor (JDCNet) squeezes its output, which can drop the batch axis
    when the batch size is 1.
    """
    return x.unsqueeze(0) if x.dim() == 1 else x


def _match_len(x, target_len):
    """Crop or zero-pad the last axis of *x* to exactly *target_len* entries.

    Padding is appended on the right; all leading axes are left untouched.
    """
    cur = x.shape[-1]
    if cur > target_len:
        return x[..., :target_len]
    if cur < target_len:
        return torch.nn.functional.pad(x, (0, target_len - cur))
    return x


def _strip_module_prefix(state_dict):
    """Drop a DataParallel/DDP ``module.`` key prefix if *every* key has it.

    Returns *state_dict* unchanged when it is empty or any key lacks the
    prefix, so ordinary checkpoints pass straight through.
    """
    prefix = "module."
    if state_dict and all(key.startswith(prefix) for key in state_dict):
        return {key[len(prefix):]: val for key, val in state_dict.items()}
    return state_dict


# ────────────────────────── CLI ────────────────────────────────
p = argparse.ArgumentParser(description="Stage-1 StyleTTS-2 sanity check")
p.add_argument("--ckpt", required=True, help="epoch_1st_*.pth")
p.add_argument("--ref", required=True, help="reference wav (24 kHz mono)")
p.add_argument("--text", required=True, help="IPA / phoneme string")
p.add_argument("--cfg", default="Configs/config_ft_single.yml")
p.add_argument(
    "--device",
    default="cuda" if torch.cuda.is_available() else "cpu",
    help="torch device (default: cuda if available, else cpu)",
)
args = p.parse_args()

# ───────────────── net & cfg ───────────────────────────────────
with open(args.cfg, encoding="utf-8") as fh:
    cfg = yaml.safe_load(fh)
sr = cfg["preprocess_params"]["sr"]
device = args.device

asr = load_ASR_models(cfg["ASR_path"], cfg["ASR_config"])
f0 = load_F0_models(cfg["F0_path"])
bert = load_plbert(cfg["PLBERT_dir"])
model = build_model(recursive_munch(cfg["model_params"]), asr, f0, bert)

state = torch.load(args.ckpt, map_location="cpu")["net"]
for k in model:
    # strict=False silently skips keys that don't match, so a checkpoint
    # whose keys all carry a "module." wrapper prefix would otherwise load
    # nothing and leave the sub-module at its random initialisation.
    result = model[k].load_state_dict(
        _strip_module_prefix(state[k]), strict=False
    )
    if result is not None and (result.missing_keys or result.unexpected_keys):
        print(
            f"⚠️  {k}: {len(result.missing_keys)} missing, "
            f"{len(result.unexpected_keys)} unexpected keys"
        )
    model[k].eval().to(device)

# ───────────────── prepare inputs ──────────────────────────────
cleaner = TextCleaner()
# Every forward pass below is pure inference, so run them all under
# no_grad instead of only the decoder: tracking gradients just wastes memory.
with torch.no_grad():
    tokens = cleaner(args.text)
    if len(tokens) == 0:
        raise SystemExit("error: --text produced no known symbols")
    text_ids = torch.LongTensor(tokens).unsqueeze(0).to(device)  # (1,L)
    input_lengths = torch.LongTensor([text_ids.shape[1]]).to(device)
    text_mask = length_to_mask(input_lengths).to(device)

    # torchaudio.load returns (channels, N) plus the file's own rate; the
    # mel front-end needs a mono signal at the configured rate `sr`.
    wav, ref_sr = torchaudio.load(args.ref)
    wav = wav.mean(dim=0)  # downmix to (N,); unchanged for mono input
    if ref_sr != sr:
        wav = torchaudio.functional.resample(wav, ref_sr, sr)
    mel_ref = preprocess(wav.numpy()).to(device)  # (1,80,T)

    style = model.style_encoder(mel_ref.unsqueeze(1))  # (1,128)
    F0_real, _, _ = model.pitch_extractor(mel_ref.unsqueeze(1))
    F0_real = _restore_batch(F0_real)  # (1,T')
    real_norm = log_norm(mel_ref.unsqueeze(1)).squeeze(1)  # (1,T")
    real_norm = _restore_batch(real_norm)  # (1,T")

    # ───────────────── align lengths ───────────────────────────
    enc = model.text_encoder(text_ids, input_lengths, text_mask)  # (1,512,L)
    enc_len = enc.shape[-1]  # L
    target = enc_len * 2  # decoder expects 2×L
    F0_real = _match_len(F0_real, target)  # (1,2L)
    real_norm = _match_len(real_norm, target)  # (1,2L)

    # ───────────────── decode ──────────────────────────────────
    y = model.decoder(enc, F0_real, real_norm, style)

# ─── make it (channels, samples) = (1, T) ──────────────────────
# Drop the batch axis; if the decoder already returned (1, T) that leaves
# a 1-D tensor, so put a channel axis back for torchaudio.save.
y = y.squeeze(0)
if y.dim() == 1:
    y = y.unsqueeze(0)
torchaudio.save("preview.wav", y.cpu(), sr)
print("✅ wrote preview.wav")