# maskgct/models/tts/valle_v2/valle_ar.py
# (Hugging Face file-page header: uploaded by Hecheng0625, "Upload 409 files",
#  commit c968fc3, verified, 10.4 kB)
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from .modeling_llama import LlamaConfig, LlamaForCausalLM, LlamaModel
import torch
import torch.nn.functional as F
import numpy as np
import os
import torch.nn as nn
class ValleAR(nn.Module):
    """Llama-based autoregressive model over a joint phone + codec vocabulary.

    Phone ids and target (codec) ids share one embedding table:
    - Target ids keep their raw values.
    - Phone ids are shifted up by ``target_vocab_size``.
    - The special tokens (pad, and BOS/EOS for each stream) use the ids passed
      to the constructor.

    The vocabulary is sized ``phone_vocab_size + target_vocab_size + 10`` so that
    the default special ids (1281-1285) fit above both ranges.

    A training sequence is
    ``[BOS_p, phones..., EOS_p, pad..., BOS_t, targets..., EOS_t, pad...]``.
    Loss is only computed on the target segment.
    """

    def __init__(
        self,
        phone_vocab_size=256,
        target_vocab_size=1024,
        hidden_size=1024,
        intermediate_size=4096,
        num_hidden_layers=12,
        num_attention_heads=16,
        pad_token_id=1281,
        bos_target_id=1282,
        eos_target_id=1283,
        bos_phone_id=1284,
        eos_phone_id=1285,
        use_input_embeds=False,
        emb_dim=256,
        **kwargs,  # accepted for config compatibility; ignored
    ):
        super().__init__()
        self.config = LlamaConfig(
            # +10 leaves room for the special ids above both vocabularies.
            vocab_size=phone_vocab_size + target_vocab_size + 10,
            hidden_size=hidden_size,
            intermediate_size=intermediate_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            pad_token_id=pad_token_id,
            # The target-stream BOS/EOS double as the model's BOS/EOS.
            bos_token_id=bos_target_id,
            eos_token_id=eos_target_id,
        )
        self.phone_vocab_size = phone_vocab_size
        self.target_vocab_size = target_vocab_size
        self.pad_token_id = pad_token_id
        self.bos_target_id = bos_target_id
        self.eos_target_id = eos_target_id
        self.bos_phone_id = bos_phone_id
        self.eos_phone_id = eos_phone_id
        self.model = LlamaForCausalLM(self.config)
        self.use_input_embeds = use_input_embeds
        # Optional projection for an extra conditioning embedding (e.g. speaker).
        # NOTE: forward() and sample_hf() do not support it yet.
        if self.use_input_embeds:
            self.emb_linear = nn.Linear(emb_dim, hidden_size)
            self.emb_linear.weight.data.normal_(mean=0.0, std=0.01)
            self.emb_linear.bias.data.zero_()

    def forward(
        self, phone_ids, phone_mask, target_ids, target_mask, input_embeds=None
    ):
        """Teacher-forced pass that returns the LM loss plus accuracy metrics.

        Args:
            phone_ids: right-padded phone ids, [B, T_phone].
            phone_mask: integer mask (1 = valid, 0 = padding), [B, T_phone].
            target_ids: right-padded codec ids, [B, T_target].
            target_mask: integer mask (1 = valid, 0 = padding), [B, T_target].
            input_embeds: extra conditioning embeddings; not supported yet.

        Returns:
            The ``LlamaForCausalLM`` output (``loss``, ``logits``, ...).
            ``top1_acc``, ``top5_acc`` and ``top10_acc`` are attached. Each one
            is measured over the valid next-token predictions made inside the
            target segment.

        Raises:
            NotImplementedError: if ``input_embeds`` is given.
        """
        if input_embeds is not None:
            raise NotImplementedError("input_embeds conditioning is not implemented")

        phone_ids, phone_mask, phone_label = self.add_phone_eos_bos_label(
            phone_ids,
            phone_mask,
            self.eos_phone_id,
            self.bos_phone_id,
            self.pad_token_id,
        )
        target_ids, target_mask, target_label = self.add_target_eos_bos_label(
            target_ids,
            target_mask,
            self.eos_target_id,
            self.bos_target_id,
            self.pad_token_id,
        )
        input_token_ids = torch.cat([phone_ids, target_ids], dim=-1)
        attention_mask = torch.cat([phone_mask, target_mask], dim=-1)
        labels = torch.cat([phone_label, target_label], dim=-1)

        out = self.model(
            input_token_ids,
            attention_mask=attention_mask,
            labels=labels,
            return_dict=True,
        )

        # Logits at the target positions; logits[:, t] predicts target_ids[:, t + 1].
        logits = out.logits[:, -target_ids.shape[1] :]
        out.top1_acc = self._topk_accuracy(logits, target_ids, target_mask, 1)
        out.top5_acc = self._topk_accuracy(logits, target_ids, target_mask, 5)
        out.top10_acc = self._topk_accuracy(logits, target_ids, target_mask, 10)
        return out

    @staticmethod
    def _topk_accuracy(logits, target_ids, target_mask, k):
        """Return the fraction of valid next tokens that appear in the top-k logits.

        ``logits[:, t]`` is scored against ``target_ids[:, t + 1]``. That label
        is only counted when ``target_mask[:, t + 1]`` is 1, so the last logit
        (no label) and the trailing pad positions are excluded.
        """
        labels = target_ids[:, 1:]
        valid = target_mask[:, 1:]
        topk_ids = torch.topk(logits[:, :-1], k, dim=-1).indices
        hit = (topk_ids == labels.unsqueeze(-1)).any(dim=-1)
        return (hit * valid).sum() / valid.sum()

    @staticmethod
    def _wrap_with_bos_eos(ids, mask, eos_id, bos_id, pad_id):
        """Append EOS after the valid tokens, refill padding with ``pad_id``, prepend BOS.

        Assumes right-padded integer masks (the valid tokens come first).
        Returns ``(ids, mask)``, each one position longer on both sides: [B, T + 2].
        """
        ids = ids * mask
        # Every padded slot, plus one extra slot at the end, becomes EOS.
        ids = F.pad(ids, (0, 1), value=0) + eos_id * F.pad(1 - mask, (0, 1), value=1)
        # Shifting the mask right by one marks the first EOS as valid.
        mask = F.pad(mask, (1, 0), value=1)
        # Any EOS beyond the first is turned back into padding.
        ids = ids * mask + pad_id * (1 - mask)
        ids = F.pad(ids, (1, 0), value=bos_id)
        mask = F.pad(mask, (1, 0), value=1)
        return ids, mask

    def add_phone_eos_bos_label(
        self, phone_ids, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
    ):
        """Shift phones into the joint vocabulary and wrap them with BOS/EOS.

        Args:
            phone_ids: right-padded phone ids, [B, T].
            phone_mask: integer mask (1 = valid, 0 = padding), [B, T].

        Returns:
            ``(ids, mask, labels)``, each [B, T + 2]. ``labels`` is all -100, so
            no loss is computed on the phone segment.
        """
        # Phone ids live above the target vocabulary in the shared embedding table.
        shifted = phone_ids + self.target_vocab_size * phone_mask
        ids, mask = self._wrap_with_bos_eos(
            shifted, phone_mask, phone_eos_id, phone_bos_id, pad_token_id
        )
        labels = torch.full_like(ids, -100)
        return ids, mask, labels

    def add_target_eos_bos_label(
        self, target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
    ):
        """Wrap codec ids with BOS/EOS and build LM labels.

        Args:
            target_ids: right-padded codec ids, [B, T].
            target_mask: integer mask (1 = valid, 0 = padding), [B, T].

        Returns:
            ``(ids, mask, labels)``, each [B, T + 2]. ``labels`` equals ``ids``
            on valid positions (BOS, tokens, EOS) and -100 on padding.
        """
        ids, mask = self._wrap_with_bos_eos(
            target_ids, target_mask, target_eos_id, target_bos_id, pad_token_id
        )
        labels = ids * mask + (-100) * (1 - mask)
        return ids, mask, labels

    def sample_hf(
        self,
        phone_ids,  # the phones of prompt and target should be concatenated together
        prompt_ids,
        inputs_embeds=None,
        max_length=2000,
        temperature=1.0,
        top_k=100,
        top_p=0.9,
        repeat_penalty=1.0,
        num_beams=1,
    ):
        """Continue a codec prompt autoregressively with Hugging Face ``generate``.

        Args:
            phone_ids: unpadded phone ids of the prompt text followed by the target text.
            prompt_ids: unpadded codec ids of the acoustic prompt.
            inputs_embeds: extra conditioning embeddings; not supported yet.
            max_length: maximum total length, prefix included, passed to ``generate``.
            temperature: forwarded to ``generate``.
            top_k: forwarded to ``generate``.
            top_p: forwarded to ``generate``.
            repeat_penalty: forwarded to ``generate`` as ``repetition_penalty``.
            num_beams: forwarded to ``generate``.
            Sampling is always enabled.

        Returns:
            Newly generated codec ids, with the prefix removed. A final column
            that holds only EOS/pad is dropped; real generated tokens are never
            discarded.

        Raises:
            NotImplementedError: if ``inputs_embeds`` is given.
        """
        if inputs_embeds is not None:
            raise NotImplementedError("inputs_embeds conditioning is not implemented")

        phone_ids, _, _ = self.add_phone_eos_bos_label(
            phone_ids,
            torch.ones_like(phone_ids),
            self.eos_phone_id,
            self.bos_phone_id,
            self.pad_token_id,
        )
        prompt_ids, _, _ = self.add_target_eos_bos_label(
            prompt_ids,
            torch.ones_like(prompt_ids),
            self.eos_target_id,
            self.bos_target_id,
            self.pad_token_id,
        )
        # Remove the prompt's EOS so generation continues from the prompt.
        prompt_ids = prompt_ids[:, :-1]
        input_token_ids = torch.cat([phone_ids, prompt_ids], dim=-1)
        input_length = input_token_ids.shape[1]

        generated_ids = self.model.generate(
            input_token_ids,
            do_sample=True,
            max_length=max_length,
            pad_token_id=self.pad_token_id,
            eos_token_id=self.eos_target_id,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            repetition_penalty=repeat_penalty,
            num_beams=num_beams,
        )
        gen_tokens = generated_ids[:, input_length:]
        # The final column is normally the EOS that stopped generation (or pad
        # for rows that stopped earlier). When max_length is hit first, it is a
        # real token instead, so only strip it if it holds nothing but EOS/pad.
        if gen_tokens.shape[1] > 0:
            last = gen_tokens[:, -1]
            if ((last == self.eos_target_id) | (last == self.pad_token_id)).all():
                gen_tokens = gen_tokens[:, :-1]
        return gen_tokens
def test():
    """Smoke test: overfit a tiny batch for a few steps, then sample a continuation.

    Prints the loss at each step and the shape of the sampled ids, and returns
    the sampled ids.
    """
    model = ValleAR()
    phone_ids = torch.LongTensor([[1, 2, 3, 4, 5, 0], [1, 2, 3, 4, 5, 6]])
    phone_mask = torch.LongTensor([[1, 1, 1, 0, 0, 0], [1, 1, 1, 0, 0, 0]])
    target_ids = torch.LongTensor([765, 234, 123, 234, 123, 599]).expand(2, -1)
    target_mask = torch.LongTensor([1, 1, 1, 1, 0, 0]).expand(2, -1)
    optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
    for i in range(15):
        optimizer.zero_grad()
        out = model(
            phone_ids=phone_ids,
            phone_mask=phone_mask,
            target_ids=target_ids,
            target_mask=target_mask,
        )
        loss = out.loss
        loss.backward()
        optimizer.step()
        print(f"iter={i}, {loss}.")
    phone_ids = torch.LongTensor([1, 2, 3]).reshape(1, -1)
    target_ids = torch.LongTensor([765, 234]).reshape(1, -1)
    sampled = model.sample_hf(phone_ids, target_ids)
    # Report the result instead of dropping into the debugger, so the smoke test
    # can run non-interactively.
    print(f"sampled shape={tuple(sampled.shape)}")
    return sampled
if __name__ == "__main__":
    # Executing this module directly runs the smoke test above.
    test()