|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import argparse |
|
import os |
|
import pprint |
|
|
|
import torch |
|
import torch.nn.parallel |
|
import torch.backends.cudnn as cudnn |
|
import torch.optim |
|
import torch.utils.data |
|
import torch.utils.data.distributed |
|
import torchvision.transforms as transforms |
|
|
|
import _init_paths |
|
from config import cfg |
|
from config import update_config |
|
from core.loss import JointsMSELoss |
|
from core.function import validate |
|
from utils.utils import create_logger |
|
|
|
import dataset |
|
import models |
|
|
|
|
|
def parse_args(argv=None):
    """Parse command-line arguments for the keypoint-network test script.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default) argparse reads ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace with:
          * ``cfg`` -- path of the experiment config file (required).
          * ``opts`` -- all tokens after the recognised options, kept
            unparsed (help text: "Modify config options using the
            command-line").
          * ``modelDir``, ``logDir``, ``dataDir``, ``prevModelDir`` --
            optional directory overrides; empty string when not given.
    """
    # This script evaluates a model (see main()), so describe it as such.
    parser = argparse.ArgumentParser(description='Test keypoints network')

    parser.add_argument('--cfg',
                        help='experiment configure file name',
                        required=True,
                        type=str)

    # REMAINDER swallows every remaining token verbatim; presumably these
    # are applied as config overrides by update_config() -- confirm there.
    parser.add_argument('opts',
                        help="Modify config options using the command-line",
                        default=None,
                        nargs=argparse.REMAINDER)

    parser.add_argument('--modelDir',
                        help='model directory',
                        type=str,
                        default='')
    parser.add_argument('--logDir',
                        help='log directory',
                        type=str,
                        default='')
    parser.add_argument('--dataDir',
                        help='data directory',
                        type=str,
                        default='')
    parser.add_argument('--prevModelDir',
                        help='prev Model directory',
                        type=str,
                        default='')

    return parser.parse_args(argv)
|
|
|
|
|
def _resolve_attr(root, dotted_name):
    """Return the object reached by following ``dotted_name`` from ``root``.

    Equivalent to evaluating ``root.<dotted_name>`` but done with getattr,
    so a config string can only select attributes, never execute code.
    """
    obj = root
    for part in dotted_name.split('.'):
        obj = getattr(obj, part)
    return obj


def _build_model(cfg, final_output_dir, logger):
    """Build the pose network named by ``cfg.MODEL.NAME`` and load weights.

    Weights come from ``cfg.TEST.MODEL_FILE`` when it is set (loaded with
    ``strict=False``, so missing/unexpected keys are tolerated); otherwise
    from ``final_state.pth`` inside ``final_output_dir`` (strict load).

    Returns:
        The unwrapped model, still on CPU; the caller moves it to GPU.
    """
    get_pose_net = _resolve_attr(models, cfg.MODEL.NAME + '.get_pose_net')
    model = get_pose_net(cfg, is_train=False)

    # map_location='cpu': deserialize onto CPU regardless of the device the
    # checkpoint was saved from (avoids failures when that GPU index does
    # not exist here). Parameters are copied into the CPU model anyway.
    if cfg.TEST.MODEL_FILE:
        logger.info('=> loading model from {}'.format(cfg.TEST.MODEL_FILE))
        state_dict = torch.load(cfg.TEST.MODEL_FILE, map_location='cpu')
        model.load_state_dict(state_dict, strict=False)
    else:
        model_state_file = os.path.join(
            final_output_dir, 'final_state.pth'
        )
        logger.info('=> loading model from {}'.format(model_state_file))
        state_dict = torch.load(model_state_file, map_location='cpu')
        model.load_state_dict(state_dict)

    return model


def _build_valid_loader(cfg):
    """Create the test-split dataset named by ``cfg.DATASET.DATASET``.

    Returns:
        ``(valid_dataset, valid_loader)``; the loader is unshuffled with a
        batch size of ``cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS)``.
    """
    # Per-channel mean/std; these look like the usual ImageNet statistics,
    # presumably matching the backbone's pretraining -- confirm.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    dataset_cls = _resolve_attr(dataset, cfg.DATASET.DATASET)
    # The positional False is forwarded as-is to the dataset constructor
    # (likely an is_train flag -- check the dataset class).
    valid_dataset = dataset_cls(
        cfg, cfg.DATASET.ROOT, cfg.DATASET.TEST_SET, False,
        transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])
    )
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        batch_size=cfg.TEST.BATCH_SIZE_PER_GPU * len(cfg.GPUS),
        shuffle=False,
        num_workers=cfg.WORKERS,
        pin_memory=True
    )
    return valid_dataset, valid_loader


def main():
    """Evaluate a trained keypoint model on the configured test split.

    Parses the command line, merges it into the global ``cfg``, obtains a
    logger and output/log directories from ``create_logger`` (phase
    ``'valid'``), applies the cuDNN flags from the config, loads the model,
    wraps it in ``DataParallel`` over ``cfg.GPUS`` and runs ``validate``.
    """
    args = parse_args()
    update_config(cfg, args)

    logger, final_output_dir, tb_log_dir = create_logger(
        cfg, args.cfg, 'valid')

    logger.info(pprint.pformat(args))
    logger.info(cfg)

    # cuDNN related settings (cudnn is torch.backends.cudnn).
    cudnn.benchmark = cfg.CUDNN.BENCHMARK
    cudnn.deterministic = cfg.CUDNN.DETERMINISTIC
    cudnn.enabled = cfg.CUDNN.ENABLED

    # Weights are loaded before DataParallel wrapping, so checkpoint keys
    # must match the bare model (no 'module.' prefix).
    model = _build_model(cfg, final_output_dir, logger)
    model = torch.nn.DataParallel(model, device_ids=cfg.GPUS).cuda()

    criterion = JointsMSELoss(
        use_target_weight=cfg.LOSS.USE_TARGET_WEIGHT
    ).cuda()

    valid_dataset, valid_loader = _build_valid_loader(cfg)

    validate(cfg, valid_loader, valid_dataset, model, criterion,
             final_output_dir, tb_log_dir)
|
|
|
|
|
if __name__ == '__main__':
    # Run the evaluation only when executed as a script, not when imported.
    main()
|
|