# moe-tts / MoeGoe.py
# Interactive command-line front-end for a VITS model: text-to-speech and
# voice conversion. (Header repaired from web-page scrape residue.)
import sys
from torch import no_grad, LongTensor
import logging
logging.getLogger('numba').setLevel(logging.WARNING)
import commons
import utils
from models import SynthesizerTrn
from text import text_to_sequence
from mel_processing import spectrogram_torch
from scipy.io.wavfile import write
def get_text(text, hps):
    """Convert raw text into a LongTensor of symbol IDs for the model.

    Args:
        text: Input string to synthesize.
        hps: Hyperparameter namespace; must expose ``symbols``,
            ``data.text_cleaners`` and ``data.add_blank``.

    Returns:
        torch.LongTensor of symbol indices, interspersed with blank (0)
        tokens when the model was trained with ``add_blank``.
    """
    # Bug fix: use the `hps` argument rather than the module-level `hps_ms`
    # global, so the function works with whatever hparams object is passed in.
    text_norm = text_to_sequence(text, hps.symbols, hps.data.text_cleaners)
    if hps.data.add_blank:
        # Models trained with add_blank expect a blank token between symbols.
        text_norm = commons.intersperse(text_norm, 0)
    text_norm = LongTensor(text_norm)
    return text_norm
def ask_if_continue():
    """Prompt until the user answers 'y' (return) or 'n' (exit with status 0).

    Any other answer simply re-prompts.
    """
    while True:
        reply = input('Continue? (y/n): ')
        if reply == 'n':
            sys.exit(0)
        if reply == 'y':
            return
def print_speakers(speakers):
    """Print a two-column table mapping speaker IDs to speaker names.

    Args:
        speakers: Iterable of speaker name strings; the printed ID is the
            name's position in the iterable.
    """
    print('ID\tSpeaker')
    # `idx` instead of `id` to avoid shadowing the builtin; f-string over
    # manual str() concatenation.
    for idx, name in enumerate(speakers):
        print(f'{idx}\t{name}')
def get_speaker_id(message):
    """Prompt the user for a speaker ID and return it as an int.

    Args:
        message: Prompt string shown to the user.

    Returns:
        The entered ID as an ``int``.

    Exits:
        Terminates the process with status 1 when the input is not an integer.
    """
    raw = input(message)
    try:
        return int(raw)
    except ValueError:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # not swallowed; only a genuinely non-numeric entry is rejected.
        print(str(raw) + ' is not a valid ID!')
        sys.exit(1)
if __name__ == '__main__':
    model = input('Path of a VITS model: ')
    config = input('Path of a config file: ')

    # Load hyperparameters and build the synthesizer. Any failure here (bad
    # path, malformed config, mismatched checkpoint) is fatal.
    try:
        hps_ms = utils.get_hparams_from_file(config)
        net_g_ms = SynthesizerTrn(
            len(hps_ms.symbols),
            hps_ms.data.filter_length // 2 + 1,
            hps_ms.train.segment_size // hps_ms.data.hop_length,
            n_speakers=hps_ms.data.n_speakers,
            **hps_ms.model)
        _ = net_g_ms.eval()  # inference mode (disables dropout etc.)
        _ = utils.load_checkpoint(model, net_g_ms, None)
    except Exception:  # narrowed from bare except: keep Ctrl-C working
        print('Failed to load!')
        sys.exit(1)

    while True:
        choice = input('TTS or VC? (t/v):')
        if choice == 't':
            # ---- Text-to-speech ----
            text = input('Text to read: ')
            try:
                stn_tst = get_text(text, hps_ms)
            except Exception:
                print('Invalid text!')
                sys.exit(1)
            print_speakers(hps_ms.speakers)
            speaker_id = get_speaker_id('Speaker ID: ')
            out_path = input('Path to save: ')
            try:
                with no_grad():
                    x_tst = stn_tst.unsqueeze(0)  # add batch dimension
                    x_tst_lengths = LongTensor([stn_tst.size(0)])
                    sid = LongTensor([speaker_id])
                    audio = net_g_ms.infer(
                        x_tst, x_tst_lengths, sid=sid, noise_scale=.667,
                        noise_scale_w=0.8, length_scale=1
                    )[0][0, 0].data.cpu().float().numpy()
                    write(out_path, hps_ms.data.sampling_rate, audio)
            except Exception:
                print('Failed to generate!')
                sys.exit(1)
            print('Successfully saved!')
            ask_if_continue()
        elif choice == 'v':
            # ---- Voice conversion ----
            wav_path = input(
                'Path of a WAV file (22050 Hz, 16 bits, 1 channel) to convert:\n')
            print_speakers(hps_ms.speakers)
            audio, sampling_rate = utils.load_wav_to_torch(wav_path)
            # Fixed typo in the local variable name (originnal_id -> original_id).
            original_id = get_speaker_id('Original speaker ID: ')
            target_id = get_speaker_id('Target speaker ID: ')
            out_path = input('Path to save: ')
            # Normalize samples to [-1, 1] and compute the linear spectrogram
            # the VC path of the model expects.
            y = audio / hps_ms.data.max_wav_value
            y = y.unsqueeze(0)
            spec = spectrogram_torch(y, hps_ms.data.filter_length,
                                     hps_ms.data.sampling_rate,
                                     hps_ms.data.hop_length,
                                     hps_ms.data.win_length,
                                     center=False)
            spec_lengths = LongTensor([spec.size(-1)])
            sid_src = LongTensor([original_id])
            try:
                with no_grad():
                    sid_tgt = LongTensor([target_id])
                    audio = net_g_ms.voice_conversion(
                        spec, spec_lengths, sid_src=sid_src, sid_tgt=sid_tgt
                    )[0][0, 0].data.cpu().float().numpy()
                    write(out_path, hps_ms.data.sampling_rate, audio)
            except Exception:
                print('Failed to generate!')
                sys.exit(1)
            print('Successfully saved!')
            ask_if_continue()