|
from transformers import AutoModelForSequenceClassification, PreTrainedModel, AutoConfig, AutoModel, AutoTokenizer
import torch
import torch.nn as nn

from text_utils import TextCleaner

textcleaner = TextCleaner()
|
|
|
|
|
def length_to_mask(lengths):
    # Build a boolean padding mask from a 1-D tensor of sequence lengths:
    # True marks positions at or beyond each sequence's length.
    mask = torch.arange(lengths.max()).unsqueeze(0).expand(lengths.shape[0], -1).type_as(lengths)
    mask = torch.gt(mask + 1, lengths.unsqueeze(1))
    return mask
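
# Illustrative example of the mask semantics: for lengths [3, 5], positions
# past each length come out True (i.e., padded):
#   >>> length_to_mask(torch.LongTensor([3, 5]))
#   tensor([[False, False, False,  True,  True],
#           [False, False, False, False, False]])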
|
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Separate tokenizers for the two models defined below.
tokenizer_koto_prompt = AutoTokenizer.from_pretrained("ku-nlp/deberta-v3-base-japanese", trust_remote_code=True)
tokenizer_koto_text = AutoTokenizer.from_pretrained("line-corporation/line-distilbert-base-japanese", trust_remote_code=True)
|
|
|
class KotoDama_Prompt(PreTrainedModel):
    """Regression head over a transformer backbone: the [CLS] embedding is
    projected through a two-layer MLP to `config.num_labels` continuous
    outputs, trained with MSE."""

    def __init__(self, config):
        super().__init__(config)
        self.backbone = AutoModel.from_config(config)
        self.output = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, config.num_labels),
        )

    def forward(
        self,
        input_ids,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        labels=None,
    ):
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
        )

        # Pool the first ([CLS]) token and project to the output dimension.
        sequence_output = outputs.last_hidden_state[:, 0, :]
        outputs = self.output(sequence_output)

        # Targets are continuous, so the loss is MSE rather than cross-entropy.
        loss = None
        if labels is not None:
            loss_fn = nn.MSELoss()
            loss = loss_fn(outputs, labels)

        return {"loss": loss, "logits": outputs}
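
# Usage sketch, not executed at import. The checkpoint directory is a
# hypothetical placeholder; a fine-tuned checkpoint is assumed and nothing in
# this file downloads one.
def _example_kotodama_prompt(checkpoint_dir="path/to/kotodama_prompt"):
    config = AutoConfig.from_pretrained(checkpoint_dir)
    model = KotoDama_Prompt.from_pretrained(checkpoint_dir, config=config).to(device)
    batch = tokenizer_koto_prompt("落ち着いた声でゆっくり話す。", return_tensors="pt").to(device)
    with torch.no_grad():
        out = model(**batch)
    return out["logits"]  # shape (1, config.num_labels), continuous outputs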
|
|
|
|
|
class KotoDama_Text(PreTrainedModel):
    """Same architecture as KotoDama_Prompt, but paired with the DistilBERT
    tokenizer above; its backbone takes no token_type_ids or position_ids."""

    def __init__(self, config):
        super().__init__(config)
        self.backbone = AutoModel.from_config(config)
        self.output = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, config.num_labels),
        )

    def forward(
        self,
        input_ids,
        attention_mask=None,
        labels=None,
    ):
        outputs = self.backbone(
            input_ids,
            attention_mask=attention_mask,
        )

        # [CLS] pooling followed by the MLP head, as in KotoDama_Prompt.
        sequence_output = outputs.last_hidden_state[:, 0, :]
        outputs = self.output(sequence_output)

        loss = None
        if labels is not None:
            loss_fn = nn.MSELoss()
            loss = loss_fn(outputs, labels)

        return {"loss": loss, "logits": outputs}
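
# Same pattern for the text model. Note that tokenizer_koto_text yields no
# token_type_ids, which matches the narrower forward() above. The checkpoint
# path is again a hypothetical placeholder.
def _example_kotodama_text(checkpoint_dir="path/to/kotodama_text"):
    config = AutoConfig.from_pretrained(checkpoint_dir)
    model = KotoDama_Text.from_pretrained(checkpoint_dir, config=config).to(device)
    batch = tokenizer_koto_text("今日は本当にいい天気ですね。", return_tensors="pt").to(device)
    with torch.no_grad():
        return model(**batch)["logits"]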
|
|
|
|
|
def inference(model, diffusion_sampler, text=None, ref_s=None, alpha=0.3, beta=0.7,
              diffusion_steps=5, embedding_scale=1, rate_of_speech=1.):

    # Clean/tokenize the text and prepend the padding/start token (index 0).
    tokens = textcleaner(text)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        # Two text encodings: one for the decoder path, one from the
        # BERT-style encoder feeding the prosody predictor.
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        # Sample a 256-dim style vector conditioned on the text embedding and
        # the reference style.
        s_pred = diffusion_sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                                   embedding=bert_dur,
                                   embedding_scale=embedding_scale,
                                   features=ref_s,
                                   num_steps=diffusion_steps).squeeze(1)

        # First 128 dims condition the decoder ("ref"), last 128 the prosody
        # predictor ("s"); blend each with the reference style.
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        # Predict per-token durations; rate_of_speech > 1 speeds speech up.
        x = model.predictor.lstm(d)
        x_mod = model.predictor.prepare_projection(x)
        duration = model.predictor.duration_proj(x_mod)

        duration = torch.sigmoid(duration).sum(dim=-1) / rate_of_speech
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Expand tokens to frames with a hard (0/1) alignment matrix.
        pred_aln_trg = torch.zeros(int(input_lengths.item()), int(pred_dur.sum().item()))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].item())] = 1
            c_frame += int(pred_dur[i].item())

        en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)

        # Predict pitch (F0) and energy (N) curves, then decode to audio.
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # Trim the last 50 samples, which tend to carry an end-of-utterance artifact.
    return out.squeeze().cpu().numpy()[..., :-50]
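
# Usage sketch: `nets` (a bundle exposing the .text_encoder / .bert /
# .predictor / .decoder modules used above), `sampler`, and the (1, 256)
# reference style `ref_s` must be built by the surrounding project; their
# loaders are not part of this file.
def _example_inference(nets, sampler, ref_s):
    wav = inference(nets, sampler, text="こんにちは、今日はいい天気ですね。",
                    ref_s=ref_s, alpha=0.3, beta=0.7,
                    diffusion_steps=10, embedding_scale=1.5)
    return wav  # 1-D numpy waveform, already trimmed of trailing samples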
|
|
|
|
|
def Longform(model, diffusion_sampler, text, s_prev, ref_s, alpha=0.3, beta=0.7, t=0.7,
             diffusion_steps=5, embedding_scale=1, rate_of_speech=1.0):

    tokens = textcleaner(text)
    tokens.insert(0, 0)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = diffusion_sampler(noise=torch.randn((1, 256)).unsqueeze(1).to(device),
                                   embedding=bert_dur,
                                   embedding_scale=embedding_scale,
                                   features=ref_s,
                                   num_steps=diffusion_steps).squeeze(1)

        # Convex blend with the previous segment's style keeps prosody
        # consistent across segments; t controls how much carries over.
        if s_prev is not None:
            s_pred = t * s_prev + (1 - t) * s_pred

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        ref = alpha * ref + (1 - alpha) * ref_s[:, :128]
        s = beta * s + (1 - beta) * ref_s[:, 128:]

        # Re-assemble the blended style so it can be handed to the next call.
        s_pred = torch.cat([ref, s], dim=-1)

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x = model.predictor.lstm(d)
        x_mod = model.predictor.prepare_projection(x)
        duration = model.predictor.duration_proj(x_mod)

        duration = torch.sigmoid(duration).sum(dim=-1) / rate_of_speech
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Hard token-to-frame alignment, as in inference().
        pred_aln_trg = torch.zeros(int(input_lengths.item()), int(pred_dur.sum().item()))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].item())] = 1
            c_frame += int(pred_dur[i].item())

        en = d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device)

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = t_en @ pred_aln_trg.unsqueeze(0).to(device)

        out = model.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

    # Trim the trailing 100 samples (end-of-segment artifact) and return the
    # blended style for the next segment.
    return out.squeeze().cpu().numpy()[..., :-100], s_pred
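
# Sketch of the intended long-form loop: feed each segment's returned style
# vector back in as `s_prev` so consecutive segments keep consistent prosody.
# `nets`, `sampler`, and `ref_s` are assumed as in _example_inference.
def _example_longform(nets, sampler, ref_s, segments):
    import numpy as np  # local import; numpy is not otherwise used at module level
    s_prev, wavs = None, []
    for seg in segments:
        wav, s_prev = Longform(nets, sampler, seg, s_prev, ref_s, t=0.7)
        wavs.append(wav)
    return np.concatenate(wavs)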
|
|
|
|
|
def merge_short_elements(lst):
    # Fold any element shorter than 10 characters into its predecessor, so
    # very short fragments are not synthesized as standalone segments.
    i = 0
    while i < len(lst):
        if i > 0 and len(lst[i]) < 10:
            lst[i-1] += ' ' + lst[i]
            lst.pop(i)
        else:
            i += 1
    return lst
|
|
|
|
|
def merge_three(text_list, maxim=2):
    # Group consecutive elements into chunks of at most `maxim`; despite the
    # name, the default merges pairs.
    merged_list = []
    for i in range(0, len(text_list), maxim):
        merged_text = ' '.join(text_list[i:i+maxim])
        merged_list.append(merged_text)
    return merged_list
|
|
|
|
|
def merging_sentences(lst):
    # Full preprocessing: absorb short fragments, then group into pairs.
    return merge_three(merge_short_elements(lst))
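
# Worked example: the short fragment is folded into its predecessor, then the
# result is grouped in pairs:
#   merging_sentences(["こんにちは。", "はい。", "今日はいい天気ですね。", "散歩に行きましょう。"])
#   -> ["こんにちは。 はい。 今日はいい天気ですね。", "散歩に行きましょう。"]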
|
|