import torch from torch import nn import torch.optim as optim import torch.nn.functional as F from torch.utils.data.dataloader import DataLoader from torchvision import transforms from torchvision import utils as vutils import argparse import random from tqdm import tqdm from models import weights_init, Discriminator, Generator from operation import copy_G_params, load_params, get_dir from operation import ImageFolder, InfiniteSamplerWrapper from diffaug import DiffAugment policy = 'color,translation' import lpips percept = lpips.PerceptualLoss(model='net-lin', net='vgg', use_gpu=True) #torch.backends.cudnn.benchmark = True def crop_image_by_part(image, part): hw = image.shape[2]//2 if part==0: return image[:,:,:hw,:hw] if part==1: return image[:,:,:hw,hw:] if part==2: return image[:,:,hw:,:hw] if part==3: return image[:,:,hw:,hw:] def train_d(net, data, label="real"): """Train function of discriminator""" if label=="real": part = random.randint(0, 3) pred, [rec_all, rec_small, rec_part] = net(data, label, part=part) err = F.relu( torch.rand_like(pred) * 0.2 + 0.8 - pred).mean() + \ percept( rec_all, F.interpolate(data, rec_all.shape[2]) ).sum() +\ percept( rec_small, F.interpolate(data, rec_small.shape[2]) ).sum() +\ percept( rec_part, F.interpolate(crop_image_by_part(data, part), rec_part.shape[2]) ).sum() err.backward() return pred.mean().item(), rec_all, rec_small, rec_part else: pred = net(data, label) err = F.relu( torch.rand_like(pred) * 0.2 + 0.8 + pred).mean() err.backward() return pred.mean().item() def train(args): data_root = args.path total_iterations = args.iter checkpoint = args.ckpt batch_size = args.batch_size im_size = args.im_size ndf = 64 ngf = 64 nz = 256 nlr = 0.0002 nbeta1 = 0.5 use_cuda = True multi_gpu = True dataloader_workers = 8 current_iteration = 0 save_interval = 100 saved_model_folder, saved_image_folder = get_dir(args) device = torch.device("cpu") if use_cuda: device = torch.device("cuda:0") transform_list = [ transforms.Resize((int(im_size),int(im_size))), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]) ] trans = transforms.Compose(transform_list) if 'lmdb' in data_root: from operation import MultiResolutionDataset dataset = MultiResolutionDataset(data_root, trans, 1024) else: dataset = ImageFolder(root=data_root, transform=trans) dataloader = iter(DataLoader(dataset, batch_size=batch_size, shuffle=False, sampler=InfiniteSamplerWrapper(dataset), num_workers=dataloader_workers, pin_memory=True)) ''' loader = MultiEpochsDataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=dataloader_workers, pin_memory=True) dataloader = CudaDataLoader(loader, 'cuda') ''' #from model_s import Generator, Discriminator netG = Generator(ngf=ngf, nz=nz, im_size=im_size) netG.apply(weights_init) netD = Discriminator(ndf=ndf, im_size=im_size) netD.apply(weights_init) netG.to(device) netD.to(device) avg_param_G = copy_G_params(netG) fixed_noise = torch.FloatTensor(8, nz).normal_(0, 1).to(device) if checkpoint != 'None': ckpt = torch.load(checkpoint) netG.load_state_dict(ckpt['g']) netD.load_state_dict(ckpt['d']) avg_param_G = ckpt['g_ema'] optimizerG.load_state_dict(ckpt['opt_g']) optimizerD.load_state_dict(ckpt['opt_d']) current_iteration = int(checkpoint.split('_')[-1].split('.')[0]) del ckpt if multi_gpu: netG = nn.DataParallel(netG.to(device)) netD = nn.DataParallel(netD.to(device)) optimizerG = optim.Adam(netG.parameters(), lr=nlr, betas=(nbeta1, 0.999)) optimizerD = optim.Adam(netD.parameters(), lr=nlr, betas=(nbeta1, 0.999)) for iteration in tqdm(range(current_iteration, total_iterations+1)): real_image = next(dataloader) real_image = real_image.to(device) current_batch_size = real_image.size(0) noise = torch.Tensor(current_batch_size, nz).normal_(0, 1).to(device) fake_images = netG(noise) real_image = DiffAugment(real_image, policy=policy) fake_images = [DiffAugment(fake, policy=policy) for fake in fake_images] ## 2. train Discriminator netD.zero_grad() err_dr, rec_img_all, rec_img_small, rec_img_part = train_d(netD, real_image, label="real") train_d(netD, [fi.detach() for fi in fake_images], label="fake") optimizerD.step() ## 3. train Generator netG.zero_grad() pred_g = netD(fake_images, "fake") err_g = -pred_g.mean() err_g.backward() optimizerG.step() for p, avg_p in zip(netG.parameters(), avg_param_G): avg_p.mul_(0.999).add_(0.001 * p.data) if iteration % 100 == 0: print("GAN: loss d: %.5f loss g: %.5f"%(err_dr, -err_g.item())) if iteration % (save_interval*10) == 0: backup_para = copy_G_params(netG) load_params(netG, avg_param_G) with torch.no_grad(): vutils.save_image(netG(fixed_noise)[0].add(1).mul(0.5), saved_image_folder+'/%d.jpg'%iteration, nrow=4) vutils.save_image( torch.cat([ F.interpolate(real_image, 128), rec_img_all, rec_img_small, rec_img_part]).add(1).mul(0.5), saved_image_folder+'/rec_%d.jpg'%iteration ) load_params(netG, backup_para) if iteration % (save_interval*50) == 0 or iteration == total_iterations: backup_para = copy_G_params(netG) load_params(netG, avg_param_G) torch.save({'g':netG.state_dict(),'d':netD.state_dict()}, saved_model_folder+'/%d.pth'%iteration) load_params(netG, backup_para) torch.save({'g':netG.state_dict(), 'd':netD.state_dict(), 'g_ema': avg_param_G, 'opt_g': optimizerG.state_dict(), 'opt_d': optimizerD.state_dict()}, saved_model_folder+'/all_%d.pth'%iteration) if __name__ == "__main__": parser = argparse.ArgumentParser(description='region gan') parser.add_argument('--path', type=str, default='../lmdbs/art_landscape_1k', help='path of resource dataset, should be a folder that has one or many sub image folders inside') parser.add_argument('--cuda', type=int, default=1, help='index of gpu to use') parser.add_argument('--name', type=str, default='test1', help='experiment name') parser.add_argument('--iter', type=int, default=50000, help='number of iterations') parser.add_argument('--start_iter', type=int, default=0, help='the iteration to start training') parser.add_argument('--batch_size', type=int, default=8, help='mini batch number of images') parser.add_argument('--im_size', type=int, default=256, help='image resolution') parser.add_argument('--ckpt', type=str, default='None', help='checkpoint weight path if have one') args = parser.parse_args() print(args) train(args)