# torchnet/Trainer.py
import os
import time
import functools

import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

import options as opt

# helpers is expected to provide contains_nan_or_inf and show_lr,
# which are used below via the star import
from helpers import *
from dataset import GridDataset, CharMap
from datetime import datetime as Datetime
from models.LipNet import LipNet
from tqdm.auto import tqdm
from PauseChecker import PauseChecker
from torch.utils.data import DataLoader
from torch.multiprocessing import Manager
from BaseTrainer import BaseTrainer


class Trainer(BaseTrainer):
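    """
    Training and evaluation harness for LipNet on GridDataset data:
    loads the train/validation datasets, runs CTC training with
    periodic validation, and checkpoints the best model by test loss.
    """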
def __init__(
self, name=opt.run_name, write_logs=True,
num_workers=None, base_dir='', char_map=opt.char_map,
pre_gru_repeats=None
):
super().__init__(name=name, base_dir=base_dir)
images_dir = opt.images_dir
if opt.use_lip_crops:
images_dir = opt.crop_images_dir
if num_workers is None:
num_workers = opt.num_workers
if pre_gru_repeats is None:
pre_gru_repeats = opt.pre_gru_repeats
assert pre_gru_repeats >= 1
assert isinstance(pre_gru_repeats, int)
self.images_dir = images_dir
self.num_workers = num_workers
self.pre_gru_repeats = pre_gru_repeats
self.char_map = char_map
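        # a multiprocessing Manager dict is shared across DataLoader
        # worker processes so that GridDataset can cache loaded videos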
        if opt.cache_videos:
            # Manager() spawns a server process, so only start it when needed
            manager = Manager()
            shared_dict = manager.dict()
        else:
            shared_dict = None
self.shared_dict = shared_dict
self.dataset_kwargs = self.get_dataset_kwargs(
shared_dict=shared_dict, base_dir=self.base_dir,
char_map=self.char_map
)
self.best_test_loss = float('inf')
self.train_dataset = None
self.test_dataset = None
self.model = None
self.net = None
if write_logs:
self.init_tensorboard()
def load_datasets(self):
if self.train_dataset is None:
self.train_dataset = GridDataset(
**self.dataset_kwargs, phase='train',
file_list=opt.train_list
)
if self.test_dataset is None:
self.test_dataset = GridDataset(
**self.dataset_kwargs, phase='test',
file_list=opt.val_list
)
def create_model(self):
output_classes = len(self.train_dataset.get_char_mapping())
if self.model is None:
self.model = LipNet(
output_classes=output_classes,
pre_gru_repeats=self.pre_gru_repeats
)
self.model = self.model.cuda()
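        # self.net wraps the model in DataParallel for multi-GPU use;
        # self.model keeps the raw module for saving and loading weights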
if self.net is None:
self.net = nn.DataParallel(self.model).cuda()
def load_weights(self, weights_path):
self.load_datasets()
self.create_model()
weights_path = os.path.join(self.base_dir, weights_path)
pretrained_dict = torch.load(weights_path)
model_dict = self.model.state_dict()
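        # keep only checkpoint entries whose names and shapes match the
        # current model, so partially compatible weights can still load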
pretrained_dict = {
k: v for k, v in pretrained_dict.items() if
k in model_dict.keys() and v.size() == model_dict[k].size()
}
missed_params = [
k for k, v in model_dict.items()
if k not in pretrained_dict.keys()
]
        print('loaded params / total params: {}/{}'.format(
            len(pretrained_dict), len(model_dict)
        ))
        print('params missing from checkpoint: {}'.format(missed_params))
model_dict.update(pretrained_dict)
self.model.load_state_dict(model_dict)
@staticmethod
def make_date_stamp():
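        # timestamp such as 240131-0942 (yymmdd-HHMM)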
return Datetime.now().strftime("%y%m%d-%H%M")
@staticmethod
def dataset2dataloader(
dataset, num_workers, shuffle=True
):
return DataLoader(
dataset,
batch_size=opt.batch_size,
shuffle=shuffle,
num_workers=num_workers,
drop_last=False
)
def test(self):
dataset = self.test_dataset
with torch.no_grad():
print('num_test_data:{}'.format(len(dataset.data)))
self.model.eval()
loader = self.dataset2dataloader(
dataset, shuffle=False, num_workers=self.num_workers
)
loss_list = []
wer = []
cer = []
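            # zero_infinity stops CTC from returning infinite loss for
            # sequences that are too short (mirrors the training loop)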
crit = nn.CTCLoss(zero_infinity=True)
tic = time.time()
print('RUNNING VALIDATION')
pbar = tqdm(loader)
for (i_iter, input_sample) in enumerate(pbar):
PauseChecker.check()
vid = input_sample.get('vid').cuda()
vid_len = input_sample.get('vid_len').cuda()
txt, txt_len = self.extract_char_output(input_sample)
y = self.net(vid)
# assert not contains_nan_or_inf(y)
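                # CTC requires the upsampled input length to exceed the
                # target length; pre_gru_repeats scales the time dimension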
assert (
self.pre_gru_repeats * vid_len.view(-1) >
2 * txt_len.view(-1)
).all()
loss = crit(
y.transpose(0, 1).log_softmax(-1), txt,
self.pre_gru_repeats * vid_len.view(-1),
txt_len.view(-1)
).detach().cpu().numpy()
loss_list.append(loss)
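                # decode network output back to text for WER/CER scoring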
pred_txt = dataset.ctc_decode(y)
truth_txt = [
dataset.arr2txt(txt[_], start=1)
for _ in range(txt.size(0))
]
wer.extend(dataset.wer(pred_txt, truth_txt))
cer.extend(dataset.cer(pred_txt, truth_txt))
if i_iter % opt.display == 0:
v = 1.0 * (time.time() - tic) / (i_iter + 1)
eta = v * (len(loader) - i_iter) / 3600.0
self.log_pred_texts(pred_txt, truth_txt, sub_samples=10)
                    print('test_iter={},eta={:.2f}h,wer={},cer={}'.format(
                        i_iter, eta, np.array(wer).mean(),
                        np.array(cer).mean()
                    ))
                    print(161 * '-')
return (
np.array(loss_list).mean(), np.array(wer).mean(),
np.array(cer).mean()
)
def extract_char_output(self, input_sample):
"""
extract output character sequence from input_sample
output character sequence is text if char_map is CharMap.letters
output character sequence is phonemes if char_map is CharMap.phonemes
"""
if self.char_map == CharMap.letters:
txt = input_sample.get('txt').cuda()
txt_len = input_sample.get('txt_len').cuda()
elif self.char_map == CharMap.phonemes:
txt = input_sample.get('phonemes').cuda()
txt_len = input_sample.get('phonemes_len').cuda()
elif self.char_map == CharMap.cmu_phonemes:
txt = input_sample.get('cmu_phonemes').cuda()
txt_len = input_sample.get('cmu_phonemes_len').cuda()
else:
raise ValueError(f'UNSUPPORTED CHAR_MAP: {self.char_map}')
return txt, txt_len
def train(self):
self.load_datasets()
self.create_model()
dataset = self.train_dataset
loader = self.dataset2dataloader(
dataset, num_workers=self.num_workers
)
"""
optimizer = optim.Adam(
self.model.parameters(), lr=opt.base_lr,
weight_decay=0., amsgrad=True
)
"""
optimizer = optim.RMSprop(
self.model.parameters(), lr=opt.base_lr
)
print('num_train_data:{}'.format(len(dataset.data)))
# don't allow loss function to create infinite loss for
# sequences that are too short
crit = nn.CTCLoss(zero_infinity=True)
tic = time.time()
train_wer = []
self.best_test_loss = float('inf')
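        # scalars logged in this loop are tagged with the 'train' label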
log_scalar = functools.partial(self.log_scalar, label='train')
for epoch in range(opt.max_epoch):
print(f'RUNNING EPOCH {epoch}')
pbar = tqdm(loader)
for (i_iter, input_sample) in enumerate(pbar):
PauseChecker.check()
self.model.train()
vid = input_sample.get('vid').cuda()
vid_len = input_sample.get('vid_len').cuda()
txt, txt_len = self.extract_char_output(input_sample)
optimizer.zero_grad()
y = self.net(vid)
assert not contains_nan_or_inf(y)
assert (
self.pre_gru_repeats * vid_len.view(-1) >
2 * txt_len.view(-1)
).all()
loss = crit(
y.transpose(0, 1).log_softmax(-1), txt,
self.pre_gru_repeats * vid_len.view(-1),
txt_len.view(-1)
)
if contains_nan_or_inf(loss):
print(f'LOSS IS INVALID. SKIPPING {i_iter}')
# print('Y', y)
# print('txt', txt)
continue
loss.backward()
                params = self.model.parameters()
                # check for NaNs in gradients, skipping params without grads
                if any(
                    torch.isnan(p.grad).any()
                    for p in params if p.grad is not None
                ):
                    # clear gradients so no corrupted update is applied
                    optimizer.zero_grad()
                    print('SKIPPING NAN GRADS')
                    continue
if opt.is_optimize:
optimizer.step()
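                        # sanity-check that the update kept weights finite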
assert not contains_nan_or_inf(self.model.conv1.weight)
tot_iter = i_iter + epoch * len(loader)
pred_txt = dataset.ctc_decode(y)
truth_txt = [
dataset.arr2txt(txt[_], start=1)
for _ in range(txt.size(0))
]
train_wer.extend(dataset.wer(pred_txt, truth_txt))
if tot_iter % opt.display == 0:
v = 1.0 * (time.time() - tic) / (tot_iter + 1)
eta = (len(loader) - i_iter) * v / 3600.0
wer = np.array(train_wer).mean()
log_scalar('loss', loss, tot_iter)
log_scalar('wer', wer, tot_iter)
self.log_pred_texts(pred_txt, truth_txt, sub_samples=3)
                    print(
                        'epoch={},tot_iter={},eta={:.2f}h,loss={},train_wer={}'
                        .format(
                            epoch, tot_iter, eta, loss,
                            np.array(train_wer).mean()
                        )
                    )
                    print(161 * '-')
if (tot_iter > 0) and (tot_iter % opt.test_step == 0):
self.run_test(tot_iter, optimizer)
@staticmethod
    def log_pred_texts(pred_txt, truth_txt, pad=80, sub_samples=None):
        # print predicted and ground-truth transcripts side by side
        line_length = 2 * pad + 1
        print(line_length * '-')
        print('{:<{pad}}|{:>{pad}}'.format('predict', 'truth', pad=pad))
        print(line_length * '-')
        zipped_samples = list(zip(pred_txt, truth_txt))
        if sub_samples is not None:
            # only show the first sub_samples prediction/truth pairs
            zipped_samples = zipped_samples[:sub_samples]
        for (predict, truth) in zipped_samples:
            print('{:<{pad}}|{:>{pad}}'.format(predict, truth, pad=pad))
        print(line_length * '-')
def run_test(self, tot_iter, optimizer):
log_scalar = functools.partial(self.log_scalar, label='test')
(loss, wer, cer) = self.test()
print('i_iter={},lr={},loss={},wer={},cer={}'.format(
tot_iter, show_lr(optimizer), loss, wer, cer
))
log_scalar('loss', loss, tot_iter)
log_scalar('wer', wer, tot_iter)
log_scalar('cer', cer, tot_iter)
if loss < self.best_test_loss:
print(f'NEW BEST LOSS: {loss}')
self.best_test_loss = loss
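            # encode iteration, loss, WER and CER into the checkpoint name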
savename = 'I{}-L{:.4f}-W{:.4f}-C{:.4f}'.format(
tot_iter, loss, wer, cer
)
savename = savename.replace('.', '') + '.pt'
savepath = os.path.join(self.weights_dir, savename)
(save_dir, name) = os.path.split(savepath)
if not os.path.exists(save_dir):
os.makedirs(save_dir)
torch.save(self.model.state_dict(), savepath)
print(f'best model saved at {savepath}')
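        # when training is disabled, stop after this single evaluation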
if not opt.is_optimize:
exit()
def predict_sample(self, input_sample):
self.model.eval()
vid = input_sample.get('vid').cuda()
return self.predict_video(vid)
def predict_video(self, video):
video = video.cuda()
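        # add a leading batch dimension before running the network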
vid = video.unsqueeze(0)
y = self.net(vid)
pred_txt = self.train_dataset.ctc_decode(y)
return pred_txt
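

# Minimal usage sketch (an assumption on top of the original file):
# hyperparameters, dataset file lists and paths all come from options.py.
if __name__ == '__main__':
    trainer = Trainer(write_logs=True)
    # optionally warm-start from a checkpoint before training, e.g.
    # trainer.load_weights('weights/pretrained.pt')  # hypothetical path
    trainer.train()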