#!/usr/bin/env python3
"""
inference_first.py β€” quick Stage-1 sanity-check for StyleTTS-2
Example
-------
python inference_first.py \
--ckpt logs/pod_90h_30k/epoch_1st_0004.pth \
--ref data/wavs/123_abcd_part042_00.wav \
--text "<evt_gasp> Γ°Ιͺs Ιͺz ɐ tΙ›st ˈsΙ›ntΙ™ns"
It writes preview.wav in the current directory.
"""
import argparse, yaml, torch, torchaudio
from models import build_model, load_ASR_models, load_F0_models
from Utils.PLBERT.util import load_plbert
from utils import recursive_munch, log_norm, length_to_mask
from meldataset import TextCleaner, preprocess
# ────────────────────────── helpers ────────────────────────────
def _restore_batch(x):
    """(T,) ▸ (1,T) or (C,T) ▸ (1,C,T) (handles the squeeze in JDCNet)."""
    return x.unsqueeze(0) if x.dim() == 1 else x
def _match_len(x, target_len):
    """Crop or zero-pad the last axis to target_len."""
    cur = x.shape[-1]
    if cur > target_len:
        return x[..., :target_len]
    if cur < target_len:
        pad = target_len - cur
        return torch.nn.functional.pad(x, (0, pad))
    return x
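# Quick illustration of the helpers (expected shapes, as comments only):
#   _restore_batch(torch.zeros(80)).shape      -> (1, 80)
#   _match_len(torch.zeros(1, 5), 8).shape     -> (1, 8)   zero-padded
#   _match_len(torch.zeros(1, 9), 8).shape     -> (1, 8)   cropped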
# ────────────────────────── CLI ────────────────────────────────
p = argparse.ArgumentParser()
p.add_argument("--ckpt", required=True, help="epoch_1st_*.pth")
p.add_argument("--ref", required=True, help="reference wav (24 kHz mono)")
p.add_argument("--text", required=True, help="IPA / phoneme string")
p.add_argument("--cfg", default="Configs/config_ft_single.yml")
args = p.parse_args()
# ───────────────── net & cfg ───────────────────────────────────
with open(args.cfg) as f:
    cfg = yaml.safe_load(f)
sr = cfg["preprocess_params"]["sr"]
device = "cuda" if torch.cuda.is_available() else "cpu"
asr = load_ASR_models(cfg["ASR_path"], cfg["ASR_config"])
f0 = load_F0_models(cfg["F0_path"])
bert = load_plbert(cfg["PLBERT_dir"])
model = build_model(recursive_munch(cfg["model_params"]), asr, f0, bert)
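# build_model returns a dict-like Munch of sub-modules; the Stage-1
# checkpoint stores one state_dict per module under the "net" key.
# strict=False tolerates keys (e.g. discriminator heads) that the
# inference graph doesn't define.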
state = torch.load(args.ckpt, map_location="cpu")["net"]
for k in model:
    model[k].load_state_dict(state[k], strict=False)
    model[k].eval().to(device)
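# Inference only: disable autograd globally so the style/pitch passes
# below don't build graphs (the decoder call further down is wrapped
# in no_grad as well).
torch.set_grad_enabled(False)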
# ───────────────── prepare inputs ──────────────────────────────
cleaner = TextCleaner()
text_ids = torch.LongTensor(cleaner(args.text)).unsqueeze(0).to(device)
input_lengths = torch.LongTensor([text_ids.shape[1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
wav, ref_sr = torchaudio.load(args.ref)                      # (C,N)
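# Defensive guards (not in the original script): downmix a stereo
# reference to mono and resample if it doesn't match the config rate.
if wav.size(0) > 1:
    wav = wav.mean(dim=0, keepdim=True)
if ref_sr != sr:
    wav = torchaudio.functional.resample(wav, ref_sr, sr)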
mel_ref = preprocess(wav.squeeze().numpy()).to(device) # (1,80,T)
style = model.style_encoder(mel_ref.unsqueeze(1)) # (1,128)
F0_real, _, _ = model.pitch_extractor(mel_ref.unsqueeze(1))
F0_real = _restore_batch(F0_real) # (1,T')
real_norm = log_norm(mel_ref.unsqueeze(1)).squeeze(1) # (1,T")
real_norm = _restore_batch(real_norm) # (1,T")
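# Note: this Stage-1 check re-synthesises the *reference* prosody; F0 and
# frame energy come straight from the reference mel, not from the
# prosody predictors trained in Stage 2.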
# ───────────────── align lengths ───────────────────────────────
enc = model.text_encoder(text_ids, input_lengths, text_mask) # (1,512,L)
enc_len = enc.shape[-1] # L
target = enc_len * 2                                        # decoder expects 2×L
F0_real = _match_len(F0_real, target) # (1,2L)
real_norm = _match_len(real_norm, target) # (1,2L)
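# Cropping/padding the reference features to 2*L is a crude stand-in for
# a real text-to-speech alignment; fine for a sanity check, but expect
# drift when the utterance length differs a lot from the text.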
# ───────────────── decode & save ───────────────────────────────
with torch.no_grad():
    y = model.decoder(enc, F0_real, real_norm, style)
# ─── make it (channels, samples) = (1, T) ────────────────────────────
y = y.squeeze(0) # (1, T)
torchaudio.save("preview.wav", y.cpu(), sr)
print("βœ… wrote preview.wav")