#!/usr/bin/env python
"""Train a GAN.

Usage:

* Train a MNIST model: `python train_gan.py`
* Train a Quickdraw model: `python train_gan.py --task quickdraw`

See the end of this file for more example invocations.
"""
import argparse
import os

import numpy as np
import torch as th
from torch.utils.data import DataLoader

import ttools
import ttools.interfaces

import losses
import data
import models

import pydiffvg


LOG = ttools.get_logger(__name__)

BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
OUTPUT = os.path.join(BASE_DIR, "results")


class Callback(ttools.callbacks.ImageDisplayCallback):
    """Simple callback that visualizes images."""

    def visualized_image(self, batch, step_data, is_val=False):
        if is_val:
            return

        gen = step_data["gen_image"][:16].detach()
        ref = step_data["gt_image"][:16].detach()

        # tensor to visualize, concatenate images
        vizdata = th.cat([ref, gen], 2)

        vector = step_data["vector_image"]
        if vector is not None:
            vector = vector[:16].detach()
            vizdata = th.cat([vizdata, vector], 2)

        # map from [-1, 1] to [0, 1] for display
        vizdata = (vizdata + 1.0) * 0.5
        viz = th.clamp(vizdata, 0, 1)
        return viz

    def caption(self, batch, step_data, is_val=False):
        if step_data["vector_image"] is not None:
            s = "top: real, middle: raster, bottom: vector"
        else:
            s = "top: real, bottom: fake"
        return s


class Interface(ttools.ModelInterface):
    def __init__(self, generator, vect_generator,
                 discriminator, vect_discriminator,
                 lr=1e-4, lr_decay=0.9999,
                 gradient_penalty=10,
                 wgan_gp=False,
                 raster_resolution=32, device="cpu",
                 grad_clip=1.0):
        super(Interface, self).__init__()

        self.wgan_gp = wgan_gp
        self.w_gradient_penalty = gradient_penalty

        # WGAN-GP uses several discriminator updates per generator update
        self.n_critic = 1
        if self.wgan_gp:
            self.n_critic = 5

        self.grad_clip = grad_clip
        self.raster_resolution = raster_resolution

        self.gen = generator
        self.vect_gen = vect_generator
        self.discrim = discriminator
        self.vect_discrim = vect_discriminator

        self.device = device

        self.gen.to(self.device)
        self.discrim.to(self.device)

        beta1 = 0.5
        beta2 = 0.9

        self.gen_opt = th.optim.Adam(
            self.gen.parameters(), lr=lr, betas=(beta1, beta2))
        self.discrim_opt = th.optim.Adam(
            self.discrim.parameters(), lr=lr, betas=(beta1, beta2))

        self.schedulers = [
            th.optim.lr_scheduler.ExponentialLR(self.gen_opt, lr_decay),
            th.optim.lr_scheduler.ExponentialLR(self.discrim_opt, lr_decay),
        ]

        self.optimizers = [self.gen_opt, self.discrim_opt]

        if self.vect_gen is not None:
            assert self.vect_discrim is not None

            self.vect_gen.to(self.device)
            self.vect_discrim.to(self.device)

            self.vect_gen_opt = th.optim.Adam(
                self.vect_gen.parameters(), lr=lr, betas=(beta1, beta2))
            self.vect_discrim_opt = th.optim.Adam(
                self.vect_discrim.parameters(), lr=lr, betas=(beta1, beta2))

            self.schedulers += [
                th.optim.lr_scheduler.ExponentialLR(
                    self.vect_gen_opt, lr_decay),
                th.optim.lr_scheduler.ExponentialLR(
                    self.vect_discrim_opt, lr_decay),
            ]

            self.optimizers += [self.vect_gen_opt, self.vect_discrim_opt]

        # include loss on alpha
        self.im_loss = losses.MultiscaleMSELoss(channels=4).to(self.device)

        self.iter = 0

        self.cross_entropy = th.nn.BCEWithLogitsLoss()
        self.mse = th.nn.MSELoss()

    def _gradient_penalty(self, discrim, fake, real):
        bs = real.size(0)
        epsilon = th.rand(bs, 1, 1, 1, device=real.device)
        epsilon = epsilon.expand_as(real)

        # interpolate between real and fake samples, and track gradients
        # with respect to the interpolation
        interpolation = epsilon * real.data + (1 - epsilon) * fake.data
        interpolation.requires_grad_(True)

        interpolation_logits = discrim(interpolation)
        grad_outputs = th.ones(interpolation_logits.size(), device=real.device)

        gradients = th.autograd.grad(outputs=interpolation_logits,
                                     inputs=interpolation,
                                     grad_outputs=grad_outputs,
                                     create_graph=True,
                                     retain_graph=True)[0]
        gradients = gradients.view(bs, -1)
        gradients_norm = th.sqrt(th.sum(gradients ** 2, dim=1) + 1e-12)

        # Zero-centered gradient penalty [Thanh-Tung 2019]
        # https://openreview.net/pdf?id=ByxPYjC5KQ
        return self.w_gradient_penalty * ((gradients_norm - 0) ** 2).mean()
        # Standard 1-centered WGAN-GP penalty, kept for reference:
        # return self.w_gradient_penalty * ((gradients_norm - 1) ** 2).mean()

    def _discriminator_step(self, discrim, opt, fake, real):
        """Try to classify fake as 0 and real as 1."""
        opt.zero_grad()

        # no backprop to gen
        fake = fake.detach()

        fake_pred = discrim(fake)
        real_pred = discrim(real)

        if self.wgan_gp:
            gradient_penalty = self._gradient_penalty(discrim, fake, real)
            loss_d = fake_pred.mean() - real_pred.mean() + gradient_penalty
            gradient_penalty = gradient_penalty.item()
        else:
            fake_loss = self.cross_entropy(fake_pred,
                                           th.zeros_like(fake_pred))
            real_loss = self.cross_entropy(real_pred,
                                           th.ones_like(real_pred))
            # fake_loss = self.mse(fake_pred, th.zeros_like(fake_pred))
            # real_loss = self.mse(real_pred, th.ones_like(real_pred))
            loss_d = 0.5 * (fake_loss + real_loss)
            gradient_penalty = None

        loss_d.backward()

        nrm = th.nn.utils.clip_grad_norm_(
            discrim.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped discriminator gradient (%.5f) to %.2f",
                      nrm, self.grad_clip)

        opt.step()

        return loss_d.item(), gradient_penalty

    def _generator_step(self, gen, discrim, opt, fake):
        """Try to classify fake as 1."""
        opt.zero_grad()

        fake_pred = discrim(fake)

        if self.wgan_gp:
            loss_g = -fake_pred.mean()
        else:
            loss_g = self.cross_entropy(fake_pred, th.ones_like(fake_pred))
            # loss_g = self.mse(fake_pred, th.ones_like(fake_pred))

        loss_g.backward()

        # clip gradients
        nrm = th.nn.utils.clip_grad_norm_(
            gen.parameters(), self.grad_clip)
        if nrm > self.grad_clip:
            LOG.debug("Clipped generator gradient (%.5f) to %.2f",
                      nrm, self.grad_clip)

        opt.step()

        return loss_g.item()

    def training_step(self, batch):
        im = batch
        im = im.to(self.device)

        z = self.gen.sample_z(im.shape[0], device=self.device)

        generated = self.gen(z)

        vect_generated = None
        if self.vect_gen is not None:
            vect_generated = self.vect_gen(z)

        loss_g = None
        loss_d = None
        loss_g_vect = None
        loss_d_vect = None
        gp = None
        gp_vect = None

        if self.iter < self.n_critic:  # Discriminator update
            self.iter += 1
            loss_d, gp = self._discriminator_step(
                self.discrim, self.discrim_opt, generated, im)
            if vect_generated is not None:
                loss_d_vect, gp_vect = self._discriminator_step(
                    self.vect_discrim, self.vect_discrim_opt,
                    vect_generated, im)
        else:  # Generator update
            self.iter = 0
            loss_g = self._generator_step(
                self.gen, self.discrim, self.gen_opt, generated)
            if vect_generated is not None:
                loss_g_vect = self._generator_step(
                    self.vect_gen, self.vect_discrim, self.vect_gen_opt,
                    vect_generated)

        return {
            "loss_g": loss_g,
            "loss_d": loss_d,
            "loss_g_vect": loss_g_vect,
            "loss_d_vect": loss_d_vect,
            "gp": gp,
            "gp_vect": gp_vect,
            "gt_image": im,
            "gen_image": generated,
            "vector_image": vect_generated,
            "lr": self.gen_opt.param_groups[0]["lr"],
        }

    def init_validation(self):
        return dict(sample=None)

    def validation_step(self, batch, running_data):
        # Switch to eval mode for dropout, batchnorm, etc.
        self.gen.eval()
        if self.vect_gen is not None:
            self.vect_gen.eval()
        return running_data


def train(args):
    th.manual_seed(0)
    np.random.seed(0)

    color_output = False
    if args.task == "mnist":
        dataset = data.MNISTDataset(args.raster_resolution, train=True)
    elif args.task == "quickdraw":
        dataset = data.QuickDrawImageDataset(
            args.raster_resolution, train=True)
    else:
        raise NotImplementedError()

    dataloader = DataLoader(
        dataset, batch_size=args.bs, num_workers=args.workers, shuffle=True)

    val_dataloader = None
    model_params = {
        "zdim": args.zdim,
        "num_strokes": args.num_strokes,
        "imsize": args.raster_resolution,
        "stroke_width": args.stroke_width,
        "color_output": color_output,
    }

    gen = models.Generator(**model_params)
    gen.train()

    discrim = models.Discriminator(color_output=color_output)
    discrim.train()

    if args.raster_only:
        vect_gen = None
        vect_discrim = None
    else:
        if args.generator == "fc":
            vect_gen = models.VectorGenerator(**model_params)
        elif args.generator == "bezier_fc":
            vect_gen = models.BezierVectorGenerator(**model_params)
        elif args.generator in ["rnn"]:
            vect_gen = models.RNNVectorGenerator(**model_params)
        elif args.generator in ["chain_rnn"]:
            vect_gen = models.ChainRNNVectorGenerator(**model_params)
        else:
            raise NotImplementedError()
        vect_gen.train()

        vect_discrim = models.Discriminator(color_output=color_output)
        vect_discrim.train()

    LOG.info("Model parameters:\n%s", model_params)

    device = "cpu"
    if th.cuda.is_available():
        device = "cuda"
        LOG.info("Using CUDA")

    interface = Interface(gen, vect_gen, discrim, vect_discrim,
                          raster_resolution=args.raster_resolution,
                          lr=args.lr, wgan_gp=args.wgan_gp,
                          lr_decay=args.lr_decay, device=device)

    env_name = args.task + "_gan"
    if args.raster_only:
        env_name += "_raster"
    else:
        env_name += "_vector"
        env_name += "_" + args.generator

    if args.wgan_gp:
        env_name += "_wgan"

    chkpt = os.path.join(OUTPUT, env_name)

    meta = {
        "model_params": model_params,
        "task": args.task,
        "generator": args.generator,
    }
    checkpointer = ttools.Checkpointer(
        chkpt, gen, meta=meta,
        optimizers=interface.optimizers,
        schedulers=interface.schedulers,
        prefix="g_")
    checkpointer_d = ttools.Checkpointer(
        chkpt, discrim, prefix="d_")

    # Resume from checkpoint, if any
    extras, _ = checkpointer.load_latest()
    checkpointer_d.load_latest()

    if not args.raster_only:
        checkpointer_vect = ttools.Checkpointer(
            chkpt, vect_gen, meta=meta,
            optimizers=interface.optimizers,
            schedulers=interface.schedulers,
            prefix="vect_g_")
        checkpointer_d_vect = ttools.Checkpointer(
            chkpt, vect_discrim, prefix="vect_d_")
        extras, _ = checkpointer_vect.load_latest()
        checkpointer_d_vect.load_latest()

    epoch = extras["epoch"] if extras and "epoch" in extras.keys() else 0

    # if meta is not None and meta["model_parameters"] != model_params:
    #     LOG.info("Checkpoint's metaparams differ "
    #              "from CLI, aborting: %s and %s", meta, model_params)

    trainer = ttools.Trainer(interface)

    # Add callbacks
    loss_names = ["loss_g", "loss_d", "loss_g_vect", "loss_d_vect",
                  "gp", "gp_vect"]
    training_debug = ["lr"]

    trainer.add_callback(Callback(
        env=env_name, win="samples", port=args.port, frequency=args.freq))
    trainer.add_callback(ttools.callbacks.ProgressBarCallback(
        keys=loss_names, val_keys=None))
    trainer.add_callback(ttools.callbacks.MultiPlotCallback(
        keys=loss_names, val_keys=None, env=env_name, port=args.port,
        server=args.server, base_url=args.base_url,
        win="losses", frequency=args.freq))
    trainer.add_callback(ttools.callbacks.VisdomLoggingCallback(
        keys=training_debug, smoothing=0, val_keys=None, env=env_name,
        server=args.server, base_url=args.base_url, port=args.port))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer, max_files=2, interval=600, max_epochs=10))
    trainer.add_callback(ttools.callbacks.CheckpointingCallback(
        checkpointer_d, max_files=2, interval=600, max_epochs=10))
    if not args.raster_only:
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_vect, max_files=2, interval=600, max_epochs=10))
        trainer.add_callback(ttools.callbacks.CheckpointingCallback(
            checkpointer_d_vect, max_files=2, interval=600,
            max_epochs=10))

    trainer.add_callback(
        ttools.callbacks.LRSchedulerCallback(interface.schedulers))

    # Start training
    trainer.train(dataloader, starting_epoch=epoch,
                  val_dataloader=val_dataloader,
                  num_epochs=args.num_epochs)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", default="mnist",
                        choices=["mnist", "quickdraw"])
    parser.add_argument("--generator", default="bezier_fc",
                        choices=["bezier_fc", "fc", "rnn", "chain_rnn"],
                        help="model to use as generator")
    parser.add_argument("--raster_only", action="store_true", default=False,
                        help="if set, only train the raster baseline")
    parser.add_argument("--standard_gan", dest="wgan_gp",
                        action="store_false", default=True,
                        help="if set, use a regular GAN loss instead of "
                             "WGAN-GP")

    # Training params
    parser.add_argument("--bs", type=int, default=4, help="batch size")
    parser.add_argument("--workers", type=int, default=4,
                        help="number of dataloader threads")
    parser.add_argument("--num_epochs", type=int, default=200,
                        help="number of epochs to train for")
    parser.add_argument("--lr", type=float, default=1e-4,
                        help="learning rate")
    parser.add_argument("--lr_decay", type=float, default=0.9999,
                        help="exponential learning rate decay rate")

    # Model configuration
    parser.add_argument("--zdim", type=int, default=32,
                        help="latent space dimension")
    parser.add_argument("--stroke_width", type=float, nargs=2,
                        default=(0.5, 1.5),
                        help="min and max stroke width")
    parser.add_argument("--num_strokes", type=int, default=16,
                        help="number of strokes to generate")
    parser.add_argument("--raster_resolution", type=int, default=32,
                        help="raster canvas resolution on each side")

    # Viz params
    parser.add_argument("--freq", type=int, default=10,
                        help="visualization frequency")
    parser.add_argument("--port", type=int, default=8097,
                        help="visdom port")
    parser.add_argument("--server", default=None,
                        help="visdom server if not local")
    parser.add_argument("--base_url", default="",
                        help="visdom entrypoint URL")

    args = parser.parse_args()

    pydiffvg.set_use_gpu(False)

    ttools.set_logger(False)

    train(args)
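
# Example invocations (illustrative only; the argparse flags above are the
# source of truth, and the specific values below are assumptions rather than
# tested recipes):
#
#   python train_gan.py                        # MNIST, Bezier generator, WGAN-GP
#   python train_gan.py --task quickdraw --generator rnn --raster_resolution 64
#   python train_gan.py --raster_only --standard_gan   # raster baseline, plain GAN loss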