# --------------------------------------------------------
# OpenVQA
# Written by Yuhao Cui https://github.com/cuiyuhao1996
# --------------------------------------------------------
import os, torch, datetime, shutil, time
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from openvqa.models.model_loader import ModelLoader
from openvqa.utils.optim import get_optim, adjust_lr
from utils.test_engine import test_engine, ckpt_proc
from utils.extract_engine import extract_engine
def train_engine(__C, dataset, dataset_eval=None):
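    """Run the training loop for an OpenVQA model.

    Args:
        __C: global configuration object (model, dataset, optimizer and loss settings).
        dataset: training split (provides data_size, token_size, ans_size, pretrained_emb).
        dataset_eval: optional held-out split; when given it is evaluated after every epoch.
    """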
    data_size = dataset.data_size
    token_size = dataset.token_size
    ans_size = dataset.ans_size
    pretrained_emb = dataset.pretrained_emb

    net = ModelLoader(__C).Net(
        __C,
        pretrained_emb,
        token_size,
        ans_size
    )
    net.cuda()
    net.train()

    if __C.N_GPU > 1:
        net = nn.DataParallel(net, device_ids=__C.DEVICES)
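    # __C.LOSS_FUNC selects a torch.nn loss class through LOSS_FUNC_NAME_DICT and
    # __C.LOSS_REDUCTION sets its reduction mode; the instance is built with eval() below.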
    # Define Loss Function
    loss_fn = eval('torch.nn.' + __C.LOSS_FUNC_NAME_DICT[__C.LOSS_FUNC] + "(reduction='" + __C.LOSS_REDUCTION + "').cuda()")

    # Load checkpoint if resume training
    if __C.RESUME:
        print(' ========== Resume training')

        if __C.CKPT_PATH is not None:
            print('Warning: Now using CKPT_PATH args, '
                  'CKPT_VERSION and CKPT_EPOCH will not work')

            path = __C.CKPT_PATH
        else:
            path = __C.CKPTS_PATH + \
                   '/ckpt_' + __C.CKPT_VERSION + \
                   '/epoch' + str(__C.CKPT_EPOCH) + '.pkl'

        # Load the network parameters
        print('Loading ckpt from {}'.format(path))
        ckpt = torch.load(path)
        print('Finish!')

        if __C.N_GPU > 1:
            net.load_state_dict(ckpt_proc(ckpt['state_dict']))
        else:
            net.load_state_dict(ckpt['state_dict'])
        start_epoch = ckpt['epoch']

        # Load the optimizer parameters
        optim = get_optim(__C, net, data_size, ckpt['lr_base'])
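        # Rebuild the optimizer step counter so it matches the number of batches already
        # seen in the resumed epochs, then restore the optimizer's internal state.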
        optim._step = int(data_size / __C.BATCH_SIZE * start_epoch)
        optim.optimizer.load_state_dict(ckpt['optimizer'])

        if ('ckpt_' + __C.VERSION) not in os.listdir(__C.CKPTS_PATH):
            os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

    else:
        if ('ckpt_' + __C.VERSION) not in os.listdir(__C.CKPTS_PATH):
            # shutil.rmtree(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)
            os.mkdir(__C.CKPTS_PATH + '/ckpt_' + __C.VERSION)

        optim = get_optim(__C, net, data_size)
        start_epoch = 0

    loss_sum = 0
    named_params = list(net.named_parameters())
    grad_norm = np.zeros(len(named_params))

    # Define multi-thread dataloader
    # if __C.SHUFFLE_MODE in ['external']:
    #     dataloader = Data.DataLoader(
    #         dataset,
    #         batch_size=__C.BATCH_SIZE,
    #         shuffle=False,
    #         num_workers=__C.NUM_WORKERS,
    #         pin_memory=__C.PIN_MEM,
    #         drop_last=True
    #     )
    # else:
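    # drop_last=True keeps every batch at the full BATCH_SIZE, so it can always be split
    # evenly into GRAD_ACCU_STEPS sub-batches of SUB_BATCH_SIZE inside the training loop.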
    dataloader = Data.DataLoader(
        dataset,
        batch_size=__C.BATCH_SIZE,
        shuffle=True,
        num_workers=__C.NUM_WORKERS,
        pin_memory=__C.PIN_MEM,
        drop_last=True
    )

    logfile = open(
        __C.LOG_PATH +
        '/log_run_' + __C.VERSION + '.txt',
        'a+'
    )
    logfile.write(str(__C))
    logfile.close()

    # Training script
    for epoch in range(start_epoch, __C.MAX_EPOCH):

        # Save log to file
        logfile = open(
            __C.LOG_PATH +
            '/log_run_' + __C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            '=====================================\nnowTime: ' +
            datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S') +
            '\n'
        )
        logfile.close()

        # Learning Rate Decay
        if epoch in __C.LR_DECAY_LIST:
            adjust_lr(optim, __C.LR_DECAY_R)

        # Externally shuffle data list
        # if __C.SHUFFLE_MODE == 'external':
        #     dataset.shuffle_list(dataset.ans_list)

        time_start = time.time()
        # Iteration
        for step, (
                frcn_feat_iter,
                grid_feat_iter,
                bbox_feat_iter,
                ques_ix_iter,
                ans_iter
        ) in enumerate(dataloader):

            optim.zero_grad()

            frcn_feat_iter = frcn_feat_iter.cuda()
            grid_feat_iter = grid_feat_iter.cuda()
            bbox_feat_iter = bbox_feat_iter.cuda()
            ques_ix_iter = ques_ix_iter.cuda()
            ans_iter = ans_iter.cuda()

            loss_tmp = 0
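            # Gradient accumulation: the batch is processed in GRAD_ACCU_STEPS slices of
            # SUB_BATCH_SIZE each; every slice calls backward(), and optim.step() runs
            # once per batch after the loop below.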
            for accu_step in range(__C.GRAD_ACCU_STEPS):
                loss_tmp = 0

                sub_frcn_feat_iter = \
                    frcn_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_grid_feat_iter = \
                    grid_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_bbox_feat_iter = \
                    bbox_feat_iter[accu_step * __C.SUB_BATCH_SIZE:
                                   (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ques_ix_iter = \
                    ques_ix_iter[accu_step * __C.SUB_BATCH_SIZE:
                                 (accu_step + 1) * __C.SUB_BATCH_SIZE]
                sub_ans_iter = \
                    ans_iter[accu_step * __C.SUB_BATCH_SIZE:
                             (accu_step + 1) * __C.SUB_BATCH_SIZE]

                pred = net(
                    sub_frcn_feat_iter,
                    sub_grid_feat_iter,
                    sub_bbox_feat_iter,
                    sub_ques_ix_iter
                )

                loss_item = [pred, sub_ans_iter]
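                # Per-loss input transforms from the config: 'flat' flattens the tensor,
                # any other non-empty entry is applied as a torch.nn.functional op
                # (e.g. F.log_softmax) along dim 1 before the loss is computed.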
                loss_nonlinear_list = __C.LOSS_FUNC_NONLINEAR[__C.LOSS_FUNC]
                for item_ix, loss_nonlinear in enumerate(loss_nonlinear_list):
                    if loss_nonlinear in ['flat']:
                        loss_item[item_ix] = loss_item[item_ix].view(-1)
                    elif loss_nonlinear:
                        loss_item[item_ix] = eval('F.' + loss_nonlinear + '(loss_item[item_ix], dim=1)')

                loss = loss_fn(loss_item[0], loss_item[1])
                if __C.LOSS_REDUCTION == 'mean':
                    # only mean-reduction needs to be divided by grad_accu_steps
                    loss /= __C.GRAD_ACCU_STEPS
                loss.backward()

                loss_tmp += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS
                loss_sum += loss.cpu().data.numpy() * __C.GRAD_ACCU_STEPS

                if __C.VERBOSE:
                    if dataset_eval is not None:
                        mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['val']
                    else:
                        mode_str = __C.SPLIT['train'] + '->' + __C.SPLIT['test']

                    print("\r[Version %s][Model %s][Dataset %s][Epoch %2d][Step %4d/%4d][%s] Loss: %.4f, Lr: %.2e" % (
                        __C.VERSION,
                        __C.MODEL_USE,
                        __C.DATASET,
                        epoch + 1,
                        step,
                        int(data_size / __C.BATCH_SIZE),
                        mode_str,
                        loss_tmp / __C.SUB_BATCH_SIZE,
                        optim._rate
                    ), end=' ')

            # Gradient norm clipping
            if __C.GRAD_NORM_CLIP > 0:
                nn.utils.clip_grad_norm_(
                    net.parameters(),
                    __C.GRAD_NORM_CLIP
                )

            # Save the gradient information
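            # grad_norm accumulates the L2 norm of each parameter's gradient; it is only
            # consumed by the commented-out per-parameter logging block further down.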
            for name in range(len(named_params)):
                norm_v = torch.norm(named_params[name][1].grad).cpu().data.numpy() \
                    if named_params[name][1].grad is not None else 0
                grad_norm[name] += norm_v * __C.GRAD_ACCU_STEPS
                # print('Param %-3s Name %-80s Grad_Norm %-20s'%
                #       (str(grad_wt),
                #        params[grad_wt][0],
                #        str(norm_v)))

            optim.step()

        time_end = time.time()
        elapse_time = time_end - time_start
        print('Finished in {}s'.format(int(elapse_time)))
        epoch_finish = epoch + 1

        # Save checkpoint
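        # When SAVE_LAST is enabled only the final epoch's checkpoint is written;
        # otherwise a checkpoint is saved after every epoch.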
        if not __C.SAVE_LAST or epoch_finish == __C.MAX_EPOCH:
            if __C.N_GPU > 1:
                state = {
                    'state_dict': net.module.state_dict(),
                    'optimizer': optim.optimizer.state_dict(),
                    'lr_base': optim.lr_base,
                    'epoch': epoch_finish
                }
            else:
                state = {
                    'state_dict': net.state_dict(),
                    'optimizer': optim.optimizer.state_dict(),
                    'lr_base': optim.lr_base,
                    'epoch': epoch_finish
                }
            torch.save(
                state,
                __C.CKPTS_PATH +
                '/ckpt_' + __C.VERSION +
                '/epoch' + str(epoch_finish) +
                '.pkl'
            )

        # Logging
        logfile = open(
            __C.LOG_PATH +
            '/log_run_' + __C.VERSION + '.txt',
            'a+'
        )
        logfile.write(
            'Epoch: ' + str(epoch_finish) +
            ', Loss: ' + str(loss_sum / data_size) +
            ', Lr: ' + str(optim._rate) + '\n' +
            'Elapsed time: ' + str(int(elapse_time)) +
            ', Speed(s/batch): ' + str(elapse_time / step) +
            '\n\n'
        )
        logfile.close()

        # Eval after every epoch
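        # (test_engine is called with the in-memory state_dict and validation=True,
        #  so no checkpoint needs to be reloaded from disk)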
        if dataset_eval is not None:
            test_engine(
                __C,
                dataset_eval,
                state_dict=net.state_dict(),
                validation=True
            )

        # if self.__C.VERBOSE:
        #     logfile = open(
        #         self.__C.LOG_PATH +
        #         '/log_run_' + self.__C.VERSION + '.txt',
        #         'a+'
        #     )
        #     for name in range(len(named_params)):
        #         logfile.write(
        #             'Param %-3s Name %-80s Grad_Norm %-25s\n' % (
        #                 str(name),
        #                 named_params[name][0],
        #                 str(grad_norm[name] / data_size * self.__C.BATCH_SIZE)
        #             )
        #         )
        #     logfile.write('\n')
        #     logfile.close()

        loss_sum = 0
        grad_norm = np.zeros(len(named_params))

    # Modification - optionally run full result extract after training ends
    if __C.EXTRACT_AFTER:
        extract_engine(__C, state_dict=net.state_dict())