import re import argparse from string import punctuation import torch import yaml import numpy as np from torch.utils.data import DataLoader from g2p_en import G2p from pypinyin import pinyin, Style from utils.model import get_model, get_vocoder from utils.tools import to_device, synth_samples, get_roberta_emotion_embeddings from dataset import TextDataset from text import text_to_sequence from transformers import RobertaTokenizerFast, AutoModel, AutoModelForSequenceClassification ro_model = "roberta_pretrained" roberta_tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base') roberta_model = AutoModelForSequenceClassification.from_pretrained(ro_model) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") def read_lexicon(lex_path): lexicon = {} with open(lex_path) as f: for line in f: temp = re.split(r"\s+", line.strip("\n")) word = temp[0] phones = temp[1:] if word.lower() not in lexicon: lexicon[word.lower()] = phones return lexicon def preprocess_english(text, preprocess_config): text = text.rstrip(punctuation) lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"]) g2p = G2p() phones = [] words = re.split(r"([,;.\-\?\!\s+])", text) for w in words: if w.lower() in lexicon: phones += lexicon[w.lower()] else: phones += list(filter(lambda p: p != " ", g2p(w))) phones = "{" + "}{".join(phones) + "}" phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones) phones = phones.replace("}{", " ") print("Raw Text Sequence: {}".format(text)) print("Phoneme Sequence: {}".format(phones)) sequence = np.array( text_to_sequence( phones, preprocess_config["preprocessing"]["text"]["text_cleaners"] ) ) return np.array(sequence) def preprocess_mandarin(text, preprocess_config): lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"]) phones = [] pinyins = [ p[0] for p in pinyin( text, style=Style.TONE3, strict=False, neutral_tone_with_five=True ) ] for p in pinyins: if p in lexicon: phones += lexicon[p] else: phones.append("sp") phones = "{" + " ".join(phones) + "}" print("Raw Text Sequence: {}".format(text)) print("Phoneme Sequence: {}".format(phones)) sequence = np.array( text_to_sequence( phones, preprocess_config["preprocessing"]["text"]["text_cleaners"] ) ) return np.array(sequence) def synthesize(model, step, configs, vocoder, batchs, control_values): preprocess_config, model_config, train_config = configs pitch_control, energy_control, duration_control = control_values for batch in batchs: batch = to_device(batch, device) with torch.no_grad(): # Forward output = model( *(batch[2:]), p_control=pitch_control, e_control=energy_control, d_control=duration_control ) synth_samples( batch, output, vocoder, model_config, preprocess_config, train_config["path"]["result_path"], ) if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--restore_step", type=int, required=True) parser.add_argument( "--mode", type=str, choices=["batch", "single"], required=True, help="Synthesize a whole dataset or a single sentence", ) parser.add_argument( "--source", type=str, default=None, help="path to a source file with format like train.txt and val.txt, for batch mode only", ) parser.add_argument( "--text", type=str, default=None, help="raw text to synthesize, for single-sentence mode only", ) parser.add_argument( "--speaker_id", type=int, default=0, help="speaker ID for multi-speaker synthesis, for single-sentence mode only", ) parser.add_argument( "--emotion_id", type=int, default=0, help="emotion ID for multi-emotion synthesis, for single-sentence mode only", ) parser.add_argument( "--bert_embed", type=int, default=0, help="Use bert embedings to control sentiment", ) parser.add_argument( "-p", "--preprocess_config", type=str, required=True, help="path to preprocess.yaml", ) parser.add_argument( "-m", "--model_config", type=str, required=True, help="path to model.yaml" ) parser.add_argument( "-t", "--train_config", type=str, required=True, help="path to train.yaml" ) parser.add_argument( "--pitch_control", type=float, default=1.0, help="control the pitch of the whole utterance, larger value for higher pitch", ) parser.add_argument( "--energy_control", type=float, default=1.0, help="control the energy of the whole utterance, larger value for larger volume", ) parser.add_argument( "--duration_control", type=float, default=1.0, help="control the speed of the whole utterance, larger value for slower speaking rate", ) args = parser.parse_args() # Check source texts if args.mode == "batch": assert args.source is not None and args.text is None if args.mode == "single": assert args.source is None and args.text is not None # Read Config preprocess_config = yaml.load( open(args.preprocess_config, "r"), Loader=yaml.FullLoader ) model_config = yaml.load( open(args.model_config, "r"), Loader=yaml.FullLoader) train_config = yaml.load( open(args.train_config, "r"), Loader=yaml.FullLoader) configs = (preprocess_config, model_config, train_config) # Get model model = get_model(args, configs, device, train=False) # Load vocoder vocoder = get_vocoder(model_config, device) # Preprocess texts if args.mode == "batch": # Get dataset dataset = TextDataset(args.source, preprocess_config) batchs = DataLoader( dataset, batch_size=8, collate_fn=dataset.collate_fn, ) if args.mode == "single": if np.array([args.bert_embed]) == 0: emotions = np.array([args.emotion_id]) # print(f'FS2 emotions: {emotions}') else: emotions = get_roberta_emotion_embeddings( roberta_tokenizer, roberta_model, args.text) emotions = torch.argmax(emotions, dim=1).cpu().numpy() # print(f'RoBERTa emotions {emotions}') ids = raw_texts = [args.text[:100]] speakers = np.array([args.speaker_id]) if preprocess_config["preprocessing"]["text"]["language"] == "en": texts = np.array( [preprocess_english(args.text, preprocess_config)]) elif preprocess_config["preprocessing"]["text"]["language"] == "zh": texts = np.array( [preprocess_mandarin(args.text, preprocess_config)]) text_lens = np.array([len(texts[0])]) batchs = [(ids, raw_texts, speakers, texts, text_lens, max(text_lens), emotions)] control_values = args.pitch_control, args.energy_control, args.duration_control synthesize(model, args.restore_step, configs, vocoder, batchs, control_values)