import os

import numpy as np
import torch
import peft
from transformers.trainer_pt_utils import LabelSmoother

import ChatTTS

from utils.dataset import AudioCollator, XzListTar
from utils.logger import MetricLogger
# NOTE: the module providing DVAEEncoder, get_encoder_config and quantize is
# assumed to be utils.model; adjust this import to the actual project layout.
from utils.model import DVAEEncoder, get_encoder_config, quantize
from utils.output import ansi, get_ansi_len, output_iter

IGNORE_TOKEN_ID = LabelSmoother.ignore_index


def train_gpt_lora(
    chat,
    dataset,
    decoder_encoder,
    dvae_encoder,
    batch_size=16,
    epochs=10,
    train_text=True,
    speaker_embeds=None,
    lora_r=8,
    lora_alpha=16,
):
    if speaker_embeds is None:
        speaker_embeds = {}

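    # Grab the tokenizer; the pretrained decoder and DVAE components below stay
    # frozen, so only the LoRA-adapted GPT and the speaker embeddings are trained.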
    tokenizer = chat.pretrain_models["tokenizer"]
    decoder_decoder = chat.pretrain_models["decoder"]
    decoder_decoder.eval().requires_grad_(False)
    decoder_encoder.to(device=dataset.device).eval().requires_grad_(False)
    dvae_decoder = chat.pretrain_models["dvae"]
    dvae_decoder.eval().requires_grad_(False)
    dvae_encoder.to(device=dataset.device).eval().requires_grad_(False)

    gpt = chat.pretrain_models["gpt"]
    gpt.train().requires_grad_()

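    # Inject trainable LoRA adapters into the GPT transformer backbone.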
    lora_config = peft.LoraConfig(r=lora_r, lora_alpha=lora_alpha)
    gpt.gpt = peft.get_peft_model(gpt.gpt, lora_config)

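    # Create one trainable 768-dim embedding per dataset speaker (embeddings
    # passed in by the caller take precedence) and rescale them with the
    # pretrained speaker statistics.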
    speaker_embeds = {
        speaker: torch.randn(768, device=dataset.device, requires_grad=True)
        for speaker in dataset.speakers
    } | speaker_embeds

    std, mean = chat.pretrain_models["spk_stat"].chunk(2)
    for speaker_embed in speaker_embeds.values():
        speaker_embed.data = speaker_embed.data * std + mean

    SPEAKER_TOKEN_ID = tokenizer.convert_tokens_to_ids("[spk_emb]")
    AUDIO_EOS_TOKEN_ID = 0
    AUDIO_PAD_TOKEN_ID = AUDIO_EOS_TOKEN_ID

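    # Two parameter groups: the LoRA-wrapped GPT at lr 1e-3 and the speaker
    # embeddings at a much larger lr of 1e-1.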
    train_params = list(gpt.parameters()) + list(speaker_embeds.values())
    optimizer = torch.optim.Adam(
        gpt.parameters(), lr=1e-3, weight_decay=0, betas=[0.9, 0.95], eps=1e-5
    )
    optimizer.add_param_group({"params": speaker_embeds.values(), "lr": 1e-1})

    loss_fn = torch.nn.CrossEntropyLoss()
    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs, 1e-7)

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=AudioCollator(text_pad=tokenizer.pad_token_id),
    )
    logger = MetricLogger()
    logger.create_meters(loss=None, mse_loss=None, audio_loss=None, text_loss=None)

    for _epoch in range(epochs):
        _epoch += 1
        logger.reset()
        header = "{blue_light}{0}: {1}{reset}".format(
            "Epoch", output_iter(_epoch, epochs), **ansi
        )
        header = header.ljust(max(len("Epoch"), 30) + get_ansi_len(header))
        iterator = logger.log_every(loader, header=header, tqdm_header="Batch")

        for batch in iterator:
            speakers = batch["speaker"]
            text_input_ids = batch["text_input_ids"]
            text_attention_mask = batch["text_attention_mask"]
            audio_mel_specs = batch["audio_mel_specs"]
            audio_attention_mask = batch["audio_attention_mask"]

            batch_size, text_len = text_attention_mask.size()

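            # Encode the mel spectrograms with the frozen DVAE encoder and
            # quantize the latents into discrete audio token ids; masked
            # positions are filled with the pad token.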
            dvae_audio_latents = dvae_encoder(audio_mel_specs, audio_attention_mask)
            _, dvae_audio_input_ids = quantize(
                dvae_decoder.vq_layer.quantizer, dvae_audio_latents
            )
            dvae_audio_input_ids[~audio_attention_mask.bool()] = AUDIO_PAD_TOKEN_ID

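            # Extend each audio sequence by one frame so an EOS token can be
            # written right after the last real frame of every sample.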
            extended_audio_attention_mask = torch.cat(
                [
                    audio_attention_mask,
                    torch.zeros(
                        (batch_size, 1),
                        dtype=audio_attention_mask.dtype,
                        device=audio_attention_mask.device,
                    ),
                ],
                dim=1,
            )
            extended_audio_input_ids = torch.cat(
                [
                    dvae_audio_input_ids,
                    AUDIO_PAD_TOKEN_ID
                    * torch.ones(
                        (batch_size, 1, gpt.num_vq),
                        dtype=dvae_audio_input_ids.dtype,
                        device=dvae_audio_input_ids.device,
                    ),
                ],
                dim=1,
            )

            indices = audio_attention_mask.int().sum(dim=1)
            for i in range(batch_size):
                extended_audio_attention_mask[i, indices[i]] = 1
                extended_audio_input_ids[i, indices[i]] = AUDIO_EOS_TOKEN_ID

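            # Concatenate text and audio tokens into a single sequence; padded
            # positions are excluded from the loss via IGNORE_TOKEN_ID.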
            input_ids = torch.cat(
                [
                    text_input_ids.unsqueeze(-1).repeat(1, 1, gpt.num_vq),
                    extended_audio_input_ids,
                ],
                dim=1,
            )
            attention_mask = torch.cat(
                [text_attention_mask, extended_audio_attention_mask], dim=1
            )
            text_mask = torch.cat(
                [
                    torch.ones_like(text_attention_mask, dtype=bool),
                    torch.zeros_like(extended_audio_attention_mask, dtype=bool),
                ],
                dim=1,
            )
            labels = input_ids.clone()
            labels[~attention_mask.bool()] = IGNORE_TOKEN_ID

            inputs_embeds = gpt.get_emb(input_ids=input_ids, text_mask=text_mask)

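            # Replace every [spk_emb] placeholder (a position where all VQ
            # channels hold the speaker token) with the L2-normalized trainable
            # embedding of that sample's speaker.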
            indices = torch.all(input_ids == SPEAKER_TOKEN_ID, dim=-1)
            for i, speaker in enumerate(speakers):
                inputs_embeds[i, indices[i]] = torch.nn.functional.normalize(
                    speaker_embeds[speaker].to(dtype=inputs_embeds.dtype),
                    p=2.0,
                    dim=-1,
                    eps=1e-12,
                ).unsqueeze(0)

            outputs = gpt.gpt.forward(
                inputs_embeds=inputs_embeds, attention_mask=attention_mask
            )
            hidden_states = outputs.last_hidden_state
            text_hidden_states = hidden_states[:, : text_len - 1]
            audio_hidden_states = hidden_states[:, text_len - 1 : -1]

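            # Audio loss: cross-entropy over every VQ codebook head, shifted by
            # one so the hidden state at position t predicts the token at t + 1.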
            audio_logits = torch.stack(
                [gpt.head_code[i](audio_hidden_states) for i in range(gpt.num_vq)],
                dim=2,
            )
            audio_loss = loss_fn(
                audio_logits.flatten(0, 2), labels[:, text_len:].flatten(0, 2)
            )
            loss = audio_loss

            if train_text:
                text_logits = gpt.head_text(text_hidden_states)
                text_loss = loss_fn(
                    text_logits.flatten(0, 1), labels[:, 1:text_len, 0].flatten(0, 1)
                )
                loss += text_loss
                logger.meters["text_loss"].update(text_loss.item(), n=batch_size)

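            # Auxiliary loss: decode the audio hidden states back to mel
            # spectrograms and compare against the ground truth (weight 0.01).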
            gpt_gen_mel_specs = decoder_decoder(
                audio_hidden_states[:, :-1].transpose(1, 2)
            ).transpose(1, 2)
            mse_loss = torch.nn.functional.mse_loss(gpt_gen_mel_specs, audio_mel_specs)
            loss += 0.01 * mse_loss

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(train_params, 1.0)
            optimizer.step()

            logger.meters["loss"].update(loss.item(), n=batch_size)
            logger.meters["mse_loss"].update(mse_loss.item(), n=batch_size)
            logger.meters["audio_loss"].update(audio_loss.item(), n=batch_size)

        lr_scheduler.step()
    optimizer.zero_grad()
    return speaker_embeds
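

def load_lora_checkpoint(chat, lora_save_path, speaker_embeds_path):
    """Example helper (a sketch, not part of the training loop itself): reload
    the artifacts saved by main(). It assumes the GPT backbone is wrapped with
    the same LoraConfig that was used during training."""
    gpt = chat.pretrain_models["gpt"]
    if not isinstance(gpt.gpt, peft.PeftModel):
        gpt.gpt = peft.get_peft_model(gpt.gpt, peft.LoraConfig(r=8, lora_alpha=16))
    # Load only the LoRA adapter weights back into the wrapped model.
    peft.set_peft_model_state_dict(gpt.gpt, torch.load(lora_save_path))
    # Restore the learned speaker embeddings from the .npz archive.
    npz = np.load(speaker_embeds_path)
    speaker_embeds = {k: torch.from_numpy(npz[k]) for k in npz.files}
    return speaker_embeds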


def main():
    chat = ChatTTS.Chat()
    chat.load_models()
    dataset = XzListTar(
        root="data/all.list",
        tokenizer=chat.pretrain_models["tokenizer"],
        vocos_model=chat.pretrain_models["vocos"],
        tar_path="data/Xz.tar",
        tar_in_memory=True,
        process_ahead=True,
    )

    decoder_encoder = DVAEEncoder(
        **get_encoder_config(chat.pretrain_models["decoder"].decoder)
    )
    dvae_encoder = DVAEEncoder(
        **get_encoder_config(chat.pretrain_models["dvae"].decoder)
    )

    speaker_embeds = train_gpt_lora(
        chat=chat,
        dataset=dataset,
        decoder_encoder=decoder_encoder,
        dvae_encoder=dvae_encoder,
        batch_size=32,
        epochs=10,
        train_text=True,
        lora_r=8,
        lora_alpha=16,
    )

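    # Persist only the LoRA adapter weights (not the full GPT) together with
    # the learned speaker embeddings.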
    os.makedirs("./saved_models", exist_ok=True)
    lora_save_path = "./saved_models/gpt_lora.pth"
    gpt = chat.pretrain_models["gpt"]
    torch.save(peft.get_peft_model_state_dict(gpt.gpt), lora_save_path)
    np.savez(
        "./saved_models/speaker_embeds.npz",
        **{k: v.detach().cpu().numpy() for k, v in speaker_embeds.items()}
    )


if __name__ == "__main__":
    main()