|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import argparse |
|
import os |
|
import pprint |
|
import shutil |
|
|
|
import torch |
|
import torch.nn.parallel |
|
import torch.backends.cudnn as cudnn |
|
import torch.optim |
|
import torch.utils.data |
|
import torch.utils.data.distributed |
|
import torchvision.transforms as transforms |
|
from tensorboardX import SummaryWriter |
|
|
|
import _init_paths |
|
from config import cfg |
|
from config import update_config |
|
from core.loss import JointsMSELoss |
|
from core.function import train |
|
from core.function import validate |
|
from utils.utils import get_optimizer |
|
from utils.utils import save_checkpoint |
|
from utils.utils import create_logger |
|
from utils.utils import get_model_summary |
|
|
|
import dataset |
|
import models |
|
|
|
|
|
def parse_args(): |
|
parser = argparse.ArgumentParser(description='Train keypoints network') |
|
# general |
|
parser.add_argument('--cfg', |
|
help='experiment configure file name', |
|
required=True, |
|
type=str) |
|
|
|
parser.add_argument('opts', |
|
help="Modify config options using the command-line", |
|
default=None, |
|
nargs=argparse.REMAINDER) |
|
|
|
# philly |
|
parser.add_argument('--modelDir', |
|
help='model directory', |
|
type=str, |
|
default='') |
|
parser.add_argument('--logDir', |
|
help='log directory', |
|
type=str, |
|
default='') |
|
parser.add_argument('--dataDir', |
|
help='data directory', |
|
type=str, |
|
default='') |
|
parser.add_argument('--prevModelDir', |
|
help='prev Model directory', |
|
type=str, |
|
default='') |
|
|
|
args = parser.parse_args() |
|
|
|
return args |
|
|
|
|
|
def main(): |
|
args = parse_args() |
|
update_config(cfg, args) |
|
|
|
logger, final_output_dir, tb_log_dir = create_logger( |
|
cfg, args.cfg, 'train') |
|
|
|
logger.info(pprint.pformat(args)) |
|
logger.info(cfg) |
|
|
|
# cudnn related setting |
|
cudnn.benchmark = cfg.CUDNN.BENCHMARK |
|
torch.backends.cudnn.deterministic = cfg.CUDNN.DETERMINISTIC |
|
torch.backends.cudnn.enabled = cfg.CUDNN.ENABLED |
|
|
|
model = eval('models.'+cfg.MODEL.NAME+'.get_pose_net')( |
|
cfg, is_train=True |
|
) |
|
|
|
# copy model file |
|
this_dir = os.path.dirname(__file__) |
|
shutil.copy2( |
|
os.path.join(this_dir, '../lib/models', cfg.MODEL.NAME + '.py'), |
|
final_output_dir) |
|
# logger.info(pprint.pformat(model)) |
|
|
|
writer_dict = { |
|
'writer': SummaryWriter(log_dir=tb_log_dir), |
|
'train_global_steps': 0, |
|
'valid_global_steps': 0, |
|
} |
|
|
|
dump_input = torch.rand( |
|
(1, 3, cfg.MODEL.IMAGE_SIZE[1], cfg.MODEL.IMAGE_SIZE[0]) |
|
) |
|
writer_dict['writer'].add_graph(model, (dump_input, )) |
|
|
|
logger.info(get_model_summary(model, dump_input)) |
|
|
|
model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda() |
|
|
|
# define loss function (criterion) and optimizer |
|
criterion = JointsMSELoss( |
|
use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT |
|
).cuda() |
|
|
|
# Data loading code |
|
normalize = transforms.Normalize( |
|
mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] |
|
) |
|
train_dataset = eval('dataset.'+cfg.DATASET.DATASET)( |
|
cfg, cfg.DATASET.ROOT, cfg.DATASET.TRAIN_SET, True, |
|
transforms.Compose([ |
|
transforms.ToTensor(), |
|
normalize, |
|
]) |
|
) |
|
valid_dataset = eval('dataset.'+cfg.DATASET.DATASET)( |
|
cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False, |
|
transforms.Compose([ |
|
transforms.ToTensor(), |
|
normalize, |
|
]) |
|
) |
|
|
|
train_loader = torch.utils.data.DataLoader( |
|
train_dataset, |
|
batch_size=cfg.TRAIN.BATCH_SIZE_PER_GPU*len(cfg.GPUS), |
|
shuffle=cfg.TRAIN.SHUFFLE, |
|
num_workers=cfg.WORKERS, |
|
pin_memory=cfg.PIN_MEMORY |
|
) |
|
valid_loader = torch.utils.data.DataLoader( |
|
valid_dataset, |
|
batch_size=cfg.TEST.BATCH_SIZE_PER_GPU*len(cfg.GPUS), |
|
shuffle=False, |
|
num_workers=cfg.WORKERS, |
|
pin_memory=cfg.PIN_MEMORY |
|
) |
|
|
|
best_perf = 0.0 |
|
best_model = False |
|
last_epoch = -1 |
|
optimizer = get_optimizer(cfg, model) |
|
begin_epoch = cfg.TRAIN.BEGIN_EPOCH |
|
checkpoint_file = os.path.join( |
|
final_output_dir, 'checkpoint.pth' |
|
) |
|
|
|
if cfg.AUTO_RESUME and os.path.exists(checkpoint_file): |
|
logger.info("=> loading checkpoint '{}'".format(checkpoint_file)) |
|
checkpoint = torch.load(checkpoint_file) |
|
begin_epoch = checkpoint['epoch'] |
|
best_perf = checkpoint['perf'] |
|
last_epoch = checkpoint['epoch'] |
|
model.load_state_dict(checkpoint['state_dict']) |
|
|
|
optimizer.load_state_dict(checkpoint['optimizer']) |
|
logger.info("=> loaded checkpoint '{}' (epoch {})".format( |
|
checkpoint_file, checkpoint['epoch'])) |
|
|
|
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR( |
|
optimizer, cfg.TRAIN.LR_STEP, cfg.TRAIN.LR_FACTOR, |
|
last_epoch=last_epoch |
|
) |
|
|
|
for epoch in range(begin_epoch, cfg.TRAIN.END_EPOCH): |
|
lr_scheduler.step() |
|
|
|
# train for one epoch |
|
train(cfg, train_loader, model, criterion, optimizer, epoch, |
|
final_output_dir, tb_log_dir, writer_dict) |
|
|
|
|
|
# evaluate on validation set |
|
perf_indicator = validate( |
|
cfg, valid_loader, valid_dataset, model, criterion, |
|
final_output_dir, tb_log_dir, writer_dict |
|
) |
|
|
|
if perf_indicator >= best_perf: |
|
best_perf = perf_indicator |
|
best_model = True |
|
print("best model so far!") |
|
else: |
|
best_model = False |
|
|
|
logger.info('=> saving checkpoint to {}'.format(final_output_dir)) |
|
save_checkpoint({ |
|
'epoch': epoch + 1, |
|
'model': cfg.MODEL.NAME, |
|
'state_dict': model.state_dict(), |
|
'best_state_dict': model.module.state_dict(), |
|
'perf': perf_indicator, |
|
'optimizer': optimizer.state_dict(), |
|
}, best_model, final_output_dir) |
|
|
|
final_model_state_file = os.path.join( |
|
final_output_dir, 'final_state.pth' |
|
) |
|
logger.info('=> saving final model state to {}'.format( |
|
final_model_state_file) |
|
) |
|
torch.save(model.module.state_dict(), final_model_state_file) |
|
writer_dict['writer'].close() |
|
|
|
|
|
if __name__ == '__main__': |
|
main() |
|
|