"""Training entry point.

Parses the command line and config file, builds the train / validation /
training-progress datasets, creates the model and runs the training loop either
per iteration (``epoch_as_iter``) or per epoch, in a single process (single GPU
or DataParallel) or under DDP with one spawned process per GPU.
"""
import argparse
import copy
import datetime
import os
import sys
import time

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter

from data import CustomDataLoader
from data.super_dataset import SuperDataset
from models import create_model
from configs import parse_config
from utils.util import print_losses, check_path, make_grid, AverageMeter
from utils.data_utils import check_old_config_val_possible


def main():
    """Parse arguments and config, pin the visible GPUs, then launch training.

    Under DDP one worker per GPU is started with ``mp.spawn``; otherwise
    ``train_val`` runs in this process with the batch sizes scaled by the
    number of GPUs for DataParallel.
    """
    parser = argparse.ArgumentParser(description='Style Master')
    parser.add_argument('--cfg_file', type=str, default='./exp/sp2pII-phase4.yaml')
    parser.add_argument('--image_prompt', type=str, default='./example/reference/01.png')
    parser.add_argument('--text_prompt', type=str, default='sketch with black pencil')
    parser.add_argument('--name', type=str, default='01')
    args = parser.parse_args()

    config = parse_config(args.cfg_file)
    # Non-empty command-line values override the config file.
    if args.image_prompt:
        config['training']['image_prompt'] = args.image_prompt
    if args.text_prompt:
        config['training']['text_prompt'] = args.text_prompt
    if args.name:
        config['common']['name'] = args.name

    for group in config:
        print(group + ':')
        for k, v in config[group].items():
            print(' {}: {}'.format(k, v))

    # GPU ids should match nvidia-smi order, but inside the process they must
    # start from 0 (CUDA_VISIBLE_DEVICES may select e.g. GPUs 2 and 3), so expose
    # the requested GPUs and renumber them 0..n-1 in the config.
    gpu_ids = config['common']['gpu_ids']
    os.environ['CUDA_VISIBLE_DEVICES'] = ','.join(map(str, gpu_ids))
    num_gpus = len(gpu_ids)
    config['common']['gpu_ids'] = list(range(num_gpus))

    if config['training']['DDP']:
        config['training']['world_size'] = num_gpus * config['training']['num_nodes']
        # os.environ values must be str; a numeric port/address from the config
        # would otherwise raise TypeError.
        os.environ['MASTER_ADDR'] = str(config['training']['DDP_address'])
        os.environ['MASTER_PORT'] = str(config['training']['DDP_port'])
        # mp.spawn passes the process index (0..num_gpus-1) as train_val's first argument.
        mp.spawn(train_val, nprocs=num_gpus, args=(config,))
    else:
        # Under DP mode the batch is split across GPUs, so set the batch size to
        # (per-GPU batch size * num_gpus).
        config['dataset']['batch_size'] *= num_gpus
        if 'patch_batch_size' in config['dataset']:
            config['dataset']['patch_batch_size'] *= num_gpus
        train_val(None, config)


def _is_main_process(gpu):
    """Return True for the process that prints, writes TensorBoard logs and saves checkpoints.

    ``gpu`` is None without DDP (single process) and the spawned process index
    under DDP, where index 0 is the main process.
    """
    return gpu is None or gpu == 0


def _num_batches(num_samples, batch_size):
    """Return ceil(num_samples / batch_size); the progress total passed to print_losses."""
    return (num_samples + batch_size - 1) // batch_size


def _accumulate_losses(meters, losses):
    """Fold one step's losses into per-name running averages, creating meters on first use."""
    for name, value in losses.items():
        if name not in meters:
            meters[name] = AverageMeter()
        meters[name].update(value)


def _clear_meters(meters):
    """Reset every running average in ``meters``."""
    for meter in meters.values():
        meter.clear()


def _log_image_grids(writer, model, step, use_new_log, new_prefix, old_prefix, suffix=''):
    """Write the model's current image grids (from ``make_grid``) to TensorBoard.

    With ``use_new_log`` each grid is tagged ``new_prefix + name + suffix`` and
    logged at ``step``; otherwise the legacy tag ``old_prefix + name + suffix`` is
    used without a step (callers encode the epoch/iteration in ``old_prefix``).
    """
    visuals = make_grid(model)
    grids, names = visuals[0], visuals[1]
    for img_grid, name in zip(grids, names):
        if use_new_log:
            writer.add_image(new_prefix + name + suffix, img_grid, step)
        else:
            writer.add_image(old_prefix + name + suffix, img_grid)


def _log_loss_scalars(writer, meters, names, prefix, step):
    """Write the running average of each loss in ``names`` as a TensorBoard scalar."""
    for name in names:
        writer.add_scalar(prefix + name, meters[name].avg(), step)


def _validate(model, val_dataloader, meters):
    """Run ``model.eval_step`` on every validation batch without gradients.

    Losses are folded into ``meters``. Returns ``(num_batches, last_losses)``;
    ``last_losses`` is an empty dict when the loader yielded no batches.
    """
    model.eval()
    num_batches = 0
    losses = {}
    for data in val_dataloader:
        num_batches += 1
        with torch.no_grad():
            model.set_input(data)
            model.eval_step()
        losses = model.get_current_losses()
        _accumulate_losses(meters, losses)
    return num_batches, losses


def _save_if_best_val(model, meters, losses, is_main):
    """Checkpoint the model when the main validation loss beats ``model.best_val_loss``.

    The main loss is 'G' when present, otherwise ``model.loss_names[0]``. Only the
    main process updates ``best_val_loss`` and saves; nothing happens when no
    validation batch was processed (``losses`` empty).
    """
    if not is_main or not losses:
        return
    main_loss = 'G' if 'G' in losses else model.loss_names[0]
    avg_loss = meters[main_loss].avg()
    if avg_loss < model.best_val_loss:
        model.best_val_loss = avg_loss
        print('New validation best loss. saving the model.')
        model.save_networks('', val_loss=model.best_val_loss)


def _log_training_progress(model, loader, writer, step, use_new_log, old_prefix):
    """Run a no-grad forward pass on the fixed progress samples and log their image grids.

    The batch index is appended to each tag so successive batches stay distinct.
    """
    for i, data in enumerate(loader):
        with torch.no_grad():
            model.set_input(data)
            model.forward()
        _log_image_grids(writer, model, step, use_new_log, 'Training Video/', old_prefix, ' ' + str(i))


def _prepare_datasets(config, config_val, gpu):
    """Build the train, validation and training-progress datasets.

    Returns ``(train_dataset, val_dataset, train_video_dataset)``; ``val_dataset``
    is None when validation is disabled. If no usable validation data exists, part
    of the training data is split off (or, failing that, the training set is
    copied). Exits the process when the training set is empty.
    """
    is_main = _is_main_process(gpu)

    train_dataset = SuperDataset(config, shuffle=True, check_all_data=config['dataset']['check_all_data'],
                                 DDP_device=gpu)
    if config['dataset']['train_data'] == {}:
        print("old data config format detected. Converting to new format.")
        train_dataset.config = train_dataset.convert_old_config_to_new()
    train_dataset.static_data.load_static_data()
    train_dataset.static_data.create_transforms()
    if train_dataset.check_all_data:
        train_dataset.check_data()
    if len(train_dataset) == 0:
        if is_main:
            print("Train set has 0 data samples. Exiting.")
        sys.exit(0)
    # A copy of 5 training samples, rendered repeatedly to visualize training progress.
    _, train_video_dataset = train_dataset.split_data('count', 5, mode='copy_partial')

    val_dataset = None
    if config['training']['val']:
        force_use_train_data = False
        val_dataset = SuperDataset(config_val)
        if len(config_val['dataset']['train_data']) == 0:
            if check_old_config_val_possible(config_val):
                # Keep the converted config, exactly as the training path does.
                val_dataset.config = val_dataset.convert_old_config_to_new()
                val_dataset.static_data.load_static_data()
            else:
                force_use_train_data = True
        else:
            val_dataset.static_data.load_static_data()

        if len(val_dataset) == 0 or force_use_train_data:
            if is_main:
                print("Validation set has 0 data samples. Using part of training data for validation.")
            validation_ratio = config['training']['val_percent'] / 100
            train_dataset, val_dataset = train_dataset.split_data('ratio', validation_ratio, mode='split')
            if len(val_dataset) == 0:
                print("There are too few training data to establish a validation set. "
                      "Use the training set as validation set.")
                val_dataset = copy.deepcopy(train_dataset)
        val_dataset.static_data.create_transforms()
        if val_dataset.check_all_data:
            val_dataset.check_data()

    return train_dataset, val_dataset, train_video_dataset


def _print_dataset_summary(train_dataset, val_dataset):
    """Print per-group sample counts of the train (and, if present, validation) static data."""
    print("--------train dataset static data content-----------")
    train_static = train_dataset.static_data
    for group_name, count in zip(train_static.group_names, train_static.len_of_groups):
        print("%s: %d" % (group_name, count))
    print("----------------------------------------")
    if val_dataset is not None:
        print("--------val dataset static data content-----------")
        val_static = val_dataset.static_data
        for group_name, count in zip(val_static.group_names, val_static.len_of_groups):
            print("group %s: %d" % (group_name, count))
        print("----------------------------------------")


def _train_by_iterations(config, model, train_dataloader, val_dataloader, train_video_dataloader, writer, gpu):
    """Train for a fixed iteration budget (``epoch_as_iter`` mode).

    ``n_epochs + n_epochs_decay`` is read as an iteration count: the model's
    ``curr_epoch`` follows the iteration counter and the LR scheduler is stepped
    once per iteration. Validation and progress visualization run after each full
    pass over the training data.

    Raises:
        RuntimeError: if a pass over the training dataloader yields no batches.
    """
    training_cfg = config['training']
    is_main = _is_main_process(gpu)
    use_new_log = training_cfg['use_new_log']
    train_meters, val_meters = {}, {}
    total_iters = model.total_iters
    iter_ub = training_cfg['n_epochs'] + training_cfg['n_epochs_decay']

    while total_iters < iter_ub:
        if is_main:
            print('-------------------------------Training----------------------------------')
        model.train()
        finished = False
        fetched = 0
        for data in train_dataloader:
            fetched += 1
            total_iters += 1
            # The batch that reaches the budget is not trained on.
            if total_iters >= iter_ub:
                finished = True
                break
            model.curr_epoch = total_iters
            model.total_iters = total_iters
            model.set_input(data)
            model.optimize_parameters()
            losses = model.get_current_losses()
            _accumulate_losses(train_meters, losses)

            if total_iters % training_cfg['print_freq'] == 0 and is_main:
                print_losses(total_iters, 1, 1, losses, train_meters)
                _log_image_grids(writer, model, total_iters, use_new_log, 'Training/',
                                 'iteration ' + str(total_iters) + ' training ')
                _log_loss_scalars(writer, train_meters, losses, 'Training/', total_iters)
            if total_iters % training_cfg['save_latest_freq'] == 0 and is_main:
                print('saving the latest model (total_iters %d)' % total_iters)
                model.save_networks('latest')
            if total_iters % training_cfg['save_epoch_freq'] == 0 and is_main:
                print('saving the model at the end of iters %d' % total_iters)
                model.save_networks('latest')
            model.update_learning_rate()

        if fetched == 0:
            raise RuntimeError('The training dataloader yielded no batches; '
                               'check the dataset size, batch_size and drop_last.')
        _clear_meters(train_meters)
        if finished:
            # Return normally so the caller can flush/close the TensorBoard writer.
            return

        if training_cfg['val']:
            if is_main:
                print('-------------------------------Validating----------------------------------')
            val_steps, val_losses = _validate(model, val_dataloader, val_meters)
            if is_main and val_steps:
                print_losses(total_iters, 1, 1, val_losses, val_meters, mode='validating')
                _log_image_grids(writer, model, total_iters, use_new_log, 'Validation/',
                                 'iteration ' + str(total_iters) + ' validating ')
                _log_loss_scalars(writer, val_meters, val_losses, 'Validation/', total_iters)
            _save_if_best_val(model, val_meters, val_losses, is_main)
            _clear_meters(val_meters)

        if training_cfg['save_training_progress'] and is_main:
            # Render the same samples each pass to visualize how training is progressing.
            _log_training_progress(model, train_video_dataloader, writer, total_iters, use_new_log,
                                   'epoch 0 iteration ' + str(total_iters) + ' training_video ')


def _train_by_epochs(config, model, train_dataloader, val_dataloader, train_video_dataloader, writer, gpu,
                     num_train_batches, num_val_batches):
    """Train epoch by epoch from ``model.curr_epoch`` up to ``n_epochs + n_epochs_decay`` inclusive.

    ``num_train_batches`` / ``num_val_batches`` are the progress totals shown by
    print_losses. The LR scheduler is stepped once per epoch.

    Raises:
        RuntimeError: if an epoch's training dataloader yields no batches.
    """
    training_cfg = config['training']
    is_main = _is_main_process(gpu)
    use_new_log = training_cfg['use_new_log']
    train_meters, val_meters = {}, {}
    total_iters = model.total_iters
    last_epoch = training_cfg['n_epochs'] + training_cfg['n_epochs_decay']

    for epoch in range(model.curr_epoch, last_epoch + 1):
        epoch_start_time = time.time()
        epoch_iter = 0  # training iterations in the current epoch
        model.curr_epoch = epoch

        if is_main:
            print('-------------------------------Training----------------------------------')
        model.train()
        for data in train_dataloader:
            total_iters += 1
            epoch_iter += 1
            model.set_input(data)  # unpack data from dataset and apply preprocessing
            model.optimize_parameters()  # compute losses, gradients and update network weights
            losses = model.get_current_losses()
            _accumulate_losses(train_meters, losses)

            if epoch_iter % training_cfg['print_freq'] == 0 and is_main:
                print_losses(epoch, epoch_iter, num_train_batches, losses, train_meters)
                _log_image_grids(writer, model, total_iters, use_new_log, 'Training/',
                                 'epoch ' + str(epoch) + ' iteration ' + str(total_iters) + ' training ')
                _log_loss_scalars(writer, train_meters, losses, 'Training/', total_iters)
            if total_iters % training_cfg['save_latest_freq'] == 0 and is_main:
                print('saving the latest model (epoch %d, total_iters %d)' % (epoch, total_iters))
                model.save_networks('latest')

        if epoch_iter == 0:
            raise RuntimeError('The training dataloader yielded no batches; '
                               'check the dataset size, batch_size and drop_last.')

        if is_main:
            print_losses(epoch, epoch_iter, num_train_batches, losses, train_meters)
            print('End of epoch %d / %d \t Time Taken: %d sec' % (epoch, last_epoch,
                                                                  time.time() - epoch_start_time))
            _log_image_grids(writer, model, total_iters, use_new_log, 'Training/',
                             'epoch ' + str(epoch) + ' iteration ' + str(total_iters) + ' training ')
            _log_loss_scalars(writer, train_meters, losses, 'Training/', total_iters)
        if epoch % training_cfg['save_epoch_freq'] == 0 and is_main:
            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
            model.save_networks('latest')
            model.save_networks(epoch)
        model.update_learning_rate()  # update learning rates at the end of every epoch
        _clear_meters(train_meters)

        if training_cfg['val']:
            if is_main:
                print('-------------------------------Validating----------------------------------')
            val_steps, val_losses = _validate(model, val_dataloader, val_meters)
            if is_main and val_steps:
                print_losses(epoch, val_steps, num_val_batches, val_losses, val_meters, mode='validating')
                _log_image_grids(writer, model, total_iters, use_new_log, 'Validation/',
                                 'epoch ' + str(epoch) + ' iteration ' + str(total_iters) + ' validating ')
                _log_loss_scalars(writer, val_meters, val_losses, 'Validation/', total_iters)
            _save_if_best_val(model, val_meters, val_losses, is_main)
            _clear_meters(val_meters)

        if training_cfg['save_training_progress'] and is_main:
            # Render the same samples every epoch to visualize how training is progressing.
            _log_training_progress(model, train_video_dataloader, writer, total_iters, use_new_log,
                                   'epoch ' + str(epoch) + ' iteration ' + str(total_iters) + ' training_video ')


def train_val(gpu, config):
    """Build data and model, then train (with optional validation) until the schedule ends.

    Args:
        gpu: under DDP, the process index supplied by ``mp.spawn`` and used as the
            rank and CUDA device; None otherwise (an irrelevant placeholder).
        config: parsed configuration (reads its 'common', 'dataset' and
            'training' groups).
    """
    is_main = _is_main_process(gpu)
    use_ddp = config['training']['DDP']

    config_val = None
    if config['training']['val']:
        config_val = copy.deepcopy(config)
        config_val['common']['phase'] = 'val'

    if use_ddp:
        # NOTE(review): rank is the per-node process index while world_size counts
        # num_nodes * num_gpus, so ranks only form a valid group when num_nodes == 1.
        dist.init_process_group(backend='nccl', init_method='env://',
                                world_size=config['training']['world_size'], rank=gpu)
        torch.cuda.set_device(gpu)  # DDP requirement

    train_dataset, val_dataset, train_video_dataset = _prepare_datasets(config, config_val, gpu)
    if is_main:
        _print_dataset_summary(train_dataset, val_dataset)

    drop_last = config['dataset']['drop_last']
    train_dataloader = CustomDataLoader(config, train_dataset, DDP_gpu=gpu, drop_last=drop_last)
    val_dataloader = None
    if config['training']['val']:
        val_dataloader = CustomDataLoader(config_val, val_dataset, DDP_gpu=gpu, drop_last=drop_last)

    # Progress samples are visited in a fixed order, one sample per batch.
    config_train_video = copy.deepcopy(config)
    config_train_video['dataset']['serial_batches'] = True
    config_train_video['dataset']['batch_size'] = 1
    config_train_video['batch_size'] = 1  # legacy top-level key, kept for compatibility
    train_video_dataloader = CustomDataLoader(config_train_video, train_video_dataset, DDP_gpu=None,
                                              drop_last=False)

    if use_ddp:
        model = create_model(config, DDP_device=gpu)  # create a DDP model
        model.setup(config, DDP_device=gpu)  # load and print networks; create schedulers
    else:
        model = create_model(config)  # single-GPU or DataParallel model
        model.setup(config)  # load and print networks; create schedulers

    writer = None
    if is_main:
        log_dir = os.path.join(config['training']['log_dir'],
                               config['common']['name'] + '_'
                               + datetime.datetime.now().strftime("%Y_%m_%d__%H_%M_%S"))
        check_path(log_dir)
        writer = SummaryWriter(log_dir)

    try:
        if config['training']['epoch_as_iter']:
            _train_by_iterations(config, model, train_dataloader, val_dataloader, train_video_dataloader,
                                 writer, gpu)
        else:
            batch_size = config['dataset']['batch_size']
            num_train_batches = _num_batches(len(train_dataset), batch_size)
            num_val_batches = _num_batches(len(val_dataset), batch_size) if val_dataset is not None else 0
            _train_by_epochs(config, model, train_dataloader, val_dataloader, train_video_dataloader,
                             writer, gpu, num_train_batches, num_val_batches)
    finally:
        # close() flushes pending events so the log on disk is complete.
        if writer is not None:
            writer.close()

    if use_ddp:
        dist.destroy_process_group()


if __name__ == '__main__':
    main()