# torchnet / TransformerTrainer.py
import os
import sys
import time
sys.path.append('../models')
import torch
import functools
import options as opt
from torch import optim
from tqdm.auto import tqdm
from PauseChecker import PauseChecker
from Trainer import Trainer
from models.LipNetPlus import LipNetPlus
from TranslatorTrainer import TranslatorTrainer
from dataset import GridDataset, CharMap, Datasets
from helpers import contains_nan_or_inf
from models.PhonemeTransformer import *
from helpers import *
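# note: np, nn, the PAD_IDX/BOS_IDX/EOS_IDX/UNK_IDX special-token indices,
# create_mask and show_lr are not imported explicitly; they are assumed to
# come in through the wildcard imports above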
class TransformerTrainer(Trainer, TranslatorTrainer):
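    """
    Trainer for the transformer head of LipNetPlus: the GRU visual
    front-end is run without gradients, its per-frame embeddings are
    CTC-decoded and re-packed as source tokens, and a seq2seq
    transformer is trained to emit the target text.
    """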
def __init__(
self, batch_size=opt.batch_size, word_tokenize=False,
dataset_type: Datasets = opt.dataset, embeds_size=256,
vocab_files=None, write_logs=True,
input_char_map=CharMap.phonemes,
output_char_map=CharMap.letters,
name='embeds-transformer-v2',
**kwargs
):
super().__init__(**kwargs, name=name)
self.batch_size = batch_size
self.word_tokenize = word_tokenize
self.input_char_map = input_char_map
self.output_char_map = output_char_map
self.dataset_type = dataset_type
self.embeds_size = embeds_size
self.text_tokenizer = functools.partial(
GridDataset.tokenize_text, word_tokenize=word_tokenize
)
self.device = torch.device(
'cuda' if torch.cuda.is_available() else 'cpu'
)
if vocab_files is None:
vocabs = self.load_vocabs(self.base_dir)
self.phonemes_vocab, self.text_vocab = vocabs
else:
phonemes_vocab_path, text_vocab_path = vocab_files
self.phonemes_vocab = torch.load(phonemes_vocab_path)
self.text_vocab = torch.load(text_vocab_path)
self.model = None
self.optimizer = None
self.best_test_loss = float('inf')
self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)
"""
self.phonemes_encoder = self.sequential_transforms(
GridDataset.tokenize_phonemes, self.phonemes_vocab,
self.tensor_transform
)
"""
self.text_encoder = self.sequential_transforms(
self.text_tokenizer, self.text_vocab,
self.tensor_transform
)
if write_logs:
self.init_tensorboard()
def create_model(self):
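        # lazily build LipNetPlus and a DataParallel wrapper (self.net);
        # checkpoints are saved from the unwrapped self.model (see run_test)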
if self.model is None:
output_classes = len(self.train_dataset.get_char_mapping())
self.model = LipNetPlus(
output_classes=output_classes,
pre_gru_repeats=self.pre_gru_repeats,
embeds_size=self.embeds_size,
output_vocab_size=len(self.text_vocab)
)
self.model = self.model.cuda()
if self.net is None:
self.net = nn.DataParallel(self.model).cuda()
def load_datasets(self):
if self.train_dataset is None:
self.train_dataset = GridDataset(
**self.dataset_kwargs, phase='train',
file_list=opt.train_list,
sample_all_props=True
)
if self.test_dataset is None:
self.test_dataset = GridDataset(
**self.dataset_kwargs, phase='test',
file_list=opt.val_list,
sample_all_props=True
)
def train(self):
self.load_datasets()
self.create_model()
dataset = self.train_dataset
loader = self.dataset2dataloader(
dataset, num_workers=self.num_workers
)
"""
optimizer = optim.Adam(
self.model.parameters(), lr=opt.base_lr,
weight_decay=0., amsgrad=True
)
"""
optimizer = optim.RMSprop(
self.model.parameters(), lr=opt.base_lr
)
print('num_train_data:{}'.format(len(dataset.data)))
        # the CrossEntropyLoss configured in __init__ uses
        # ignore_index=PAD_IDX, so padded target positions do not
        # contribute to the loss
tic = time.time()
self.best_test_loss = float('inf')
log_scalar = functools.partial(self.log_scalar, label='train')
for epoch in range(opt.max_epoch):
print(f'RUNNING EPOCH {epoch}')
train_wer = []
pbar = tqdm(loader)
for (i_iter, input_sample) in enumerate(pbar):
PauseChecker.check()
self.model.train()
vid = input_sample.get('vid').cuda()
# vid_len = input_sample.get('vid_len').cuda()
# txt, txt_len = self.extract_char_output(input_sample)
batch_arr_sentences = input_sample['txt_anno']
batch_arr_sentences = np.array(batch_arr_sentences)
_, batch_size = batch_arr_sentences.shape
batch_sentences = [
''.join(batch_arr_sentences[:, k]).strip()
for k in range(batch_size)
]
tgt = self.collate_tgt_fn(batch_sentences)
tgt = tgt.to(self.device)
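                # teacher forcing: the decoder input drops the last target
                # token; the loss target (tgt[1:, :] below) drops the first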
tgt_input = tgt[:-1, :]
with torch.no_grad():
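                    # the GRU visual front-end runs without gradients,
                    # so only the transformer components receive updates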
gru_output = self.model.forward_gru(vid)
y = self.model.predict_from_gru_out(gru_output)
src_embeds = self.model.make_src_embeds(gru_output)
transformer_out = self.make_transformer_embeds(
dataset, src_embeds, y, batch_size=batch_size
)
transformer_src_embeds, src_idx_arr = transformer_out
transformer_src_embeds = transformer_src_embeds.to(self.device)
src_idx_arr = src_idx_arr.to(self.device)
max_seq_len, batch_size = src_idx_arr.shape
(
src_mask, tgt_mask,
src_padding_mask, tgt_padding_mask
) = create_mask(
src_idx_arr, tgt_input, self.device
)
logits = self.model.seq_forward(
transformer_src_embeds, tgt_input, src_mask, tgt_mask,
src_padding_mask, tgt_padding_mask, src_padding_mask
)
optimizer.zero_grad()
tgt_out = tgt[1:, :]
loss = self.loss_fn(
logits.reshape(-1, logits.shape[-1]),
tgt_out.reshape(-1)
)
tot_iter = i_iter + epoch * len(loader)
loss.backward()
optimizer.step()
                # greedily decode the transformer logits back to text
                # (no gradients needed for decoding or metrics)
                with torch.no_grad():
probs = torch.softmax(logits, dim=-1)
token_indices = torch.argmax(probs, dim=-1)
# Convert token indices to strings for
# each sequence in the batch
gap = ' ' if self.word_tokenize else ''
# print('TT', token_indices.shape)
pred_sentences = self.batch_indices_to_text(
token_indices, batch_size=batch_size, gap=gap
)
wer = np.mean(GridDataset.get_wer(
pred_sentences, batch_sentences,
char_map=self.output_char_map
))
train_wer.append(wer)
if tot_iter % opt.display == 0:
v = 1.0 * (time.time() - tic) / (tot_iter + 1)
eta = (len(loader) - i_iter) * v / 3600.0
wer = np.array(train_wer).mean()
log_scalar('loss', loss, tot_iter)
log_scalar('wer', wer, tot_iter)
self.log_pred_texts(
pred_sentences, batch_sentences, sub_samples=3
)
print('epoch={},tot_iter={},eta={},loss={},train_wer={}'
.format(
epoch, tot_iter, eta, loss,
np.array(train_wer).mean()
))
print(''.join(161 * '-'))
if (tot_iter > -1) and (tot_iter % opt.test_step == 0):
# if tot_iter % opt.test_step == 0:
self.run_test(tot_iter, optimizer)
def make_transformer_embeds(
self, dataset, src_embeds, y, batch_size
):
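        """
        Build transformer source embeddings from the GRU outputs:
        CTC-decode the per-frame predictions y to select the frames to
        keep, gather the corresponding rows of src_embeds, and wrap each
        sequence with BOS/EOS embeddings, padding the remainder with the
        PAD embedding. Also returns an index mask (PAD_IDX at padded
        positions, UNK_IDX elsewhere) used to build the attention and
        padding masks.
        """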
batch_indices = dataset.ctc_decode_indices(y)
filter_batch_embeds = []
pad_embed = self.model.src_tok_emb(
torch.IntTensor([PAD_IDX]).to(self.device)
)
begin_embed = self.model.src_tok_emb(
torch.IntTensor([BOS_IDX]).to(self.device)
)
end_embed = self.model.src_tok_emb(
torch.IntTensor([EOS_IDX]).to(self.device)
)
max_sentence_len = max([len(x) for x in batch_indices])
        # initialize embeds with pad token embeddings,
        # shape [max_sentence_len + 2, batch_size, embeds_size];
        # .clone() makes the expanded view writable for the in-place
        # assignments below
        transformer_src_embeds = pad_embed.expand(
            max_sentence_len + 2, batch_size, pad_embed.shape[1]
        ).clone()
src_idx_mask = torch.full(
transformer_src_embeds.shape[:2], PAD_IDX,
dtype=torch.int
)
# k is sentence index in batch
for k, sentence_indices in enumerate(batch_indices):
filter_sentence_embeds = []
for sentence_index in sentence_indices:
filter_sentence_embeds.append(
src_embeds[sentence_index][k]
)
sentence_length = len(filter_sentence_embeds)
filter_batch_embeds.append(filter_sentence_embeds)
# set beginning to sequence embed
transformer_src_embeds[0][k] = begin_embed
src_idx_mask[0][k] = UNK_IDX
# index i is char index in sentence
for i, char_embed in enumerate(filter_sentence_embeds):
transformer_src_embeds[i + 1][k] = char_embed
src_idx_mask[i + 1][k] = UNK_IDX
transformer_src_embeds[sentence_length + 1][k] = end_embed
src_idx_mask[sentence_length + 1][k] = UNK_IDX
return transformer_src_embeds, src_idx_mask
@staticmethod
def log_pred_texts(
pred_txt, truth_txt, pad=80, sub_samples=None
):
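        """Print prediction/ground-truth pairs side by side in two columns."""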
line_length = 2 * pad + 1
print(''.join(line_length * '-'))
print('{:<{pad}}|{:>{pad}}'.format(
'predict', 'truth', pad=pad
))
print(''.join(line_length * '-'))
zipped_samples = list(zip(pred_txt, truth_txt))
if sub_samples is not None:
zipped_samples = zipped_samples[:sub_samples]
for (predict, truth) in zipped_samples:
print('{:<{pad}}|{:>{pad}}'.format(
predict, truth, pad=pad
))
print(''.join(line_length * '-'))
def test(self):
dataset = self.test_dataset
with torch.no_grad():
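            # the whole validation pass runs without gradients and mirrors
            # the training-time forward pass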
print('num_test_data:{}'.format(len(dataset.data)))
self.model.eval()
loader = self.dataset2dataloader(
dataset, shuffle=False, num_workers=self.num_workers
)
loss_list = []
wer = []
cer = []
tic = time.time()
print('RUNNING VALIDATION')
pbar = tqdm(loader)
for (i_iter, input_sample) in enumerate(pbar):
PauseChecker.check()
vid = input_sample.get('vid').cuda()
batch_arr_sentences = input_sample['txt_anno']
batch_arr_sentences = np.array(batch_arr_sentences)
_, batch_size = batch_arr_sentences.shape
batch_sentences = [
''.join(batch_arr_sentences[:, k]).strip()
for k in range(batch_size)
]
tgt = self.collate_tgt_fn(batch_sentences)
tgt = tgt.to(self.device)
tgt_input = tgt[:-1, :]
with torch.no_grad():
gru_output = self.model.forward_gru(vid)
y = self.model.predict_from_gru_out(gru_output)
src_embeds = self.model.make_src_embeds(gru_output)
transformer_out = self.make_transformer_embeds(
dataset, src_embeds, y, batch_size=batch_size
)
transformer_src_embeds, src_idx_arr = transformer_out
transformer_src_embeds = transformer_src_embeds.to(self.device)
src_idx_arr = src_idx_arr.to(self.device)
max_seq_len, batch_size = src_idx_arr.shape
(
src_mask, tgt_mask,
src_padding_mask, tgt_padding_mask
) = create_mask(
src_idx_arr, tgt_input, self.device
)
logits = self.model.seq_forward(
transformer_src_embeds, tgt_input, src_mask, tgt_mask,
src_padding_mask, tgt_padding_mask, src_padding_mask
)
                # greedily decode the transformer logits back to text
                with torch.no_grad():
probs = torch.softmax(logits, dim=-1)
token_indices = torch.argmax(probs, dim=-1)
# Convert token indices to strings for
# each sequence in the batch
gap = ' ' if self.word_tokenize else ''
# print('TT', token_indices.shape)
pred_sentences = self.batch_indices_to_text(
token_indices, batch_size=batch_size, gap=gap
)
tgt_out = tgt[1:, :]
loss = self.loss_fn(
logits.reshape(-1, logits.shape[-1]),
tgt_out.reshape(-1)
)
loss_item = loss.detach().cpu().numpy()
loss_list.append(loss_item)
wer.extend(GridDataset.get_wer(
pred_sentences, batch_sentences,
char_map=self.output_char_map
))
cer.extend(GridDataset.get_cer(
pred_sentences, batch_sentences,
char_map=self.output_char_map
))
if i_iter % opt.display == 0:
v = 1.0 * (time.time() - tic) / (i_iter + 1)
eta = v * (len(loader) - i_iter) / 3600.0
self.log_pred_texts(
pred_sentences, batch_sentences, sub_samples=10
)
print('test_iter={},eta={},wer={},cer={}'.format(
i_iter, eta, np.array(wer).mean(),
np.array(cer).mean()
))
print(''.join(161 * '-'))
return (
np.array(loss_list).mean(), np.array(wer).mean(),
np.array(cer).mean()
)
def run_test(self, tot_iter, optimizer):
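        """Run validation, log loss/WER/CER, and checkpoint on a new best loss."""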
log_scalar = functools.partial(self.log_scalar, label='test')
(loss, wer, cer) = self.test()
print('i_iter={},lr={},loss={},wer={},cer={}'.format(
tot_iter, show_lr(optimizer), loss, wer, cer
))
log_scalar('loss', loss, tot_iter)
log_scalar('wer', wer, tot_iter)
log_scalar('cer', cer, tot_iter)
if loss < self.best_test_loss:
print(f'NEW BEST LOSS: {loss}')
self.best_test_loss = loss
savename = 'I{}-L{:.4f}-W{:.4f}-C{:.4f}'.format(
tot_iter, loss, wer, cer
)
savename = savename.replace('.', '') + '.pt'
savepath = os.path.join(self.weights_dir, savename)
(save_dir, name) = os.path.split(savepath)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
torch.save(self.model.state_dict(), savepath)
print(f'best model saved at {savepath}')
if not opt.is_optimize:
exit()
if __name__ == '__main__':
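    # default configuration uses the GRID vocabularies; the commented-out
    # block below switches to the lsr2 vocabulary files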
vocab_filepaths = (
'data/grid_phoneme_vocab.pth',
'data/grid_text_char_vocab.pth'
)
"""
vocab_filepaths = (
'data/lsr2_phoneme_vocab.pth',
'data/lsr2_text_char_vocab.pth'
)
"""
trainer = TransformerTrainer(
word_tokenize=False, vocab_files=vocab_filepaths,
input_char_map=opt.char_map,
output_char_map=opt.text_char_map
)
if hasattr(opt, 'weights'):
trainer.load_weights(opt.weights)
trainer.train()