|
|
|
|
|
""" |
|
|
VITS2 Remy - Luxembourgish TTS Inference Script |
|
|
|
|
|
Usage: |
|
|
python inference.py "Moien, wéi geet et dir?" |
|
|
python inference.py "Moien, wéi geet et dir?" -o output.wav |
|
|
python inference.py "Moien, wéi geet et dir?" --noise_scale 0.5 |
|
|
""" |
|
|
import argparse |
|
|
import torch |
|
|
import scipy.io.wavfile as wavfile |
|
|
|
|
|
import utils |
|
|
import commons |
|
|
from models import SynthesizerTrn |
|
|
from text.symbols import symbols |
|
|
from text import text_to_sequence |
|
|
|
|
|
|
|
|
def get_text(text, hps):
    """Convert raw text into a LongTensor of symbol ids for the model.

    Runs the text cleaners configured in ``hps.data.text_cleaners``, then
    optionally intersperses blank tokens (id 0) between symbols when
    ``hps.data.add_blank`` is enabled (standard VITS preprocessing).
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        sequence = commons.intersperse(sequence, 0)
    return torch.LongTensor(sequence)
|
|
|
|
|
|
|
|
def main():
    """Parse CLI arguments, load the VITS2 model, and synthesize speech to a WAV file."""
    parser = argparse.ArgumentParser(description="VITS2 Remy TTS")
    parser.add_argument("text", type=str, help="Text to synthesize")
    parser.add_argument("-o", "--output", type=str, default="output.wav", help="Output WAV file")
    parser.add_argument("--noise_scale", type=float, default=0.667, help="Noise scale (default: 0.667)")
    parser.add_argument("--noise_scale_w", type=float, default=0.8, help="Noise scale W (default: 0.8)")
    parser.add_argument("--length_scale", type=float, default=1.0, help="Length scale (default: 1.0)")
    parser.add_argument("--config", type=str, default="config.json", help="Path to model config (default: config.json)")
    parser.add_argument("--model", type=str, default="model.pth", help="Path to model checkpoint (default: model.pth)")
    parser.add_argument("--cpu", action="store_true", help="Use CPU instead of GPU")
    args = parser.parse_args()

    # Fall back to CPU when CUDA is unavailable instead of crashing on .to("cuda").
    device = "cpu" if (args.cpu or not torch.cuda.is_available()) else "cuda"

    hps = utils.get_hparams_from_file(args.config)

    # Posterior encoder input size: mel channels when the mel posterior encoder
    # is enabled, otherwise the linear-spectrogram bin count (n_fft/2 + 1).
    if getattr(hps.model, 'use_mel_posterior_encoder', False):
        posterior_channels = hps.data.n_mel_channels
    else:
        posterior_channels = hps.data.filter_length // 2 + 1

    net_g = SynthesizerTrn(
        len(symbols),
        posterior_channels,
        hps.train.segment_size // hps.data.hop_length,
        n_speakers=hps.data.n_speakers,
        **hps.model
    ).to(device)

    _ = utils.load_checkpoint(args.model, net_g, None)
    net_g.eval()  # disable dropout etc. for inference

    # Lowercase the input — presumably to match training-time text
    # normalization; confirm against the training pipeline's cleaners.
    text = args.text.lower()
    print(f"Synthesizing: {text}")

    with torch.no_grad():
        stn_tst = get_text(text, hps)
        # Add a batch dimension and build the matching length tensor.
        x_tst = stn_tst.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)

        # infer() returns a tuple; [0] is the audio batch, [0, 0] strips
        # the batch and channel dimensions to a 1-D waveform.
        audio = net_g.infer(
            x_tst, x_tst_lengths,
            noise_scale=args.noise_scale,
            noise_scale_w=args.noise_scale_w,
            length_scale=args.length_scale
        )[0][0, 0].data.cpu().float().numpy()

    wavfile.write(args.output, hps.data.sampling_rate, audio)
    print(f"Saved to: {args.output}")
|
|
|
|
|
|
|
|
# Script entry point: synthesize the CLI-provided text when run directly.
if __name__ == "__main__":


    main()
|
|
|