""" |
|
inference_first.py β quick Stage-1 sanity-check for StyleTTS-2 |
|
|
|
Example |
|
------- |
|
python inference_first.py \ |
|
--ckpt logs/pod_90h_30k/epoch_1st_0004.pth \ |
|
--ref data/wavs/123_abcd_part042_00.wav \ |
|
--text "<evt_gasp> Γ°Ιͺs Ιͺz Ι tΙst ΛsΙntΙns" |
|
It writes preview.wav in the current directory. |
|
""" |
|
import argparse, yaml, torch, torchaudio

from models import build_model, load_ASR_models, load_F0_models
from Utils.PLBERT.util import load_plbert
from utils import recursive_munch, log_norm, length_to_mask
from meldataset import TextCleaner, preprocess
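# NOTE: models, utils, meldataset and Utils.PLBERT are modules from the StyleTTS-2
# repository; this script is assumed to sit in (or be run from) the repo root so that
# these imports resolve.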
def _restore_batch(x):
    """(T,) → (1, T) or (C, T) → (1, C, T) (handles squeeze in JDCNet)."""
    return x.unsqueeze(0) if x.dim() == 1 else x
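# e.g. the pitch extractor can hand back a bare (T,) curve for a single clip;
# _restore_batch puts the batch dimension back so shapes stay (1, T) downstream.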
def _match_len(x, target_len):
    """Crop or zero-pad last axis to target_len."""
    cur = x.shape[-1]
    if cur > target_len:
        return x[..., :target_len]
    if cur < target_len:
        pad = target_len - cur
        return torch.nn.functional.pad(x, (0, pad))
    return x
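# Quick illustration of _match_len (shapes only, not executed as a test):
#   _match_len(torch.zeros(1, 80, 10), 8).shape  -> (1, 80, 8)   # cropped
#   _match_len(torch.zeros(1, 80, 10), 12).shape -> (1, 80, 12)  # zero-padded on the right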
p = argparse.ArgumentParser()
p.add_argument("--ckpt", required=True, help="epoch_1st_*.pth")
p.add_argument("--ref", required=True, help="reference wav (24 kHz mono)")
p.add_argument("--text", required=True, help="IPA / phoneme string")
p.add_argument("--cfg", default="Configs/config_ft_single.yml")
args = p.parse_args()

with open(args.cfg) as f:
    cfg = yaml.safe_load(f)
sr = cfg["preprocess_params"]["sr"]
device = "cuda" if torch.cuda.is_available() else "cpu"

asr = load_ASR_models(cfg["ASR_path"], cfg["ASR_config"])
f0 = load_F0_models(cfg["F0_path"])
bert = load_plbert(cfg["PLBERT_dir"])
model = build_model(recursive_munch(cfg["model_params"]), asr, f0, bert)
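# build_model returns a dict-like bundle of sub-modules; the ones this script uses
# below are text_encoder, style_encoder, pitch_extractor and decoder.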
state = torch.load(args.ckpt, map_location="cpu")["net"]
for k in model:
    model[k].load_state_dict(state[k], strict=False)
    model[k].eval().to(device)
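# load_state_dict(strict=False) silently skips parameters whose names do not match,
# so a sub-module could be left partly at random init without any error.
# Optional sanity check (a sketch; assumes "net" is keyed by the same module names):
for k in model:
    not_loaded = set(model[k].state_dict()) - set(state[k])
    if not_loaded:
        print(f"warning: {k}: {len(not_loaded)} params not found in checkpoint")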
cleaner = TextCleaner()
text_ids = torch.LongTensor(cleaner(args.text)).unsqueeze(0).to(device)
input_lengths = torch.LongTensor([text_ids.shape[1]]).to(device)
text_mask = length_to_mask(input_lengths).to(device)
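# Single-utterance batch: text_ids is (1, T_text) symbol ids, input_lengths is [T_text],
# and text_mask marks padded positions for the text encoder (trivially none here,
# since the batch holds one item).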
wav, _ = torchaudio.load(args.ref)
mel_ref = preprocess(wav.squeeze().numpy()).to(device)
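# The reference is assumed to already be 24 kHz mono (see the --ref help). If that is
# not guaranteed, keep the sample rate returned by torchaudio.load and resample first,
# e.g. torchaudio.functional.resample(wav, orig_freq=file_sr, new_freq=sr), before
# computing the mel. (file_sr is hypothetical here; the script currently discards it.)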
style = model.style_encoder(mel_ref.unsqueeze(1))
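# style: fixed-size embedding of the reference (speaker timbre / prosody) that
# conditions the decoder below; unsqueeze(1) adds the channel dim the encoder expects.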
F0_real, _, _ = model.pitch_extractor(mel_ref.unsqueeze(1))
F0_real = _restore_batch(F0_real)

real_norm = log_norm(mel_ref.unsqueeze(1)).squeeze(1)
real_norm = _restore_batch(real_norm)
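# F0_real and real_norm are the reference's pitch track and an energy-like curve
# (log_norm of the mel), i.e. the ground-truth conditioning the Stage-1 decoder sees
# in training. Note they come from the reference clip rather than from audio that
# matches --text, which is fine for a smoke test.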
enc = model.text_encoder(text_ids, input_lengths, text_mask)
enc_len = enc.shape[-1]
target = enc_len * 2

F0_real = _match_len(F0_real, target)
real_norm = _match_len(real_norm, target)
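# No alignment/duration model is used in this quick check: the script simply assumes
# two acoustic frames per text position (target = enc_len * 2) and crops or pads the
# reference curves to that length. Crude, but enough to confirm the Stage-1 decoder
# produces audio at all.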
with torch.no_grad():
    y = model.decoder(enc, F0_real, real_norm, style)
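# y is a batch of one waveform (a (1, 1, num_samples) tensor here), already at the
# config sample rate; the squeeze below drops the batch dim so torchaudio.save gets
# (channels, time).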
y = y.squeeze(0)

torchaudio.save("preview.wav", y.cpu(), sr)
print("✅ wrote preview.wav")
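# Optional: report the clip length as a quick plausibility check (not in the original
# script; assumes y is (channels, samples) after the squeeze above).
print(f"duration: {y.shape[-1] / sr:.2f} s")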