wasm-ar-tts / wasq /test.py
wasmdashai's picture
first commit
7694c84
raw
history blame
4.52 kB
import argparse
import os
import torch
import torchaudio
import text
import utils.make_html as html
from utils.plotting import get_spectrogram_figure
from vocoder import load_hifigan
from vocoder.hifigan.denoiser import Denoiser
from utils import get_basic_config
#default:
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test
# Examples:
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test_fp_adv
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --denoise 0.01 --out_dir samples/test_fp_adv_d
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_mse.pth --out_dir samples/test_fp_mse
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --out_dir samples/test_tc2_adv
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --denoise 0.01 --out_dir samples/test_tc2_adv_d
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_mse.pth --out_dir samples/test_tc2_mse
def test(args, text_arabic):
use_cuda_if_available = not args.cpu
device = torch.device(
'cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu')
out_dir = args.out_dir
sample_rate = 22_050
# Load model
if args.model == 'fastpitch':
from models.fastpitch import FastPitch
model = FastPitch(args.checkpoint)
elif args.model == 'tacotron2':
from models.tacotron2 import Tacotron2
model = Tacotron2(args.checkpoint)
else:
raise "model type not supported"
print(f'Loaded {args.model} from: {args.checkpoint}')
model.eval()
# Load vocoder model
if args.vocoder_sd is None or args.vocoder_config is None:
config = get_basic_config()
if args.vocoder_sd is None: args.vocoder_sd = config.vocoder_state_path
if args.vocoder_config is None: args.vocoder_config = config.vocoder_config_path
vocoder = load_hifigan(
state_dict_path=args.vocoder_sd,
config_file=args.vocoder_config)
print(f'Loaded vocoder from: {args.vocoder_sd}')
model, vocoder = model.to(device), vocoder.to(device)
denoiser = Denoiser(vocoder)
# Infer spectrogram and wave
with torch.inference_mode():
mel_spec = model.ttmel(text_arabic, vowelizer=args.vowelizer)
wave = vocoder(mel_spec[None])
if args.denoise > 0:
wave = denoiser(wave, args.denoise)
# Save wave and images
if not os.path.exists(out_dir):
os.makedirs(out_dir)
print(f"Created folder: {out_dir}")
torchaudio.save(f'{out_dir}/wave.wav', wave[0].cpu(), sample_rate)
get_spectrogram_figure(mel_spec.cpu()).savefig(
f'{out_dir}/mel_spec.png')
t_phon = text.arabic_to_phonemes(text_arabic)
t_phon = text.simplify_phonemes(t_phon.replace(' ', '').replace('+', ' '))
with open(f'{out_dir}/index.html', 'w', encoding='utf-8') as f:
f.write(html.make_html_start())
f.write(html.make_h_tag("Test sample", n=1))
f.write(html.make_sample_entry2(f"./wave.wav", text_arabic, t_phon))
f.write(html.make_h_tag("Spectrogram"))
f.write(html.make_img_tag('./mel_spec.png'))
f.write(html.make_volume_script(0.42))
f.write(html.make_html_end())
print(f"Saved test sample to: {out_dir}")
if not args.do_not_play:
try:
import sounddevice as sd
sd.play(wave[0, 0].cpu(), sample_rate, blocking=True)
except:
pass
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--text', type=str,
default="ุฃูŽู„ุณู‘ูŽู„ุงู…ู ุนูŽู„ูŽูŠูƒูู… ูŠุง ุตูŽุฏูŠู‚ูŠ")
parser.add_argument('--model', type=str, default='fastpitch')
parser.add_argument(
'--checkpoint', default='pretrained/fastpitch_ar_adv.pth')
parser.add_argument('--vocoder_sd', type=str, default=None)
parser.add_argument('--vocoder_config', type=str, default=None)
parser.add_argument('--denoise', type=float, default=0)
parser.add_argument('--out_dir', default='samples/test')
parser.add_argument('--vowelizer', default=None)
parser.add_argument('--cpu', action='store_true')
parser.add_argument('--do_not_play', action='store_true')
args = parser.parse_args()
text_arabic = args.text
test(args, text_arabic)
if __name__ == '__main__':
main()