Spaces:
Sleeping
Sleeping
import re | |
import argparse | |
from string import punctuation | |
import torch | |
import yaml | |
import numpy as np | |
from torch.utils.data import DataLoader | |
from g2p_en import G2p | |
from pypinyin import pinyin, Style | |
from utils.model import get_model, get_vocoder | |
from utils.tools import to_device, synth_samples, get_roberta_emotion_embeddings | |
from dataset import TextDataset | |
from text import text_to_sequence | |
from transformers import RobertaTokenizerFast, AutoModel, AutoModelForSequenceClassification | |
# Directory holding the fine-tuned RoBERTa sequence-classification checkpoint
# used to predict an emotion label from the input text (Colab-style path).
ro_model = "/content/FastSpeech2_Text_Aware_Emotion_TTS/roberta_pretrained"
# Tokenizer uses the stock "roberta-base" vocabulary; the classifier weights
# are loaded from the local checkpoint directory above.
roberta_tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')
roberta_model = AutoModelForSequenceClassification.from_pretrained(ro_model)
# Run on GPU when available, otherwise fall back to CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def read_lexicon(lex_path):
    """Parse a pronunciation lexicon file into a dict.

    Each line has the form ``WORD PH1 PH2 ...``. Keys are lowercased words;
    values are the phoneme lists. If a word appears more than once, the first
    occurrence wins.
    """
    mapping = {}
    with open(lex_path) as lex_file:
        for raw_line in lex_file:
            tokens = re.split(r"\s+", raw_line.strip("\n"))
            # setdefault keeps the first pronunciation seen for a word.
            mapping.setdefault(tokens[0].lower(), tokens[1:])
    return mapping
def preprocess_english(text, preprocess_config):
    """Convert raw English text to a phoneme-ID sequence.

    Words found in the lexicon use their listed pronunciation; out-of-vocabulary
    words fall back to g2p_en grapheme-to-phoneme conversion. Punctuation
    tokens become the "sp" (short pause) symbol.

    Returns a 1-D numpy array of phoneme IDs.
    """
    text = text.rstrip(punctuation)
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])
    g2p = G2p()

    phones = []
    # Split on punctuation/whitespace, keeping the separators so they can be
    # mapped to pause symbols below.
    words = re.split(r"([,;.\-\?\!\s+])", text)
    for w in words:
        if w.lower() in lexicon:
            phones += lexicon[w.lower()]
        else:
            # OOV word: use grapheme-to-phoneme, dropping bare spaces.
            phones += list(filter(lambda p: p != " ", g2p(w)))

    phones = "{" + "}{".join(phones) + "}"
    # Replace lone punctuation tokens (e.g. "{,}") with a short pause "{sp}".
    phones = re.sub(r"\{[^\w\s]?\}", "{sp}", phones)
    phones = phones.replace("}{", " ")

    print("Raw Text Sequence: {}".format(text))
    print("Phoneme Sequence: {}".format(phones))

    # text_to_sequence already yields a list of IDs; wrap it exactly once
    # (the original wrapped the result in np.array twice, redundantly).
    return np.array(
        text_to_sequence(
            phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
        )
    )
def preprocess_mandarin(text, preprocess_config):
    """Convert raw Mandarin text to a phoneme-ID sequence.

    The text is converted to TONE3-style pinyin (neutral tone written as 5);
    each pinyin syllable is looked up in the lexicon, and unknown syllables
    become the "sp" (short pause) symbol.

    Returns a 1-D numpy array of phoneme IDs.
    """
    lexicon = read_lexicon(preprocess_config["path"]["lexicon_path"])

    phones = []
    pinyins = [
        p[0]
        for p in pinyin(
            text, style=Style.TONE3, strict=False, neutral_tone_with_five=True
        )
    ]
    for p in pinyins:
        if p in lexicon:
            phones += lexicon[p]
        else:
            # Unknown syllable: emit a short pause rather than failing.
            phones.append("sp")

    phones = "{" + " ".join(phones) + "}"
    print("Raw Text Sequence: {}".format(text))
    print("Phoneme Sequence: {}".format(phones))

    # Wrap the ID list exactly once (the original applied np.array twice,
    # which was redundant).
    return np.array(
        text_to_sequence(
            phones, preprocess_config["preprocessing"]["text"]["text_cleaners"]
        )
    )
def synthesize(model, step, configs, vocoder, batchs, control_values):
    """Run inference over every batch and write the synthesized samples.

    ``configs`` is (preprocess_config, model_config, train_config);
    ``control_values`` is (pitch, energy, duration) scalars forwarded to the
    model. Results are written under train_config["path"]["result_path"].
    """
    preprocess_config, model_config, train_config = configs
    pitch_control, energy_control, duration_control = control_values

    for raw_batch in batchs:
        moved = to_device(raw_batch, device)
        with torch.no_grad():
            # Forward pass; the first two batch entries (ids, raw text) are
            # metadata and are not model inputs.
            output = model(
                *moved[2:],
                p_control=pitch_control,
                e_control=energy_control,
                d_control=duration_control,
            )
            synth_samples(
                moved,
                output,
                vocoder,
                model_config,
                preprocess_config,
                train_config["path"]["result_path"],
            )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--restore_step", type=int, required=True)
    parser.add_argument(
        "--mode",
        type=str,
        choices=["batch", "single"],
        required=True,
        help="Synthesize a whole dataset or a single sentence",
    )
    parser.add_argument(
        "--source",
        type=str,
        default=None,
        help="path to a source file with format like train.txt and val.txt, for batch mode only",
    )
    parser.add_argument(
        "--text",
        type=str,
        default=None,
        help="raw text to synthesize, for single-sentence mode only",
    )
    parser.add_argument(
        "--speaker_id",
        type=int,
        default=0,
        help="speaker ID for multi-speaker synthesis, for single-sentence mode only",
    )
    parser.add_argument(
        "--emotion_id",
        type=int,
        default=0,
        help="emotion ID for multi-emotion synthesis, for single-sentence mode only",
    )
    parser.add_argument(
        "--bert_embed",
        type=int,
        default=0,
        help="Use bert embedings to control sentiment",
    )
    parser.add_argument(
        "-p",
        "--preprocess_config",
        type=str,
        required=True,
        help="path to preprocess.yaml",
    )
    parser.add_argument(
        "-m", "--model_config", type=str, required=True, help="path to model.yaml"
    )
    parser.add_argument(
        "-t", "--train_config", type=str, required=True, help="path to train.yaml"
    )
    parser.add_argument(
        "--pitch_control",
        type=float,
        default=1.0,
        help="control the pitch of the whole utterance, larger value for higher pitch",
    )
    parser.add_argument(
        "--energy_control",
        type=float,
        default=1.0,
        help="control the energy of the whole utterance, larger value for larger volume",
    )
    parser.add_argument(
        "--duration_control",
        type=float,
        default=1.0,
        help="control the speed of the whole utterance, larger value for slower speaking rate",
    )
    args = parser.parse_args()

    # Validate the mode / source / text combination with a proper CLI error
    # instead of assert (asserts are stripped under `python -O`).
    if args.mode == "batch" and (args.source is None or args.text is not None):
        parser.error("batch mode requires --source and must not be given --text")
    if args.mode == "single" and (args.source is not None or args.text is None):
        parser.error("single mode requires --text and must not be given --source")

    # Read configs; `with` ensures the file handles are closed (the original
    # passed bare open() results to yaml.load and leaked them).
    with open(args.preprocess_config, "r") as f:
        preprocess_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.model_config, "r") as f:
        model_config = yaml.load(f, Loader=yaml.FullLoader)
    with open(args.train_config, "r") as f:
        train_config = yaml.load(f, Loader=yaml.FullLoader)
    configs = (preprocess_config, model_config, train_config)

    # Get model
    model = get_model(args, configs, device, train=False)

    # Load vocoder
    vocoder = get_vocoder(model_config, device)

    # Preprocess texts
    if args.mode == "batch":
        # Get dataset
        dataset = TextDataset(args.source, preprocess_config)
        batchs = DataLoader(
            dataset,
            batch_size=8,
            collate_fn=dataset.collate_fn,
        )
    if args.mode == "single":
        if args.bert_embed == 0:
            # Plain scalar comparison; wrapping the flag in np.array([...])
            # as the original did added nothing.
            emotions = np.array([args.emotion_id])
        else:
            # Derive the emotion label from the text via the RoBERTa classifier.
            emotions = get_roberta_emotion_embeddings(
                roberta_tokenizer, roberta_model, args.text
            )
            emotions = torch.argmax(emotions, dim=1).cpu().numpy()
        ids = raw_texts = [args.text[:100]]
        speakers = np.array([args.speaker_id])
        language = preprocess_config["preprocessing"]["text"]["language"]
        if language == "en":
            texts = np.array([preprocess_english(args.text, preprocess_config)])
        elif language == "zh":
            texts = np.array([preprocess_mandarin(args.text, preprocess_config)])
        else:
            # Fail loudly instead of hitting a NameError on `texts` below.
            raise ValueError("Unsupported language: {}".format(language))
        text_lens = np.array([len(texts[0])])
        batchs = [(ids, raw_texts, speakers, texts,
                   text_lens, max(text_lens), emotions)]

    control_values = args.pitch_control, args.energy_control, args.duration_control

    synthesize(model, args.restore_step, configs,
               vocoder, batchs, control_values)