# swapae/optimizers/swapping_autoencoder_optimizer.py
import torch
import swapae.util as util
from swapae.models import MultiGPUModelWrapper
from swapae.optimizers.base_optimizer import BaseOptimizer


class SwappingAutoencoderOptimizer(BaseOptimizer):
""" Class for running the optimization of the model parameters.
Implements Generator / Discriminator training, R1 gradient penalty,
decaying learning rates, and reporting training progress.
"""

    @staticmethod
def modify_commandline_options(parser, is_train):
parser.add_argument("--lr", default=0.002, type=float)
parser.add_argument("--beta1", default=0.0, type=float)
parser.add_argument("--beta2", default=0.99, type=float)
        parser.add_argument(
            "--R1_once_every", default=16, type=int,
            help="lazy R1 regularization: the R1 penalty is computed only "
                 "once every |R1_once_every| discriminator iterations",
        )
return parser
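
    # Example invocation showing the options registered above ("train.py" is a
    # hypothetical entry-point name; the actual training script is repo-specific):
    #   python train.py --lr 0.002 --beta1 0.0 --beta2 0.99 --R1_once_every 16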

    def __init__(self, model: MultiGPUModelWrapper):
self.opt = model.opt
opt = self.opt
self.model = model
self.train_mode_counter = 0
self.discriminator_iter_counter = 0
self.Gparams = self.model.get_parameters_for_mode("generator")
self.Dparams = self.model.get_parameters_for_mode("discriminator")
self.optimizer_G = torch.optim.Adam(
self.Gparams, lr=opt.lr, betas=(opt.beta1, opt.beta2)
)
        # cf. StyleGAN2 (https://arxiv.org/abs/1912.04958), Appendix B:
        # lazy-regularization compensation of the Adam hyperparameters.
        c = opt.R1_once_every / (1 + opt.R1_once_every)
self.optimizer_D = torch.optim.Adam(
self.Dparams, lr=opt.lr * c, betas=(opt.beta1 ** c, opt.beta2 ** c)
)
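        # Worked example with the defaults above: R1_once_every=16 gives
        # c = 16/17 ≈ 0.941, so optimizer_D effectively runs with
        # lr ≈ 0.002 * 0.941 ≈ 0.00188 and betas ≈ (0.0, 0.99**0.941 ≈ 0.9906),
        # compensating for the R1 term only being added every 16th D step.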

    def set_requires_grad(self, params, requires_grad):
""" For more efficient optimization, turn on and off
recording of gradients for |params|.
"""
for p in params:
p.requires_grad_(requires_grad)

    def prepare_images(self, data_i):
return data_i["real_A"]

    def toggle_training_mode(self):
        """ Alternate between discriminator and generator updates and return
        the mode to be trained on the current step (discriminator first).
        """
        modes = ["discriminator", "generator"]
        mode = modes[self.train_mode_counter]
        self.train_mode_counter = (self.train_mode_counter + 1) % len(modes)
        return mode

    def train_one_step(self, data_i, total_steps_so_far):
        # Each call updates only one of the two networks, so a full
        # generator + discriminator training cycle spans two calls.
        images_minibatch = self.prepare_images(data_i)
        if self.toggle_training_mode() == "discriminator":
            losses = self.train_discriminator_one_step(images_minibatch)
        else:
            losses = self.train_generator_one_step(images_minibatch)
        return util.to_numpy(losses)

    def train_generator_one_step(self, images):
self.set_requires_grad(self.Dparams, False)
self.set_requires_grad(self.Gparams, True)
sp_ma, gl_ma = None, None
self.optimizer_G.zero_grad()
g_losses, g_metrics = self.model(
images, sp_ma, gl_ma, command="compute_generator_losses"
)
g_loss = sum([v.mean() for v in g_losses.values()])
g_loss.backward()
self.optimizer_G.step()
g_losses.update(g_metrics)
return g_losses

    def train_discriminator_one_step(self, images):
if self.opt.lambda_GAN == 0.0 and self.opt.lambda_PatchGAN == 0.0:
return {}
self.set_requires_grad(self.Dparams, True)
self.set_requires_grad(self.Gparams, False)
self.discriminator_iter_counter += 1
self.optimizer_D.zero_grad()
d_losses, d_metrics, sp, gl = self.model(
images, command="compute_discriminator_losses"
)
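        # Keep detached copies of the most recent codes from this pass
        # (sp: structure code, gl: global texture code in the swapping
        # autoencoder).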
self.previous_sp = sp.detach()
self.previous_gl = gl.detach()
d_loss = sum([v.mean() for v in d_losses.values()])
d_loss.backward()
self.optimizer_D.step()
needs_R1 = self.opt.lambda_R1 > 0.0 or self.opt.lambda_patch_R1 > 0.0
needs_R1_at_current_iter = needs_R1 and \
self.discriminator_iter_counter % self.opt.R1_once_every == 0
if needs_R1_at_current_iter:
self.optimizer_D.zero_grad()
r1_losses = self.model(images, command="compute_R1_loss")
d_losses.update(r1_losses)
r1_loss = sum([v.mean() for v in r1_losses.values()])
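            # R1 is evaluated only every R1_once_every iterations, so scale it
            # up by that factor to keep its time-averaged strength the same as
            # applying it at every step (lazy regularization).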
r1_loss = r1_loss * self.opt.R1_once_every
r1_loss.backward()
self.optimizer_D.step()
d_losses["D_total"] = sum([v.mean() for v in d_losses.values()])
d_losses.update(d_metrics)
return d_losses

    def get_visuals_for_snapshot(self, data_i):
images = self.prepare_images(data_i)
with torch.no_grad():
return self.model(images, command="get_visuals_for_snapshot")

    def save(self, total_steps_so_far):
self.model.save(total_steps_so_far)
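

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original training code). It
# assumes an `opt` namespace providing the attributes read above (lr, beta1,
# beta2, R1_once_every, lambda_GAN, lambda_PatchGAN, lambda_R1,
# lambda_patch_R1), an already-constructed MultiGPUModelWrapper `model`, and
# a dataloader yielding dicts with the key "real_A"; how those objects are
# built is specific to the surrounding repo.
#
#   optimizer = SwappingAutoencoderOptimizer(model)
#   for step, data_i in enumerate(dataloader):
#       losses = optimizer.train_one_step(data_i, step)
#       if step % snapshot_freq == 0:  # `snapshot_freq` is hypothetical
#           visuals = optimizer.get_visuals_for_snapshot(data_i)
#           optimizer.save(step)
# ---------------------------------------------------------------------------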