Spaces:
Runtime error
Runtime error
from cached_path import cached_path
import torch
# Seed the torch, Python `random` and NumPy RNGs and force deterministic cuDNN
# kernels so repeated runs with the same inputs are reproducible.
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
import random
random.seed(0)
import numpy as np
np.random.seed(0)
import nltk
# Fetches the NLTK 'punkt' tokenizer data at import time (needed by word_tokenize).
nltk.download('punkt')
# load packages
# NOTE(review): random, numpy and torch are imported a second time below;
# harmless, but redundant.
import time
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
from nltk.tokenize import word_tokenize
from models import *
from utils import *
from text_utils import TextCleaner
# Turns a cleaned phoneme string into a list of integer token ids (see inference()).
# NOTE: the name is misspelled ("clenaer") but other code refers to it by this name.
textclenaer = TextCleaner()
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Mel-spectrogram front end used by preprocess(): 80 mel bands, 2048-point FFT,
# 1200-sample window, 300-sample hop.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
# Normalization constants applied to the log-mel spectrogram in preprocess().
mean, std = -4, 4
def length_to_mask(lengths):
    """Build a padding mask from a 1-D tensor of sequence lengths.

    Returns a boolean tensor of shape ``(len(lengths), lengths.max())`` whose
    entry ``[b, t]`` is True when position ``t`` lies beyond ``lengths[b]``
    (i.e. True marks padding).
    """
    batch_size = lengths.shape[0]
    positions = torch.arange(lengths.max()).unsqueeze(0)
    positions = positions.expand(batch_size, -1).type_as(lengths)
    return (positions + 1) > lengths.unsqueeze(1)
def preprocess(wave):
    """Convert a NumPy waveform into a normalized log-mel tensor.

    The mel spectrogram from ``to_mel`` gets a leading batch dimension, is
    log-compressed with a 1e-5 floor, and is normalized with the module-level
    ``mean``/``std`` constants.
    """
    waveform = torch.from_numpy(wave).float()
    log_mel = torch.log(1e-5 + to_mel(waveform).unsqueeze(0))
    return (log_mel - mean) / std
def compute_style(ref_dicts):
    """Encode reference audio clips into style embeddings.

    Args:
        ref_dicts: mapping of an arbitrary key to the path of an audio file.

    Returns:
        dict mapping each key to ``(style_embedding, trimmed_audio)``, where
        ``style_embedding`` is the output of ``model.style_encoder`` with
        dim 1 squeezed and ``trimmed_audio`` is the silence-trimmed waveform
        at 24 kHz.
    """
    target_sr = 24000
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        # librosa.load resamples to target_sr, so `sr` normally equals it.
        wave, sr = librosa.load(path, sr=target_sr)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if sr != target_sr:
            # librosa >= 0.10 makes the sample rates keyword-only; passing them
            # positionally raises TypeError.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=target_sr)
        mel_tensor = preprocess(audio).to(device)
        with torch.no_grad():
            ref = model.style_encoder(mel_tensor.unsqueeze(1))
        reference_embeddings[key] = (ref.squeeze(1), audio)
    return reference_embeddings
# load phonemizer | |
# import phonemizer | |
# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore') | |
# phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt'))) | |
import fugashi | |
import pykakasi | |
from collections import OrderedDict | |
# MB-iSTFT-VITS2 | |
import re | |
from unidecode import unidecode | |
import pyopenjtalk | |
# Regular expression matching Japanese without punctuation marks:
# ASCII letters/digits, 々 (U+3005), hiragana/katakana, CJK ideographs,
# full-width digits 1-9, full-width Latin letters and half-width katakana.
# NOTE(review): the full-width range starts at U+FF11, so full-width '０'
# (U+FF10) is not treated as Japanese — confirm this is intended.
_japanese_characters = re.compile(
    r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# Regular expression matching non-Japanese characters or punctuation marks
# (the complement of the class above); used to split text into sentences.
_japanese_marks = re.compile(
    r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# List of (symbol, Japanese) pairs for marks spelled out before phonemization:
_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
    ('%', 'パーセント')
]]
# List of (romaji, ipa) pairs used by japanese_to_ipa. Applied in list order:
# e.g. 'j' -> 'ʥ' runs before 'y' -> 'j', so the 'j' produced from 'y' is not
# converted again.
_romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
    ('ts', 'ʦ'),
    ('u', 'ɯ'),
    ('j', 'ʥ'),
    ('y', 'j'),
    ('ni', 'n^i'),
    ('nj', 'n^'),
    ('hi', 'çi'),
    ('hj', 'ç'),
    ('f', 'ɸ'),
    ('I', 'i*'),
    ('U', 'ɯ*'),
    ('r', 'ɾ')
]]
# List of (romaji, ipa2) pairs used by japanese_to_ipa2 (also applied in order):
_romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
    ('u', 'ɯ'),
    ('ʧ', 'tʃ'),
    ('j', 'dʑ'),
    ('y', 'j'),
    ('ni', 'n^i'),
    ('nj', 'n^'),
    ('hi', 'çi'),
    ('hj', 'ç'),
    ('f', 'ɸ'),
    ('I', 'i*'),
    ('U', 'ɯ*'),
    ('r', 'ɾ')
]]
# List of (consonant, sokuon) pairs: rewrite the geminate marker 'Q' according
# to the following consonant (an optional pitch arrow may sit in between).
# '#' marks the stop closure; note the s/ʃ case emits no '#'.
_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
    (r'Q([↑↓]*[kg])', r'k#\1'),
    (r'Q([↑↓]*[tdjʧ])', r't#\1'),
    (r'Q([↑↓]*[sʃ])', r's\1'),
    (r'Q([↑↓]*[pb])', r'p#\1')
]]
# List of (consonant, hatsuon) pairs: rewrite the moraic nasal 'N' to the nasal
# matching the place of articulation of the following consonant.
_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
    (r'N([↑↓]*[pbm])', r'm\1'),
    (r'N([↑↓]*[ʧʥj])', r'n^\1'),
    (r'N([↑↓]*[tdn])', r'n\1'),
    (r'N([↑↓]*[kg])', r'ŋ\1')
]]
def symbols_to_japanese(text):
    """Spell out symbols listed in ``_symbols_to_japanese`` (e.g. '%') in Japanese."""
    for symbol_re, spoken_form in _symbols_to_japanese:
        text = symbol_re.sub(spoken_form, text)
    return text
def japanese_to_romaji_with_accent(text):
    '''Convert Japanese text to romaji phonemes annotated with pitch-accent marks.

    The text is split on non-Japanese characters/punctuation (``_japanese_marks``).
    Each Japanese run is converted with pyopenjtalk full-context labels, and each
    punctuation mark is re-appended after ASCII transliteration with unidecode
    (spaces removed). Within a run, ' ' marks an accent-phrase boundary, '↓' a
    pitch fall and '↑' a pitch rise. Phonemes 'ch', 'sh' and 'cl' are emitted as
    'ʧ', 'ʃ' and 'Q'.

    Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html
    '''
    text = symbols_to_japanese(text)
    sentences = re.split(_japanese_marks, text)
    marks = re.findall(_japanese_marks, text)
    text = ''
    for i, sentence in enumerate(sentences):
        # re.match only checks that the run *starts* with a Japanese character.
        if re.match(_japanese_characters, sentence):
            if text != '':
                text += ' '
            labels = pyopenjtalk.extract_fullcontext(sentence)
            for n, label in enumerate(labels):
                # The current phoneme sits between '-' and '+' in the label.
                phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
                if phoneme not in ['sil', 'pau']:
                    text += phoneme.replace('ch', 'ʧ').replace('sh',
                                                               'ʃ').replace('cl', 'Q')
                else:
                    continue
                # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
                # Accent-related fields of the /A: section of the label.
                a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
                a2 = int(re.search(r"\+(\d+)\+", label).group(1))
                a3 = int(re.search(r"\+(\d+)/", label).group(1))
                # NOTE(review): labels[n + 1] assumes the label list ends with a
                # 'sil'/'pau' label (the `continue` above then fires before n + 1
                # runs past the end); otherwise this raises IndexError — confirm
                # against pyopenjtalk's output.
                if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
                    a2_next = -1
                else:
                    a2_next = int(
                        re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
                # Accent phrase boundary
                if a3 == 1 and a2_next == 1:
                    text += ' '
                # Falling
                elif a1 == 0 and a2_next == a2 + 1:
                    text += '↓'
                # Rising
                elif a2 == 1 and a2_next == 2:
                    text += '↑'
        # Re-insert the punctuation mark that followed this sentence.
        if i < len(marks):
            text += unidecode(marks[i]).replace(' ', '')
    return text
def get_real_sokuon(text):
    """Rewrite each sokuon marker 'Q' according to the consonant that follows it."""
    for sokuon_re, realized in _real_sokuon:
        text = sokuon_re.sub(realized, text)
    return text
def get_real_hatsuon(text):
    """Rewrite each moraic nasal 'N' into the nasal matching the next consonant."""
    for hatsuon_re, realized in _real_hatsuon:
        text = hatsuon_re.sub(realized, text)
    return text
def japanese_to_ipa(text):
    """Convert Japanese text to IPA using the ``_romaji_to_ipa`` table.

    Runs of a repeated vowel become the vowel plus one 'ː' per extra
    repetition (e.g. 'aaa' -> 'aːː') before sokuon/hatsuon are realized.
    """
    romaji = japanese_to_romaji_with_accent(text).replace('...', '…')
    romaji = re.sub(
        r'([aiueo])\1+',
        lambda run: run.group(1) + 'ː' * (len(run.group(0)) - 1),
        romaji)
    romaji = get_real_hatsuon(get_real_sokuon(romaji))
    for romaji_re, ipa in _romaji_to_ipa:
        romaji = re.sub(romaji_re, ipa, romaji)
    return romaji
def japanese_to_ipa2(text):
    """Convert Japanese text to IPA using the ``_romaji_to_ipa2`` table (no vowel-length merging)."""
    romaji = japanese_to_romaji_with_accent(text).replace('...', '…')
    romaji = get_real_hatsuon(get_real_sokuon(romaji))
    for romaji_re, ipa in _romaji_to_ipa2:
        romaji = re.sub(romaji_re, ipa, romaji)
    return romaji
def japanese_to_ipa3(text):
    """Refine japanese_to_ipa2 output with diacritics, vowel length and aspiration.

    'n^' -> 'ȵ', 'ʃ' -> 'ɕ', '*' -> combining ring below, '#' -> combining
    no-audible-release mark; repeated vowels get 'ː'; word-initial
    ts/tɕ/k/p/t receive 'ʰ'.
    """
    ipa = japanese_to_ipa2(text)
    for old, new in (('n^', 'ȵ'), ('ʃ', 'ɕ'), ('*', '\u0325'), ('#', '\u031a')):
        ipa = ipa.replace(old, new)
    ipa = re.sub(
        r'([aiɯeo])\1+',
        lambda run: run.group(1) + 'ː' * (len(run.group(0)) - 1),
        ipa)
    return re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', ipa)
""" from https://github.com/keithito/tacotron """ | |
''' | |
Cleaners are transformations that run over the input text at both training and eval time. | |
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners" | |
hyperparameter. Some cleaners are English-specific. You'll typically want to use: | |
1. "english_cleaners" for English text | |
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using | |
the Unidecode library (https://pypi.python.org/pypi/Unidecode) | |
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update | |
the symbols in symbols.py to match your data). | |
''' | |
# English text normalization (adapted from keithito/tacotron): number, currency and abbreviation rules plus IPA substitution tables.
import re | |
import inflect | |
from unidecode import unidecode | |
# Shared inflect engine; _expand_ordinal and _expand_number use its
# number_to_words() to spell numbers out as English words.
_inflect = inflect.engine()
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])') | |
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)') | |
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)') | |
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)') | |
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)') | |
_number_re = re.compile(r'[0-9]+') | |
# List of (regular expression, replacement) pairs for abbreviations: | |
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [ | |
('mrs', 'misess'), | |
('mr', 'mister'), | |
('dr', 'doctor'), | |
('st', 'saint'), | |
('co', 'company'), | |
('jr', 'junior'), | |
('maj', 'major'), | |
('gen', 'general'), | |
('drs', 'doctors'), | |
('rev', 'reverend'), | |
('lt', 'lieutenant'), | |
('hon', 'honorable'), | |
('sgt', 'sergeant'), | |
('capt', 'captain'), | |
('esq', 'esquire'), | |
('ltd', 'limited'), | |
('col', 'colonel'), | |
('ft', 'fort'), | |
]] | |
# List of (ipa, lazy ipa) pairs: | |
_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [ | |
('r', 'ɹ'), | |
('æ', 'e'), | |
('ɑ', 'a'), | |
('ɔ', 'o'), | |
('ð', 'z'), | |
('θ', 's'), | |
('ɛ', 'e'), | |
('ɪ', 'i'), | |
('ʊ', 'u'), | |
('ʒ', 'ʥ'), | |
('ʤ', 'ʥ'), | |
('', '↓'), | |
]] | |
# List of (ipa, lazy ipa2) pairs: | |
_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ | |
('r', 'ɹ'), | |
('ð', 'z'), | |
('θ', 's'), | |
('ʒ', 'ʑ'), | |
('ʤ', 'dʑ'), | |
('', '↓'), | |
]] | |
# List of (ipa, ipa2) pairs | |
_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [ | |
('r', 'ɹ'), | |
('ʤ', 'dʒ'), | |
('ʧ', 'tʃ') | |
]] | |
def expand_abbreviations(text):
    """Replace abbreviations such as 'Dr.' with their spelled-out form (case-insensitive)."""
    for abbreviation_re, expansion in _abbreviations:
        text = abbreviation_re.sub(expansion, text)
    return text
def collapse_whitespace(text):
    """Replace every run of whitespace with a single space (ends are not stripped)."""
    whitespace_run = r'\s+'
    collapsed = re.sub(whitespace_run, ' ', text)
    return collapsed
def _remove_commas(m): | |
return m.group(1).replace(',', '') | |
def _expand_decimal_point(m): | |
return m.group(1).replace('.', ' point ') | |
def _expand_dollars(m): | |
match = m.group(1) | |
parts = match.split('.') | |
if len(parts) > 2: | |
return match + ' dollars' # Unexpected format | |
dollars = int(parts[0]) if parts[0] else 0 | |
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0 | |
if dollars and cents: | |
dollar_unit = 'dollar' if dollars == 1 else 'dollars' | |
cent_unit = 'cent' if cents == 1 else 'cents' | |
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit) | |
elif dollars: | |
dollar_unit = 'dollar' if dollars == 1 else 'dollars' | |
return '%s %s' % (dollars, dollar_unit) | |
elif cents: | |
cent_unit = 'cent' if cents == 1 else 'cents' | |
return '%s %s' % (cents, cent_unit) | |
else: | |
return 'zero dollars' | |
def _expand_ordinal(m):
    """Spell out the whole matched ordinal (e.g. '21st') via the inflect engine."""
    ordinal_text = m.group(0)
    return _inflect.number_to_words(ordinal_text)
def _expand_number(m):
    """Spell out the matched integer, reading 1001-2999 the way years are spoken."""
    num = int(m.group(0))
    if not (1000 < num < 3000):
        return _inflect.number_to_words(num, andword='')
    if num == 2000:
        return 'two thousand'
    if 2000 < num < 2010:
        return 'two thousand ' + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + ' hundred'
    # Read as two-digit groups, e.g. 1984 -> "nineteen eighty-four".
    return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
def normalize_numbers(text):
    """Spell out numbers, currency amounts and ordinals in English text.

    Order matters: thousands separators are stripped first, currency and
    decimals are expanded next, and bare integers are handled last.
    """
    rewrite_rules = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in rewrite_rules:
        text = re.sub(pattern, replacement, text)
    return text
def mark_dark_l(text):
    """Replace an 'l' with dark 'ɫ' when no vowel follows it before the next space or end of text."""
    dark_l_context = r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))'
    return re.sub(dark_l_context, r'ɫ\1', text)
import re | |
#from text.thai import num_to_thai, latin_to_thai | |
#from text.shanghainese import shanghainese_to_ipa | |
#from text.cantonese import cantonese_to_ipa | |
#from text.ngu_dialect import ngu_dialect_to_ipa | |
from unidecode import unidecode | |
# Regular expression matching a run of whitespace (used by collapse_whitespace):
_whitespace_re = re.compile(r'\s+')
# NOTE(review): _japanese_characters, _japanese_marks and _abbreviations below
# re-bind names defined earlier in this module with identical values; only one
# copy is needed.
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
    ('mrs', 'misess'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
]]
def expand_abbreviations(text):
    """Expand every abbreviation in ``_abbreviations`` (e.g. 'Mr.' -> 'mister')."""
    expanded = text
    for pattern, full_form in _abbreviations:
        expanded = pattern.sub(full_form, expanded)
    return expanded
def collapse_whitespace(text):
    """Collapse each whitespace run matched by ``_whitespace_re`` into one space."""
    return _whitespace_re.sub(' ', text)
def convert_to_ascii(text):
    """Transliterate ``text`` to ASCII with unidecode."""
    ascii_text = unidecode(text)
    return ascii_text
def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    # - For replication of https://github.com/FENRlR/MB-iSTFT-VITS2/issues/2
    #   you may need to replace the symbol to Russian one
    lowered = text.lower()
    return collapse_whitespace(lowered)
''' | |
def fix_g2pk2_error(text): | |
new_text = "" | |
i = 0 | |
while i < len(text) - 4: | |
if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == 'ㄹ': | |
new_text += text[i:i+3] + ' ' + 'ㄴ' | |
i += 5 | |
else: | |
new_text += text[i] | |
i += 1 | |
new_text += text[i:] | |
return new_text | |
''' | |
def japanese_cleaners(text):
    """Romanize Japanese with accent marks; append '.' when the result ends in an ASCII letter."""
    romaji = japanese_to_romaji_with_accent(text)
    return re.sub(r'([A-Za-z])$', r'\1.', romaji)
def japanese_cleaners2(text):
    """Run japanese_cleaners, then fuse 'ts' into 'ʦ' and '...' into '…'."""
    cleaned = japanese_cleaners(text)
    cleaned = cleaned.replace('ts', 'ʦ')
    return cleaned.replace('...', '…')
def japanese_cleaners3(text):
    """Convert Japanese text via japanese_to_ipa3 and normalize punctuation/spacing.

    NOTE(review): a later ``def japanese_cleaners3`` in this module re-binds the
    name, so this version is shadowed once the module finishes importing.
    """
    text = japanese_to_ipa3(text)
    has_special_marks = any(mark in text for mark in ("<<", ">>", "¡", "¿"))
    if has_special_marks:
        for old, new in (("<<", "«"), (">>", "»"), ("!", "¡"), ("?", "¿")):
            text = text.replace(old, new)
    # Curly closing quote, em dash, and removal of all spaces.
    for old, new in (('"', '”'), ('--', '—'), (' ', '')):
        text = text.replace(old, new)
    return text
# ------------------------------ | |
''' cjke type cleaners below ''' | |
#- text for these cleaners must be labeled first | |
# ex1 (single) : some.wav|[EN]put some text here[EN] | |
# ex2 (multi) : some.wav|0|[EN]put some text here[EN] | |
# ------------------------------ | |
def kej_cleaners(text):
    """Convert [KO]/[EN]/[JA]-tagged segments to IPA and ensure final punctuation.

    NOTE(review): korean_to_ipa and english_to_ipa2 are not defined in this
    module as shown; a [KO] or [EN] segment raises NameError unless a star
    import provides them. The converters are looked up lazily (inside lambdas)
    so untagged or [JA]-only text still works.
    """
    tagged_converters = (
        (r'\[KO\](.*?)\[KO\]', lambda segment: korean_to_ipa(segment)),
        (r'\[EN\](.*?)\[EN\]', lambda segment: english_to_ipa2(segment)),
        (r'\[JA\](.*?)\[JA\]', lambda segment: japanese_to_ipa2(segment)),
    )
    for tag_pattern, convert in tagged_converters:
        text = re.sub(tag_pattern,
                      lambda match, convert=convert: convert(match.group(1)) + ' ',
                      text)
    text = re.sub(r'\s+$', '', text)
    return re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
def cjks_cleaners(text):
    """Convert [JA]/[EN]-tagged segments to IPA and ensure final punctuation.

    NOTE(review): english_to_lazy_ipa is not defined in this module as shown,
    so an [EN] segment raises NameError unless a star import provides it. It is
    looked up lazily so [JA]-only text still works.
    """
    tagged_converters = (
        (r'\[JA\](.*?)\[JA\]', lambda segment: japanese_to_ipa(segment)),
        # (r'\[SA\](.*?)\[SA\]', devanagari_to_ipa),  # disabled
        (r'\[EN\](.*?)\[EN\]', lambda segment: english_to_lazy_ipa(segment)),
    )
    for tag_pattern, convert in tagged_converters:
        text = re.sub(tag_pattern,
                      lambda match, convert=convert: convert(match.group(1)) + ' ',
                      text)
    text = re.sub(r'\s+$', '', text)
    return re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
''' | |
#- reserves | |
def thai_cleaners(text): | |
text = num_to_thai(text) | |
text = latin_to_thai(text) | |
return text | |
def shanghainese_cleaners(text): | |
text = shanghainese_to_ipa(text) | |
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) | |
return text | |
def chinese_dialect_cleaners(text): | |
text = re.sub(r'\[ZH\](.*?)\[ZH\]', | |
lambda x: chinese_to_ipa2(x.group(1))+' ', text) | |
text = re.sub(r'\[JA\](.*?)\[JA\]', | |
lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text) | |
text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5', | |
'˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text) | |
text = re.sub(r'\[GD\](.*?)\[GD\]', | |
lambda x: cantonese_to_ipa(x.group(1))+' ', text) | |
text = re.sub(r'\[EN\](.*?)\[EN\]', | |
lambda x: english_to_lazy_ipa2(x.group(1))+' ', text) | |
text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group( | |
1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text) | |
text = re.sub(r'\s+$', '', text) | |
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text) | |
return text | |
''' | |
def japanese_cleaners3(text):
    """Convert Japanese text to the IPA-style phoneme string used by inference.

    Side effect: stores the unmodified input in the module-level global ``orig``.

    NOTE(review): an earlier revision ran ``text.replace('', '')`` here, which
    is a no-op. It probably lost a character in an encoding round-trip; confirm
    which symbol was meant to be stripped.
    """
    global orig
    orig = text  # saving the original unmodified text for future use
    text = japanese_to_ipa2(text)
    if any(mark in text for mark in ("<<", ">>", "¡", "¿")):
        for old, new in (("<<", "«"), (">>", "»"), ("!", "¡"), ("?", "¿")):
            text = text.replace(old, new)
    ordered_replacements = (
        ('"', '”'),
        ('--', '—'),
        ("#", "ʔ"),   # '#' is emitted by get_real_sokuon (e.g. 'k#')
        ("^", ""),    # drop the '^' of 'n^' from _romaji_to_ipa2
        ("kj", "kʲ"),
        ("ɾj", "ɾʲ"),
        ("mj", "mʲ"),
        ("ʃ", "ɕ"),
        ("*", ""),    # drop the '*' attached to I/U by _romaji_to_ipa2
        ("bj", "bʲ"),
        ("h", "ç"),
        ("gj", "gʲ"),
    )
    for old, new in ordered_replacements:
        text = text.replace(old, new)
    return text
def japanese_cleaners4(text):
    """Full Japanese cleaner used by inference(): japanese_cleaners3 plus kana fixes.

    A palatalization fix is chosen by inspecting the original kana input (e.g.
    'にゃ' -> 'na' becomes 'nʲa'). Remaining 'Q'/'N' markers are then mapped to
    'ʔ'/'ɴ'.

    NOTE(review): the kana checks form an if/elif chain, so at most one fix is
    applied per call even when several of these kana occur. This is kept as-is
    because changing it alters the phoneme stream fed to the model.
    """
    # Keep a local reference to the raw input instead of reading the global
    # `orig` written by japanese_cleaners3. The value is the same, but a
    # concurrent call could overwrite the global between the two steps.
    source = text
    text = japanese_cleaners3(text)
    if "にゃ" in source:
        text = text.replace("na", "nʲa")
    elif "にゅ" in source or "にょ" in source:
        text = text.replace("n", "nʲ")
    elif "にぃ" in source:
        text = text.replace("ni i", "niː")
    elif "いゃ" in source:
        text = text.replace("i↑ja", "ja")
    elif "ひょ" in source:
        text = text.replace("ço", "çʲo")
    elif "しょ" in source:
        text = text.replace("ɕo", "ɕʲo")
    text = text.replace("Q", "ʔ")
    text = text.replace("N", "ɴ")
    # Drop the character just before each 'ʔ'. For example, sokuon yields
    # 'k#k' -> 'kʔk', which becomes 'ʔk'.
    text = re.sub(r'.ʔ', 'ʔ', text)
    # Remove a space immediately following a quote character.
    text = text.replace('" ', '"')
    text = text.replace('” ', '”')
    return text
# Load the StyleTTS2 LJSpeech training config (downloaded/cached by cached_path).
# The `with` block closes the file handle that a bare open() would leak.
with open(str(cached_path('hf://yl4579/StyleTTS2-LJSpeech/Models/LJSpeech/config.yml')),
          encoding='utf-8') as _config_file:
    config = yaml.safe_load(_config_file)
# NOTE(review): config.get(..., False) passes False to the loaders when a key is
# missing instead of failing fast — confirm the loaders handle that.
# load pretrained ASR model (text aligner)
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)
# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)
# load BERT model
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
# Put every sub-module in eval mode and move it to the inference device.
for key in model:
    model[key].eval()
    model[key].to(device)
# params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu')
# NOTE(security): torch.load unpickles the checkpoint; only load trusted files.
params_whole = torch.load("Models/Kaede.pth", map_location='cpu')
params = params_whole['net']
_MODULE_PREFIX = 'module.'
for key in model:
    if key in params:
        print('%s loaded' % key)
        try:
            model[key].load_state_dict(params[key])
        except RuntimeError:
            # load_state_dict raises RuntimeError on key/shape mismatches. Keys
            # saved from a wrapped model carry a "module." prefix; strip it only
            # where present, then retry non-strictly.
            new_state_dict = OrderedDict(
                (k[len(_MODULE_PREFIX):] if k.startswith(_MODULE_PREFIX) else k, v)
                for k, v in params[key].items()
            )
            # load params
            model[key].load_state_dict(new_state_dict, strict=False)
for key in model:
    model[key].eval()
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
# Diffusion sampler used by inference()/LFinference() to produce s_pred, which is
# then split into the `ref` (first 128 dims) and `s` (remaining dims) styles.
# It wraps the model's diffusion network with an ADPM2 sampler over a Karras
# sigma schedule.
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=0.68, rho=4.6),  # empirical parameters
    clamp=False
)
def _text_to_tokens(text):
    """Clean Japanese ``text`` and return its token ids as a (1, T) LongTensor on ``device``.

    The cleaned phoneme string is printed, and a 0 token is prepended to the ids.
    """
    text = japanese_cleaners4(text)
    print(text)
    tokens = textclenaer(text)
    tokens.insert(0, 0)
    return torch.LongTensor(tokens).to(device).unsqueeze(0)


def _duration_to_alignment(input_lengths, pred_dur):
    """Build a hard alignment matrix where row i is 1 over pred_dur[i] consecutive frames."""
    pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
    c_frame = 0
    for i in range(pred_aln_trg.size(0)):
        n_frames = int(pred_dur[i].data)
        pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
        c_frame += n_frames
    return pred_aln_trg


def _synthesize(text, noise, diffusion_steps, embedding_scale,
                s_prev=None, alpha=0.7, extra_last_frames=0):
    """Inference path shared by inference() and LFinference().

    Args:
        text: raw Japanese input text.
        noise: initial noise for the diffusion sampler (shape set by the caller).
        diffusion_steps: number of sampler steps.
        embedding_scale: passed through to the sampler.
        s_prev: optional previous style; when given, blended as
            ``alpha * s_prev + (1 - alpha) * s_pred``.
        alpha: weight of ``s_prev`` in the blend.
        extra_last_frames: frames added to the last token's predicted duration.

    Returns:
        ``(waveform as a NumPy array, s_pred)`` where s_pred is the (possibly
        blended) style vector.
    """
    tokens = _text_to_tokens(text)
    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
        text_mask = length_to_mask(input_lengths).to(tokens.device)
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
        s_pred = sampler(noise,
                         embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
                         embedding_scale=embedding_scale).squeeze(0)
        if s_prev is not None:
            # convex combination of previous and current style
            s_pred = alpha * s_prev + (1 - alpha) * s_pred
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]
        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
        if extra_last_frames:
            pred_dur[-1] += extra_last_frames
        # Move the alignment to the device once; it is used twice below.
        aln = _duration_to_alignment(input_lengths, pred_dur).unsqueeze(0).to(device)
        # encode prosody
        en = d.transpose(-1, -2) @ aln
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        out = model.decoder(t_en @ aln, F0_pred, N_pred, ref.squeeze().unsqueeze(0))
    return out.squeeze().cpu().numpy(), s_pred


def inference(text, noise, diffusion_steps=5, embedding_scale=1):
    """Synthesize ``text`` with a style sampled from ``noise``.

    The last token's duration is extended by 5 frames, presumably so the final
    phoneme is not cut off.

    Returns:
        The generated waveform as a NumPy array.
    """
    wav, _ = _synthesize(text, noise, diffusion_steps, embedding_scale,
                         extra_last_frames=5)
    return wav


def LFinference(text, s_prev, noise, alpha=0.7, diffusion_steps=5, embedding_scale=1):
    """Synthesize one segment of long-form text, blending with the previous style.

    When ``s_prev`` is not None, the sampled style becomes
    ``alpha * s_prev + (1 - alpha) * s_pred``.

    Returns:
        ``(waveform as a NumPy array, s_pred)``. Pass s_pred as ``s_prev``
        for the next segment.
    """
    return _synthesize(text, noise, diffusion_steps, embedding_scale,
                       s_prev=s_prev, alpha=alpha)