# Spaces: Runtime error
import torch | |
import swapae.util as util | |
from swapae.models import BaseModel | |
import swapae.models.networks as networks | |
import swapae.models.networks.loss as loss | |
class SwappingAutoencoderModel(BaseModel):
    """Autoencoder whose codes can be swapped between images of a minibatch.

    The encoder ``E`` turns an image batch into two codes ``(sp, gl)`` (named
    "spatial" and "global" by the options and by :meth:`decode`), and the
    generator ``G`` turns such a pair back into images.  Training uses an
    optional image discriminator ``D`` and an optional patch discriminator
    ``Dpatch``; each exists only when its GAN weight is positive.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this model's command-line options on ``parser``.

        Declared static because the function takes neither ``self`` nor
        ``cls`` and is invoked on the class, exactly as the ``BaseModel``
        call below does.

        Args:
            parser: Object exposing ``add_argument`` (presumably an
                ``argparse.ArgumentParser`` -- confirm against the caller).
            is_train: Forwarded to ``BaseModel.modify_commandline_options``;
                not used otherwise.

        Returns:
            The same ``parser`` object, with the options added.
        """
        BaseModel.modify_commandline_options(parser, is_train)
        parser.add_argument("--spatial_code_ch", default=8, type=int)
        parser.add_argument("--global_code_ch", default=2048, type=int)
        parser.add_argument("--lambda_R1", default=10.0, type=float)
        parser.add_argument("--lambda_patch_R1", default=1.0, type=float)
        parser.add_argument("--lambda_L1", default=1.0, type=float)
        parser.add_argument("--lambda_GAN", default=1.0, type=float)
        parser.add_argument("--lambda_PatchGAN", default=1.0, type=float)
        parser.add_argument("--patch_min_scale", default=1 / 8, type=float)
        parser.add_argument("--patch_max_scale", default=1 / 4, type=float)
        parser.add_argument("--patch_num_crops", default=8, type=int)
        parser.add_argument("--patch_use_aggregation",
                            type=util.str2bool, default=True)
        return parser

    def initialize(self):
        """Build the networks and buffers, then optionally load and move them.

        ``D`` is created only when ``lambda_GAN > 0`` and ``Dpatch`` only when
        ``lambda_PatchGAN > 0``; the other methods check the same options
        before using either network.
        """
        self.E = networks.create_network(self.opt, self.opt.netE, "encoder")
        self.G = networks.create_network(self.opt, self.opt.netG, "generator")
        if self.opt.lambda_GAN > 0.0:
            self.D = networks.create_network(
                self.opt, self.opt.netD, "discriminator")
        if self.opt.lambda_PatchGAN > 0.0:
            self.Dpatch = networks.create_network(
                self.opt, self.opt.netPatchD, "patch_discriminator"
            )

        # Count the iteration count of the discriminator
        # Used for lazy R1 regularization (c.f. Appendix B of StyleGAN2)
        self.register_buffer(
            "num_discriminator_iters", torch.zeros(1, dtype=torch.long)
        )
        self.l1_loss = torch.nn.L1Loss()

        # Weights are restored for inference and for resumed training.
        if (not self.opt.isTrain) or self.opt.continue_train:
            self.load()

        if self.opt.num_gpus > 0:
            self.to("cuda:0")

    def per_gpu_initialize(self):
        """Hook with no per-GPU work for this model."""
        pass

    def swap(self, x):
        """ Swaps (or mixes) the ordering of the minibatch.

        Entries are exchanged within consecutive pairs along dim 0:
        ``[a, b, c, d]`` becomes ``[b, a, d, c]``.

        Args:
            x: Tensor whose first dimension has even length.

        Returns:
            A tensor with the same shape as ``x`` and pairwise-swapped rows.

        Raises:
            AssertionError: If the first dimension is odd.
        """
        shape = x.shape
        assert shape[0] % 2 == 0, "Minibatch size must be a multiple of 2"
        new_shape = [shape[0] // 2, 2] + list(shape[1:])
        # reshape (rather than view) so non-contiguous inputs are accepted;
        # for contiguous inputs the result is identical.
        x = x.reshape(*new_shape)
        x = torch.flip(x, [1])
        return x.view(*shape)

    def compute_image_discriminator_losses(self, real, rec, mix):
        """Return the image-discriminator losses for one batch.

        Args:
            real: Batch passed to ``D`` and labelled real.
            rec: Reconstructions, labelled fake.
            mix: Swapped-code outputs, labelled fake.

        Returns:
            A dict with keys ``"D_real"``, ``"D_rec"`` and ``"D_mix"``, or an
            empty dict when the image GAN term is disabled.  ``rec`` and
            ``mix`` each carry half of ``lambda_GAN``.
        """
        # `D` only exists for a positive weight (see initialize), so any
        # non-positive weight means there is nothing to evaluate.
        if self.opt.lambda_GAN <= 0.0:
            return {}

        pred_real = self.D(real)
        pred_rec = self.D(rec)
        pred_mix = self.D(mix)

        losses = {}
        losses["D_real"] = loss.gan_loss(
            pred_real, should_be_classified_as_real=True
        ) * self.opt.lambda_GAN
        losses["D_rec"] = loss.gan_loss(
            pred_rec, should_be_classified_as_real=False
        ) * (0.5 * self.opt.lambda_GAN)
        losses["D_mix"] = loss.gan_loss(
            pred_mix, should_be_classified_as_real=False
        ) * (0.5 * self.opt.lambda_GAN)
        return losses

    def get_random_crops(self, x, crop_window=None):
        """ Make random crops.
        Corresponds to the yellow and blue random crops of Figure 2.

        Args:
            x: Image batch handed to ``util.apply_random_crop``.
            crop_window: Accepted for interface compatibility; currently
                unused.

        Returns:
            Whatever ``util.apply_random_crop`` returns for
            ``patch_num_crops`` crops of size ``patch_size`` with scales
            drawn from ``(patch_min_scale, patch_max_scale)``.
        """
        # NOTE(review): `patch_size` is not registered by this class's
        # options; presumably another module declares it -- confirm.
        crops = util.apply_random_crop(
            x, self.opt.patch_size,
            (self.opt.patch_min_scale, self.opt.patch_max_scale),
            num_crops=self.opt.patch_num_crops
        )
        return crops

    def compute_patch_discriminator_losses(self, real, mix):
        """Return the patch-discriminator losses for one batch.

        Three independent sets of random crops are drawn: two from ``real``
        (a reference set, extracted with the ``patch_use_aggregation``
        setting, and a target set) and one from ``mix``.

        Args:
            real: Batch of real images.
            mix: Batch of swapped-code outputs.

        Returns:
            A dict with keys ``"PatchD_real"`` (reference vs. real target,
            labelled real) and ``"PatchD_mix"`` (reference vs. mix, labelled
            fake), both scaled by ``lambda_PatchGAN``.
        """
        losses = {}
        real_feat = self.Dpatch.extract_features(
            self.get_random_crops(real),
            aggregate=self.opt.patch_use_aggregation
        )
        target_feat = self.Dpatch.extract_features(self.get_random_crops(real))
        mix_feat = self.Dpatch.extract_features(self.get_random_crops(mix))

        losses["PatchD_real"] = loss.gan_loss(
            self.Dpatch.discriminate_features(real_feat, target_feat),
            should_be_classified_as_real=True,
        ) * self.opt.lambda_PatchGAN

        losses["PatchD_mix"] = loss.gan_loss(
            self.Dpatch.discriminate_features(real_feat, mix_feat),
            should_be_classified_as_real=False,
        ) * self.opt.lambda_PatchGAN

        return losses

    def compute_discriminator_losses(self, real):
        """Run one discriminator iteration's forward pass.

        Increments ``num_discriminator_iters`` in place as a side effect.

        Args:
            real: Batch of real images; its first dimension must be even.

        Returns:
            A 4-tuple ``(losses, metrics, sp, gl)``: the loss dict, an empty
            metrics dict, and the two encoder outputs detached from the
            graph.

        Raises:
            AssertionError: If the batch size is odd.
        """
        self.num_discriminator_iters.add_(1)

        sp, gl = self.E(real)
        B = real.size(0)
        assert B % 2 == 0, "Batch size must be even on each GPU."

        # To save memory, compute the GAN loss on only
        # half of the reconstructed images
        rec = self.G(sp[:B // 2], gl[:B // 2])
        mix = self.G(self.swap(sp), gl)

        losses = self.compute_image_discriminator_losses(real, rec, mix)

        if self.opt.lambda_PatchGAN > 0.0:
            patch_losses = self.compute_patch_discriminator_losses(real, mix)
            losses.update(patch_losses)

        metrics = {}  # no metrics to report for the Discriminator iteration

        return losses, metrics, sp.detach(), gl.detach()

    def compute_R1_loss(self, real):
        """Return the R1 gradient penalty of the discriminators on ``real``.

        Each penalty is computed only when both its own weight and the
        weight that creates the matching discriminator are positive.  The
        R1 weights default to non-zero, so disabling a GAN term alone must
        not make this method touch a network that was never built.

        Args:
            real: Batch of real images.  When the image penalty is active,
                ``requires_grad_()`` is called on this tensor in place.

        Returns:
            A dict with the single key ``"D_R1"``: the sum of the image
            penalty and the patch penalty, where a disabled term
            contributes ``0.0``.
        """
        losses = {}
        if self.opt.lambda_GAN > 0.0 and self.opt.lambda_R1 > 0.0:
            real.requires_grad_()
            pred_real = self.D(real).sum()
            # create_graph=True keeps the gradient differentiable so the
            # penalty itself can be back-propagated.
            grad_real, = torch.autograd.grad(
                outputs=pred_real,
                inputs=[real],
                create_graph=True,
                retain_graph=True,
            )
            grad_real2 = grad_real.pow(2)
            # Sum over every dimension except the first.
            dims = list(range(1, grad_real2.ndim))
            grad_penalty = grad_real2.sum(dims) * (self.opt.lambda_R1 * 0.5)
        else:
            grad_penalty = 0.0

        if self.opt.lambda_PatchGAN > 0.0 and self.opt.lambda_patch_R1 > 0.0:
            # Crops are detached so they become fresh leaves for the
            # gradient computation.
            real_crop = self.get_random_crops(real).detach()
            real_crop.requires_grad_()
            target_crop = self.get_random_crops(real).detach()
            target_crop.requires_grad_()

            real_feat = self.Dpatch.extract_features(
                real_crop,
                aggregate=self.opt.patch_use_aggregation)
            target_feat = self.Dpatch.extract_features(target_crop)
            pred_real_patch = self.Dpatch.discriminate_features(
                real_feat, target_feat
            ).sum()

            grad_real, grad_target = torch.autograd.grad(
                outputs=pred_real_patch,
                inputs=[real_crop, target_crop],
                create_graph=True,
                retain_graph=True,
            )

            dims = list(range(1, grad_real.ndim))
            grad_crop_penalty = grad_real.pow(2).sum(dims) + \
                grad_target.pow(2).sum(dims)
            grad_crop_penalty *= (0.5 * self.opt.lambda_patch_R1 * 0.5)
        else:
            grad_crop_penalty = 0.0

        losses["D_R1"] = grad_penalty + grad_crop_penalty

        return losses

    def compute_generator_losses(self, real, sp_ma, gl_ma):
        """Return the encoder/generator losses and metrics for one batch.

        Args:
            real: Batch of real images with an even first dimension
                (``swap`` asserts this).
            sp_ma: Unused here; kept for interface compatibility.
            gl_ma: Unused here; kept for interface compatibility.

        Returns:
            A pair ``(losses, metrics)``.  ``metrics`` holds ``"L1_dist"``.
            ``losses`` holds, depending on the weights, ``"G_L1"``,
            ``"G_GAN_rec"``, ``"G_GAN_mix"`` and ``"G_mix"``.
        """
        losses, metrics = {}, {}
        B = real.size(0)

        sp, gl = self.E(real)
        rec = self.G(sp[:B // 2], gl[:B // 2])  # only on B//2 to save memory
        # Bind the reconstruction target now.  `real` may be re-sliced to
        # its second half below, and `rec` must be compared against the
        # images it was actually encoded from.
        rec_target = real[:B // 2]
        sp_mix = self.swap(sp)

        if self.opt.crop_size >= 1024:
            # another memory-saving trick: reduce #outputs to save memory
            real = real[B // 2:]
            gl = gl[B // 2:]
            sp_mix = sp_mix[B // 2:]

        mix = self.G(sp_mix, gl)

        # record the error of the reconstructed images for monitoring purposes
        metrics["L1_dist"] = self.l1_loss(rec, rec_target)
        if self.opt.lambda_L1 > 0.0:
            losses["G_L1"] = metrics["L1_dist"] * self.opt.lambda_L1

        if self.opt.lambda_GAN > 0.0:
            losses["G_GAN_rec"] = loss.gan_loss(
                self.D(rec),
                should_be_classified_as_real=True
            ) * (self.opt.lambda_GAN * 0.5)

            losses["G_GAN_mix"] = loss.gan_loss(
                self.D(mix),
                should_be_classified_as_real=True
            ) * (self.opt.lambda_GAN * 1.0)

        if self.opt.lambda_PatchGAN > 0.0:
            # The reference features are detached, so this term sends
            # gradients only through the crops of `mix`.
            real_feat = self.Dpatch.extract_features(
                self.get_random_crops(real),
                aggregate=self.opt.patch_use_aggregation).detach()
            mix_feat = self.Dpatch.extract_features(self.get_random_crops(mix))

            losses["G_mix"] = loss.gan_loss(
                self.Dpatch.discriminate_features(real_feat, mix_feat),
                should_be_classified_as_real=True,
            ) * self.opt.lambda_PatchGAN

        return losses, metrics

    def get_visuals_for_snapshot(self, real):
        """Produce images for a visual snapshot of the current model.

        Args:
            real: Batch of real images.  During training only the first 2
                (multi-GPU) or 4 images are used.  The number used must be
                even because ``swap`` is applied to the second code.

        Returns:
            A dict with keys ``"real"``, ``"layout"`` (the visualised first
            code resized via ``util.resize2d_tensor`` against ``real``),
            ``"rec"`` and ``"mix"``.
        """
        if self.opt.isTrain:
            # avoid the overhead of generating too many visuals during training
            real = real[:2] if self.opt.num_gpus > 1 else real[:4]
        sp, gl = self.E(real)
        layout = util.resize2d_tensor(util.visualize_spatial_code(sp), real)
        rec = self.G(sp, gl)
        mix = self.G(sp, self.swap(gl))

        visuals = {"real": real, "layout": layout, "rec": rec, "mix": mix}

        return visuals

    def fix_noise(self, sample_image=None):
        """ The generator architecture is stochastic because of the noise
        input at each layer (StyleGAN2 architecture). It could lead to
        flickering of the outputs even when identical inputs are given.
        Prevent flickering by fixing the noise injection of the generator.

        Args:
            sample_image: Optional batch used for one warm-up pass through
                the encoder and generator before the noise is fixed.

        Returns:
            The value returned by ``G.fix_and_gather_noise_parameters()``.
        """
        if sample_image is not None:
            # The generator should be run at least once,
            # so that the noise dimensions could be computed
            sp, gl = self.E(sample_image)
            self.G(sp, gl)
        noise_var = self.G.fix_and_gather_noise_parameters()
        return noise_var

    def encode(self, image, extract_features=False):
        """Run the encoder on ``image``.

        Args:
            image: Input batch for ``E``.
            extract_features: Forwarded to ``E`` unchanged.

        Returns:
            Whatever ``E`` returns for these arguments.
        """
        return self.E(image, extract_features=extract_features)

    def decode(self, spatial_code, global_code):
        """Run the generator on a ``(spatial_code, global_code)`` pair."""
        return self.G(spatial_code, global_code)

    def get_parameters_for_mode(self, mode):
        """Return the list of parameters optimised in ``mode``.

        Args:
            mode: ``"generator"`` or ``"discriminator"``.

        Returns:
            For ``"generator"``, the parameters of ``G`` followed by those
            of ``E``.  For ``"discriminator"``, the parameters of whichever
            of ``D`` and ``Dpatch`` are enabled (possibly an empty list).
            Any other ``mode`` yields ``None``.
        """
        if mode == "generator":
            return list(self.G.parameters()) + list(self.E.parameters())
        elif mode == "discriminator":
            Dparams = []
            if self.opt.lambda_GAN > 0.0:
                Dparams += list(self.D.parameters())
            if self.opt.lambda_PatchGAN > 0.0:
                Dparams += list(self.Dpatch.parameters())
            return Dparams