import functools
import os
import time
from datetime import datetime as Datetime

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.multiprocessing import Manager
from tqdm.auto import tqdm

import options as opt
from helpers import *
from dataset import GridDataset, CharMap
from models.LipNet import LipNet
from PauseChecker import PauseChecker
from BaseTrainer import BaseTrainer


class Trainer(BaseTrainer):
    def __init__(
        self, name=opt.run_name, write_logs=True,
        num_workers=None, base_dir='', char_map=opt.char_map,
        pre_gru_repeats=None
    ):
        super().__init__(name=name, base_dir=base_dir)

        images_dir = opt.images_dir
        if opt.use_lip_crops:
            images_dir = opt.crop_images_dir
        if num_workers is None:
            num_workers = opt.num_workers
        if pre_gru_repeats is None:
            pre_gru_repeats = opt.pre_gru_repeats

        assert pre_gru_repeats >= 1
        assert isinstance(pre_gru_repeats, int)

        self.images_dir = images_dir
        self.num_workers = num_workers
        self.pre_gru_repeats = pre_gru_repeats
        self.char_map = char_map

        manager = Manager()
        if opt.cache_videos:
            shared_dict = manager.dict()
        else:
            shared_dict = None

        self.shared_dict = shared_dict
        self.dataset_kwargs = self.get_dataset_kwargs(
            shared_dict=shared_dict, base_dir=self.base_dir,
            char_map=self.char_map
        )

        self.best_test_loss = float('inf')
        self.train_dataset = None
        self.test_dataset = None
        self.model = None
        self.net = None

        if write_logs:
            self.init_tensorboard()

    def load_datasets(self):
        if self.train_dataset is None:
            self.train_dataset = GridDataset(
                **self.dataset_kwargs, phase='train',
                file_list=opt.train_list
            )
        if self.test_dataset is None:
            self.test_dataset = GridDataset(
                **self.dataset_kwargs, phase='test',
                file_list=opt.val_list
            )

    def create_model(self):
        output_classes = len(self.train_dataset.get_char_mapping())

        if self.model is None:
            self.model = LipNet(
                output_classes=output_classes,
                pre_gru_repeats=self.pre_gru_repeats
            )
            self.model = self.model.cuda()

        if self.net is None:
            self.net = nn.DataParallel(self.model).cuda()

    def load_weights(self, weights_path):
        self.load_datasets()
        self.create_model()

        weights_path = os.path.join(self.base_dir, weights_path)
        pretrained_dict = torch.load(weights_path)
        model_dict = self.model.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict.keys() and v.size() == model_dict[k].size()
        }
        missed_params = [
            k for k, v in model_dict.items()
            if k not in pretrained_dict.keys()
        ]

        print('loaded params/tot params: {}/{}'.format(
            len(pretrained_dict), len(model_dict)
        ))
        print('missed params: {}'.format(missed_params))

        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

    @staticmethod
    def make_date_stamp():
        return Datetime.now().strftime("%y%m%d-%H%M")

    @staticmethod
    def dataset2dataloader(
        dataset, num_workers, shuffle=True
    ):
        return DataLoader(
            dataset, batch_size=opt.batch_size, shuffle=shuffle,
            num_workers=num_workers, drop_last=False
        )

    def test(self):
        dataset = self.test_dataset

        with torch.no_grad():
            print('num_test_data:{}'.format(len(dataset.data)))
            self.model.eval()
            loader = self.dataset2dataloader(
                dataset, shuffle=False, num_workers=self.num_workers
            )

            loss_list = []
            wer = []
            cer = []
            crit = nn.CTCLoss(zero_infinity=True)
            tic = time.time()

            print('RUNNING VALIDATION')
            pbar = tqdm(loader)

            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()

                vid = input_sample.get('vid').cuda()
                vid_len = input_sample.get('vid_len').cuda()
                txt, txt_len = self.extract_char_output(input_sample)

                y = self.net(vid)
                # assert not contains_nan_or_inf(y)
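                # nn.CTCLoss needs every input (time) length to be at least as
                # long as its target; requiring more than twice the target
                # length is a conservative margin that also covers the blanks
                # CTC must insert between repeated symbols. vid_len is scaled
                # by pre_gru_repeats to match the input_lengths passed to the
                # loss below.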
                assert (
                    self.pre_gru_repeats * vid_len.view(-1)
                    > 2 * txt_len.view(-1)
                ).all()

                loss = crit(
                    y.transpose(0, 1).log_softmax(-1), txt,
                    self.pre_gru_repeats * vid_len.view(-1),
                    txt_len.view(-1)
                ).detach().cpu().numpy()
                loss_list.append(loss)

                pred_txt = dataset.ctc_decode(y)
                truth_txt = [
                    dataset.arr2txt(txt[_], start=1)
                    for _ in range(txt.size(0))
                ]
                wer.extend(dataset.wer(pred_txt, truth_txt))
                cer.extend(dataset.cer(pred_txt, truth_txt))

                if i_iter % opt.display == 0:
                    v = 1.0 * (time.time() - tic) / (i_iter + 1)
                    eta = v * (len(loader) - i_iter) / 3600.0

                    self.log_pred_texts(pred_txt, truth_txt, sub_samples=10)
                    print('test_iter={},eta={},wer={},cer={}'.format(
                        i_iter, eta, np.array(wer).mean(),
                        np.array(cer).mean()
                    ))
                    print(''.join(161 * '-'))

            return (
                np.array(loss_list).mean(),
                np.array(wer).mean(),
                np.array(cer).mean()
            )

    def extract_char_output(self, input_sample):
        """
        extract output character sequence from input_sample
        output character sequence is text if char_map is CharMap.letters
        output character sequence is phonemes if char_map is CharMap.phonemes
        """
        if self.char_map == CharMap.letters:
            txt = input_sample.get('txt').cuda()
            txt_len = input_sample.get('txt_len').cuda()
        elif self.char_map == CharMap.phonemes:
            txt = input_sample.get('phonemes').cuda()
            txt_len = input_sample.get('phonemes_len').cuda()
        elif self.char_map == CharMap.cmu_phonemes:
            txt = input_sample.get('cmu_phonemes').cuda()
            txt_len = input_sample.get('cmu_phonemes_len').cuda()
        else:
            raise ValueError(f'UNSUPPORTED CHAR_MAP: {self.char_map}')

        return txt, txt_len

    def train(self):
        self.load_datasets()
        self.create_model()

        dataset = self.train_dataset
        loader = self.dataset2dataloader(
            dataset, num_workers=self.num_workers
        )
        """
        optimizer = optim.Adam(
            self.model.parameters(), lr=opt.base_lr,
            weight_decay=0., amsgrad=True
        )
        """
        optimizer = optim.RMSprop(
            self.model.parameters(), lr=opt.base_lr
        )

        print('num_train_data:{}'.format(len(dataset.data)))
        # don't allow loss function to create infinite loss for
        # sequences that are too short
        crit = nn.CTCLoss(zero_infinity=True)
        tic = time.time()

        train_wer = []
        self.best_test_loss = float('inf')
        log_scalar = functools.partial(self.log_scalar, label='train')

        for epoch in range(opt.max_epoch):
            print(f'RUNNING EPOCH {epoch}')
            pbar = tqdm(loader)

            for (i_iter, input_sample) in enumerate(pbar):
                PauseChecker.check()
                self.model.train()

                vid = input_sample.get('vid').cuda()
                vid_len = input_sample.get('vid_len').cuda()
                txt, txt_len = self.extract_char_output(input_sample)

                optimizer.zero_grad()
                y = self.net(vid)
                assert not contains_nan_or_inf(y)
                assert (
                    self.pre_gru_repeats * vid_len.view(-1)
                    > 2 * txt_len.view(-1)
                ).all()

                loss = crit(
                    y.transpose(0, 1).log_softmax(-1), txt,
                    self.pre_gru_repeats * vid_len.view(-1),
                    txt_len.view(-1)
                )

                if contains_nan_or_inf(loss):
                    print(f'LOSS IS INVALID. SKIPPING {i_iter}')
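                    # Skip the batch before backward() so a single degenerate
                    # sample cannot push NaN/inf gradients into the weights.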
                    # print('Y', y)
                    # print('txt', txt)
                    continue

                loss.backward()
                params = self.model.parameters()

                # Check for NaNs in gradients
                if any(torch.isnan(p.grad).any() for p in params):
                    optimizer.zero_grad()  # Clear gradients to prevent update
                    print('SKIPPING NAN GRADS')
                    continue

                if opt.is_optimize:
                    optimizer.step()

                assert not contains_nan_or_inf(self.model.conv1.weight)
                tot_iter = i_iter + epoch * len(loader)

                pred_txt = dataset.ctc_decode(y)
                truth_txt = [
                    dataset.arr2txt(txt[_], start=1)
                    for _ in range(txt.size(0))
                ]
                train_wer.extend(dataset.wer(pred_txt, truth_txt))

                if tot_iter % opt.display == 0:
                    v = 1.0 * (time.time() - tic) / (tot_iter + 1)
                    eta = (len(loader) - i_iter) * v / 3600.0
                    wer = np.array(train_wer).mean()

                    log_scalar('loss', loss, tot_iter)
                    log_scalar('wer', wer, tot_iter)
                    self.log_pred_texts(pred_txt, truth_txt, sub_samples=3)

                    print('epoch={},tot_iter={},eta={},loss={},train_wer={}'
                        .format(
                            epoch, tot_iter, eta, loss,
                            np.array(train_wer).mean()
                        )
                    )
                    print(''.join(161 * '-'))

                if (tot_iter > 0) and (tot_iter % opt.test_step == 0):
                    # if tot_iter % opt.test_step == 0:
                    self.run_test(tot_iter, optimizer)

    @staticmethod
    def log_pred_texts(pred_txt, truth_txt, pad=80, sub_samples=None):
        line_length = 2 * pad + 1
        print(''.join(line_length * '-'))
        print('{:<{pad}}|{:>{pad}}'.format(
            'predict', 'truth', pad=pad
        ))
        print(''.join(line_length * '-'))

        zipped_samples = list(zip(pred_txt, truth_txt))
        if sub_samples is not None:
            zipped_samples = zipped_samples[:sub_samples]

        for (predict, truth) in zipped_samples:
            print('{:<{pad}}|{:>{pad}}'.format(
                predict, truth, pad=pad
            ))

        print(''.join(line_length * '-'))

    def run_test(self, tot_iter, optimizer):
        log_scalar = functools.partial(self.log_scalar, label='test')
        (loss, wer, cer) = self.test()

        print('i_iter={},lr={},loss={},wer={},cer={}'.format(
            tot_iter, show_lr(optimizer), loss, wer, cer
        ))
        log_scalar('loss', loss, tot_iter)
        log_scalar('wer', wer, tot_iter)
        log_scalar('cer', cer, tot_iter)

        if loss < self.best_test_loss:
            print(f'NEW BEST LOSS: {loss}')
            self.best_test_loss = loss

            savename = 'I{}-L{:.4f}-W{:.4f}-C{:.4f}'.format(
                tot_iter, loss, wer, cer
            )
            savename = savename.replace('.', '') + '.pt'
            savepath = os.path.join(self.weights_dir, savename)
            (save_dir, name) = os.path.split(savepath)

            if not os.path.exists(save_dir):
                os.makedirs(save_dir)

            torch.save(self.model.state_dict(), savepath)
            print(f'best model saved at {savepath}')

        if not opt.is_optimize:
            exit()

    def predict_sample(self, input_sample):
        self.model.eval()
        vid = input_sample.get('vid').cuda()
        return self.predict_video(vid)

    def predict_video(self, video):
        video = video.cuda()
        vid = video.unsqueeze(0)
        y = self.net(vid)
        pred_txt = self.train_dataset.ctc_decode(y)
        return pred_txt
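

# Minimal usage sketch (illustrative only; this entry point and the commented
# weights path are assumptions, not part of the original module):
if __name__ == '__main__':
    trainer = Trainer()
    # trainer.load_weights('weights/pretrained.pt')  # hypothetical warm start
    trainer.train()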