|
import os |
|
import numpy as np |
|
import functools |
|
import shutil |
|
|
|
from typing import List |
|
|
|
import torch |
|
from tqdm.auto import tqdm |
|
|
|
from torch.utils.data import DataLoader |
|
from torchtext.datasets import Multi30k |
|
|
|
import options |
|
from Loader import GridLoader |
|
from PauseChecker import PauseChecker |
|
from dataset import GridDataset, CharMap, Datasets |
|
from datetime import datetime as Datetime |
|
|
|
from models.PhonemeTransformer import * |
|
from torchtext.vocab import build_vocab_from_iterator |
|
from torch.nn.utils.rnn import pad_sequence |
|
from BaseTrainer import BaseTrainer |
|
|
|
|
|
class TranslationDataset(GridDataset):
    """Grid dataset variant that yields (source, target) sentence pairs,
    reading one annotation file per character map."""

    def __init__(
        self, input_char_map: CharMap,
        output_char_map: CharMap, **kwargs
    ):
        super().__init__(**kwargs)
        # character maps select which annotation flavour to read
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map

    def _read_sentence(self, speaker, basename, char_map):
        # Resolve the annotation path for this char map, then load its text.
        anno_path = self.fetch_anno_path(
            speaker, basename, char_map=char_map
        )
        return self.load_str_sentence(anno_path, char_map=char_map)

    def __getitem__(self, idx):
        """Return the (input_str, output_str) sentence pair for sample idx."""
        _vid, speaker, filename = self.data[idx]
        basename = os.path.splitext(filename)[0]

        return (
            self._read_sentence(speaker, basename, self.input_char_map),
            self._read_sentence(speaker, basename, self.output_char_map),
        )
|
|
|
|
|
class TranslatorTrainer(BaseTrainer):
    """Trains a Seq2SeqTransformer that translates phoneme sequences
    into text sentences."""

    def __init__(
        self, dataset_type: Datasets = options.dataset,
        batch_size=128, validate_every=20, display_every=10,
        name='translate', write_logs=True, base_dir='',
        word_tokenize=False, vocab_files=None,
        input_char_map=CharMap.phonemes,
        output_char_map=CharMap.letters
    ):
        super().__init__(name=name, base_dir=base_dir)

        # training / logging cadence
        self.batch_size = batch_size
        self.validate_every = validate_every
        self.display_every = display_every

        # tokenization and char-map configuration
        self.word_tokenize = word_tokenize
        self.input_char_map = input_char_map
        self.output_char_map = output_char_map
        self.dataset_type = dataset_type

        self.text_tokenizer = functools.partial(
            GridDataset.tokenize_text, word_tokenize=word_tokenize
        )
        use_cuda = torch.cuda.is_available()
        self.device = torch.device('cuda' if use_cuda else 'cpu')

        if vocab_files is None:
            # build vocabularies directly from the dataset annotations
            self.phonemes_vocab, self.text_vocab = self.load_vocabs(
                self.base_dir
            )
        else:
            # load pre-built vocabularies from disk
            phonemes_vocab_path, text_vocab_path = vocab_files
            self.phonemes_vocab = torch.load(phonemes_vocab_path)
            self.text_vocab = torch.load(text_vocab_path)

        # model/optimizer are created lazily by create_model()
        self.model = None
        self.optimizer = None
        self.best_test_loss = float('inf')
        # padding positions contribute nothing to the loss
        self.loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

        # raw sentence -> tokens -> vocab indices -> tensor with BOS/EOS
        self.phonemes_encoder = self.sequential_transforms(
            GridDataset.tokenize_phonemes, self.phonemes_vocab,
            self.tensor_transform
        )
        self.text_encoder = self.sequential_transforms(
            self.text_tokenizer, self.text_vocab,
            self.tensor_transform
        )

        if write_logs:
            self.init_tensorboard()
|
|
|
def load_vocabs(self, base_dir): |
|
loader = GridLoader(base_dir=base_dir) |
|
|
|
if self.dataset_type == Datasets.GRID: |
|
phonemes_text_map = loader.load_grid_phonemes_text_map( |
|
phonemes_char_map=self.input_char_map, |
|
text_char_map=self.output_char_map |
|
) |
|
elif self.dataset_type == Datasets.LRS2: |
|
phonemes_text_map = loader.load_lsr2_phonemes_text_map( |
|
phonemes_char_map=self.input_char_map, |
|
text_char_map=self.output_char_map |
|
) |
|
else: |
|
raise NotImplementedError |
|
|
|
phonemes_map = phonemes_text_map[self.input_char_map] |
|
text_map = phonemes_text_map[self.output_char_map] |
|
|
|
phonemes_vocab = self.build_vocab( |
|
phonemes_map, tokenizer=GridDataset.tokenize_phonemes |
|
) |
|
text_vocab = self.build_vocab( |
|
text_map, tokenizer=self.text_tokenizer |
|
) |
|
|
|
return phonemes_vocab, text_vocab |
|
|
|
def save_vocabs( |
|
self, phoneme_vocab_path, text_vocab_path |
|
): |
|
torch.save(self.phonemes_vocab, phoneme_vocab_path) |
|
torch.save(self.text_vocab, text_vocab_path) |
|
|
|
def load_weights(self, weights): |
|
self.create_model() |
|
|
|
pretrained_dict = torch.load(weights) |
|
model_dict = self.model.state_dict() |
|
pretrained_dict = { |
|
k: v for k, v in pretrained_dict.items() if |
|
k in model_dict.keys() and v.size() == model_dict[k].size() |
|
} |
|
|
|
missed_params = [ |
|
k for k, v in model_dict.items() |
|
if k not in pretrained_dict.keys() |
|
] |
|
|
|
print('loaded params/tot params: {}/{}'.format( |
|
len(pretrained_dict), len(model_dict) |
|
)) |
|
print('miss matched params:{}'.format(missed_params)) |
|
model_dict.update(pretrained_dict) |
|
self.model.load_state_dict(model_dict) |
|
|
|
def create_model(self): |
|
self.model = Seq2SeqTransformer( |
|
src_vocab_size=len(self.phonemes_vocab), |
|
tgt_vocab_size=len(self.text_vocab) |
|
) |
|
|
|
self.model = self.model.to(self.device) |
|
self.optimizer = torch.optim.Adam( |
|
self.model.parameters(), |
|
lr=0.0001, betas=(0.9, 0.98), eps=1e-9 |
|
) |
|
|
|
def collate_tgt_fn(self, batch): |
|
tgt_batch = [] |
|
for tgt_sample in batch: |
|
tgt_batch.append(self.text_encoder(tgt_sample.rstrip("\n"))) |
|
|
|
tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX) |
|
return tgt_batch |
|
|
|
|
|
def collate_fn(self, batch): |
|
src_batch, tgt_batch = [], [] |
|
for src_sample, tgt_sample in batch: |
|
src_batch.append(self.phonemes_encoder(src_sample.rstrip("\n"))) |
|
tgt_batch.append(self.text_encoder(tgt_sample.rstrip("\n"))) |
|
|
|
src_batch = pad_sequence(src_batch, padding_value=PAD_IDX) |
|
tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX) |
|
return src_batch, tgt_batch |
|
|
|
def train(self, max_iters=10*1000): |
|
assert self.writer is not None |
|
assert self.display_every < self.validate_every |
|
|
|
self.create_model() |
|
self.best_test_loss = float('inf') |
|
log_scalar = functools.partial(self.log_scalar, label='train') |
|
self.model.train() |
|
losses = 0 |
|
|
|
dataset_kwargs = self.get_dataset_kwargs( |
|
input_char_map=self.input_char_map, |
|
char_map=self.output_char_map, |
|
output_char_map=self.output_char_map, |
|
file_list=options.train_list |
|
) |
|
|
|
train_iter = TranslationDataset(**dataset_kwargs, phase='train') |
|
test_iter = TranslationDataset(**dataset_kwargs, phase='test') |
|
|
|
train_dataloader = DataLoader( |
|
train_iter, batch_size=self.batch_size, |
|
|
|
) |
|
test_dataloader = DataLoader( |
|
test_iter, batch_size=self.batch_size, |
|
|
|
) |
|
|
|
tot_iters = 0 |
|
pbar = tqdm(total=max_iters) |
|
|
|
while tot_iters < max_iters: |
|
for train_pair in train_dataloader: |
|
PauseChecker.check() |
|
|
|
raw_src, raw_tgt = train_pair |
|
src, tgt = self.collate_fn(zip(raw_src, raw_tgt)) |
|
batch_size, max_seq_len = src.shape |
|
|
|
src = src.to(self.device) |
|
tgt = tgt.to(self.device) |
|
tgt_input = tgt[:-1, :] |
|
( |
|
src_mask, tgt_mask, |
|
src_padding_mask, tgt_padding_mask |
|
) = create_mask(src, tgt_input, self.device) |
|
|
|
logits = self.model( |
|
src, tgt_input, src_mask, tgt_mask, |
|
src_padding_mask, tgt_padding_mask, src_padding_mask |
|
) |
|
|
|
self.optimizer.zero_grad() |
|
|
|
tgt_out = tgt[1:, :] |
|
loss = self.loss_fn( |
|
logits.reshape(-1, logits.shape[-1]), |
|
tgt_out.reshape(-1) |
|
) |
|
|
|
loss.backward() |
|
self.optimizer.step() |
|
loss_item = loss.item() |
|
|
|
|
|
with torch.no_grad(): |
|
|
|
probs = torch.softmax(logits, dim=-1) |
|
token_indices = torch.argmax(probs, dim=-1) |
|
|
|
|
|
|
|
gap = ' ' if self.word_tokenize else '' |
|
pred_sentences = self.batch_indices_to_text( |
|
token_indices, batch_size=max_seq_len, gap=gap |
|
) |
|
wer = np.mean(GridDataset.get_wer( |
|
pred_sentences, raw_tgt, char_map=self.output_char_map |
|
)) |
|
|
|
desc = f'loss: {loss_item:.4f}, wer: {wer:.4f}' |
|
pbar.desc = desc |
|
|
|
losses += loss_item |
|
tot_iters += 1 |
|
pbar.update(1) |
|
|
|
run_validation = ( |
|
(tot_iters > 0) and |
|
(tot_iters % self.validate_every == 0) |
|
) |
|
run_display = ( |
|
(tot_iters > 0) and |
|
(tot_iters % self.display_every == 0) |
|
) |
|
|
|
if run_validation: |
|
self.run_test(test_dataloader, tot_iters=tot_iters) |
|
elif run_display: |
|
print('TRAIN PREDICTIONS') |
|
self.show_sentences(pred_sentences, raw_tgt, batch_size) |
|
|
|
if self.writer is not None: |
|
log_scalar('loss', loss, tot_iters) |
|
log_scalar('wer', wer, tot_iters) |
|
|
|
return losses / len(list(train_dataloader)) |
|
|
|
@staticmethod |
|
def show_sentences( |
|
pred_sentences, target_sentences, batch_size, pad=40 |
|
): |
|
print('{:<{pad}}|{:>{pad}}'.format( |
|
'predict', 'target', pad=pad |
|
)) |
|
|
|
line_length = 2 * pad + 1 |
|
print(''.join(line_length * '-')) |
|
|
|
for k in range(batch_size): |
|
pred_sentence = pred_sentences[k] |
|
target_sentence = target_sentences[k] |
|
print('{:<{pad}}|{:>{pad}}'.format( |
|
pred_sentence, target_sentence, pad=pad |
|
)) |
|
|
|
print(''.join(line_length * '-')) |
|
|
|
def run_test(self, test_dataloader, tot_iters): |
|
log_scalar = functools.partial(self.log_scalar, label='test') |
|
|
|
with torch.no_grad(): |
|
self.model.eval() |
|
|
|
for batch in test_dataloader: |
|
break |
|
|
|
raw_src, raw_tgt = batch |
|
src, tgt = self.collate_fn(zip(raw_src, raw_tgt)) |
|
batch_size, max_seq_len = src.shape |
|
src = src.to(self.device) |
|
tgt = tgt.to(self.device) |
|
|
|
tgt_input = tgt[:-1, :] |
|
( |
|
src_mask, tgt_mask, |
|
src_padding_mask, tgt_padding_mask |
|
) = create_mask(src, tgt_input, self.device) |
|
|
|
logits = self.model( |
|
src, tgt_input, src_mask, tgt_mask, |
|
src_padding_mask, tgt_padding_mask, src_padding_mask |
|
) |
|
|
|
self.optimizer.zero_grad() |
|
|
|
tgt_out = tgt[1:, :] |
|
loss = self.loss_fn( |
|
logits.reshape(-1, logits.shape[-1]), |
|
tgt_out.reshape(-1) |
|
) |
|
|
|
loss_item = loss.item() |
|
|
|
|
|
probs = torch.softmax(logits, dim=-1) |
|
token_indices = torch.argmax(torch.softmax(logits, dim=-1), dim=-1) |
|
|
|
gap = ' ' if self.word_tokenize else '' |
|
pred_sentences = self.batch_indices_to_text( |
|
token_indices, batch_size=max_seq_len, gap=gap |
|
) |
|
wer = np.mean(GridDataset.get_wer( |
|
pred_sentences, raw_tgt, char_map=self.output_char_map |
|
)) |
|
|
|
log_scalar('loss', loss, tot_iters) |
|
log_scalar('wer', wer, tot_iters) |
|
print(f'TEST PREDS [loss={loss_item:.4f}, wer={wer:.4f}]') |
|
self.show_sentences(pred_sentences, raw_tgt, batch_size) |
|
|
|
if loss < self.best_test_loss: |
|
print(f'NEW BEST LOSS: {loss}') |
|
self.best_test_loss = loss |
|
savename = 'I{}-L{:.4f}-W{:.4f}'.format( |
|
tot_iters, loss, wer |
|
) |
|
|
|
savename = savename.replace('.', '') + '.pt' |
|
savepath = os.path.join(self.weights_dir, savename) |
|
|
|
(save_dir, name) = os.path.split(savepath) |
|
if not os.path.exists(save_dir): |
|
os.makedirs(save_dir) |
|
|
|
torch.save(self.model.state_dict(), savepath) |
|
print(f'best model saved at {savepath}') |
|
|
|
def batch_indices_to_text( |
|
self, indices_tensor, batch_size, gap='' |
|
): |
|
sentences = [] |
|
|
|
for k in range(batch_size): |
|
tokens = [] |
|
|
|
for indices_row in indices_tensor: |
|
idx = indices_row[k] |
|
|
|
if idx == EOS_IDX: |
|
break |
|
if idx in [PAD_IDX, BOS_IDX, EOS_IDX]: |
|
continue |
|
|
|
token = self.text_vocab.lookup_token(idx) |
|
tokens.append(token) |
|
|
|
sentence = gap.join(tokens) |
|
sentences.append(sentence) |
|
|
|
return sentences |
|
|
|
@staticmethod |
|
def batch_tokenize_text(batch_sentences, word_tokenize=False): |
|
return [ |
|
GridDataset.tokenize_text( |
|
sentence, word_tokenize=word_tokenize |
|
) for sentence in batch_sentences |
|
] |
|
|
|
def evaluate(self, model): |
|
model.eval() |
|
losses = 0 |
|
|
|
language_pair = (str(CharMap.phonemes), str(CharMap.letters)) |
|
val_iter = Multi30k( |
|
split='valid', language_pair=language_pair |
|
) |
|
val_dataloader = DataLoader( |
|
val_iter, batch_size=self.batch_size, |
|
collate_fn=self.collate_fn |
|
) |
|
|
|
for src, tgt in val_dataloader: |
|
src = src.to(self.device) |
|
tgt = tgt.to(self.device) |
|
tgt_input = tgt[:-1, :] |
|
( |
|
src_mask, tgt_mask, |
|
src_padding_mask, tgt_padding_mask |
|
) = create_mask(src, tgt_input, self.device) |
|
|
|
logits = model( |
|
src, tgt_input, src_mask, tgt_mask, |
|
src_padding_mask, tgt_padding_mask, src_padding_mask |
|
) |
|
|
|
tgt_out = tgt[1:, :] |
|
loss = self.loss_fn( |
|
logits.reshape(-1, logits.shape[-1]), |
|
tgt_out.reshape(-1) |
|
) |
|
losses += loss.item() |
|
|
|
return losses / len(list(val_dataloader)) |
|
|
|
|
|
def translate( |
|
self, phoneme_sentence: str, beam_size=0 |
|
): |
|
self.model.eval() |
|
dummy_sentence = self.text_vocab.lookup_token( |
|
len(self.text_vocab) - 1 |
|
) |
|
src, _ = self.collate_fn(zip( |
|
[phoneme_sentence], [dummy_sentence] |
|
)) |
|
|
|
batch_size, max_seq_len = src.shape |
|
src = src.to(self.device) |
|
|
|
num_tokens = src.shape[0] |
|
src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool) |
|
max_len = num_tokens + 5 |
|
|
|
if beam_size > 0: |
|
tgt_tokens = self.beam_search_decode( |
|
src, src_mask, max_len=max_len, |
|
start_symbol=BOS_IDX, beam_size=beam_size |
|
) |
|
else: |
|
tgt_tokens = self.greedy_decode( |
|
src, src_mask, max_len=max_len, |
|
start_symbol=BOS_IDX |
|
) |
|
|
|
gap = ' ' if self.word_tokenize else '' |
|
pred_sentence = self.batch_indices_to_text( |
|
tgt_tokens, batch_size=max_seq_len, gap=gap |
|
)[0] |
|
return pred_sentence |
|
|
|
|
|
def greedy_decode(self, src, src_mask, max_len, start_symbol): |
|
src = src.to(self.device) |
|
src_mask = src_mask.to(self.device) |
|
memory = self.model.encode(src, src_mask) |
|
ys = ( |
|
torch.ones(1, 1).fill_(start_symbol). |
|
type(torch.long).to(self.device) |
|
) |
|
|
|
for i in range(max_len - 1): |
|
memory = memory.to(self.device) |
|
tgt_mask = ( |
|
generate_square_subsequent_mask( |
|
ys.size(0), device=self.device |
|
).type(torch.bool) |
|
).to(self.device) |
|
|
|
out = self.model.decode(ys, memory, tgt_mask) |
|
out = out.transpose(0, 1) |
|
prob = self.model.generator(out[:, -1]) |
|
_, next_word = torch.max(prob, dim=1) |
|
next_word = next_word.item() |
|
|
|
ys = torch.cat([ |
|
ys, torch.ones(1, 1).type_as(src.data).fill_(next_word) |
|
], dim=0) |
|
|
|
if next_word == EOS_IDX: |
|
break |
|
|
|
return ys |
|
|
|
def beam_search_decode( |
|
self, src, src_mask, max_len, start_symbol, beam_size=5 |
|
): |
|
src = src.to(self.device) |
|
src_mask = src_mask.to(self.device) |
|
memory = self.model.encode(src, src_mask) |
|
ys = ( |
|
torch.ones(1, 1).fill_(start_symbol). |
|
type(torch.long).to(self.device) |
|
) |
|
|
|
|
|
hypotheses = [(ys, 0.0)] |
|
|
|
for _ in range(max_len - 1): |
|
new_hypotheses = [] |
|
|
|
for seq, score in hypotheses: |
|
if seq[-1] == EOS_IDX: |
|
new_hypotheses.append((seq, score)) |
|
continue |
|
|
|
tgt_mask = generate_square_subsequent_mask( |
|
seq.size(0), device=self.device |
|
).type(torch.bool) |
|
|
|
out = self.model.decode(seq, memory, tgt_mask) |
|
out = out.transpose(0, 1) |
|
prob = self.model.generator(out[:, -1]) |
|
|
|
topk_prob, topk_indices = torch.topk(prob, beam_size) |
|
|
|
for i in range(beam_size): |
|
next_word = topk_indices[0][i] |
|
|
|
next_score = score - topk_prob[0][i].item() |
|
new_seq = torch.cat([ |
|
seq, torch.ones(1, 1).type_as(src.data).fill_(next_word) |
|
], dim=0) |
|
|
|
|
|
new_hypotheses.append((new_seq, next_score)) |
|
|
|
if len(new_hypotheses) == 0: |
|
break |
|
|
|
|
|
hypotheses = sorted( |
|
new_hypotheses, key=lambda x: x[1] |
|
)[:beam_size] |
|
|
|
return hypotheses[0][0] |
|
|
|
@staticmethod |
|
def yield_tokens(sequence_map, tokenizer): |
|
for key in sequence_map: |
|
yield tokenizer(sequence_map[key]) |
|
|
|
def build_vocab(self, sequence_map, tokenizer): |
|
return build_vocab_from_iterator( |
|
self.yield_tokens(sequence_map, tokenizer), |
|
min_freq=1, specials=SPECIAL_SYMBOLS, |
|
special_first=True |
|
) |
|
|
|
|
|
@staticmethod |
|
def sequential_transforms(*transforms): |
|
def func(txt_input): |
|
for transform in transforms: |
|
txt_input = transform(txt_input) |
|
|
|
return txt_input |
|
|
|
return func |
|
|
|
|
|
@staticmethod |
|
def tensor_transform(token_ids: List[int]): |
|
return torch.cat(( |
|
torch.tensor([BOS_IDX]), torch.tensor(token_ids), |
|
torch.tensor([EOS_IDX]) |
|
)) |
|
|
|
|
|
if __name__ == '__main__':
    # GRID vocabularies; to train on LSR2 swap in:
    #   'data/lsr2_phoneme_vocab.pth', 'data/lsr2_text_char_vocab.pth'
    vocab_filepaths = (
        'data/grid_phoneme_vocab.pth',
        'data/grid_text_char_vocab.pth'
    )

    trainer = TranslatorTrainer(
        word_tokenize=False, vocab_files=vocab_filepaths,
        input_char_map=options.char_map,
        output_char_map=options.text_char_map
    )
    trainer.train()
|
|
|
|
|
|
|
|