# Text2Human / Text2Human / train_sampler.py
# yumingj's picture
# update
# bde71cb
import argparse
import logging
import os
import os.path as osp
import random
import time
import torch
from data.segm_attr_dataset import DeepFashionAttrSegmDataset
from models import create_model
from utils.logger import MessageLogger, get_root_logger, init_tb_logger
from utils.options import dict2str, dict_to_nonedict, parse
from utils.util import make_exp_dirs
def _build_eval_loader(opt, img_dir, ann_file):
    """Create a non-shuffled dataset/loader pair for evaluation.

    Args:
        opt: Option mapping; ``segm_dir``, ``pose_dir`` and ``batch_size``
            are read from it.
        img_dir: Image directory passed to the dataset.
        ann_file: Annotation file passed to the dataset as ``ann_dir``.

    Returns:
        A ``(dataset, loader)`` tuple.
    """
    dataset = DeepFashionAttrSegmDataset(
        img_dir=img_dir,
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=ann_file)
    loader = torch.utils.data.DataLoader(
        dataset=dataset, batch_size=opt['batch_size'], shuffle=False)
    return dataset, loader


def _run_inference(model, loader, opt, split, epoch):
    """Run ``model.inference`` on ``loader`` into a per-epoch directory.

    The output directory is
    ``<opt['path']['visualization']>/<split>/epoch_<epoch:03d>``.

    Raises:
        FileExistsError: If the directory already exists and ``opt['debug']``
            is falsy (the directory is only allowed to pre-exist in debug
            runs).
    """
    save_dir = f'{opt["path"]["visualization"]}/{split}/epoch_{epoch:03d}'
    # bool() keeps the original semantics when 'debug' is missing (None).
    os.makedirs(save_dir, exist_ok=bool(opt['debug']))
    model.inference(loader, save_dir)


def main():
    """Train the sampler described by the YAML file given via ``-opt``.

    Parses the options, sets up the experiment directories and loggers,
    builds the train/val/test data loaders, creates the model and runs the
    epoch loop. Every ``val_freq`` epochs (except epoch 0) inference is run
    on the validation and test loaders and the denoising network is saved.
    """
    # options
    parser = argparse.ArgumentParser()
    parser.add_argument('-opt', type=str, help='Path to option YAML file.')
    args = parser.parse_args()
    opt = parse(args.opt, is_train=True)

    # mkdir and loggers
    make_exp_dirs(opt)
    log_file = osp.join(opt['path']['log'], f"train_{opt['name']}.log")
    logger = get_root_logger(
        logger_name='base', log_level=logging.INFO, log_file=log_file)
    logger.info(dict2str(opt))

    # initialize tensorboard logger
    tb_logger = None
    if opt['use_tb_logger'] and 'debug' not in opt['name']:
        tb_logger = init_tb_logger(log_dir='./tb_logger/' + opt['name'])

    # convert to NoneDict, which returns None for missing keys
    opt = dict_to_nonedict(opt)

    # set up data loader
    train_dataset = DeepFashionAttrSegmDataset(
        img_dir=opt['train_img_dir'],
        segm_dir=opt['segm_dir'],
        pose_dir=opt['pose_dir'],
        ann_dir=opt['train_ann_file'],
        xflip=True)
    # A missing 'num_workers' option yields None; fall back to in-process
    # loading instead of passing None to DataLoader.
    num_workers = opt['num_workers'] or 0
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=opt['batch_size'],
        shuffle=True,
        num_workers=num_workers,
        # DataLoader rejects persistent_workers=True when num_workers == 0.
        persistent_workers=num_workers > 0,
        drop_last=True)
    logger.info(f'Number of train set: {len(train_dataset)}.')

    # NOTE(review): because the loader uses drop_last=True, the number of
    # optimisation steps actually executed is
    # num_epochs * len(train_loader), which can be smaller than this value.
    # Kept unchanged since the model may rely on it -- confirm.
    opt['max_iters'] = opt['num_epochs'] * len(
        train_dataset) // opt['batch_size']

    # The validation split reads images from the training image directory.
    val_dataset, val_loader = _build_eval_loader(
        opt, opt['train_img_dir'], opt['val_ann_file'])
    logger.info(f'Number of val set: {len(val_dataset)}.')

    test_dataset, test_loader = _build_eval_loader(
        opt, opt['test_img_dir'], opt['test_ann_file'])
    logger.info(f'Number of test set: {len(test_dataset)}.')

    model = create_model(opt)
    current_iter = 0

    # create message logger (formatted outputs)
    msg_logger = MessageLogger(opt, current_iter, tb_logger)

    for epoch in range(opt['num_epochs']):
        lr = model.update_learning_rate(epoch, current_iter)

        # Restart the timer every epoch so that the first iteration does not
        # report a bogus duration (previously measured from 0, i.e. the Unix
        # epoch) or include the preceding validation pass.
        iter_start = time.time()
        for batch_data in train_loader:
            # Time spent waiting for the batch since the last step finished.
            data_time = time.time() - iter_start

            current_iter += 1

            model.feed_data(batch_data)
            model.optimize_parameters()

            # Data loading plus optimisation time for this step.
            iter_time = time.time() - iter_start
            if current_iter % opt['print_freq'] == 0:
                log_vars = {
                    'epoch': epoch,
                    'iter': current_iter,
                    'lrs': [lr],
                    'time': iter_time,
                    'data_time': data_time,
                }
                log_vars.update(model.get_current_log())
                msg_logger(log_vars)

            iter_start = time.time()

        if epoch % opt['val_freq'] == 0 and epoch != 0:
            _run_inference(model, val_loader, opt, 'valset', epoch)
            _run_inference(model, test_loader, opt, 'testset', epoch)

            # save model
            # NOTE(review): the original indentation was lost; saving is
            # assumed to belong to the validation branch -- confirm.
            model.save_network(
                model._denoise_fn,
                f'{opt["path"]["models"]}/sampler_epoch{epoch}.pth')
# Script entry point: start training only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()