from cached_path import cached_path

# load packages
import time
import random

import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
import nltk
import yaml
from munch import Munch
from nltk.tokenize import word_tokenize

from models import *
from utils import *
from text_utils import TextCleaner

# Make runs reproducible: seed every RNG the pipeline touches and force
# deterministic cuDNN kernels (benchmark mode may pick non-deterministic ones).
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(0)
np.random.seed(0)

# Tokenizer data for nltk (downloads on first run).
nltk.download('punkt')

# NOTE: the misspelled name is kept because the inference functions below
# reference `textclenaer`.
textclenaer = TextCleaner()

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Mel front-end used for style-reference audio; mean/std are the fixed
# normalisation constants applied to the log-mel in `preprocess`.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
mean, std = -4, 4


def length_to_mask(lengths):
    """Build a padding mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: integer tensor of shape (batch,).

    Returns:
        Tensor of shape (batch, lengths.max()) that is True at positions
        ``>= length`` (i.e. padding) and False at valid positions.
    """
    positions = torch.arange(lengths.max()).unsqueeze(0)
    positions = positions.expand(lengths.shape[0], -1).type_as(lengths)
    # position p is padding iff p + 1 > length  <=>  p >= length
    return torch.gt(positions + 1, lengths.unsqueeze(1))


def preprocess(wave):
    """Convert a waveform array into a normalised log-mel tensor.

    Args:
        wave: numpy waveform (presumably 1-D float audio at 24 kHz -- TODO
            confirm against callers).

    Returns:
        ``(log(1e-5 + mel) - mean) / std`` with a leading dimension added.
    """
    wave_tensor = torch.from_numpy(wave).float()
    mel_tensor = to_mel(wave_tensor)
    # 1e-5 floor avoids log(0) on silent frames.
    return (torch.log(1e-5 + mel_tensor.unsqueeze(0)) - mean) / std


def compute_style(ref_dicts):
    """Encode reference recordings into style embeddings.

    Relies on the module-level ``model`` (built later in this script) being
    loaded before this function is called.

    Args:
        ref_dicts: mapping of arbitrary key -> audio file path.

    Returns:
        dict mapping each key to ``(style_embedding, trimmed_audio)``.
    """
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        wave, sr = librosa.load(path, sr=24000)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if sr != 24000:
            # librosa >= 0.10 makes orig_sr/target_sr keyword-only; the old
            # positional call raises TypeError there.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
        mel_tensor = preprocess(audio).to(device)

        with torch.no_grad():
            ref = model.style_encoder(mel_tensor.unsqueeze(1))
        reference_embeddings[key] = (ref.squeeze(1), audio)

    return reference_embeddings


# load phonemizer
# import phonemizer
# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')
# phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))

import fugashi
import pykakasi
from collections import OrderedDict
# MB-iSTFT-VITS2
# ---------------------------------------------------------------------------
# Japanese text -> romaji / IPA conversion (pyopenjtalk based) and
# tacotron-style text cleaners.
# ---------------------------------------------------------------------------
import re

import inflect
import pyopenjtalk
from unidecode import unidecode


def _compile_pairs(pairs, flags=0):
    """Compile ``(pattern, replacement)`` pairs into ``(regex, replacement)``."""
    return [(re.compile(pattern, flags), replacement) for pattern, replacement in pairs]


def _apply_pairs(text, pairs):
    """Apply every compiled ``(regex, replacement)`` pair to *text*, in order."""
    for regex, replacement in pairs:
        text = regex.sub(replacement, text)
    return text


# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
    r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')

# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(
    r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')

# List of (symbol, Japanese) pairs for marks:
_symbols_to_japanese = _compile_pairs([
    ('%', 'パーセント'),
])

# List of (romaji, ipa) pairs for marks:
_romaji_to_ipa = _compile_pairs([
    ('ts', 'ʦ'),
    ('u', 'ɯ'),
    ('j', 'ʥ'),
    ('y', 'j'),
    ('ni', 'n^i'),
    ('nj', 'n^'),
    ('hi', 'çi'),
    ('hj', 'ç'),
    ('f', 'ɸ'),
    ('I', 'i*'),
    ('U', 'ɯ*'),
    ('r', 'ɾ'),
])

# List of (romaji, ipa2) pairs for marks:
_romaji_to_ipa2 = _compile_pairs([
    ('u', 'ɯ'),
    ('ʧ', 'tʃ'),
    ('j', 'dʑ'),
    ('y', 'j'),
    ('ni', 'n^i'),
    ('nj', 'n^'),
    ('hi', 'çi'),
    ('hj', 'ç'),
    ('f', 'ɸ'),
    ('I', 'i*'),
    ('U', 'ɯ*'),
    ('r', 'ɾ'),
])

# List of (consonant, sokuon) pairs: geminate marker Q becomes the following
# consonant (with a '#' marker for stops).
_real_sokuon = _compile_pairs([
    (r'Q([↑↓]*[kg])', r'k#\1'),
    (r'Q([↑↓]*[tdjʧ])', r't#\1'),
    (r'Q([↑↓]*[sʃ])', r's\1'),
    (r'Q([↑↓]*[pb])', r'p#\1'),
])

# List of (consonant, hatsuon) pairs: moraic nasal N assimilates to the
# place of articulation of the following consonant.
_real_hatsuon = _compile_pairs([
    (r'N([↑↓]*[pbm])', r'm\1'),
    (r'N([↑↓]*[ʧʥj])', r'n^\1'),
    (r'N([↑↓]*[tdn])', r'n\1'),
    (r'N([↑↓]*[kg])', r'ŋ\1'),
])

# Compiled once: extracts the current phoneme "p3" from a full-context label
# of the form "...-p3+...".
_label_phoneme_re = re.compile(r'\-([^\+]*)\+')


def symbols_to_japanese(text):
    """Spell out symbols (e.g. '%') as Japanese words."""
    return _apply_pairs(text, _symbols_to_japanese)


def japanese_to_romaji_with_accent(text):
    '''Convert Japanese text to romaji annotated with pitch-accent marks.

    Japanese runs are phonemised with ``pyopenjtalk.extract_fullcontext``;
    accent-phrase boundaries become ' ', falls '↓' and rises '↑'. Non-Japanese
    marks between runs are transliterated with unidecode (spaces removed).

    Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html
    '''
    text = symbols_to_japanese(text)
    sentences = re.split(_japanese_marks, text)
    marks = re.findall(_japanese_marks, text)
    text = ''
    for i, sentence in enumerate(sentences):
        if re.match(_japanese_characters, sentence):
            if text != '':
                text += ' '
            labels = pyopenjtalk.extract_fullcontext(sentence)
            for n, label in enumerate(labels):
                phoneme = _label_phoneme_re.search(label).group(1)
                if phoneme in ('sil', 'pau'):
                    continue
                text += phoneme.replace('ch', 'ʧ').replace('sh', 'ʃ').replace('cl', 'Q')
                # NOTE(review): labels[n + 1] assumes the label list ends with
                # a 'sil'/'pau' entry (skipped above) -- confirm for all inputs.
                # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
                a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
                a2 = int(re.search(r"\+(\d+)\+", label).group(1))
                a3 = int(re.search(r"\+(\d+)/", label).group(1))
                if _label_phoneme_re.search(labels[n + 1]).group(1) in ('sil', 'pau'):
                    a2_next = -1
                else:
                    a2_next = int(re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
                if a3 == 1 and a2_next == 1:
                    text += ' '        # accent phrase boundary
                elif a1 == 0 and a2_next == a2 + 1:
                    text += '↓'        # falling
                elif a2 == 1 and a2_next == 2:
                    text += '↑'        # rising
        if i < len(marks):
            text += unidecode(marks[i]).replace(' ', '')
    return text


def get_real_sokuon(text):
    """Replace the geminate marker 'Q' with the following consonant."""
    return _apply_pairs(text, _real_sokuon)


def get_real_hatsuon(text):
    """Replace the moraic nasal 'N' with its assimilated nasal."""
    return _apply_pairs(text, _real_hatsuon)


def _lengthen_repeated_vowels(text, vowels):
    """Collapse runs of a repeated vowel into one vowel plus 'ː' per repeat."""
    return re.sub(
        r'([%s])\1+' % vowels,
        lambda x: x.group(0)[0] + 'ː' * (len(x.group(0)) - 1),
        text)


def japanese_to_ipa(text):
    """Japanese -> IPA variant 1 (long vowels marked with 'ː')."""
    text = japanese_to_romaji_with_accent(text).replace('...', '…')
    text = _lengthen_repeated_vowels(text, 'aiueo')
    text = get_real_sokuon(text)
    text = get_real_hatsuon(text)
    return _apply_pairs(text, _romaji_to_ipa)


def japanese_to_ipa2(text):
    """Japanese -> IPA variant 2 (no vowel-length collapsing)."""
    text = japanese_to_romaji_with_accent(text).replace('...', '…')
    text = get_real_sokuon(text)
    text = get_real_hatsuon(text)
    return _apply_pairs(text, _romaji_to_ipa2)


def japanese_to_ipa3(text):
    """Japanese -> IPA variant 3 (ipa2 plus diacritics, length and aspiration)."""
    text = japanese_to_ipa2(text).replace('n^', 'ȵ').replace(
        'ʃ', 'ɕ').replace('*', '\u0325').replace('#', '\u031a')
    text = _lengthen_repeated_vowels(text, 'aiɯeo')
    # Aspirate word-initial voiceless stops/affricates.
    text = re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', text)
    return text


# ---------------------------------------------------------------------------
# from https://github.com/keithito/tacotron
#
# Cleaners are transformations that run over the input text at both training
# and eval time.
#
# Cleaners can be selected by passing a comma-delimited list of cleaner names
# as the "cleaners" hyperparameter. Some cleaners are English-specific. You'll
# typically want to use:
#   1. "english_cleaners" for English text
#   2. "transliteration_cleaners" for non-English text that can be
#      transliterated to ASCII using the Unidecode library
#      (https://pypi.python.org/pypi/Unidecode)
#   3. "basic_cleaners" if you do not want to transliterate (in this case,
#      you should also update the symbols in symbols.py to match your data).
# ---------------------------------------------------------------------------

#from text.thai import num_to_thai, latin_to_thai
#from text.shanghainese import shanghainese_to_ipa
#from text.cantonese import cantonese_to_ipa
#from text.ngu_dialect import ngu_dialect_to_ipa

_inflect = inflect.engine()

# Regular expression matching whitespace:
_whitespace_re = re.compile(r'\s+')
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
_number_re = re.compile(r'[0-9]+')

# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % abbrev, re.IGNORECASE), expansion) for abbrev, expansion in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]

# List of (ipa, lazy ipa) pairs:
# The last entry previously had an empty pattern, which matches at every
# position and would insert '↓' between every character; the primary-stress
# mark 'ˈ' is the intended source symbol.
_lazy_ipa = _compile_pairs([
    ('r', 'ɹ'),
    ('æ', 'e'),
    ('ɑ', 'a'),
    ('ɔ', 'o'),
    ('ð', 'z'),
    ('θ', 's'),
    ('ɛ', 'e'),
    ('ɪ', 'i'),
    ('ʊ', 'u'),
    ('ʒ', 'ʥ'),
    ('ʤ', 'ʥ'),
    ('ˈ', '↓'),
])

# List of (ipa, lazy ipa2) pairs (same empty-pattern fix as above):
_lazy_ipa2 = _compile_pairs([
    ('r', 'ɹ'),
    ('ð', 'z'),
    ('θ', 's'),
    ('ʒ', 'ʑ'),
    ('ʤ', 'dʑ'),
    ('ˈ', '↓'),
])

# List of (ipa, ipa2) pairs
_ipa_to_ipa2 = _compile_pairs([
    ('r', 'ɹ'),
    ('ʤ', 'dʒ'),
    ('ʧ', 'tʃ'),
])


def expand_abbreviations(text):
    """Expand common English abbreviations ('Dr.' -> 'doctor', ...)."""
    return _apply_pairs(text, _abbreviations)


def collapse_whitespace(text):
    """Replace every run of whitespace with a single space."""
    return re.sub(_whitespace_re, ' ', text)


def _remove_commas(m):
    return m.group(1).replace(',', '')


def _expand_decimal_point(m):
    return m.group(1).replace('.', ' point ')


def _expand_dollars(m):
    """Spell out a '$d.cc' amount as dollars and cents."""
    match = m.group(1)
    parts = match.split('.')
    if len(parts) > 2:
        return match + ' dollars'  # Unexpected format
    dollars = int(parts[0]) if parts[0] else 0
    cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
    dollar_unit = 'dollar' if dollars == 1 else 'dollars'
    cent_unit = 'cent' if cents == 1 else 'cents'
    if dollars and cents:
        return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
    if dollars:
        return '%s %s' % (dollars, dollar_unit)
    if cents:
        return '%s %s' % (cents, cent_unit)
    return 'zero dollars'


def _expand_ordinal(m):
    return _inflect.number_to_words(m.group(0))


def _expand_number(m):
    """Spell out an integer, reading 1001-2999 year-style ('nineteen ten')."""
    num = int(m.group(0))
    if 1000 < num < 3000:
        if num == 2000:
            return 'two thousand'
        if 2000 < num < 2010:
            return 'two thousand ' + _inflect.number_to_words(num % 100)
        if num % 100 == 0:
            return _inflect.number_to_words(num // 100) + ' hundred'
        return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
    return _inflect.number_to_words(num, andword='')


def normalize_numbers(text):
    """Spell out currency, decimals, ordinals and integers in English words."""
    text = re.sub(_comma_number_re, _remove_commas, text)
    text = re.sub(_pounds_re, r'\1 pounds', text)
    text = re.sub(_dollars_re, _expand_dollars, text)
    text = re.sub(_decimal_number_re, _expand_decimal_point, text)
    text = re.sub(_ordinal_re, _expand_ordinal, text)
    text = re.sub(_number_re, _expand_number, text)
    return text


def mark_dark_l(text):
    """Replace 'l' not followed by a vowel within its word with dark 'ɫ'."""
    return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', lambda x: 'ɫ' + x.group(1), text)


def convert_to_ascii(text):
    return unidecode(text)


def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    # - For replication of https://github.com/FENRlR/MB-iSTFT-VITS2/issues/2
    #   you may need to replace the symbol to Russian one
    text = text.lower()
    text = collapse_whitespace(text)
    return text


# Disabled (kept for reference):
'''
def fix_g2pk2_error(text):
    new_text = ""
    i = 0
    while i < len(text) - 4:
        if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == 'ㄹ':
            new_text += text[i:i+3] + ' ' + 'ㄴ'
            i += 5
        else:
            new_text += text[i]
            i += 1

    new_text += text[i:]
    return new_text
'''


def japanese_cleaners(text):
    """Romaji-with-accent cleaner; appends '.' if the text ends in a letter."""
    text = japanese_to_romaji_with_accent(text)
    text = re.sub(r'([A-Za-z])$', r'\1.', text)
    return text


def japanese_cleaners2(text):
    return japanese_cleaners(text).replace('ts', 'ʦ').replace('...', '…')


# (japanese_cleaners3 is defined further down in this script; an earlier,
# shadowed duplicate that was never reachable has been removed here.)

# ------------------------------
''' cjke type cleaners below '''
#- text for these cleaners must be labeled first
# ex1 (single) : some.wav|[EN]put some text here[EN]
# ex2 (multi) : some.wav|0|[EN]put some text here[EN]
# ------------------------------


def kej_cleaners(text):
    """Korean/English/Japanese cleaner for [KO]/[EN]/[JA]-tagged text.

    NOTE(review): ``korean_to_ipa`` and ``english_to_ipa2`` are not defined
    in this script; they must come from a star import or this will raise
    NameError on tagged input.
    """
    text = re.sub(r'\[KO\](.*?)\[KO\]', lambda x: korean_to_ipa(x.group(1)) + ' ', text)
    text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_ipa2(x.group(1)) + ' ', text)
    text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa2(x.group(1)) + ' ', text)
    text = re.sub(r'\s+$', '', text)
    text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
    return text


def cjks_cleaners(text):
    """Japanese/English cleaner for [JA]/[EN]-tagged text.

    NOTE(review): ``english_to_lazy_ipa`` is not defined in this script.
    """
    text = re.sub(r'\[JA\](.*?)\[JA\]', lambda x: japanese_to_ipa(x.group(1)) + ' ', text)
    #text = re.sub(r'\[SA\](.*?)\[SA\]',
    #              lambda x: devanagari_to_ipa(x.group(1))+' ', text)
    text = re.sub(r'\[EN\](.*?)\[EN\]', lambda x: english_to_lazy_ipa(x.group(1)) + ' ', text)
    text = re.sub(r'\s+$', '', text)
    text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
    return text


# Reserved cleaners (disabled; raw string so the regex backslashes do not
# trigger invalid-escape warnings):
r'''
#- reserves

def thai_cleaners(text):
    text = num_to_thai(text)
    text = latin_to_thai(text)
    return text


def shanghainese_cleaners(text):
    text = shanghainese_to_ipa(text)
    text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
    return text


def chinese_dialect_cleaners(text):
    text = re.sub(r'\[ZH\](.*?)\[ZH\]',
                  lambda x: chinese_to_ipa2(x.group(1))+' ', text)
    text = re.sub(r'\[JA\](.*?)\[JA\]',
                  lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text)
    text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5',
                  '˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text)
    text = re.sub(r'\[GD\](.*?)\[GD\]',
                  lambda x: cantonese_to_ipa(x.group(1))+' ', text)
    text = re.sub(r'\[EN\](.*?)\[EN\]',
                  lambda x: english_to_lazy_ipa2(x.group(1))+' ', text)
    text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group(
        1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text)
    text = re.sub(r'\s+$', '', text)
    text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
    return text
'''
# Ordered (old, new) substitutions applied by japanese_cleaners3 after the
# punctuation handling; order matters (e.g. 'ʃ' -> 'ɕ' before 'h' -> 'ç').
_JA3_REPLACEMENTS = (
    ("#", "ʔ"),
    ("^", ""),
    ("kj", "kʲ"),
    ("ɾj", "ɾʲ"),
    ("mj", "mʲ"),
    ("ʃ", "ɕ"),
    ("*", ""),
    ("bj", "bʲ"),
    ("h", "ç"),
    ("gj", "gʲ"),
)


def japanese_cleaners3(text):
    """Convert Japanese text to the IPA symbol set used by this model.

    Builds on ``japanese_to_ipa2`` and then normalises punctuation and
    palatalised consonants.

    Side effect: stores the raw input in the module global ``orig``
    (kept for backward compatibility; ``japanese_cleaners4`` no longer
    depends on it).
    """
    global orig
    orig = text  # saving the original unmodified text for future use
    text = japanese_to_ipa2(text)
    # Only when guillemets/inverted marks already occur are ASCII !/? swapped
    # for their inverted forms (original behaviour preserved).
    if "<<" in text or ">>" in text or "¡" in text or "¿" in text:
        text = (text.replace("<<", "«").replace(">>", "»")
                    .replace("!", "¡").replace("?", "¿"))
    text = text.replace('"', '”').replace('--', '—')
    for old, new in _JA3_REPLACEMENTS:
        text = text.replace(old, new)
    return text


# First-match rules for japanese_cleaners4: (kana in raw input, old, new).
# Only the FIRST matching rule is applied, mirroring the original elif chain
# (kept as-is since the model's training data was presumably produced this way).
_JA4_KANA_RULES = (
    ("にゃ", "na", "nʲa"),
    ("にゅ", "n", "nʲ"),
    ("にょ", "n", "nʲ"),
    ("にぃ", "ni i", "niː"),
    ("いゃ", "i↑ja", "ja"),
    ("ひょ", "ço", "çʲo"),
    ("しょ", "ɕo", "ɕʲo"),
)


def japanese_cleaners4(text):
    """Cleaner used for inference: ``japanese_cleaners3`` plus palatal fixes.

    Inspects the raw input for certain kana and patches the IPA accordingly,
    then maps leftover 'Q'/'N' to 'ʔ'/'ɴ'.
    """
    raw = text  # kana checks need the unconverted input
    text = japanese_cleaners3(text)
    for kana, old, new in _JA4_KANA_RULES:
        if kana in raw:
            text = text.replace(old, new)
            break
    text = text.replace("Q", "ʔ")
    text = text.replace("N", "ɴ")
    # Drop the placeholder character that precedes each glottal stop.
    text = re.sub(r'.ʔ', 'ʔ', text)
    text = text.replace('" ', '"')
    text = text.replace('” ', '”')
    return text


with open(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/config.yml'))) as _config_file:
    config = yaml.safe_load(_config_file)

# load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load BERT model
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
for key in model:
    model[key].eval()
    model[key].to(device)

# params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu')
# NOTE: torch.load unpickles the checkpoint; only load trusted files.
params_whole = torch.load("Models/Kaede.pth", map_location='cpu')
params = params_whole['net']

_MODULE_PREFIX = 'module.'
for key in model:
    if key not in params:
        continue
    print('%s loaded' % key)
    try:
        model[key].load_state_dict(params[key])
    except RuntimeError:
        # Key mismatch: most likely weights saved from a DataParallel/DDP
        # wrapper whose keys carry a 'module.' prefix. Strip only that
        # prefix (not blindly 7 characters) and retry non-strictly.
        from collections import OrderedDict
        new_state_dict = OrderedDict(
            (k[len(_MODULE_PREFIX):] if k.startswith(_MODULE_PREFIX) else k, v)
            for k, v in params[key].items())
        model[key].load_state_dict(new_state_dict, strict=False)

for key in model:
    model[key].eval()

from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=0.61, rho=4.6),  # empirical parameters
    clamp=False,
)


def _text_to_tokens(text):
    """Clean *text* with ``japanese_cleaners4``, print it, and tokenize it.

    Returns a LongTensor of shape (1, T) on ``device``; a 0 token is
    prepended, as in the original pipeline.
    """
    text = japanese_cleaners4(text)
    print(text)
    tokens = textclenaer(text)
    tokens.insert(0, 0)
    return torch.LongTensor(tokens).to(device).unsqueeze(0)


def _synthesize(tokens, noise, diffusion_steps, embedding_scale,
                s_prev=None, alpha=0.7, tail_frames=0):
    """Run the StyleTTS2 model on a (1, T) token tensor.

    Args:
        tokens: output of ``_text_to_tokens``.
        noise: initial noise for the diffusion style sampler.
        diffusion_steps, embedding_scale: forwarded to ``sampler``.
        s_prev: optional previous style; if given, the sampled style becomes
            ``alpha * s_prev + (1 - alpha) * s_pred``.
        tail_frames: extra frames added to the last token's duration.

    Returns:
        ``(waveform_numpy, s_pred)``.
    """
    with torch.no_grad():
        n_tokens = tokens.shape[-1]
        input_lengths = torch.LongTensor([n_tokens]).to(tokens.device)
        text_mask = length_to_mask(input_lengths).to(tokens.device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = sampler(noise,
                         embedding=bert_dur[0].unsqueeze(0),
                         num_steps=diffusion_steps,
                         embedding_scale=embedding_scale).squeeze(0)
        if s_prev is not None:
            # convex combination of previous and current style
            s_pred = alpha * s_prev + (1 - alpha) * s_pred

        # First 128 channels feed the decoder, the rest the prosody predictor.
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        # reshape(-1) (not squeeze()) keeps a 1-D tensor even for T == 1.
        pred_dur = torch.round(duration.reshape(-1)).clamp(min=1)
        pred_dur[-1] += tail_frames

        # Hard monotonic alignment: token i covers pred_dur[i] frames.
        pred_aln_trg = torch.zeros(n_tokens, int(pred_dur.sum().item()))
        c_frame = 0
        for i in range(n_tokens):
            span = int(pred_dur[i].item())
            pred_aln_trg[i, c_frame:c_frame + span] = 1
            c_frame += span
        alignment = pred_aln_trg.unsqueeze(0).to(device)

        # encode prosody
        en = d.transpose(-1, -2) @ alignment
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        out = model.decoder(t_en @ alignment, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
    return out.squeeze().cpu().numpy(), s_pred


def inference(text, noise, diffusion_steps=5, embedding_scale=1):
    """Synthesize *text* (Japanese) and return the waveform as numpy.

    The last token's duration is extended by 5 frames.
    """
    tokens = _text_to_tokens(text)
    wav, _ = _synthesize(tokens, noise, diffusion_steps, embedding_scale, tail_frames=5)
    return wav


def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):
    """Long-form synthesis step that blends the style with the previous one.

    Returns ``(waveform_numpy, s_pred)`` so the caller can pass ``s_pred`` as
    *s_prev* for the next segment.
    """
    tokens = _text_to_tokens(text)
    return _synthesize(tokens, noise, diffusion_steps, embedding_scale,
                       s_prev=s_prev, alpha=alpha)