# Tsukasa_Speech / Modules / KotoDama_sampler.py
from transformers import AutoModelForSequenceClassification, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
import torch
import torch.nn as nn

from text_utils import TextCleaner

textcleaner = TextCleaner()

def length_to_mask(lengths):
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask
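
# Worked example: for lengths [3, 5] the mask marks padded positions as True:
#   length_to_mask(torch.LongTensor([3, 5]))
#   -> tensor([[False, False, False,  True,  True],
#              [False, False, False, False, False]])
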
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# tokenizer_koto_prompt = AutoTokenizer.from_pretrained("google/mt5-small", trust_remote_code=True)
tokenizer_koto_prompt = AutoTokenizer.from_pretrained("ku-nlp/deberta-v3-base-japanese", trust_remote_code=True)
tokenizer_koto_text = AutoTokenizer.from_pretrained("line-corporation/line-distilbert-base-japanese", trust_remote_code=True)

class KotoDama_Prompt(PreTrainedModel):
    """Transformer backbone with a two-layer regression head, trained with MSE
    to map a prompt to a continuous target vector."""

    def __init__(self, config):
        super().__init__(config)
        self.backbone = AutoModel.from_config(config)
        self.output = nn.Sequential(nn.Linear(config.hidden_size, 512),
                                    nn.LeakyReLU(0.2),
                                    nn.Linear(512, config.num_labels))

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        labels=None,
    ):
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )
        # pool with the hidden state of the first ([CLS]) token
        sequence_output = outputs.last_hidden_state[:, 0, :]
        outputs = self.output(sequence_output)

        # if labels are given, we are training: regress onto them with MSE
        loss = None
        if labels is not None:
            loss_fn = nn.MSELoss()
            # labels = labels.unsqueeze(1)
            loss = loss_fn(outputs, labels)

        return {
            "loss": loss,
            "logits": outputs
        }
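
# Minimal loading sketch. The checkpoint path and num_labels=256 are assumptions
# (not pinned down by this file); substitute your own fine-tuned weights:
#
#   config = AutoConfig.from_pretrained("ku-nlp/deberta-v3-base-japanese", num_labels=256)
#   koto_prompt = KotoDama_Prompt.from_pretrained("path/to/kotodama_prompt", config=config).to(device).eval()
#   batch = tokenizer_koto_prompt("明るい声で話してください。", return_tensors="pt").to(device)
#   ref_s = koto_prompt(**batch)["logits"]  # (1, num_labels) continuous style vector
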

class KotoDama_Text(PreTrainedModel):
    """Same regression head as KotoDama_Prompt, but the forward signature drops
    token_type_ids/position_ids, which the DistilBERT backbone does not accept."""

    def __init__(self, config):
        super().__init__(config)
        self.backbone = AutoModel.from_config(config)
        self.output = nn.Sequential(nn.Linear(config.hidden_size, 512),
                                    nn.LeakyReLU(0.2),
                                    nn.Linear(512, config.num_labels))

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
    ):
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
        )
        # pool with the hidden state of the first ([CLS]) token
        sequence_output = outputs.last_hidden_state[:, 0, :]
        outputs = self.output(sequence_output)

        # if labels are given, we are training: regress onto them with MSE
        loss = None
        if labels is not None:
            loss_fn = nn.MSELoss()
            # labels = labels.unsqueeze(1)
            loss = loss_fn(outputs, labels)

        return {
            "loss": loss,
            "logits": outputs
        }
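
# The text head follows the same pattern; note the explicit kwargs, since its
# forward() takes no token_type_ids (checkpoint path again hypothetical):
#
#   config = AutoConfig.from_pretrained("line-corporation/line-distilbert-base-japanese", num_labels=256)
#   koto_text = KotoDama_Text.from_pretrained("path/to/kotodama_text", config=config).to(device).eval()
#   batch = tokenizer_koto_text("こんにちは、元気ですか。", return_tensors="pt")
#   ref_s = koto_text(input_ids=batch["input_ids"].to(device),
#                     attention_mask=batch["attention_mask"].to(device))["logits"]
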

def inference(model, diffusion_sampler, text=None, ref_s=None, alpha=0.3, beta=0.7,
              diffusion_steps=5, embedding_scale=1, rate_of_speech=1.):
    tokens = textcleaner(text)
    tokens.insert(0, 0)  # prepend the start/padding token
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # sample a 256-dim style vector conditioned on the text embedding
        s_pred = diffusion_sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                                   embedding=bert_dur,
                                   embedding_scale=embedding_scale,
                                   features=ref_s,  # reference from the same speaker as the embedding
                                   num_steps=diffusion_steps).squeeze(1)

        # first 128 dims condition the decoder, last 128 the prosody predictor
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        # blend predicted and reference styles (alpha/beta = 1 keeps the prediction)
        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x = model.predictor.lstm(d)
        x_mod = model.predictor.prepare_projection(x)
        duration = model.predictor.duration_proj(x_mod)

        duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # expand tokens to frames with a hard monotonic alignment matrix
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # drop the last 50 samples to trim a trailing artifact
    return out.squeeze().cpu().numpy()[..., :-50]
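
# Call sketch (assumes `model` is the loaded synthesis model exposing the
# attributes used above, and `ref_s` is a (1, 256) style vector such as a
# KotoDama prediction):
#
#   wav = inference(model, diffusion_sampler,
#                   text="今日はとてもいい天気ですね。",
#                   ref_s=ref_s,
#                   alpha=0.3, beta=0.7,   # 0 = stay close to ref_s, 1 = trust the sampler
#                   diffusion_steps=5, embedding_scale=1)
#   # `wav` is a 1-D numpy array; write it out at the decoder's sample rate.
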

def Longform(model, diffusion_sampler, text, s_prev, ref_s, alpha=0.3, beta=0.7, t=0.7,
             diffusion_steps=5, embedding_scale=1, rate_of_speech=1.0):
    tokens = textcleaner(text)
    tokens.insert(0, 0)  # prepend the start/padding token
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = diffusion_sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                                   embedding=bert_dur,
                                   embedding_scale=embedding_scale,
                                   features=ref_s,
                                   num_steps=diffusion_steps).squeeze(1)

        if s_prev is not None:
            # convex combination of previous and current style for smooth transitions
            s_pred = t * s_prev + (1 - t) * s_pred

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        # carry the blended style over to the next chunk
        s_pred = torch.cat([ref, s], dim=-1)

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x = model.predictor.lstm(d)
        x_mod = model.predictor.prepare_projection(x)  # 640 -> 512
        duration = model.predictor.duration_proj(x_mod)

        duration = torch.sigmoid(duration).sum(axis=-1) / rate_of_speech
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # expand tokens to frames with a hard monotonic alignment matrix
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = (t_en @ pred_aln_trg.unsqueeze(0).to(device))

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # drop the last 100 samples to trim a trailing artifact
    return out.squeeze().cpu().numpy()[..., :-100], s_pred
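
# Longform threads the returned style back in through `s_prev`, so consecutive
# chunks blend smoothly (t controls how much of the previous style is kept).
# Sketch, assuming numpy is imported as np by the caller:
#
#   s_prev, wavs = None, []
#   for chunk in merging_sentences(text.splitlines()):
#       wav, s_prev = Longform(model, diffusion_sampler, chunk, s_prev, ref_s, t=0.7)
#       wavs.append(wav)
#   audio = np.concatenate(wavs)
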

def merge_short_elements(lst):
    # fold any element shorter than 10 characters into its predecessor
    i = 0
    while i < len(lst):
        if i > 0 and len(lst[i]) < 10:
            lst[i - 1] += ' ' + lst[i]
            lst.pop(i)
        else:
            i += 1
    return lst


def merge_three(text_list, maxim=2):
    # join every `maxim` consecutive elements into one string
    merged_list = []
    for i in range(0, len(text_list), maxim):
        merged_text = ' '.join(text_list[i:i + maxim])
        merged_list.append(merged_text)
    return merged_list


def merging_sentences(lst):
    return merge_three(merge_short_elements(lst))
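
# Worked example of the chunking pipeline: fragments shorter than 10 characters
# are folded into their predecessor, then surviving chunks are joined in pairs:
#
#   merging_sentences(["こんにちは。", "ええ。", "今日はとてもいい天気ですね。", "そうですね。"])
#   # merge_short_elements -> ["こんにちは。 ええ。", "今日はとてもいい天気ですね。 そうですね。"]
#   # merge_three (maxim=2) -> ["こんにちは。 ええ。 今日はとてもいい天気ですね。 そうですね。"]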