File size: 4,524 Bytes
7694c84
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import argparse
import os
import torch
import torchaudio

import text
import utils.make_html as html
from utils.plotting import get_spectrogram_figure
from vocoder import load_hifigan
from vocoder.hifigan.denoiser import Denoiser
from utils import get_basic_config

# Default:
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test

# Examples:
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --out_dir samples/test_fp_adv
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_adv.pth --denoise 0.01 --out_dir samples/test_fp_adv_d
# python test.py --model fastpitch --checkpoint pretrained/fastpitch_ar_mse.pth --out_dir samples/test_fp_mse

# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --out_dir samples/test_tc2_adv
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_adv.pth --denoise 0.01 --out_dir samples/test_tc2_adv_d
# python test.py --model tacotron2 --checkpoint pretrained/tacotron2_ar_mse.pth --out_dir samples/test_tc2_mse


def test(args, text_arabic):
    """Synthesize `text_arabic` with the selected TTS model + HiFi-GAN vocoder.

    Writes to `args.out_dir`: wave.wav, mel_spec.png and an index.html page,
    then (optionally) plays the audio through the default sound device.

    Args:
        args: parsed argparse namespace (model, checkpoint, vocoder paths,
            denoise strength, out_dir, vowelizer, cpu, do_not_play).
        text_arabic: the Arabic input text to synthesize.

    Raises:
        ValueError: if `args.model` is not 'fastpitch' or 'tacotron2'.
    """
    use_cuda_if_available = not args.cpu
    device = torch.device(
        'cuda' if torch.cuda.is_available() and use_cuda_if_available else 'cpu')
    out_dir = args.out_dir
    sample_rate = 22_050  # fixed rate the pretrained vocoder was trained at

    # Load model (import lazily so only the selected model's deps are needed)
    if args.model == 'fastpitch':
        from models.fastpitch import FastPitch
        model = FastPitch(args.checkpoint)
    elif args.model == 'tacotron2':
        from models.tacotron2 import Tacotron2
        model = Tacotron2(args.checkpoint)
    else:
        # NOTE: raising a bare string is a TypeError in Python 3;
        # raise a proper exception instead.
        raise ValueError(f"model type not supported: {args.model!r}")

    print(f'Loaded {args.model} from: {args.checkpoint}')
    model.eval()

    # Load vocoder model; fall back to the paths from the basic config
    # for whichever of the two arguments was not given explicitly.
    if args.vocoder_sd is None or args.vocoder_config is None:
        config = get_basic_config()
        if args.vocoder_sd is None:
            args.vocoder_sd = config.vocoder_state_path
        if args.vocoder_config is None:
            args.vocoder_config = config.vocoder_config_path
    vocoder = load_hifigan(
        state_dict_path=args.vocoder_sd,
        config_file=args.vocoder_config)
    print(f'Loaded vocoder from: {args.vocoder_sd}')

    model, vocoder = model.to(device), vocoder.to(device)
    denoiser = Denoiser(vocoder)

    # Infer spectrogram and wave
    with torch.inference_mode():
        mel_spec = model.ttmel(text_arabic, vowelizer=args.vowelizer)
        wave = vocoder(mel_spec[None])  # add batch dim for the vocoder
        if args.denoise > 0:
            wave = denoiser(wave, args.denoise)

    # Save wave and images (exist_ok avoids a check-then-create race)
    os.makedirs(out_dir, exist_ok=True)

    torchaudio.save(f'{out_dir}/wave.wav', wave[0].cpu(), sample_rate)

    get_spectrogram_figure(mel_spec.cpu()).savefig(
        f'{out_dir}/mel_spec.png')

    # Phonemize for display on the generated HTML page.
    t_phon = text.arabic_to_phonemes(text_arabic)
    t_phon = text.simplify_phonemes(t_phon.replace(' ', '').replace('+', ' '))

    with open(f'{out_dir}/index.html', 'w', encoding='utf-8') as f:
        f.write(html.make_html_start())
        f.write(html.make_h_tag("Test sample", n=1))
        f.write(html.make_sample_entry2(f"./wave.wav", text_arabic, t_phon))
        f.write(html.make_h_tag("Spectrogram"))
        f.write(html.make_img_tag('./mel_spec.png'))
        f.write(html.make_volume_script(0.42))
        f.write(html.make_html_end())

    print(f"Saved test sample to: {out_dir}")

    # Best-effort playback: sounddevice may be missing or the host may have
    # no audio device; never let that fail the script. Catch Exception (not
    # bare except) so Ctrl-C still interrupts.
    if not args.do_not_play:
        try:
            import sounddevice as sd
            sd.play(wave[0, 0].cpu(), sample_rate, blocking=True)
        except Exception:
            pass


def main():
    """Parse command-line options and run a single synthesis test."""
    ap = argparse.ArgumentParser()

    # Input text and model selection
    ap.add_argument(
        '--text', type=str,
        default="ุฃูŽู„ุณู‘ูŽู„ุงู…ู ุนูŽู„ูŽูŠูƒูู… ูŠุง ุตูŽุฏูŠู‚ูŠ")
    ap.add_argument('--model', type=str, default='fastpitch')
    ap.add_argument('--checkpoint', default='pretrained/fastpitch_ar_adv.pth')

    # Vocoder / post-processing
    ap.add_argument('--vocoder_sd', type=str, default=None)
    ap.add_argument('--vocoder_config', type=str, default=None)
    ap.add_argument('--denoise', type=float, default=0)

    # Output and runtime behavior
    ap.add_argument('--out_dir', default='samples/test')
    ap.add_argument('--vowelizer', default=None)
    ap.add_argument('--cpu', action='store_true')
    ap.add_argument('--do_not_play', action='store_true')

    parsed = ap.parse_args()
    test(parsed, parsed.text)


# Script entry point: only run when executed directly, not when imported.
if __name__ == '__main__':
    main()