#!/usr/bin/env python
# -*- encoding: utf-8 -*-

"""
@Author  :   Peike Li
@Contact :   peike.li@yahoo.com
@File    :   train.py
@Time    :   8/4/19 3:36 PM
@Desc    :   Training script for Self-Correction for Human Parsing (SCHP) on the LIP dataset.
@License :   This source code is licensed under the license found in the
             LICENSE file in the root directory of this source tree.
"""
import os | |
import json | |
import timeit | |
import argparse | |
import torch | |
import torch.optim as optim | |
import torchvision.transforms as transforms | |
import torch.backends.cudnn as cudnn | |
from torch.utils import data | |
import networks | |
import utils.schp as schp | |
from datasets.datasets import LIPDataSet | |
from datasets.target_generation import generate_edge_tensor | |
from utils.transforms import BGR2RGB_transform | |
from utils.criterion import CriterionAll | |
from utils.encoding import DataParallelModel, DataParallelCriterion | |
from utils.warmup_scheduler import SGDRScheduler | |
def get_arguments(argv=None):
    """Parse the SCHP training options.

    Args:
        argv: Optional sequence of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, i.e. the real
            command line, exactly as before. Passing a list lets the parser
            be driven programmatically (tests, notebooks, wrapper scripts).

    Returns:
        argparse.Namespace with one attribute per option below. Dashes in flag
        names become underscores (``--batch-size`` -> ``args.batch_size``).
    """
    parser = argparse.ArgumentParser(description="Self Correction for Human Parsing")

    # Network Structure
    parser.add_argument("--arch", type=str, default='resnet101')

    # Data Preference
    parser.add_argument("--data-dir", type=str, default='./data/LIP')
    # The training script multiplies this by the number of GPUs for the loader.
    parser.add_argument("--batch-size", type=int, default=16, help='batch size per GPU')
    parser.add_argument("--input-size", type=str, default='473,473', help='comma-separated crop size')
    parser.add_argument("--num-classes", type=int, default=20)
    parser.add_argument("--ignore-label", type=int, default=255)
    parser.add_argument("--random-mirror", action="store_true")
    parser.add_argument("--random-scale", action="store_true")

    # Training Strategy
    parser.add_argument("--learning-rate", type=float, default=7e-3)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=5e-4)
    parser.add_argument("--gpu", type=str, default='0,1,2', help='comma-separated GPU ids')
    parser.add_argument("--start-epoch", type=int, default=0)
    parser.add_argument("--epochs", type=int, default=150)
    # Despite the name, this is the interval at which checkpoints are written.
    parser.add_argument("--eval-epochs", type=int, default=10, help='save a checkpoint every N epochs')
    parser.add_argument("--imagenet-pretrain", type=str, default='./pretrain_model/resnet101-imagenet.pth')
    parser.add_argument("--log-dir", type=str, default='./log')
    parser.add_argument("--model-restore", type=str, default='./log/checkpoint.pth.tar')
    parser.add_argument("--schp-start", type=int, default=100, help='schp start epoch')
    parser.add_argument("--cycle-epochs", type=int, default=10, help='schp cyclical epoch')
    parser.add_argument("--schp-restore", type=str, default='./log/schp_checkpoint.pth.tar')
    parser.add_argument("--lambda-s", type=float, default=1, help='segmentation loss weight')
    parser.add_argument("--lambda-e", type=float, default=1, help='edge loss weight')
    parser.add_argument("--lambda-c", type=float, default=0.1, help='segmentation-edge consistency loss weight')

    return parser.parse_args(argv)
def _build_transform(input_space, mean, std):
    """Build the image pre-processing pipeline for the backbone's input space.

    Args:
        input_space: Channel order reported by the backbone's ``input_space``
            attribute; must be ``'BGR'`` or ``'RGB'``.
        mean: Per-channel mean passed to ``transforms.Normalize``.
        std: Per-channel std passed to ``transforms.Normalize``.

    Returns:
        A ``transforms.Compose`` of ToTensor, an optional BGR->RGB swap, and
        Normalize.

    Raises:
        ValueError: if ``input_space`` is neither ``'BGR'`` nor ``'RGB'``.
    """
    if input_space == 'BGR':
        print('BGR Transformation')
        # No channel swap. Presumably the dataset already yields BGR images,
        # since the RGB branch below converts BGR -> RGB.
        steps = [transforms.ToTensor()]
    elif input_space == 'RGB':
        print('RGB Transformation')
        steps = [transforms.ToTensor(), BGR2RGB_transform()]
    else:
        # Fail fast instead of leaving the transform undefined.
        raise ValueError(
            "Unsupported model input_space {!r}; expected 'BGR' or 'RGB'".format(input_space))
    steps.append(transforms.Normalize(mean=mean, std=std))
    return transforms.Compose(steps)


def main():
    """Train an SCHP human-parsing model on LIP.

    Overview:
        - Reads all settings from the command line via ``get_arguments()``
          and writes them to ``<log_dir>/args.json``.
        - Optionally resumes the network from ``--model-restore`` and the
          aggregated SCHP model from ``--schp-restore``.
        - Runs SGD from the start epoch up to ``--epochs``.

    Per-epoch behaviour:
        - Every ``--eval-epochs`` epochs the network is checkpointed.
        - From epoch ``--schp-start`` on, every ``--cycle-epochs`` epochs the
          current weights are folded into the SCHP model with
          ``schp.moving_average``. Its statistics are then re-estimated with
          ``schp.bn_re_estimate`` on the training loader, and it is
          checkpointed.
        - Once at least one such cycle has completed, the SCHP model's
          predictions are passed to the criterion as extra (soft) targets.
    """
    args = get_arguments()
    print(args)

    # --start-epoch was parsed but never used. Honour it here; a resumed
    # checkpoint (below) still takes precedence.
    start_epoch = args.start_epoch
    # Number of completed self-correction cycles; 0 disables soft targets.
    cycle_n = 0

    os.makedirs(args.log_dir, exist_ok=True)
    with open(os.path.join(args.log_dir, 'args.json'), 'w') as opt_file:
        json.dump(vars(args), opt_file)

    # '--gpu None' means "don't restrict devices": leave CUDA_VISIBLE_DEVICES
    # alone and count every visible GPU. The id list is parsed only when one
    # is given, because int('None') would raise ValueError.
    if args.gpu == 'None':
        gpus = list(range(torch.cuda.device_count()))
    else:
        gpus = [int(i) for i in args.gpu.split(',')]
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu

    input_size = list(map(int, args.input_size.split(',')))

    cudnn.enabled = True
    cudnn.benchmark = True

    # Model Initialization
    backbone = networks.init_model(args.arch, num_classes=args.num_classes,
                                   pretrained=args.imagenet_pretrain)
    model = DataParallelModel(backbone)
    model.cuda()

    image_mean = backbone.mean
    image_std = backbone.std
    input_space = backbone.input_space
    print('image mean: {}'.format(image_mean))
    print('image std: {}'.format(image_std))
    print('input space:{}'.format(input_space))

    restore_from = args.model_restore
    if os.path.exists(restore_from):
        print('Resume training from {}'.format(restore_from))
        checkpoint = torch.load(restore_from)
        model.load_state_dict(checkpoint['state_dict'])
        start_epoch = checkpoint['epoch']
        # NOTE: checkpoints written by this script hold only 'epoch' and
        # 'state_dict', so the SGD optimizer starts fresh after a resume.

    # Second copy of the network into which the training weights are
    # aggregated at each self-correction cycle.
    schp_backbone = networks.init_model(args.arch, num_classes=args.num_classes,
                                        pretrained=args.imagenet_pretrain)
    schp_model = DataParallelModel(schp_backbone)
    schp_model.cuda()

    if os.path.exists(args.schp_restore):
        print('Resuming schp checkpoint from {}'.format(args.schp_restore))
        schp_checkpoint = torch.load(args.schp_restore)
        cycle_n = schp_checkpoint['cycle_n']
        schp_model.load_state_dict(schp_checkpoint['state_dict'])

    # Loss Function
    criterion = CriterionAll(lambda_1=args.lambda_s, lambda_2=args.lambda_e, lambda_3=args.lambda_c,
                             num_classes=args.num_classes)
    criterion = DataParallelCriterion(criterion)
    criterion.cuda()

    # Data Loader
    transform = _build_transform(input_space, image_mean, image_std)
    train_dataset = LIPDataSet(args.data_dir, 'train', crop_size=input_size, transform=transform)
    # --batch-size is per GPU. The loader yields one combined batch, presumably
    # split across devices by DataParallelModel.
    train_loader = data.DataLoader(train_dataset, batch_size=args.batch_size * len(gpus),
                                   num_workers=16, shuffle=True, pin_memory=True, drop_last=True)
    print('Total training samples: {}'.format(len(train_dataset)))

    # Optimizer Initialization
    optimizer = optim.SGD(model.parameters(), lr=args.learning_rate, momentum=args.momentum,
                          weight_decay=args.weight_decay)
    lr_scheduler = SGDRScheduler(optimizer, total_epoch=args.epochs,
                                 eta_min=args.learning_rate / 100, warmup_epoch=10,
                                 start_cyclical=args.schp_start, cyclical_base_lr=args.learning_rate / 2,
                                 cyclical_epoch=args.cycle_epochs)

    iters_per_epoch = len(train_loader)
    total_iters = args.epochs * iters_per_epoch
    start = timeit.default_timer()
    for epoch in range(start_epoch, args.epochs):
        lr_scheduler.step(epoch=epoch)
        lr = lr_scheduler.get_lr()[0]

        model.train()
        for i_iter, batch in enumerate(train_loader):
            # Iteration counter across all epochs, used only for logging.
            global_iter = i_iter + iters_per_epoch * epoch

            images, labels, _ = batch
            labels = labels.cuda(non_blocking=True)
            edges = generate_edge_tensor(labels)
            labels = labels.type(torch.cuda.LongTensor)
            edges = edges.type(torch.cuda.LongTensor)

            preds = model(images)

            # Online Self Correction Cycle with Label Refinement
            if cycle_n >= 1:
                with torch.no_grad():
                    soft_outputs = schp_model(images)
                    # From each element of the parallel output (assumed: one
                    # per GPU), take the last parsing and last edge prediction.
                    # Then concatenate the chunks back into one batch.
                    soft_preds = torch.cat([out[0][-1] for out in soft_outputs], dim=0)
                    soft_edges = torch.cat([out[1][-1] for out in soft_outputs], dim=0)
            else:
                soft_preds = None
                soft_edges = None

            loss = criterion(preds, [labels, edges, soft_preds, soft_edges], cycle_n)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if global_iter % 100 == 0:
                print('iter = {} of {} completed, lr = {}, loss = {}'.format(global_iter, total_iters, lr,
                                                                             loss.item()))

        if (epoch + 1) % args.eval_epochs == 0:
            schp.save_schp_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
            }, False, args.log_dir, filename='checkpoint_{}.pth.tar'.format(epoch + 1))

        # Self Correction Cycle with Model Aggregation
        if (epoch + 1) >= args.schp_start and (epoch + 1 - args.schp_start) % args.cycle_epochs == 0:
            print('Self-correction cycle number {}'.format(cycle_n))
            # Weight 1 / (cycle_n + 1): presumably makes schp_model the running
            # mean of all aggregated snapshots. Confirm in utils.schp.
            schp.moving_average(schp_model, model, 1.0 / (cycle_n + 1))
            cycle_n += 1
            schp.bn_re_estimate(train_loader, schp_model)
            schp.save_schp_checkpoint({
                'state_dict': schp_model.state_dict(),
                'cycle_n': cycle_n,
            }, False, args.log_dir, filename='schp_{}_checkpoint.pth.tar'.format(cycle_n))

        torch.cuda.empty_cache()
        end = timeit.default_timer()
        # Average wall-clock seconds per epoch since this run started.
        print('epoch = {} of {} completed using {} s'.format(epoch, args.epochs,
                                                             (end - start) / (epoch - start_epoch + 1)))

    end = timeit.default_timer()
    print('Training Finished in {} seconds'.format(end - start))
# Launch training only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()