Spaces:
Sleeping
Sleeping
import argparse | |
import os | |
import torch | |
import torchaudio | |
import text | |
import utils.make_html as html | |
from utils.plotting import get_spectrogram_figure | |
from vocoder import load_hifigan | |
from vocoder.hifigan.denoiser import Denoiser | |
from utils import get_basic_config | |
#default: | |
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test | |
# Examples: | |
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test_fp_adv | |
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --denoise 0.01 --out_dir samples/test_fp_adv_d | |
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_mse.pth --out_dir samples/test_fp_mse | |
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --out_dir samples/test_tc2_adv | |
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --denoise 0.01 --out_dir samples/test_tc2_adv_d | |
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_mse.pth --out_dir samples/test_tc2_mse | |
def test(args, text_arabic): | |
use_cuda_if_available = not args.cpu | |
device = torch.device( | |
'cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu') | |
out_dir = args.out_dir | |
sample_rate = 22_050 | |
# Load model | |
if args.model == 'fastpitch': | |
from models.fastpitch import FastPitch | |
model = FastPitch(args.checkpoint) | |
elif args.model == 'tacotron2': | |
from models.tacotron2 import Tacotron2 | |
model = Tacotron2(args.checkpoint) | |
else: | |
raise "model type not supported" | |
print(f'Loaded {args.model} from: {args.checkpoint}') | |
model.eval() | |
# Load vocoder model | |
if args.vocoder_sd is None or args.vocoder_config is None: | |
config = get_basic_config() | |
if args.vocoder_sd is None: args.vocoder_sd = config.vocoder_state_path | |
if args.vocoder_config is None: args.vocoder_config = config.vocoder_config_path | |
vocoder = load_hifigan( | |
state_dict_path=args.vocoder_sd, | |
config_file=args.vocoder_config) | |
print(f'Loaded vocoder from: {args.vocoder_sd}') | |
model, vocoder = model.to(device), vocoder.to(device) | |
denoiser = Denoiser(vocoder) | |
# Infer spectrogram and wave | |
with torch.inference_mode(): | |
mel_spec = model.ttmel(text_arabic, vowelizer=args.vowelizer) | |
wave = vocoder(mel_spec[None]) | |
if args.denoise > 0: | |
wave = denoiser(wave, args.denoise) | |
# Save wave and images | |
if not os.path.exists(out_dir): | |
os.makedirs(out_dir) | |
print(f"Created folder: {out_dir}") | |
torchaudio.save(f'{out_dir}/wave.wav', wave[0].cpu(), sample_rate) | |
get_spectrogram_figure(mel_spec.cpu()).savefig( | |
f'{out_dir}/mel_spec.png') | |
t_phon = text.arabic_to_phonemes(text_arabic) | |
t_phon = text.simplify_phonemes(t_phon.replace(' ', '').replace('+', ' ')) | |
with open(f'{out_dir}/index.html', 'w', encoding='utf-8') as f: | |
f.write(html.make_html_start()) | |
f.write(html.make_h_tag("Test sample", n=1)) | |
f.write(html.make_sample_entry2(f"./wave.wav", text_arabic, t_phon)) | |
f.write(html.make_h_tag("Spectrogram")) | |
f.write(html.make_img_tag('./mel_spec.png')) | |
f.write(html.make_volume_script(0.42)) | |
f.write(html.make_html_end()) | |
print(f"Saved test sample to: {out_dir}") | |
if not args.do_not_play: | |
try: | |
import sounddevice as sd | |
sd.play(wave[0, 0].cpu(), sample_rate, blocking=True) | |
except: | |
pass | |
def main(): | |
parser = argparse.ArgumentParser() | |
parser.add_argument('--text', type=str, | |
default="ุฃููุณูููุงู ู ุนููููููู ูุง ุตูุฏููู") | |
parser.add_argument('--model', type=str, default='fastpitch') | |
parser.add_argument( | |
'--checkpoint', default='pretrained/fastpitch_ar_adv.pth') | |
parser.add_argument('--vocoder_sd', type=str, default=None) | |
parser.add_argument('--vocoder_config', type=str, default=None) | |
parser.add_argument('--denoise', type=float, default=0) | |
parser.add_argument('--out_dir', default='samples/test') | |
parser.add_argument('--vowelizer', default=None) | |
parser.add_argument('--cpu', action='store_true') | |
parser.add_argument('--do_not_play', action='store_true') | |
args = parser.parse_args() | |
text_arabic = args.text | |
test(args, text_arabic) | |
if __name__ == '__main__': | |
main() | |