import os
import numpy as np
import functools
import shutil
from typing import List

import torch
from tqdm.auto import tqdm
from torch.utils.data import DataLoader
from torchtext.datasets import Multi30k

import options
from Loader import GridLoader
from PauseChecker import PauseChecker
from dataset import GridDataset, CharMap, Datasets
from datetime import datetime as Datetime
from models.PhonemeTransformer import *
from torchtext.vocab import build_vocab_from_iterator
from torch.nn.utils.rnn import pad_sequence
from BaseTrainer import BaseTrainer


class TranslationDataset(GridDataset):
    def __init__(
        self, input_char_map: CharMap,
        output_char_map: CharMap, **kwargs
    ):
        super().__init__(**kwargs)
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map

    def __getitem__(self, idx):
        (vid, spk, name) = self.data[idx]
        basename, _ = os.path.splitext(name)

        input_filepath = self.fetch_anno_path(
            spk, basename, char_map=self.input_char_map
        )
        output_filepath = self.fetch_anno_path(
            spk, basename, char_map=self.output_char_map
        )

        input_str = self.load_str_sentence(
            input_filepath, char_map=self.input_char_map
        )
        output_str = self.load_str_sentence(
            output_filepath, char_map=self.output_char_map
        )
        return input_str, output_str


class TranslatorTrainer(BaseTrainer):
    def __init__(
        self, dataset_type: Datasets = options.dataset,
        batch_size=128, validate_every=20, display_every=10,
        name='translate', write_logs=True, base_dir='',
        word_tokenize=False, vocab_files=None,
        input_char_map=CharMap.phonemes,
        output_char_map=CharMap.letters
    ):
        super().__init__(name=name, base_dir=base_dir)

        self.batch_size = batch_size
        self.validate_every = validate_every
        self.display_every = display_every
        self.word_tokenize = word_tokenize
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map
        self.dataset_type = dataset_type

        self.text_tokenizer = functools.partial(
            GridDataset.tokenize_text, word_tokenize=word_tokenize
        )
        self.device = torch.device(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )

        if vocab_files is None:
            vocabs = self.load_vocabs(self.base_dir)
            self.phonemes_vocab, self.text_vocab = vocabs
        else:
            phonemes_vocab_path, text_vocab_path = vocab_files
            self.phonemes_vocab = torch.load(phonemes_vocab_path)
            self.text_vocab = torch.load(text_vocab_path)

        self.model = None
        self.optimizer = None
        self.best_test_loss = float('inf')
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

        self.phonemes_encoder = self.sequential_transforms(
            GridDataset.tokenize_phonemes,
            self.phonemes_vocab,
            self.tensor_transform
        )
        self.text_encoder = self.sequential_transforms(
            self.text_tokenizer,
            self.text_vocab,
            self.tensor_transform
        )

        if write_logs:
            self.init_tensorboard()

    def load_vocabs(self, base_dir):
        loader = GridLoader(base_dir=base_dir)

        if self.dataset_type == Datasets.GRID:
            phonemes_text_map = loader.load_grid_phonemes_text_map(
                phonemes_char_map=self.input_char_map,
                text_char_map=self.output_char_map
            )
        elif self.dataset_type == Datasets.LRS2:
            phonemes_text_map = loader.load_lsr2_phonemes_text_map(
                phonemes_char_map=self.input_char_map,
                text_char_map=self.output_char_map
            )
        else:
            raise NotImplementedError

        phonemes_map = phonemes_text_map[self.input_char_map]
        text_map = phonemes_text_map[self.output_char_map]

        phonemes_vocab = self.build_vocab(
            phonemes_map, tokenizer=GridDataset.tokenize_phonemes
        )
        text_vocab = self.build_vocab(
            text_map, tokenizer=self.text_tokenizer
        )
        return phonemes_vocab, text_vocab

    def save_vocabs(
        self, phoneme_vocab_path, text_vocab_path
    ):
        torch.save(self.phonemes_vocab, phoneme_vocab_path)
        torch.save(self.text_vocab, text_vocab_path)
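    # Sketch only, not part of the training flow (file paths are hypothetical):
    # the vocabularies are built once from the annotation maps by load_vocabs and
    # can be cached with save_vocabs, so later runs pass ``vocab_files`` instead
    # of rebuilding them from the dataset:
    #
    #   trainer = TranslatorTrainer(write_logs=False)
    #   trainer.save_vocabs('data/phoneme_vocab.pth', 'data/text_vocab.pth')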
    def load_weights(self, weights):
        self.create_model()

        pretrained_dict = torch.load(weights)
        model_dict = self.model.state_dict()
        pretrained_dict = {
            k: v for k, v in pretrained_dict.items()
            if k in model_dict.keys()
            and v.size() == model_dict[k].size()
        }
        missed_params = [
            k for k, v in model_dict.items()
            if k not in pretrained_dict.keys()
        ]

        print('loaded params/tot params: {}/{}'.format(
            len(pretrained_dict), len(model_dict)
        ))
        print('mismatched params: {}'.format(missed_params))

        model_dict.update(pretrained_dict)
        self.model.load_state_dict(model_dict)

    def create_model(self):
        self.model = Seq2SeqTransformer(
            src_vocab_size=len(self.phonemes_vocab),
            tgt_vocab_size=len(self.text_vocab)
        )
        self.model = self.model.to(self.device)
        self.optimizer = torch.optim.Adam(
            self.model.parameters(), lr=0.0001,
            betas=(0.9, 0.98), eps=1e-9
        )

    def collate_tgt_fn(self, batch):
        tgt_batch = []
        for tgt_sample in batch:
            tgt_batch.append(self.text_encoder(tgt_sample.rstrip("\n")))

        tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
        return tgt_batch

    # function to collate data samples into batch tensors
    def collate_fn(self, batch):
        src_batch, tgt_batch = [], []
        for src_sample, tgt_sample in batch:
            src_batch.append(self.phonemes_encoder(src_sample.rstrip("\n")))
            tgt_batch.append(self.text_encoder(tgt_sample.rstrip("\n")))

        src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
        tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
        return src_batch, tgt_batch
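    # Note: pad_sequence is used with its default batch_first=False, so collate_fn
    # (and collate_tgt_fn) return time-major tensors shaped (max_seq_len, batch_size);
    # the training and evaluation code below indexes them that way.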
    def train(self, max_iters=10 * 1000):
        assert self.writer is not None
        assert self.display_every < self.validate_every

        self.create_model()
        self.best_test_loss = float('inf')
        log_scalar = functools.partial(self.log_scalar, label='train')

        self.model.train()
        losses = 0

        dataset_kwargs = self.get_dataset_kwargs(
            input_char_map=self.input_char_map,
            char_map=self.output_char_map,
            output_char_map=self.output_char_map,
            file_list=options.train_list
        )
        train_iter = TranslationDataset(**dataset_kwargs, phase='train')
        test_iter = TranslationDataset(**dataset_kwargs, phase='test')

        train_dataloader = DataLoader(
            train_iter, batch_size=self.batch_size,
            # collate_fn=self.collate_fn,
            shuffle=True
        )
        test_dataloader = DataLoader(
            test_iter, batch_size=self.batch_size,
            # collate_fn=self.collate_fn,
            shuffle=True
        )

        tot_iters = 0
        pbar = tqdm(total=max_iters)

        while tot_iters < max_iters:
            for train_pair in train_dataloader:
                PauseChecker.check()

                raw_src, raw_tgt = train_pair
                src, tgt = self.collate_fn(zip(raw_src, raw_tgt))
                max_seq_len, batch_size = src.shape
                src = src.to(self.device)
                tgt = tgt.to(self.device)

                tgt_input = tgt[:-1, :]
                (
                    src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask
                ) = create_mask(src, tgt_input, self.device)

                logits = self.model(
                    src, tgt_input, src_mask, tgt_mask,
                    src_padding_mask, tgt_padding_mask,
                    src_padding_mask
                )

                self.optimizer.zero_grad()
                tgt_out = tgt[1:, :]
                loss = self.loss_fn(
                    logits.reshape(-1, logits.shape[-1]),
                    tgt_out.reshape(-1)
                )
                loss.backward()
                self.optimizer.step()
                loss_item = loss.item()

                with torch.no_grad():
                    # Convert logits tensor to token indices
                    probs = torch.softmax(logits, dim=-1)
                    token_indices = torch.argmax(probs, dim=-1)

                    # Convert token indices to strings for
                    # each sequence in the batch
                    gap = ' ' if self.word_tokenize else ''
                    pred_sentences = self.batch_indices_to_text(
                        token_indices, batch_size=batch_size, gap=gap
                    )
                    wer = np.mean(GridDataset.get_wer(
                        pred_sentences, raw_tgt,
                        char_map=self.output_char_map
                    ))

                desc = f'loss: {loss_item:.4f}, wer: {wer:.4f}'
                pbar.desc = desc
                losses += loss_item
                tot_iters += 1
                pbar.update(1)

                run_validation = (
                    (tot_iters > 0) and
                    (tot_iters % self.validate_every == 0)
                )
                run_display = (
                    (tot_iters > 0) and
                    (tot_iters % self.display_every == 0)
                )

                if run_validation:
                    self.run_test(test_dataloader, tot_iters=tot_iters)
                elif run_display:
                    print('TRAIN PREDICTIONS')
                    self.show_sentences(pred_sentences, raw_tgt, batch_size)

                if self.writer is not None:
                    log_scalar('loss', loss, tot_iters)
                    log_scalar('wer', wer, tot_iters)

        return losses / len(train_dataloader)

    @staticmethod
    def show_sentences(
        pred_sentences, target_sentences, batch_size, pad=40
    ):
        print('{:<{pad}}|{:>{pad}}'.format(
            'predict', 'target', pad=pad
        ))
        line_length = 2 * pad + 1
        print(line_length * '-')

        for k in range(batch_size):
            pred_sentence = pred_sentences[k]
            target_sentence = target_sentences[k]
            print('{:<{pad}}|{:>{pad}}'.format(
                pred_sentence, target_sentence, pad=pad
            ))

        print(line_length * '-')

    def run_test(self, test_dataloader, tot_iters):
        log_scalar = functools.partial(self.log_scalar, label='test')

        with torch.no_grad():
            self.model.eval()
            batch = next(iter(test_dataloader))

            raw_src, raw_tgt = batch
            src, tgt = self.collate_fn(zip(raw_src, raw_tgt))
            max_seq_len, batch_size = src.shape
            src = src.to(self.device)
            tgt = tgt.to(self.device)

            tgt_input = tgt[:-1, :]
            (
                src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask
            ) = create_mask(src, tgt_input, self.device)

            logits = self.model(
                src, tgt_input, src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask,
                src_padding_mask
            )

            tgt_out = tgt[1:, :]
            loss = self.loss_fn(
                logits.reshape(-1, logits.shape[-1]),
                tgt_out.reshape(-1)
            )
            loss_item = loss.item()

            # Convert logits tensor to token indices
            probs = torch.softmax(logits, dim=-1)
            token_indices = torch.argmax(probs, dim=-1)

            # Convert token indices to strings for each sequence in the batch
            gap = ' ' if self.word_tokenize else ''
            pred_sentences = self.batch_indices_to_text(
                token_indices, batch_size=batch_size, gap=gap
            )
            wer = np.mean(GridDataset.get_wer(
                pred_sentences, raw_tgt, char_map=self.output_char_map
            ))

            log_scalar('loss', loss, tot_iters)
            log_scalar('wer', wer, tot_iters)
            print(f'TEST PREDS [loss={loss_item:.4f}, wer={wer:.4f}]')
            self.show_sentences(pred_sentences, raw_tgt, batch_size)

            if loss < self.best_test_loss:
                print(f'NEW BEST LOSS: {loss}')
                self.best_test_loss = loss

                savename = 'I{}-L{:.4f}-W{:.4f}'.format(
                    tot_iters, loss, wer
                )
                savename = savename.replace('.', '') + '.pt'
                savepath = os.path.join(self.weights_dir, savename)
                (save_dir, name) = os.path.split(savepath)

                if not os.path.exists(save_dir):
                    os.makedirs(save_dir)

                torch.save(self.model.state_dict(), savepath)
                print(f'best model saved at {savepath}')
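    # Illustrative note on the helper below: given a (seq_len, batch) tensor of
    # token indices, batch_indices_to_text reads one column per sample, stops at
    # EOS_IDX, skips the PAD/BOS/EOS specials and joins the looked-up tokens with
    # ``gap`` ('' for character-level output, ' ' when word_tokenize=True).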
    def batch_indices_to_text(
        self, indices_tensor, batch_size, gap=''
    ):
        sentences = []

        for k in range(batch_size):
            tokens = []

            for indices_row in indices_tensor:
                idx = indices_row[k]
                if idx == EOS_IDX:
                    break
                if idx in [PAD_IDX, BOS_IDX, EOS_IDX]:
                    continue

                token = self.text_vocab.lookup_token(idx)
                tokens.append(token)

            sentence = gap.join(tokens)
            sentences.append(sentence)

        return sentences

    @staticmethod
    def batch_tokenize_text(batch_sentences, word_tokenize=False):
        return [
            GridDataset.tokenize_text(
                sentence, word_tokenize=word_tokenize
            ) for sentence in batch_sentences
        ]

    def evaluate(self, model):
        model.eval()
        losses = 0

        language_pair = (str(CharMap.phonemes), str(CharMap.letters))
        val_iter = Multi30k(
            split='valid', language_pair=language_pair
        )
        val_dataloader = DataLoader(
            val_iter, batch_size=self.batch_size,
            collate_fn=self.collate_fn
        )

        for src, tgt in val_dataloader:
            src = src.to(self.device)
            tgt = tgt.to(self.device)

            tgt_input = tgt[:-1, :]
            (
                src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask
            ) = create_mask(src, tgt_input, self.device)

            logits = model(
                src, tgt_input, src_mask, tgt_mask,
                src_padding_mask, tgt_padding_mask,
                src_padding_mask
            )

            tgt_out = tgt[1:, :]
            loss = self.loss_fn(
                logits.reshape(-1, logits.shape[-1]),
                tgt_out.reshape(-1)
            )
            losses += loss.item()

        return losses / len(list(val_dataloader))

    # actual function to translate input sentence into target language
    def translate(
        self, phoneme_sentence: str, beam_size=0
    ):
        self.model.eval()
        dummy_sentence = self.text_vocab.lookup_token(
            len(self.text_vocab) - 1
        )
        src, _ = self.collate_fn(zip(
            [phoneme_sentence], [dummy_sentence]
        ))
        max_seq_len, batch_size = src.shape
        src = src.to(self.device)

        num_tokens = src.shape[0]
        src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
        max_len = num_tokens + 5

        if beam_size > 0:
            tgt_tokens = self.beam_search_decode(
                src, src_mask, max_len=max_len,
                start_symbol=BOS_IDX, beam_size=beam_size
            )
        else:
            tgt_tokens = self.greedy_decode(
                src, src_mask, max_len=max_len, start_symbol=BOS_IDX
            )

        gap = ' ' if self.word_tokenize else ''
        pred_sentence = self.batch_indices_to_text(
            tgt_tokens, batch_size=batch_size, gap=gap
        )[0]
        return pred_sentence

    # function to generate output sequence using greedy algorithm
    def greedy_decode(self, src, src_mask, max_len, start_symbol):
        src = src.to(self.device)
        src_mask = src_mask.to(self.device)
        memory = self.model.encode(src, src_mask)

        ys = (
            torch.ones(1, 1).fill_(start_symbol)
            .type(torch.long).to(self.device)
        )

        for i in range(max_len - 1):
            memory = memory.to(self.device)
            tgt_mask = (
                generate_square_subsequent_mask(
                    ys.size(0), device=self.device
                ).type(torch.bool)
            ).to(self.device)

            out = self.model.decode(ys, memory, tgt_mask)
            out = out.transpose(0, 1)
            prob = self.model.generator(out[:, -1])
            _, next_word = torch.max(prob, dim=1)
            next_word = next_word.item()

            ys = torch.cat([
                ys,
                torch.ones(1, 1).type_as(src.data).fill_(next_word)
            ], dim=0)

            if next_word == EOS_IDX:
                break

        return ys
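    # Sketch of the idea behind the decoder below: instead of committing to the
    # single argmax token at each step (as greedy_decode above does), beam search
    # keeps the ``beam_size`` best partial hypotheses. Scores accumulate as negated
    # generator outputs (treated as negative log-probabilities), so the hypothesis
    # list is sorted ascending and the smallest score is taken as the best.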
    def beam_search_decode(
        self, src, src_mask, max_len,
        start_symbol, beam_size=5
    ):
        src = src.to(self.device)
        src_mask = src_mask.to(self.device)
        memory = self.model.encode(src, src_mask)

        ys = (
            torch.ones(1, 1).fill_(start_symbol)
            .type(torch.long).to(self.device)
        )
        # Each hypothesis is a tuple (sequence, score)
        hypotheses = [(ys, 0.0)]

        for _ in range(max_len - 1):
            new_hypotheses = []

            for seq, score in hypotheses:
                if seq[-1] == EOS_IDX:
                    new_hypotheses.append((seq, score))
                    continue

                tgt_mask = generate_square_subsequent_mask(
                    seq.size(0), device=self.device
                ).type(torch.bool)

                out = self.model.decode(seq, memory, tgt_mask)
                out = out.transpose(0, 1)
                prob = self.model.generator(out[:, -1])
                # pick {beam_size} largest probabilities from prob
                topk_prob, topk_indices = torch.topk(prob, beam_size)

                for i in range(beam_size):
                    next_word = topk_indices[0][i]
                    # Assuming negative log probabilities
                    next_score = score - topk_prob[0][i].item()
                    new_seq = torch.cat([
                        seq,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)
                    ], dim=0)
                    # new_seq = torch.cat([seq, next_word.unsqueeze(0)], dim=0)
                    new_hypotheses.append((new_seq, next_score))

            if len(new_hypotheses) == 0:
                break

            # Keep top beam_size hypotheses
            hypotheses = sorted(
                new_hypotheses, key=lambda x: x[1]
            )[:beam_size]

        return hypotheses[0][0]  # Return the best hypothesis

    @staticmethod
    def yield_tokens(sequence_map, tokenizer):
        for key in sequence_map:
            yield tokenizer(sequence_map[key])

    def build_vocab(self, sequence_map, tokenizer):
        return build_vocab_from_iterator(
            self.yield_tokens(sequence_map, tokenizer),
            min_freq=1, specials=SPECIAL_SYMBOLS,
            special_first=True
        )

    # helper function to club together sequential operations
    @staticmethod
    def sequential_transforms(*transforms):
        def func(txt_input):
            for transform in transforms:
                txt_input = transform(txt_input)

            return txt_input

        return func

    # function to add BOS/EOS and create tensor for input sequence indices
    @staticmethod
    def tensor_transform(token_ids: List[int]):
        return torch.cat((
            torch.tensor([BOS_IDX]),
            torch.tensor(token_ids),
            torch.tensor([EOS_IDX])
        ))


if __name__ == '__main__':
    vocab_filepaths = (
        'data/grid_phoneme_vocab.pth',
        'data/grid_text_char_vocab.pth'
    )
    """
    vocab_filepaths = (
        'data/lsr2_phoneme_vocab.pth',
        'data/lsr2_text_char_vocab.pth'
    )
    """

    trainer = TranslatorTrainer(
        word_tokenize=False, vocab_files=vocab_filepaths,
        input_char_map=options.char_map,
        output_char_map=options.text_char_map
    )
    trainer.train()

    # trainer.save_vocabs(*vocab_filepaths)
    # loader = GridLoader()
    # phonemes_text_map = loader.load_phonemes_text_map()
    # print(">>>")
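    # Hypothetical follow-up once a checkpoint exists (path and input are placeholders):
    # trainer.load_weights('weights/translate/<checkpoint>.pt')
    # print(trainer.translate('<phoneme sentence>', beam_size=5))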