# eriquesouza's picture
# app v1
# e831f85
import os
import time
import torch
import torch.multiprocessing
from torch.nn.utils.rnn import pad_sequence
from torch.optim import RAdam
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.Aligner import Aligner
from TrainingInterfaces.Text_to_Spectrogram.AutoAligner.TinyTTS import TinyTTS
def collate_and_pad(batch):
    """
    Collate a list of dataset items into a single padded batch.

    Each item is indexed as (text, text_len, speech, speech_len, speaker_embedding).
    Variable-length text and speech are zero-padded along their first dimension
    (``pad_sequence`` with ``batch_first=True``). Lengths and speaker embeddings are stacked.

    Args:
        batch: list of dataset items. text_len / speech_len are presumably shaped (1,),
            since they are squeezed on dim 1 after stacking — TODO confirm against the dataset.

    Returns:
        tuple: (padded text, text lengths, padded speech, speech lengths, speaker embeddings),
        each with the batch as dimension 0 — also when the batch holds a single item.
    """
    # text, text_len, speech, speech_len, speaker_embedding
    texts = pad_sequence([datapoint[0] for datapoint in batch], batch_first=True)
    text_lens = torch.stack([datapoint[1] for datapoint in batch]).squeeze(1)
    speeches = pad_sequence([datapoint[2] for datapoint in batch], batch_first=True)
    speech_lens = torch.stack([datapoint[3] for datapoint in batch]).squeeze(1)
    speaker_embeddings = torch.stack([datapoint[4] for datapoint in batch])
    # A bare .squeeze() would also drop the batch dimension for a batch of size 1,
    # so only remove singleton dimensions after dim 0. Iterate from the last dim
    # so that earlier indices remain valid after each squeeze.
    for dim in range(speaker_embeddings.dim() - 1, 0, -1):
        if speaker_embeddings.size(dim) == 1:
            speaker_embeddings = speaker_embeddings.squeeze(dim)
    return texts, text_lens, speeches, speech_lens, speaker_embeddings
def train_loop(train_dataset,
               device,
               save_directory,
               batch_size,
               steps,
               path_to_checkpoint=None,
               fine_tune=False,
               resume=False,
               debug_img_path=None,
               use_reconstruction=True):
    """
    Train the CTC aligner, optionally regularised by a TinyTTS reconstruction loss.

    After every pass over the data a checkpoint is written to
    ``<save_directory>/aligner.pt``. Training stops once at least ``steps``
    optimiser steps have been taken. This is checked only at epoch boundaries,
    so the final step count may overshoot ``steps``.

    Args:
        train_dataset: Pytorch Dataset Object for train data; its items are
            collated by ``collate_and_pad`` (text, text_len, speech, speech_len,
            speaker_embedding)
        device: Device to put the loaded tensors on
        save_directory: Where to save the checkpoints (created if missing)
        batch_size: How many elements should be loaded at once
        steps: How many steps to train
        path_to_checkpoint: reloads a checkpoint to continue training from there
        fine_tune: whether to load everything from a checkpoint, or only the model parameters
        resume: whether to resume from ``<save_directory>/aligner.pt``
            (overrides path_to_checkpoint and forces fine_tune=False)
        debug_img_path: if not None, a directory into which
            ``asr_model.inference`` writes a debug image after every epoch
        use_reconstruction: whether to add the TinyTTS reconstruction loss and
            update the TinyTTS model

    Raises:
        ValueError: if the data loader yields no batches in an epoch (e.g. the
            dataset has fewer than ``batch_size`` items, since ``drop_last=True``).
    """
    os.makedirs(save_directory, exist_ok=True)
    checkpoint_file = os.path.join(save_directory, "aligner.pt")
    train_loader = DataLoader(batch_size=batch_size,
                              dataset=train_dataset,
                              drop_last=True,
                              num_workers=8,
                              pin_memory=False,
                              shuffle=True,
                              prefetch_factor=16,
                              collate_fn=collate_and_pad,
                              persistent_workers=True)
    asr_model = Aligner().to(device)
    optim_asr = RAdam(asr_model.parameters(), lr=0.0001)
    tiny_tts = TinyTTS().to(device)
    optim_tts = RAdam(tiny_tts.parameters(), lr=0.0001)
    step_counter = 0

    if resume:
        path_to_checkpoint = checkpoint_file
        fine_tune = False

    if path_to_checkpoint is not None:
        check_dict = torch.load(path_to_checkpoint, map_location=device)
        asr_model.load_state_dict(check_dict["asr_model"])
        tiny_tts.load_state_dict(check_dict["tts_model"])
        if not fine_tune:
            # full restore: optimiser states and the step counter as well
            optim_asr.load_state_dict(check_dict["optimizer"])
            optim_tts.load_state_dict(check_dict["tts_optimizer"])
            step_counter = check_dict["step_counter"]
            # ">=": a checkpoint sitting exactly at `steps` is already done
            if step_counter >= steps:
                print("Desired steps already reached in loaded checkpoint.")
                return

    start_time = time.time()

    while True:
        loss_sum = list()
        asr_model.train()
        tiny_tts.train()
        for batch in tqdm(train_loader):
            tokens = batch[0].to(device)
            tokens_len = batch[1].to(device)
            mel = batch[2].to(device)
            mel_len = batch[3].to(device)
            speaker_embeddings = batch[4].to(device)

            pred = asr_model(mel, mel_len)
            ctc_loss = asr_model.ctc_loss(pred.transpose(0, 1).log_softmax(2),
                                          tokens,
                                          mel_len,
                                          tokens_len)

            if use_reconstruction:
                # broadcast the normalised speaker embedding over every frame of the ASR prediction
                speaker_embeddings_expanded = torch.nn.functional.normalize(speaker_embeddings).unsqueeze(1).expand(-1, pred.size(1), -1)
                tts_lambda = min([5, step_counter / 2000])  # super simple schedule: linear ramp, capped at 5
                reconstruction_loss = tiny_tts(x=torch.cat([pred, speaker_embeddings_expanded], dim=-1),
                                               # combine ASR prediction with speaker embeddings to allow for reconstruction loss on multiple speakers
                                               lens=mel_len,
                                               ys=mel) * tts_lambda  # reconstruction loss to make the states more distinct
                loss = ctc_loss + reconstruction_loss
            else:
                loss = ctc_loss

            optim_asr.zero_grad()
            if use_reconstruction:
                optim_tts.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(asr_model.parameters(), 1.0)
            if use_reconstruction:
                torch.nn.utils.clip_grad_norm_(tiny_tts.parameters(), 1.0)
            optim_asr.step()
            if use_reconstruction:
                optim_tts.step()
            step_counter += 1
            loss_sum.append(loss.item())

        if not loss_sum:
            # Without this guard the average below divides by zero, and `mel` /
            # `tokens` used for the debug image would be unbound.
            raise ValueError("train_loader yielded no batches; with drop_last=True the dataset "
                             "must contain at least batch_size items")

        asr_model.eval()
        loss_this_epoch = sum(loss_sum) / len(loss_sum)
        torch.save({
            "asr_model"    : asr_model.state_dict(),
            "optimizer"    : optim_asr.state_dict(),
            "tts_model"    : tiny_tts.state_dict(),
            "tts_optimizer": optim_tts.state_dict(),
            "step_counter" : step_counter,
        },
            checkpoint_file)
        print("Total Loss: {}".format(round(loss_this_epoch, 3)))
        print("Time elapsed: {} Minutes".format(round((time.time() - start_time) / 60)))
        print("Steps: {}".format(step_counter))
        if debug_img_path is not None:
            # visualise the first sample of the last batch of this epoch
            asr_model.inference(mel=mel[0][:mel_len[0]],
                                tokens=tokens[0][:tokens_len[0]],
                                save_img_for_debug=os.path.join(debug_img_path, f"{step_counter}.png"),
                                train=True)  # for testing
        if step_counter >= steps:
            return