# styletts2_Japanese / styletts2importable.py
# Author: Respair. Last commit: f490eec "Update styletts2importable.py" (32.5 kB).
from cached_path import cached_path
# print("GRUUT")
# from gruut_phonemize import gphonemize
# from dp.phonemizer import Phonemizer

# Module bootstrap. The progress prints are kept so start-up logs look the same.
# Every import, the punkt download and the TextCleaner construction used to be
# repeated twice; they now run once.
print("NLTK")
import nltk
nltk.download('punkt')
print("SCIPY")
from scipy.io.wavfile import write
print("TORCH STUFF")
import torch
print("START")

# load packages
import random
import time

import librosa
import numpy as np
import torch.nn.functional as F
import torchaudio
import yaml
from munch import Munch
from nltk.tokenize import word_tokenize
from torch import nn

# Star imports come last so their names win, exactly as before.
from models import *
from utils import *
from text_utils import TextCleaner

# Seed every RNG after the project imports. Any random numbers drawn later
# (e.g. during model construction) then start from exactly the seeded state.
torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
random.seed(0)
np.random.seed(0)

# Symbol -> token-id mapper. The misspelled name is kept because the
# inference functions below refer to it.
textclenaer = TextCleaner()
# Initial device guess; `device` is assigned again further down in the module.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Log-mel front end used by preprocess(): 80 mel bins, 2048-point FFT,
# 1200-sample window, 300-sample hop.
to_mel = torchaudio.transforms.MelSpectrogram(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
# Constants used to standardize the log-mel output.
mean = -4
std = 4
def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Returns a ``(len(lengths), lengths.max())`` tensor whose entry ``[b, t]``
    is True when position ``t`` is past the end of sequence ``b``.
    """
    idx = torch.arange(lengths.max())
    idx = idx.unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    # 0-based position t is padding when t + 1 exceeds the sequence length.
    return (idx + 1).gt(lengths.unsqueeze(1))
def preprocess(wave):
    """Turn a waveform array into a standardized log-mel tensor.

    ``wave`` is a NumPy array (presumably mono, as produced by
    ``librosa.load``). The mel spectrogram from ``to_mel`` gets a leading
    batch axis, is log-compressed with a 1e-5 floor, and is standardized
    with the module-level ``mean``/``std``.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std
def compute_style(ref_dicts):
    """Encode several reference clips with ``model.style_encoder``.

    NOTE: a later ``compute_style(path)`` definition in this module rebinds
    this name, so after import callers get that version, not this one.

    Args:
        ref_dicts: mapping of arbitrary key -> audio file path.

    Returns:
        dict mapping each key to ``(style_embedding, trimmed_audio)``.
    """
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        # librosa.load already resamples to 24 kHz, so sr is 24000 here.
        wave, sr = librosa.load(path, sr=24000)
        audio, _ = librosa.effects.trim(wave, top_db=30)
        if sr != 24000:
            # librosa >= 0.10 makes orig_sr / target_sr keyword-only; the old
            # positional call raises TypeError there.
            audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
        mel_tensor = preprocess(audio).to(device)
        with torch.no_grad():
            ref = model.style_encoder(mel_tensor.unsqueeze(1))
        reference_embeddings[key] = (ref.squeeze(1), audio)
    return reference_embeddings
# load phonemizer
# import phonemizer
# global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True, words_mismatch='ignore')
# phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
import fugashi
import pykakasi
from collections import OrderedDict
# MB-iSTFT-VITS2
import re
from unidecode import unidecode
import pyopenjtalk
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(
r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(
r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# List of (symbol, Japanese) pairs for marks:
_symbols_to_japanese = [(re.compile('%s' % x[0]), x[1]) for x in [
('%', 'パーセント')
]]
# List of (romaji, ipa) pairs for marks:
_romaji_to_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
('ts', 'ʦ'),
('u', 'ɯ'),
('j', 'ʥ'),
('y', 'j'),
('ni', 'n^i'),
('nj', 'n^'),
('hi', 'çi'),
('hj', 'ç'),
('f', 'ɸ'),
('I', 'i*'),
('U', 'ɯ*'),
('r', 'ɾ')
]]
# List of (romaji, ipa2) pairs for marks:
_romaji_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
('u', 'ɯ'),
('ʧ', 'tʃ'),
('j', 'dʑ'),
('y', 'j'),
('ni', 'n^i'),
('nj', 'n^'),
('hi', 'çi'),
('hj', 'ç'),
('f', 'ɸ'),
('I', 'i*'),
('U', 'ɯ*'),
('r', 'ɾ')
]]
# List of (consonant, sokuon) pairs:
_real_sokuon = [(re.compile('%s' % x[0]), x[1]) for x in [
(r'Q([↑↓]*[kg])', r'k#\1'),
(r'Q([↑↓]*[tdjʧ])', r't#\1'),
(r'Q([↑↓]*[sʃ])', r's\1'),
(r'Q([↑↓]*[pb])', r'p#\1')
]]
# List of (consonant, hatsuon) pairs:
_real_hatsuon = [(re.compile('%s' % x[0]), x[1]) for x in [
(r'N([↑↓]*[pbm])', r'm\1'),
(r'N([↑↓]*[ʧʥj])', r'n^\1'),
(r'N([↑↓]*[tdn])', r'n\1'),
(r'N([↑↓]*[kg])', r'ŋ\1')
]]
def symbols_to_japanese(text):
    """Spell out the symbols listed in ``_symbols_to_japanese`` as Japanese words."""
    for symbol_re, word in _symbols_to_japanese:
        text = symbol_re.sub(word, text)
    return text
def japanese_to_romaji_with_accent(text):
    '''Romanize Japanese text with pitch-accent marks using pyopenjtalk labels.

    The text is split on ``_japanese_marks``. Each piece that starts with a
    Japanese character is converted via ``pyopenjtalk.extract_fullcontext``;
    'ch'/'sh'/'cl' are rewritten to 'ʧ'/'ʃ'/'Q'. A space is emitted at an
    accent-phrase boundary, '↓' for a fall and '↑' for a rise. The
    punctuation marks between pieces are re-inserted (transliterated by
    unidecode, with spaces removed).

    Reference https://r9y9.github.io/ttslearn/latest/notebooks/ch10_Recipe-Tacotron.html
    '''
    text = symbols_to_japanese(text)
    # sentences[i] is followed by marks[i] in the original text.
    sentences = re.split(_japanese_marks, text)
    marks = re.findall(_japanese_marks, text)
    text = ''
    for i, sentence in enumerate(sentences):
        # NOTE: re.match only checks that the piece *starts* with a Japanese char.
        if re.match(_japanese_characters, sentence):
            if text != '':
                text += ' '
            labels = pyopenjtalk.extract_fullcontext(sentence)
            for n, label in enumerate(labels):
                # Current phoneme: the field between '-' and '+' of the label.
                phoneme = re.search(r'\-([^\+]*)\+', label).group(1)
                if phoneme not in ['sil', 'pau']:
                    text += phoneme.replace('ch', 'ʧ').replace('sh', 'ʃ').replace('cl', 'Q')
                else:
                    continue
                # n_moras = int(re.search(r'/F:(\d+)_', label).group(1))
                # a1/a2/a3 are parsed from the label's "/A:a1+a2+a3/" field.
                # Presumably (per the HTS/Open JTalk label spec) these are: accent
                # nucleus offset, forward mora position, backward mora position
                # in the accent phrase — confirm against the spec.
                a1 = int(re.search(r"/A:(\-?[0-9]+)\+", label).group(1))
                a2 = int(re.search(r"\+(\d+)\+", label).group(1))
                a3 = int(re.search(r"\+(\d+)/", label).group(1))
                # labels[n + 1] assumes pyopenjtalk always closes the label list
                # with a silence entry, so a real phoneme is never last —
                # assumption, verify.
                if re.search(r'\-([^\+]*)\+', labels[n + 1]).group(1) in ['sil', 'pau']:
                    a2_next = -1
                else:
                    a2_next = int(
                        re.search(r"\+(\d+)\+", labels[n + 1]).group(1))
                # Accent phrase boundary
                if a3 == 1 and a2_next == 1:
                    text += ' '
                # Falling
                elif a1 == 0 and a2_next == a2 + 1:
                    text += '↓'
                # Rising
                elif a2 == 1 and a2_next == 2:
                    text += '↑'
        # Re-attach the punctuation that followed this piece.
        if i < len(marks):
            text += unidecode(marks[i]).replace(' ', '')
    return text
def get_real_sokuon(text):
    """Resolve each geminate placeholder 'Q' using the rules in ``_real_sokuon``."""
    result = text
    for rule in _real_sokuon:
        pattern, template = rule
        result = pattern.sub(template, result)
    return result
def get_real_hatsuon(text):
    """Resolve each moraic-nasal placeholder 'N' using the rules in ``_real_hatsuon``."""
    result = text
    for rule in _real_hatsuon:
        pattern, template = rule
        result = pattern.sub(template, result)
    return result
def japanese_to_ipa(text):
    """Convert Japanese text to IPA using the ``_romaji_to_ipa`` table.

    Repeated identical vowels are folded into one vowel plus 'ː' marks, and
    the Q/N placeholders are resolved before the table is applied.
    """
    romaji = japanese_to_romaji_with_accent(text).replace('...', '…')
    # e.g. 'aaa' -> 'aːː': keep the first vowel, one 'ː' per repetition.
    romaji = re.sub(r'([aiueo])\1+', lambda run: run.group(0)[0] + 'ː' * (len(run.group(0)) - 1), romaji)
    romaji = get_real_hatsuon(get_real_sokuon(romaji))
    for pattern, ipa in _romaji_to_ipa:
        romaji = pattern.sub(ipa, romaji)
    return romaji
def japanese_to_ipa2(text):
    """Convert Japanese text to IPA using the ``_romaji_to_ipa2`` table.

    Unlike ``japanese_to_ipa`` this does not fold repeated vowels.
    """
    romaji = japanese_to_romaji_with_accent(text).replace('...', '…')
    ipa = get_real_hatsuon(get_real_sokuon(romaji))
    for regex, symbol in _romaji_to_ipa2:
        ipa = regex.sub(symbol, ipa)
    return ipa
def japanese_to_ipa3(text):
    """Narrower IPA built on ``japanese_to_ipa2``.

    'n^' becomes 'ȵ', 'ʃ' becomes 'ɕ', '*' becomes U+0325 and '#' becomes
    U+031A. Repeated vowels are folded into 'ː', and 'ʰ' is added after
    'ts'/'tɕ'/k/p/t at the start of the text or after whitespace.
    """
    ipa = japanese_to_ipa2(text)
    for old, new in (('n^', 'ȵ'), ('ʃ', 'ɕ'), ('*', '\u0325'), ('#', '\u031a')):
        ipa = ipa.replace(old, new)
    ipa = re.sub(r'([aiɯeo])\1+', lambda run: run.group(0)[0] + 'ː' * (len(run.group(0)) - 1), ipa)
    return re.sub(r'((?:^|\s)(?:ts|tɕ|[kpt]))', r'\1ʰ', ipa)
""" from https://github.com/keithito/tacotron """
'''
Cleaners are transformations that run over the input text at both training and eval time.
Cleaners can be selected by passing a comma-delimited list of cleaner names as the "cleaners"
hyperparameter. Some cleaners are English-specific. You'll typically want to use:
1. "english_cleaners" for English text
2. "transliteration_cleaners" for non-English text that can be transliterated to ASCII using
the Unidecode library (https://pypi.python.org/pypi/Unidecode)
3. "basic_cleaners" if you do not want to transliterate (in this case, you should also update
the symbols in symbols.py to match your data).
'''
# English text normalization (keithito/tacotron style): number/currency regexes, abbreviations and IPA lookup tables.
import re
import inflect
from unidecode import unidecode
# Number-to-words engine used by the _expand_* helpers.
_inflect = inflect.engine()

# Numeric spans rewritten by normalize_numbers().
_number_re = re.compile(r'[0-9]+')                          # plain integers
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')            # 1st, 22nd, ...
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')        # 3.14
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')      # 1,234,567
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')           # $12.50
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')               # £12
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
# List of (ipa, lazy ipa) pairs:
_lazy_ipa = [(re.compile('%s' % x[0]), x[1]) for x in [
('r', 'ɹ'),
('æ', 'e'),
('ɑ', 'a'),
('ɔ', 'o'),
('ð', 'z'),
('θ', 's'),
('ɛ', 'e'),
('ɪ', 'i'),
('ʊ', 'u'),
('ʒ', 'ʥ'),
('ʤ', 'ʥ'),
('', '↓'),
]]
# List of (ipa, lazy ipa2) pairs:
_lazy_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
('r', 'ɹ'),
('ð', 'z'),
('θ', 's'),
('ʒ', 'ʑ'),
('ʤ', 'dʑ'),
('', '↓'),
]]
# List of (ipa, ipa2) pairs
_ipa_to_ipa2 = [(re.compile('%s' % x[0]), x[1]) for x in [
('r', 'ɹ'),
('ʤ', 'dʒ'),
('ʧ', 'tʃ')
]]
def expand_abbreviations(text):
    """Replace abbreviations such as 'Dr.' with the expansions in ``_abbreviations``."""
    result = text
    for rule in _abbreviations:
        regex, expansion = rule
        result = regex.sub(expansion, result)
    return result
def collapse_whitespace(text):
    """Replace every run of whitespace characters with a single space."""
    whitespace_run = re.compile(r'\s+')
    return whitespace_run.sub(' ', text)
def _remove_commas(m):
return m.group(1).replace(',', '')
def _expand_decimal_point(m):
return m.group(1).replace('.', ' point ')
def _expand_dollars(m):
match = m.group(1)
parts = match.split('.')
if len(parts) > 2:
return match + ' dollars' # Unexpected format
dollars = int(parts[0]) if parts[0] else 0
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
if dollars and cents:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
elif dollars:
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
return '%s %s' % (dollars, dollar_unit)
elif cents:
cent_unit = 'cent' if cents == 1 else 'cents'
return '%s %s' % (cents, cent_unit)
else:
return 'zero dollars'
def _expand_ordinal(m):
    """Spell out the ordinal matched by ``_ordinal_re`` using inflect."""
    matched_text = m.group(0)
    return _inflect.number_to_words(matched_text)
def _expand_number(m):
    """Spell out an integer, reading 1001-2999 the way years are spoken."""
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword='')
    if num == 2000:
        return 'two thousand'
    if 2000 < num < 2010:
        return 'two thousand ' + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + ' hundred'
    # Read in two-digit groups, e.g. 1984 -> "nineteen eighty-four".
    return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
def normalize_numbers(text):
    """Spell out numbers, currency, decimals and ordinals in English text.

    The rewrites run in a fixed order: commas, pounds, dollars, decimals,
    ordinals, plain integers.
    """
    steps = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in steps:
        text = re.sub(pattern, replacement, text)
    return text
def mark_dark_l(text):
    """Rewrite 'l' as dark 'ɫ' when no vowel follows it before the next space or end of text."""
    return re.sub(r'l([^aeiouæɑɔəɛɪʊ ]*(?: |$))', r'ɫ\1', text)
import re
#from text.thai import num_to_thai, latin_to_thai
#from text.shanghainese import shanghainese_to_ipa
#from text.cantonese import cantonese_to_ipa
#from text.ngu_dialect import ngu_dialect_to_ipa
from unidecode import unidecode
_whitespace_re = re.compile(r'\s+')
# Regular expression matching Japanese without punctuation marks:
_japanese_characters = re.compile(r'[A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# Regular expression matching non-Japanese characters or punctuation marks:
_japanese_marks = re.compile(r'[^A-Za-z\d\u3005\u3040-\u30ff\u4e00-\u9fff\uff11-\uff19\uff21-\uff3a\uff41-\uff5a\uff66-\uff9d]')
# List of (regular expression, replacement) pairs for abbreviations:
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
('mrs', 'misess'),
('mr', 'mister'),
('dr', 'doctor'),
('st', 'saint'),
('co', 'company'),
('jr', 'junior'),
('maj', 'major'),
('gen', 'general'),
('drs', 'doctors'),
('rev', 'reverend'),
('lt', 'lieutenant'),
('hon', 'honorable'),
('sgt', 'sergeant'),
('capt', 'captain'),
('esq', 'esquire'),
('ltd', 'limited'),
('col', 'colonel'),
('ft', 'fort'),
]]
def expand_abbreviations(text):
    """Expand abbreviations such as 'Mrs.' using the ``_abbreviations`` rules, in order."""
    expanded = text
    for abbreviation_re, expansion in _abbreviations:
        expanded = abbreviation_re.sub(expansion, expanded)
    return expanded
def collapse_whitespace(text):
    """Squash every whitespace run (per ``_whitespace_re``) into one space."""
    return _whitespace_re.sub(' ', text)
def convert_to_ascii(text):
    """Transliterate *text* to ASCII with Unidecode."""
    ascii_text = unidecode(text)
    return ascii_text
def basic_cleaners(text):
    '''Basic pipeline that lowercases and collapses whitespace without transliteration.'''
    # - For replication of https://github.com/FENRlR/MB-iSTFT-VITS2/issues/2
    # you may need to replace the symbol to Russian one
    return collapse_whitespace(text.lower())
'''
def fix_g2pk2_error(text):
new_text = ""
i = 0
while i < len(text) - 4:
if (text[i:i+3] == 'ㅇㅡㄹ' or text[i:i+3] == 'ㄹㅡㄹ') and text[i+3] == ' ' and text[i+4] == 'ㄹ':
new_text += text[i:i+3] + ' ' + 'ㄴ'
i += 5
else:
new_text += text[i]
i += 1
new_text += text[i:]
return new_text
'''
def japanese_cleaners(text):
    """Romanize Japanese text with accent marks; add '.' if it ends on a Latin letter."""
    romanized = japanese_to_romaji_with_accent(text)
    return re.sub(r'([A-Za-z])$', r'\1.', romanized)
def japanese_cleaners2(text):
    """``japanese_cleaners`` plus the 'ts' -> 'ʦ' ligature and '...' -> '…'."""
    cleaned = japanese_cleaners(text)
    cleaned = cleaned.replace('ts', 'ʦ')
    return cleaned.replace('...', '…')
def japanese_cleaners3(text):
    """Earlier variant: IPA via ``japanese_to_ipa3`` followed by punctuation tweaks.

    NOTE: ``japanese_cleaners3`` is defined again later in this module, so
    after import that later definition is the one callers get.
    """
    text = japanese_to_ipa3(text)
    # NOTE(review): the guard tests for "¡"/"¿" but the body rewrites "!"/"?",
    # so "!"/"?" are only converted when the text already contains "<<", ">>",
    # "¡" or "¿". Confirm the intent before changing it.
    if "<<" in text or ">>" in text or "¡" in text or "¿" in text:
        text = text.replace("<<","«")
        text = text.replace(">>","»")
        text = text.replace("!","¡")
        text = text.replace("?","¿")
    if'"'in text:
        text = text.replace('"','”')
    if'--'in text:
        text = text.replace('--','—')
    # Removes every space, including the accent-phrase boundary spaces
    # produced by japanese_to_romaji_with_accent.
    if ' ' in text:
        text = text.replace(' ','')
    return text
# ------------------------------
''' cjke type cleaners below '''
#- text for these cleaners must be labeled first
# ex1 (single) : some.wav|[EN]put some text here[EN]
# ex2 (multi) : some.wav|0|[EN]put some text here[EN]
# ------------------------------
def kej_cleaners(text):
    """Convert [KO]/[EN]/[JA]-tagged segments to IPA and normalize the ending.

    Each ``[XX]...[XX]`` span is replaced by its conversion plus a space.
    Trailing whitespace is then stripped, and '.' is appended unless the text
    already ends in one of . , ! ? - … ~.

    NOTE(review): ``korean_to_ipa`` and ``english_to_ipa2`` are not defined in
    this file. Unless a star import (models/utils) provides them, a [KO] or
    [EN] segment raises NameError — confirm. The names are only looked up
    when such a segment is present.
    """
    def korean_span(match):
        return korean_to_ipa(match.group(1)) + ' '

    def english_span(match):
        return english_to_ipa2(match.group(1)) + ' '

    def japanese_span(match):
        return japanese_to_ipa2(match.group(1)) + ' '

    text = re.sub(r'\[KO\](.*?)\[KO\]', korean_span, text)
    text = re.sub(r'\[EN\](.*?)\[EN\]', english_span, text)
    text = re.sub(r'\[JA\](.*?)\[JA\]', japanese_span, text)
    text = re.sub(r'\s+$', '', text)
    return re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
def cjks_cleaners(text):
    """Convert [JA]/[EN]-tagged segments to IPA and normalize the ending.

    [JA] spans go through ``japanese_to_ipa`` and [EN] spans through
    ``english_to_lazy_ipa``; each conversion is followed by a space. Trailing
    whitespace is stripped and '.' is appended unless the text already ends
    in one of . , ! ? - … ~.

    NOTE(review): ``english_to_lazy_ipa`` is not defined in this file. Unless
    a star import provides it, an [EN] segment raises NameError — confirm.
    """
    def japanese_span(match):
        return japanese_to_ipa(match.group(1)) + ' '

    def english_span(match):
        return english_to_lazy_ipa(match.group(1)) + ' '

    text = re.sub(r'\[JA\](.*?)\[JA\]', japanese_span, text)
    # text = re.sub(r'\[SA\](.*?)\[SA\]',
    #               lambda x: devanagari_to_ipa(x.group(1))+' ', text)
    text = re.sub(r'\[EN\](.*?)\[EN\]', english_span, text)
    text = re.sub(r'\s+$', '', text)
    return re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
'''
#- reserves
def thai_cleaners(text):
text = num_to_thai(text)
text = latin_to_thai(text)
return text
def shanghainese_cleaners(text):
text = shanghainese_to_ipa(text)
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
return text
def chinese_dialect_cleaners(text):
text = re.sub(r'\[ZH\](.*?)\[ZH\]',
lambda x: chinese_to_ipa2(x.group(1))+' ', text)
text = re.sub(r'\[JA\](.*?)\[JA\]',
lambda x: japanese_to_ipa3(x.group(1)).replace('Q', 'ʔ')+' ', text)
text = re.sub(r'\[SH\](.*?)\[SH\]', lambda x: shanghainese_to_ipa(x.group(1)).replace('1', '˥˧').replace('5',
'˧˧˦').replace('6', '˩˩˧').replace('7', '˥').replace('8', '˩˨').replace('ᴀ', 'ɐ').replace('ᴇ', 'e')+' ', text)
text = re.sub(r'\[GD\](.*?)\[GD\]',
lambda x: cantonese_to_ipa(x.group(1))+' ', text)
text = re.sub(r'\[EN\](.*?)\[EN\]',
lambda x: english_to_lazy_ipa2(x.group(1))+' ', text)
text = re.sub(r'\[([A-Z]{2})\](.*?)\[\1\]', lambda x: ngu_dialect_to_ipa(x.group(2), x.group(
1)).replace('ʣ', 'dz').replace('ʥ', 'dʑ').replace('ʦ', 'ts').replace('ʨ', 'tɕ')+' ', text)
text = re.sub(r'\s+$', '', text)
text = re.sub(r'([^\.,!\?\-…~])$', r'\1.', text)
return text
'''
def japanese_cleaners3(text):
    """Convert Japanese text into the IPA-like symbol string used by the model.

    Runs ``japanese_to_ipa2`` and then a fixed, order-dependent series of
    string rewrites. Side effect: the raw input is stored in the module-level
    global ``orig`` so that ``japanese_cleaners4`` can inspect it.
    NOTE(review): because of that global, concurrent calls can overwrite each
    other's ``orig``.

    This definition replaces the earlier ``japanese_cleaners3`` in this module.
    """
    global orig
    orig = text # saving the original unmodified text for future use
    text = japanese_to_ipa2(text)
    # NOTE(review): '' is a substring of every string and replace('', '')
    # returns the input unchanged, so this branch is a no-op as written. A
    # character may have been lost from both literals — confirm upstream.
    if '' in text:
        text = text.replace('','')
    # NOTE(review): the guard looks for "¡"/"¿" but the body rewrites "!"/"?",
    # so "!"/"?" are converted only when the text already contains "<<", ">>",
    # "¡" or "¿". Check against the training-time cleaner before changing it,
    # since the model's symbol inventory depends on this output.
    if "<<" in text or ">>" in text or "¡" in text or "¿" in text:
        text = text.replace("<<","«")
        text = text.replace(">>","»")
        text = text.replace("!","¡")
        text = text.replace("?","¿")
    if'"'in text:
        text = text.replace('"','”')
    if'--'in text:
        text = text.replace('--','—')
    # '#' (inserted after geminate stops by _real_sokuon) becomes 'ʔ'.
    text = text.replace("#","ʔ")
    # Drop the '^' marker of 'n^' (from _romaji_to_ipa2 / _real_hatsuon).
    text = text.replace("^","")
    # Consonant + 'j' sequences are written with the palatalization mark 'ʲ'.
    text = text.replace("kj","kʲ")
    # NOTE(review): exact duplicate of the previous line; has no effect.
    text = text.replace("kj","kʲ")
    text = text.replace("ɾj","ɾʲ")
    text = text.replace("mj","mʲ")
    text = text.replace("ʃ","ɕ")
    # Drop the '*' appended to 'I'/'U' vowels by _romaji_to_ipa2.
    text = text.replace("*","")
    text = text.replace("bj","bʲ")
    # NOTE(review): rewrites every remaining 'h' (e.g. in 'ha', 'he', 'ho'),
    # not only the 'hi'/'hj' cases — confirm this is intended.
    text = text.replace("h","ç")
    text = text.replace("gj","gʲ")
    return text
def japanese_cleaners4(text):
    """Front-end text cleaner used by the inference functions.

    Applies ``japanese_cleaners3`` and then kana-specific fix-ups decided from
    the raw input stored in the global ``orig`` (set by that call). Remaining
    'Q'/'N' placeholders are rendered as 'ʔ'/'ɴ', and spaces after double
    quotes are removed.
    """
    text = japanese_cleaners3(text)
    # NOTE(review): this is an if/elif chain, so at most ONE of these fix-ups
    # runs per call, even if the input contains several of the listed kana.
    if "にゃ" in orig:
        text = text.replace("na","nʲa")
    # NOTE(review): the にゅ/にょ rules palatalize every 'n' in the whole
    # string, not only the one belonging to にゅ/にょ.
    elif "にゅ" in orig:
        text = text.replace("n","nʲ")
    elif "にょ" in orig:
        text = text.replace("n","nʲ")
    elif "にぃ" in orig:
        text = text.replace("ni i","niː")
    elif "いゃ" in orig:
        text = text.replace("i↑ja","ja")
    # NOTE(review): same condition as the previous branch; unreachable.
    elif "いゃ" in orig:
        text = text.replace("i↑ja","ja")
    elif "ひょ" in orig:
        text = text.replace("ço","çʲo")
    elif "しょ" in orig:
        text = text.replace("ɕo","ɕʲo")
    # 'Q' left unresolved by _real_sokuon -> glottal stop; 'N' left
    # unresolved by _real_hatsuon -> uvular nasal.
    text = text.replace("Q","ʔ")
    text = text.replace("N","ɴ")
    # Delete the character in front of every 'ʔ'. For a geminate rendered as
    # e.g. 'k#k' -> 'kʔk' this removes the placeholder consonant.
    # NOTE(review): for a bare 'Q' -> 'ʔ' it deletes whatever precedes it
    # (possibly a vowel) — confirm this is intended.
    text = re.sub(r'.ʔ', 'ʔ', text)
    # Remove a space directly after a double quote (straight or ”).
    text = text.replace('" ', '"')
    text = text.replace('” ', '”')
    return text
# Log-mel front end (80 bins, 2048-point FFT, 1200-sample window, 300-sample
# hop) and the constants used to standardize its log output in preprocess().
mean, std = -4, 4
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80,
    n_fft=2048,
    win_length=1200,
    hop_length=300,
)
def length_to_mask(lengths):
    """Return a padding mask for a batch of sequence lengths.

    ``mask[b, t]`` is True where position ``t`` lies beyond ``lengths[b]``
    (padding) and False for real positions.
    """
    batch = lengths.shape[0]
    steps = torch.arange(lengths.max()).unsqueeze(0).expand(batch, -1).type_as(lengths)
    return torch.gt(steps + 1, lengths.unsqueeze(1))
def preprocess(wave):
    """Standardized log-mel spectrogram (with a leading batch axis) of a NumPy waveform."""
    wave_tensor = torch.from_numpy(wave).float()
    batched_mel = to_mel(wave_tensor).unsqueeze(0)
    # 1e-5 keeps log() finite on silent bins.
    log_mel = torch.log(1e-5 + batched_mel)
    return (log_mel - mean) / std
def compute_style(path):
    """Compute the reference style vector for one audio clip.

    The clip is loaded at 24 kHz, trimmed with
    ``librosa.effects.trim(top_db=30)``, converted by ``preprocess`` and run
    through both reference encoders.

    Args:
        path: path to the reference audio file.

    Returns:
        ``torch.cat([ref_s, ref_p], dim=1)``, where ``ref_s`` comes from
        ``model.style_encoder`` and ``ref_p`` from ``model.predictor_encoder``.
        The inference functions treat columns ``[:, :128]`` as the acoustic
        style and ``[:, 128:]`` as the prosodic style.
    """
    # librosa.load already resamples to 24 kHz, so sr is 24000 here.
    wave, sr = librosa.load(path, sr=24000)
    audio, _ = librosa.effects.trim(wave, top_db=30)
    if sr != 24000:
        # librosa >= 0.10 makes orig_sr / target_sr keyword-only.
        audio = librosa.resample(audio, orig_sr=sr, target_sr=24000)
    mel_tensor = preprocess(audio).to(device)
    with torch.no_grad():
        ref_s = model.style_encoder(mel_tensor.unsqueeze(1))
        ref_p = model.predictor_encoder(mel_tensor.unsqueeze(1))
    return torch.cat([ref_s, ref_p], dim=1)
# Choose the compute device: CUDA when available, otherwise CPU. Apple MPS is
# detected and reported but deliberately not used yet.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
    if torch.backends.mps.is_available():
        print("MPS would be available but cannot be used rn")
        # device = 'mps'
import phonemizer
# English espeak phonemizer; only STinference uses it (for ref_text).
global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)
# phonemizer = Phonemizer.from_checkpoint(str(cached_path('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')))
# config = yaml.safe_load(open("Models/LibriTTS/config.yml"))

# Fetch (or reuse the cached copy of) the StyleTTS2 LibriTTS config. The file
# is opened in a `with` block so the handle is closed deterministically
# instead of being left to the garbage collector.
with open(str(cached_path("hf://yl4579/StyleTTS2-LibriTTS/Models/LibriTTS/config.yml")), encoding="utf-8") as _config_file:
    config = yaml.safe_load(_config_file)
# load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load BERT model
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

model_params = recursive_munch(config['model_params'])
model = build_model(model_params, text_aligner, pitch_extractor, plbert)
# Switch every sub-module to eval mode and move it to the selected device.
# Plain loops replace list comprehensions that were built only for their
# side effects.
for key in model:
    model[key].eval()
for key in model:
    model[key].to(device)
# params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
# NOTE: torch.load unpickles the checkpoint; only load files you trust.
params_whole = torch.load("Models/Kaede.pth", map_location='cpu')
params = params_whole['net']
for key in model:
    if key not in params:
        continue
    print('%s loaded' % key)
    try:
        model[key].load_state_dict(params[key])
    except RuntimeError:
        # Key mismatch. Checkpoints saved from a wrapped model (e.g.
        # nn.DataParallel) prefix every key with 'module.'. Strip the prefix
        # only where it is actually present, so unprefixed keys are not
        # mangled (and then silently skipped by strict=False), and retry
        # non-strictly.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in params[key].items()
        )
        # load params
        model[key].load_state_dict(new_state_dict, strict=False)
# except:
#     _load(params[key], model[key])
for key in model:
    model[key].eval()
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

# Diffusion sampler for style vectors: ADPM2 steps over a Karras noise
# schedule (sigma/rho values are empirical), without clamping.
_diffusion_net = model.diffusion.diffusion
sampler = DiffusionSampler(
    _diffusion_net,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0),
    clamp=False,
)
def inference(text, ref_s, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
    """Synthesize Japanese ``text`` in the style of the reference vector ``ref_s``.

    Args:
        text: raw Japanese text. It is cleaned with ``japanese_cleaners4`` and
            tokenized with ``textclenaer``.
        ref_s: reference style as returned by ``compute_style(path)``.
            ``[:, :128]`` is blended into the acoustic style and ``[:, 128:]``
            into the prosodic style.
        alpha: blend weight of the sampled acoustic style
            (1 = sampled only, 0 = reference only).
        beta: same blend weight for the prosodic style.
        diffusion_steps: number of diffusion sampler steps.
        embedding_scale: forwarded to the diffusion sampler.
        use_gruut: unused; kept for interface compatibility.

    Returns:
        The decoder output as a NumPy array with the last 50 samples dropped
        (presumably 24 kHz audio — confirm).

    NOTE(review): if the cleaned text yields no tokens, ``duration.squeeze()``
    likely becomes 0-d and ``pred_dur[i]`` fails — confirm and guard.
    """
    # text = text.strip()
    # ps = global_phonemizer.phonemize([text])
    # ps = word_tokenize(ps[0])
    # ps = ' '.join(ps)
    text = japanese_cleaners4(text)
    print(text)
    tokens = textclenaer(text)
    # Prepend token id 0 (presumably a start/pad symbol — confirm in TextCleaner).
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        # True on padding positions (see length_to_mask).
        text_mask = length_to_mask(input_lengths).to(device)
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        # BERT attention mask is the inverse: 1 on real tokens.
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
        # Sample a 256-dim style vector conditioned on the text embedding
        # and the reference style.
        s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
                         embedding=bert_dur,
                         embedding_scale=embedding_scale,
                         features=ref_s, # reference from the same speaker as the embedding
                         num_steps=diffusion_steps).squeeze(1)
        # Split into prosodic (s) and acoustic (ref) halves and blend each
        # with the reference.
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]
        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]
        d = model.predictor.text_encoder(d_en,
                                         s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        # Frames per token: sigmoid outputs summed, rounded, at least 1.
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
        # Hard alignment matrix (tokens x frames): token i covers the next
        # pred_dur[i] frames.
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)
        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        # HiFi-GAN decoder: shift features one frame right (first frame repeated).
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        # Expand token-level text features to frame level with the same alignment.
        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new
        out = model.decoder(asr,
                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))
    return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later
def LFinference(text, s_prev, ref_s, alpha = 0.3, beta = 0.7, t = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
    """Long-form variant of ``inference`` that carries style across calls.

    Same pipeline as ``inference``, except that the sampled style is first
    mixed with the previous segment's style:
    ``s_pred = t * s_prev + (1 - t) * s_pred`` when ``s_prev`` is not None.

    Args:
        text, ref_s, alpha, beta, diffusion_steps, embedding_scale: as in
            ``inference``.
        s_prev: style returned by the previous call, or None for the first
            segment.
        t: weight of ``s_prev`` in the mix.
        use_gruut: unused; kept for interface compatibility.

    Returns:
        ``(audio, s_pred)``: the decoder output as a NumPy array with the last
        100 samples dropped, and the blended style ``cat([ref, s], -1)`` to
        pass as ``s_prev`` for the next segment.
    """
    # text = text.strip()
    # ps = global_phonemizer.phonemize([text])
    # ps = word_tokenize(ps[0])
    # ps = ' '.join(ps)
    # ps = ps.replace('``', '"')
    # ps = ps.replace("''", '"')
    text = japanese_cleaners4(text)
    print(text)
    tokens = textclenaer(text)
    # Prepend token id 0 (presumably a start/pad symbol — confirm in TextCleaner).
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        # True on padding positions (see length_to_mask).
        text_mask = length_to_mask(input_lengths).to(device)
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
        s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
                         embedding=bert_dur,
                         embedding_scale=embedding_scale,
                         features=ref_s, # reference from the same speaker as the embedding
                         num_steps=diffusion_steps).squeeze(1)
        if s_prev is not None:
            # convex combination of previous and current style
            s_pred = t * s_prev + (1 - t) * s_pred
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]
        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]
        # Re-pack the blended halves; this is what the caller gets back as s_prev.
        s_pred = torch.cat([ref, s], dim=-1)
        d = model.predictor.text_encoder(d_en,
                                         s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        # Frames per token: sigmoid outputs summed, rounded, at least 1.
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
        # Hard alignment matrix (tokens x frames).
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)
        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        # HiFi-GAN decoder: shift features one frame right (first frame repeated).
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new
        out = model.decoder(asr,
                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))
    return out.squeeze().cpu().numpy()[..., :-100], s_pred # weird pulse at the end of the model, need to be fixed later
def STinference(text, ref_s, ref_text, alpha = 0.3, beta = 0.7, diffusion_steps=5, embedding_scale=1, use_gruut=False):
    """Style-transfer variant of ``inference``. Marked "don't use" by the author.

    ``text`` goes through the same Japanese pipeline as ``inference``.
    ``ref_text`` is phonemized with the English espeak ``global_phonemizer``
    and encoded with BERT into ``ref_bert_dur``.
    NOTE(review): ``ref_bert_dur`` is never used afterwards, so ``ref_text``
    does not affect the output; the English phonemization also does not
    match the Japanese symbol set. Confirm before relying on this function.

    Returns:
        The decoder output as a NumPy array with the last 50 samples dropped.
    """
    print("don't use")
    # text = text.strip()
    # ps = global_phonemizer.phonemize([text])
    # ps = word_tokenize(ps[0])
    # ps = ' '.join(ps)
    text = japanese_cleaners4(text)
    tokens = textclenaer(text)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    # Reference transcript: English espeak phonemes, NLTK word-tokenized.
    ref_text = ref_text.strip()
    ps = global_phonemizer.phonemize([ref_text])
    ps = word_tokenize(ps[0])
    ps = ' '.join(ps)
    ref_tokens = textclenaer(ps)
    ref_tokens.insert(0, 0)
    ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)
    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)
        ref_input_lengths = torch.LongTensor([ref_tokens.shape[-1]]).to(device)
        ref_text_mask = length_to_mask(ref_input_lengths).to(device)
        # Computed but not used below (see NOTE in the docstring).
        ref_bert_dur = model.bert(ref_tokens, attention_mask=(~ref_text_mask).int())
        s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
                         embedding=bert_dur,
                         embedding_scale=embedding_scale,
                         features=ref_s, # reference from the same speaker as the embedding
                         num_steps=diffusion_steps).squeeze(1)
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]
        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]
        d = model.predictor.text_encoder(d_en,
                                         s, input_lengths, text_mask)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)
        # Hard alignment matrix (tokens x frames).
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)
        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        # HiFi-GAN decoder: shift features one frame right (first frame repeated).
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(en)
            asr_new[:, :, 0] = en[:, :, 0]
            asr_new[:, :, 1:] = en[:, :, 0:-1]
            en = asr_new
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)
        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))
        if model_params.decoder.type == "hifigan":
            asr_new = torch.zeros_like(asr)
            asr_new[:, :, 0] = asr[:, :, 0]
            asr_new[:, :, 1:] = asr[:, :, 0:-1]
            asr = asr_new
        out = model.decoder(asr,
                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))
    return out.squeeze().cpu().numpy()[..., :-50] # weird pulse at the end of the model, need to be fixed later