#!/usr/bin/python
# -*- encoding: utf-8 -*-

from logger import setup_logger
from model import BiSeNet
from face_dataset import FaceMask
from loss import OhemCELoss
from evaluate import evaluate
from optimizer import Optimizer

import torch
import torch.nn as nn
from torch.utils.data import DataLoader
import torch.distributed as dist

import os
import os.path as osp
import logging
import time
import datetime
import argparse

respth = './res'
if not osp.exists(respth):
    os.makedirs(respth)
logger = logging.getLogger()


def parse_args():
    parse = argparse.ArgumentParser()
    parse.add_argument(
            '--local_rank',
            dest='local_rank',
            type=int,
            default=-1,
            )
    return parse.parse_args()


def train():
    args = parse_args()
    torch.cuda.set_device(args.local_rank)
    dist.init_process_group(
                backend='nccl',
                init_method='tcp://127.0.0.1:33241',
                world_size=torch.cuda.device_count(),
                rank=args.local_rank
                )
    setup_logger(respth)

    # dataset
    n_classes = 19
    n_img_per_gpu = 16
    n_workers = 8
    cropsize = [448, 448]
    data_root = '/home/zll/data/CelebAMask-HQ/'

    ds = FaceMask(data_root, cropsize=cropsize, mode='train')
    sampler = torch.utils.data.distributed.DistributedSampler(ds)
    dl = DataLoader(ds,
                    batch_size=n_img_per_gpu,
                    shuffle=False,
                    sampler=sampler,
                    num_workers=n_workers,
                    pin_memory=True,
                    drop_last=True)

    # model
    ignore_idx = -100
    net = BiSeNet(n_classes=n_classes)
    net.cuda()
    net.train()
    net = nn.parallel.DistributedDataParallel(net,
            device_ids=[args.local_rank, ],
            output_device=args.local_rank)

    # OHEM cross-entropy: keep at least the n_min hardest pixels per batch;
    # one loss per output head (the final output plus two auxiliary heads)
    score_thres = 0.7
    n_min = n_img_per_gpu * cropsize[0] * cropsize[1] // 16
    LossP = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss2 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)
    Loss3 = OhemCELoss(thresh=score_thres, n_min=n_min, ignore_lb=ignore_idx)

    ## optimizer
    momentum = 0.9
    weight_decay = 5e-4
    lr_start = 1e-2
    max_iter = 80000
    power = 0.9
    warmup_steps = 1000
    warmup_start_lr = 1e-5
    optim = Optimizer(
            model=net.module,
            lr0=lr_start,
            momentum=momentum,
            wd=weight_decay,
            warmup_steps=warmup_steps,
            warmup_start_lr=warmup_start_lr,
            max_iter=max_iter,
            power=power)

    ## train loop
    msg_iter = 50
    loss_avg = []
    st = glob_st = time.time()
    diter = iter(dl)
    epoch = 0
    for it in range(max_iter):
        try:
            im, lb = next(diter)
            if not im.size()[0] == n_img_per_gpu:
                raise StopIteration
        except StopIteration:
            # dataloader exhausted: start a new epoch and reshuffle the sampler
            epoch += 1
            sampler.set_epoch(epoch)
            diter = iter(dl)
            im, lb = next(diter)
        im = im.cuda()
        lb = lb.cuda()
        lb = torch.squeeze(lb, 1)

        optim.zero_grad()
        out, out16, out32 = net(im)
        lossp = LossP(out, lb)
        loss2 = Loss2(out16, lb)
        loss3 = Loss3(out32, lb)
        loss = lossp + loss2 + loss3
        loss.backward()
        optim.step()

        loss_avg.append(loss.item())

        # print training log message
        if (it + 1) % msg_iter == 0:
            loss_avg = sum(loss_avg) / len(loss_avg)
            lr = optim.lr
            ed = time.time()
            t_intv, glob_t_intv = ed - st, ed - glob_st
            eta = int((max_iter - it - 1) * (glob_t_intv / (it + 1)))
            eta = str(datetime.timedelta(seconds=eta))
            msg = ', '.join([
                    'it: {it}/{max_it}',
                    'lr: {lr:.4f}',
                    'loss: {loss:.4f}',
                    'eta: {eta}',
                    'time: {time:.4f}',
                ]).format(
                    it=it + 1,
                    max_it=max_iter,
                    lr=lr,
                    loss=loss_avg,
                    time=t_intv,
                    eta=eta)
            logger.info(msg)
            loss_avg = []
            st = ed

        # periodic checkpoint and evaluation, on rank 0 only
        if dist.get_rank() == 0 and (it + 1) % 5000 == 0:
            state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
            cp_dir = osp.join(respth, 'cp')
            if not osp.exists(cp_dir):
                os.makedirs(cp_dir)
            torch.save(state, osp.join(cp_dir, '{}_iter.pth'.format(it)))
            evaluate(dspth='/home/zll/data/CelebAMask-HQ/test-img',
                     cp='{}_iter.pth'.format(it))

    # dump the final model
    save_pth = osp.join(respth, 'model_final_diss.pth')
    # net.cpu()
    state = net.module.state_dict() if hasattr(net, 'module') else net.state_dict()
    if dist.get_rank() == 0:
        torch.save(state, save_pth)
    logger.info('training done, model saved to: {}'.format(save_pth))


if __name__ == "__main__":
    train()
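

# Usage sketch: the NCCL init above expects one process per GPU, with
# --local_rank supplied by the launcher. torch.distributed.launch does this,
# e.g. (assuming this file is named train.py; <num_gpus> is your GPU count):
#
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> train.py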