import os
import sys
import time

sys.path.append('../models')

import torch
import functools
import numpy as np

import options as opt
from torch import nn, optim
from tqdm.auto import tqdm
from PauseChecker import PauseChecker
from Trainer import Trainer
from models.LipNetPlus import LipNetPlus
from TranslatorTrainer import TranslatorTrainer
from dataset import GridDataset, CharMap, Datasets
from helpers import contains_nan_or_inf
# PAD_IDX, BOS_IDX, EOS_IDX, UNK_IDX, create_mask and show_lr are assumed
# to be provided by these wildcard imports
from models.PhonemeTransformer import *
from helpers import *


class TransformerTrainer(Trainer, TranslatorTrainer):
    def __init__(
        self, batch_size=opt.batch_size, word_tokenize=False,
        dataset_type: Datasets = opt.dataset, embeds_size=256,
        vocab_files=None, write_logs=True,
        input_char_map=CharMap.phonemes,
        output_char_map=CharMap.letters,
        name='embeds-transformer-v2', **kwargs
    ):
        super().__init__(**kwargs, name=name)

        self.batch_size = batch_size
        self.word_tokenize = word_tokenize
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map
        self.dataset_type = dataset_type
        self.embeds_size = embeds_size

        self.text_tokenizer = functools.partial(
            GridDataset.tokenize_text, word_tokenize=word_tokenize
        )
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        if vocab_files is None:
            vocabs = self.load_vocabs(self.base_dir)
            self.phonemes_vocab, self.text_vocab = vocabs
        else:
            phonemes_vocab_path, text_vocab_path = vocab_files
            self.phonemes_vocab = torch.load(phonemes_vocab_path)
            self.text_vocab = torch.load(text_vocab_path)

        self.model = None
        self.optimizer = None
        self.best_test_loss = float('inf')
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

        """
        self.phonemes_encoder = self.sequential_transforms(
            GridDataset.tokenize_phonemes, self.phonemes_vocab,
            self.tensor_transform
        )
        """
        self.text_encoder = self.sequential_transforms(
            self.text_tokenizer, self.text_vocab,
            self.tensor_transform
        )

        if write_logs:
            self.init_tensorboard()

    def create_model(self):
        if self.model is None:
            output_classes = len(self.train_dataset.get_char_mapping())
            self.model = LipNetPlus(
                output_classes=output_classes,
                pre_gru_repeats=self.pre_gru_repeats,
                embeds_size=self.embeds_size,
                output_vocab_size=len(self.text_vocab)
            )
            self.model = self.model.cuda()

        if self.net is None:
            self.net = nn.DataParallel(self.model).cuda()

    def load_datasets(self):
        if self.train_dataset is None:
            self.train_dataset = GridDataset(
                **self.dataset_kwargs, phase='train',
                file_list=opt.train_list, sample_all_props=True
            )
        if self.test_dataset is None:
            self.test_dataset = GridDataset(
                **self.dataset_kwargs, phase='test',
                file_list=opt.val_list, sample_all_props=True
            )

    def train(self):
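        """Train the sequence-to-sequence transformer on top of the GRU front-end.

        For each batch the GRU front-end is run under torch.no_grad(), its
        CTC output is decoded into per-character source embeddings via
        make_transformer_embeds, and the transformer is trained with teacher
        forcing (tgt[:-1] as decoder input, tgt[1:] as the cross-entropy
        target). Train WER is logged every opt.display iterations and
        validation runs every opt.test_step iterations.
        """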
        self.load_datasets()
        self.create_model()

        dataset = self.train_dataset
        loader = self.dataset2dataloader(
            dataset, num_workers=self.num_workers
        )

        """
        optimizer = optim.Adam(
            self.model.parameters(), lr=opt.base_lr,
            weight_decay=0., amsgrad=True
        )
        """
        optimizer = optim.RMSprop(
            self.model.parameters(), lr=opt.base_lr
        )

        print('num_train_data:{}'.format(len(dataset.data)))

        # don't allow loss function to create infinite loss for
        # sequences that are too short
        tic = time.time()

        self.best_test_loss = float('inf')
        log_scalar = functools.partial(self.log_scalar, label='train')

        for epoch in range(opt.max_epoch):
            print(f'RUNNING EPOCH {epoch}')
            train_wer = []
            pbar = tqdm(loader)

            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()
                self.model.train()

                vid = input_sample.get('vid').cuda()
                # vid_len = input_sample.get('vid_len').cuda()
                # txt, txt_len = self.extract_char_output(input_sample)

                batch_arr_sentences = input_sample['txt_anno']
                batch_arr_sentences = np.array(batch_arr_sentences)
                _, batch_size = batch_arr_sentences.shape
                batch_sentences = [
                    ''.join(batch_arr_sentences[:, k]).strip()
                    for k in range(batch_size)
                ]

                tgt = self.collate_tgt_fn(batch_sentences)
                tgt = tgt.to(self.device)
                tgt_input = tgt[:-1, :]

                with torch.no_grad():
                    gru_output = self.model.forward_gru(vid)
                    y = self.model.predict_from_gru_out(gru_output)
                    src_embeds = self.model.make_src_embeds(gru_output)

                transformer_out = self.make_transformer_embeds(
                    dataset, src_embeds, y, batch_size=batch_size
                )
                transformer_src_embeds, src_idx_arr = transformer_out
                transformer_src_embeds = transformer_src_embeds.to(self.device)
                src_idx_arr = src_idx_arr.to(self.device)
                max_seq_len, batch_size = src_idx_arr.shape

                (
                    src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask
                ) = create_mask(
                    src_idx_arr, tgt_input, self.device
                )

                logits = self.model.seq_forward(
                    transformer_src_embeds, tgt_input, src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask, src_padding_mask
                )

                optimizer.zero_grad()
                tgt_out = tgt[1:, :]
                loss = self.loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    tgt_out.reshape(-1)
                )

                tot_iter = i_iter + epoch * len(loader)
                loss.backward()
                optimizer.step()

                with torch.no_grad():
                    # Convert logits tensor to string
                    probs = torch.softmax(logits, dim=-1)
                    token_indices = torch.argmax(probs, dim=-1)

                    # Convert token indices to strings for
                    # each sequence in the batch
                    gap = ' ' if self.word_tokenize else ''
                    # print('TT', token_indices.shape)
                    pred_sentences = self.batch_indices_to_text(
                        token_indices, batch_size=batch_size, gap=gap
                    )

                wer = np.mean(GridDataset.get_wer(
                    pred_sentences, batch_sentences,
                    char_map=self.output_char_map
                ))
                train_wer.append(wer)

                if tot_iter % opt.display == 0:
                    v = 1.0 * (time.time() - tic) / (tot_iter + 1)
                    eta = (len(loader) - i_iter) * v / 3600.0
                    wer = np.array(train_wer).mean()

                    log_scalar('loss', loss, tot_iter)
                    log_scalar('wer', wer, tot_iter)
                    self.log_pred_texts(
                        pred_sentences, batch_sentences, sub_samples=3
                    )
                    print('epoch={},tot_iter={},eta={},loss={},train_wer={}'
                          .format(
                              epoch, tot_iter, eta, loss,
                              np.array(train_wer).mean()
                          ))
                    print(''.join(161 * '-'))

                if (tot_iter > -1) and (tot_iter % opt.test_step == 0):
                    # if tot_iter % opt.test_step == 0:
                    self.run_test(tot_iter, optimizer)

    def make_transformer_embeds(
        self, dataset, src_embeds, y, batch_size
    ):
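        """Build transformer source embeddings from CTC-decoded GRU output.

        Returns a (max_sentence_len + 2, batch_size, embeds_size) tensor of
        source embeddings, initialized with the PAD token embedding and
        filled per sentence with BOS, the decoded character embeddings and
        EOS, together with an index mask that marks occupied positions with
        UNK_IDX and padded positions with PAD_IDX (consumed by create_mask).
        """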
        batch_indices = dataset.ctc_decode_indices(y)
        filter_batch_embeds = []

        pad_embed = self.model.src_tok_emb(
            torch.IntTensor([PAD_IDX]).to(self.device)
        )
        begin_embed = self.model.src_tok_emb(
            torch.IntTensor([BOS_IDX]).to(self.device)
        )
        end_embed = self.model.src_tok_emb(
            torch.IntTensor([EOS_IDX]).to(self.device)
        )

        max_sentence_len = max([len(x) for x in batch_indices])
        # initialize embeds with pad token embeddings
        # [max_sentence_len + 2, batch_size, embeds_size]
        # clone() so the expanded view owns its memory and can be written
        # to position by position without aliasing every row to pad_embed
        transformer_src_embeds = pad_embed.expand(
            max_sentence_len + 2, batch_size, pad_embed.shape[1]
        ).clone()
        src_idx_mask = torch.full(
            transformer_src_embeds.shape[:2], PAD_IDX, dtype=torch.int
        )

        # k is sentence index in batch
        for k, sentence_indices in enumerate(batch_indices):
            filter_sentence_embeds = []
            # src_embeds is indexed as [time_step][batch_item]
            for sentence_index in sentence_indices:
                filter_sentence_embeds.append(
                    src_embeds[sentence_index][k]
                )

            sentence_length = len(filter_sentence_embeds)
            filter_batch_embeds.append(filter_sentence_embeds)

            # set beginning-of-sequence embed
            transformer_src_embeds[0][k] = begin_embed
            src_idx_mask[0][k] = UNK_IDX

            # index i is char index in sentence
            for i, char_embed in enumerate(filter_sentence_embeds):
                transformer_src_embeds[i + 1][k] = char_embed
                src_idx_mask[i + 1][k] = UNK_IDX

            transformer_src_embeds[sentence_length + 1][k] = end_embed
            src_idx_mask[sentence_length + 1][k] = UNK_IDX

        return transformer_src_embeds, src_idx_mask

    @staticmethod
    def log_pred_texts(
        pred_txt, truth_txt, pad=80, sub_samples=None
    ):
        line_length = 2 * pad + 1
        print(''.join(line_length * '-'))
        print('{:<{pad}}|{:>{pad}}'.format(
            'predict', 'truth', pad=pad
        ))
        print(''.join(line_length * '-'))

        zipped_samples = list(zip(pred_txt, truth_txt))
        if sub_samples is not None:
            zipped_samples = zipped_samples[:sub_samples]

        for (predict, truth) in zipped_samples:
            print('{:<{pad}}|{:>{pad}}'.format(
                predict, truth, pad=pad
            ))
        print(''.join(line_length * '-'))

    def test(self):
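        """Evaluate on the validation set.

        Returns the mean cross-entropy loss, mean WER and mean CER over
        all validation batches.
        """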
        dataset = self.test_dataset

        with torch.no_grad():
            print('num_test_data:{}'.format(len(dataset.data)))
            self.model.eval()

            loader = self.dataset2dataloader(
                dataset, shuffle=False, num_workers=self.num_workers
            )
            loss_list = []
            wer = []
            cer = []
            tic = time.time()

            print('RUNNING VALIDATION')
            pbar = tqdm(loader)

            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                vid = input_sample.get('vid').cuda()

                batch_arr_sentences = input_sample['txt_anno']
                batch_arr_sentences = np.array(batch_arr_sentences)
                _, batch_size = batch_arr_sentences.shape
                batch_sentences = [
                    ''.join(batch_arr_sentences[:, k]).strip()
                    for k in range(batch_size)
                ]

                tgt = self.collate_tgt_fn(batch_sentences)
                tgt = tgt.to(self.device)
                tgt_input = tgt[:-1, :]

                with torch.no_grad():
                    gru_output = self.model.forward_gru(vid)
                    y = self.model.predict_from_gru_out(gru_output)
                    src_embeds = self.model.make_src_embeds(gru_output)

                transformer_out = self.make_transformer_embeds(
                    dataset, src_embeds, y, batch_size=batch_size
                )
                transformer_src_embeds, src_idx_arr = transformer_out
                transformer_src_embeds = transformer_src_embeds.to(self.device)
                src_idx_arr = src_idx_arr.to(self.device)
                max_seq_len, batch_size = src_idx_arr.shape

                (
                    src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask
                ) = create_mask(
                    src_idx_arr, tgt_input, self.device
                )

                logits = self.model.seq_forward(
                    transformer_src_embeds, tgt_input, src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask, src_padding_mask
                )

                with torch.no_grad():
                    # Convert logits tensor to string
                    probs = torch.softmax(logits, dim=-1)
                    token_indices = torch.argmax(probs, dim=-1)

                    # Convert token indices to strings for
                    # each sequence in the batch
                    gap = ' ' if self.word_tokenize else ''
                    # print('TT', token_indices.shape)
                    pred_sentences = self.batch_indices_to_text(
                        token_indices, batch_size=batch_size, gap=gap
                    )

                tgt_out = tgt[1:, :]
                loss = self.loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    tgt_out.reshape(-1)
                )
                loss_item = loss.detach().cpu().numpy()
                loss_list.append(loss_item)

                wer.extend(GridDataset.get_wer(
                    pred_sentences, batch_sentences,
                    char_map=self.output_char_map
                ))
                cer.extend(GridDataset.get_cer(
                    pred_sentences, batch_sentences,
                    char_map=self.output_char_map
                ))

                if i_iter % opt.display == 0:
                    v = 1.0 * (time.time() - tic) / (i_iter + 1)
                    eta = v * (len(loader) - i_iter) / 3600.0

                    self.log_pred_texts(
                        pred_sentences, batch_sentences, sub_samples=10
                    )
                    print('test_iter={},eta={},wer={},cer={}'.format(
                        i_iter, eta, np.array(wer).mean(),
                        np.array(cer).mean()
                    ))
                    print(''.join(161 * '-'))

            return (
                np.array(loss_list).mean(), np.array(wer).mean(),
                np.array(cer).mean()
            )

    def run_test(self, tot_iter, optimizer):
        log_scalar = functools.partial(self.log_scalar, label='test')
        (loss, wer, cer) = self.test()

        print('i_iter={},lr={},loss={},wer={},cer={}'.format(
            tot_iter, show_lr(optimizer), loss, wer, cer
        ))
        log_scalar('loss', loss, tot_iter)
        log_scalar('wer', wer, tot_iter)
        log_scalar('cer', cer, tot_iter)

        if loss < self.best_test_loss:
            print(f'NEW BEST LOSS: {loss}')
            self.best_test_loss = loss

            savename = 'I{}-L{:.4f}-W{:.4f}-C{:.4f}'.format(
                tot_iter, loss, wer, cer
            )
            savename = savename.replace('.', '') + '.pt'
            savepath = os.path.join(self.weights_dir, savename)
            (save_dir, name) = os.path.split(savepath)

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            torch.save(self.model.state_dict(), savepath)
            print(f'best model saved at {savepath}')

        if not opt.is_optimize:
            exit()


if __name__ == '__main__':
    vocab_filepaths = (
        'data/grid_phoneme_vocab.pth',
        'data/grid_text_char_vocab.pth'
    )
    """
    vocab_filepaths = (
        'data/lsr2_phoneme_vocab.pth',
        'data/lsr2_text_char_vocab.pth'
    )
    """

    trainer = TransformerTrainer(
        word_tokenize=False, vocab_files=vocab_filepaths,
        input_char_map=opt.char_map,
        output_char_map=opt.text_char_map
    )

    if hasattr(opt, 'weights'):
        trainer.load_weights(opt.weights)

    trainer.train()