# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------

import os, torch, datetime, shutil, time
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data

from openvqa.models.model_loader import ModelLoader
from openvqa.utils.optim import get_optim, adjust_lr
from utils.test_engine import test_engine, ckpt_proc
from utils.extract_engine import extract_engine


def train_engine(__C, dataset, dataset_eval=None):

    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    net = ModelLoader(__C).Net(
        __C,
        pretrained_emb,
        token_size,
        ans_size
    )
    net.cuda()
    net.train()

    if __C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=__C.DEVICES)
    # Define the loss function, looked up by name from the config
    loss_fn = getattr(torch.nn, __C.LOSS_FUNC_NAME_DICT[__C.LOSS_FUNC])(
        reduction=__C.LOSS_REDUCTION
    ).cuda()
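    # Note: in the upstream OpenVQA configs, LOSS_FUNC_NAME_DICT maps short
    # names such as 'bce' -> 'BCEWithLogitsLoss' and 'ce' -> 'CrossEntropyLoss';
    # the exact entries depend on the config, so treat these as examples.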
    # Load checkpoint if resuming training
    if __C.RESUME:
        print(' ========== Resume training')

        if __C.CKPT_PATH is not None:
            print('Warning: Now using CKPT_PATH args, '
                  'CKPT_VERSION and CKPT_EPOCH will not work')
            path = __C.CKPT_PATH
        else:
            path = __C.CKPTS_PATH + \
                   '/ckpt_' + __C.CKPT_VERSION + \
                   '/epoch' + str(__C.CKPT_EPOCH) + '.pkl'

        # Load the network parameters
        print('Loading ckpt from {}'.format(path))
        ckpt = torch.load(path)
        print('Finish!')

        if __C.N_GPU > 1:
            net.load_state_dict(ckpt_proc(ckpt['state_dict']))
        else:
            net.load_state_dict(ckpt['state_dict'])
        start_epoch = ckpt['epoch']

        # Load the optimizer parameters
        optim = get_optim(__C, net, data_size, ckpt['lr_base'])
        optim._step = int(data_size / __C.BATCH_SIZE * start_epoch)
        optim.optimizer.load_state_dict(ckpt['optimizer'])
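        # Note: optim here is assumed to be openvqa's warm-up optimizer
        # wrapper, whose private _step counter drives the lr schedule;
        # rewinding it to steps-per-epoch * start_epoch keeps the schedule
        # aligned with the restored epoch.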

        if ('ckpt_' + __C.VERSION) not in os.listdir(__C.CKPTS_PATH):
            os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

    else:
        if ('ckpt_' + __C.VERSION) not in os.listdir(__C.CKPTS_PATH):
            # shutil.rmtree(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)
            os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

        optim = get_optim(__C, net, data_size)
        start_epoch = 0

    loss_sum = 0
    named_params = list(net.named_parameters())
    grad_norm = np.zeros(len(named_params))
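    # grad_norm accumulates a per-parameter gradient-norm running total;
    # its only consumer is the (currently commented-out) per-epoch
    # gradient logging block near the end of the epoch loop.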

    # Define multi-thread dataloader
    # if __C.SHUFFLE_MODE in ['external']:
    #     dataloader = Data.DataLoader(
    #         dataset,
    #         batch_size=__C.BATCH_SIZE,
    #         shuffle=False,
    #         num_workers=__C.NUM_WORKERS,
    #         pin_memory=__C.PIN_MEM,
    #         drop_last=True
    #     )
    # else:
    dataloader = Data.DataLoader(
        dataset,
        batch_size=__C.BATCH_SIZE,
        shuffle=True,
        num_workers=__C.NUM_WORKERS,
        pin_memory=__C.PIN_MEM,
        drop_last=True
    )
    logfile = open(
        __C.LOG_PATH +
        '/log_run_' + __C.VERSION + '.txt',
        'a+'
    )
    logfile.write(str(__C))
    logfile.close()
    # Training script
    for epoch in range(start_epoch, __C.MAX_EPOCH):

        # Save log to file
        logfile = open(
            __C.LOG_PATH +
            '/log_run_' + __C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            '=====================================\nnowTime: ' +
            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') +
            '\n'
        )
        logfile.close()

        # Learning Rate Decay
        if epoch in __C.LR_DECAY_LIST:
            adjust_lr(optim, __C.LR_DECAY_R)
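        # adjust_lr (from openvqa.utils.optim) is assumed to scale the base
        # learning rate by the decay factor LR_DECAY_R at each epoch listed
        # in LR_DECAY_LIST.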
        # Externally shuffle data list
        # if __C.SHUFFLE_MODE == 'external':
        #     dataset.shuffle_list(dataset.ans_list)

        time_start = time.time()
        # Iteration
        for step, (
                frcn_feat_iter,
                grid_feat_iter,
                bbox_feat_iter,
                ques_ix_iter,
                ans_iter
        ) in enumerate(dataloader):
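            # Gradient accumulation: each loaded batch of BATCH_SIZE samples
            # is split into GRAD_ACCU_STEPS sub-batches of SUB_BATCH_SIZE
            # (this assumes SUB_BATCH_SIZE = BATCH_SIZE // GRAD_ACCU_STEPS,
            # as in the upstream config). loss_tmp below tracks only the
            # current sub-batch loss, for the progress display.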
            optim.zero_grad()

            frcn_feat_iter = frcn_feat_iter.cuda()
            grid_feat_iter = grid_feat_iter.cuda()
            bbox_feat_iter = bbox_feat_iter.cuda()
            ques_ix_iter = ques_ix_iter.cuda()
            ans_iter = ans_iter.cuda()

            loss_tmp = 0
            for accu_step in range(__C.GRAD_ACCU_STEPS):
                loss_tmp = 0

                sub_frcn_feat_iter = \
                    frcn_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_grid_feat_iter = \
                    grid_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_bbox_feat_iter = \
                    bbox_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ques_ix_iter = \
                    ques_ix_iter[accu_step * __C.SUB_BATCH_SIZE:
                                 (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ans_iter = \
                    ans_iter[accu_step * __C.SUB_BATCH_SIZE:
                             (accu_step + 1) * __C.SUB_BATCH_SIZE]
                pred = net(
                    sub_frcn_feat_iter,
                    sub_grid_feat_iter,
                    sub_bbox_feat_iter,
                    sub_ques_ix_iter
                )

                loss_item = [pred, sub_ans_iter]
                loss_nonlinear_list = __C.LOSS_FUNC_NONLINEAR[__C.LOSS_FUNC]
                for item_ix, loss_nonlinear in enumerate(loss_nonlinear_list):
                    if loss_nonlinear in ['flat']:
                        loss_item[item_ix] = loss_item[item_ix].view(-1)
                    elif loss_nonlinear:
                        loss_item[item_ix] = getattr(F, loss_nonlinear)(loss_item[item_ix], dim=1)
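                # 'flat' flattens the tensor for losses that expect class
                # indices (e.g. CrossEntropyLoss targets); any other truthy
                # entry is assumed to name a torch.nn.functional op such as
                # 'log_softmax', applied over the answer dimension first.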
                loss = loss_fn(loss_item[0], loss_item[1])
                if __C.LOSS_REDUCTION == 'mean':
                    # only mean-reduction needs to be divided by grad_accu_steps
                    loss /= __C.GRAD_ACCU_STEPS
                loss.backward()

                loss_tmp += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS
                loss_sum += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS
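                # Multiplying by GRAD_ACCU_STEPS undoes the mean-reduction
                # division above (when it was applied), so the logged values
                # reflect the unscaled sub-batch loss.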

                if __C.VERBOSE:
                    if dataset_eval is not None:
                        mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['val']
                    else:
                        mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['test']

                    print("\r[Version %s][Model %s][Dataset %s][Epoch %2d][Step %4d/%4d][%s] Loss: %.4f, Lr: %.2e" % (
                        __C.VERSION,
                        __C.MODEL_USE,
                        __C.DATASET,
                        epoch + 1,
                        step,
                        int(data_size / __C.BATCH_SIZE),
                        mode_str,
                        loss_tmp / __C.SUB_BATCH_SIZE,
                        optim._rate
                    ), end=' ')

            # Gradient norm clipping
            if __C.GRAD_NORM_CLIP > 0:
                nn.utils.clip_grad_norm_(
                    net.parameters(),
                    __C.GRAD_NORM_CLIP
                )

            # Save the gradient information
            for name in range(len(named_params)):
                norm_v = torch.norm(named_params[name][1].grad).cpu().data.numpy() \
                    if named_params[name][1].grad is not None else 0
                grad_norm[name] += norm_v * __C.GRAD_ACCU_STEPS
                # print('Param %-3s Name %-80s Grad_Norm %-20s' %
                #       (str(grad_wt),
                #        params[grad_wt][0],
                #        str(norm_v)))
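
            # One optimizer step per full batch: gradients from all
            # GRAD_ACCU_STEPS sub-batches have accumulated in .grad and are
            # cleared by zero_grad() at the top of the next iteration.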
            optim.step()

        time_end = time.time()
        elapse_time = time_end - time_start
        print('Finished in {}s'.format(int(elapse_time)))
        epoch_finish = epoch + 1
        # Save checkpoint
        if not __C.SAVE_LAST or epoch_finish == __C.MAX_EPOCH:
            if __C.N_GPU > 1:
                state = {
                    'state_dict': net.module.state_dict(),
                    'optimizer': optim.optimizer.state_dict(),
                    'lr_base': optim.lr_base,
                    'epoch': epoch_finish
                }
            else:
                state = {
                    'state_dict': net.state_dict(),
                    'optimizer': optim.optimizer.state_dict(),
                    'lr_base': optim.lr_base,
                    'epoch': epoch_finish
                }
            torch.save(
                state,
                __C.CKPTS_PATH +
                '/ckpt_' + __C.VERSION +
                '/epoch' + str(epoch_finish) +
                '.pkl'
            )
        # Logging
        logfile = open(
            __C.LOG_PATH +
            '/log_run_' + __C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            'Epoch: ' + str(epoch_finish) +
            ', Loss: ' + str(loss_sum / data_size) +
            ', Lr: ' + str(optim._rate) + '\n' +
            'Elapsed time: ' + str(int(elapse_time)) +
            ', Speed(s/batch): ' + str(elapse_time / step) +
            '\n\n'
        )
        logfile.close()
        # Eval after every epoch
        if dataset_eval is not None:
            test_engine(
                __C,
                dataset_eval,
                state_dict=net.state_dict(),
                validation=True
            )
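        # Validation after each epoch uses the current in-memory weights
        # (state_dict) rather than re-loading the saved checkpoint from disk.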

        # if self.__C.VERBOSE:
        #     logfile = open(
        #         self.__C.LOG_PATH +
        #         '/log_run_' + self.__C.VERSION + '.txt',
        #         'a+'
        #     )
        #     for name in range(len(named_params)):
        #         logfile.write(
        #             'Param %-3s Name %-80s Grad_Norm %-25s\n' % (
        #                 str(name),
        #                 named_params[name][0],
        #                 str(grad_norm[name] / data_size * self.__C.BATCH_SIZE)
        #             )
        #         )
        #     logfile.write('\n')
        #     logfile.close()

        loss_sum = 0
        grad_norm = np.zeros(len(named_params))

    # Modification: optionally run a full result extraction after training ends
    if __C.EXTRACT_AFTER:
        extract_engine(__C, state_dict=net.state_dict())