import torch
import torch.nn as nn
import torch.nn.functional as F

# Generator and discriminator classes come from the sibling ``networks`` module.
from .networks import UnetGenerator, PatchGAN


class Pix2Pix(nn.Module):
    """Pix2Pix model for image-to-image translation tasks.

    Wraps a ``UnetGenerator`` and, in training mode, a ``PatchGAN``
    discriminator together with one Adam optimizer per network and the
    adversarial (BCE-with-logits) and L1 criteria.

    By default, the generator uses transposed convolutions and the
    discriminator is a 70x70 PatchGAN.
    """

    def __init__(self,
                 c_in: int = 3,
                 c_out: int = 3,
                 is_train: bool = True,
                 netD: str = 'patch',
                 lambda_L1: float = 100.0,
                 is_CGAN: bool = True,
                 use_upsampling: bool = False,
                 mode: str = 'nearest',
                 c_hid: int = 64,
                 n_layers: int = 3,
                 lr: float = 0.0002,
                 beta1: float = 0.5,
                 beta2: float = 0.999
                 ):
        """Constructs the Pix2Pix class.

        Args:
            c_in: Number of input channels
            c_out: Number of output channels
            is_train: Whether the model is in training mode. The
                discriminator and both optimizers are only created when True.
            netD: Type of discriminator ('patch' or 'pixel')
            lambda_L1: Weight for L1 loss
            is_CGAN: If True, use conditional GAN architecture (the
                discriminator sees the input image concatenated with the
                real or generated output)
            use_upsampling: If True, use upsampling in generator instead of
                transpose conv
            mode: Upsampling mode ('nearest', 'bilinear', 'bicubic')
            c_hid: Number of base filters in discriminator
            n_layers: Number of layers in discriminator
            lr: Learning rate (shared by both optimizers)
            beta1: Beta1 parameter for Adam optimizer
            beta2: Beta2 parameter for Adam optimizer
        """
        super().__init__()
        self.is_CGAN = is_CGAN
        self.lambda_L1 = lambda_L1

        self.gen = UnetGenerator(c_in=c_in, c_out=c_out,
                                 use_upsampling=use_upsampling, mode=mode)
        self.gen = self.gen.apply(self.weights_init)

        if is_train:
            # A conditional discriminator receives input and output stacked
            # along the channel axis, hence c_in + c_out channels.
            disc_in = c_in + c_out if is_CGAN else c_out
            self.disc = PatchGAN(c_in=disc_in, c_hid=c_hid, mode=netD,
                                 n_layers=n_layers)
            self.disc = self.disc.apply(self.weights_init)

            self.gen_optimizer = torch.optim.Adam(
                self.gen.parameters(), lr=lr, betas=(beta1, beta2))
            self.disc_optimizer = torch.optim.Adam(
                self.disc.parameters(), lr=lr, betas=(beta1, beta2))

        # NOTE(review): the criteria hold no parameters, so they are created
        # regardless of ``is_train``; confirm this matches the intended
        # inference-mode attribute set.
        self.criterion = nn.BCEWithLogitsLoss()
        self.criterion_L1 = nn.L1Loss()

    def forward(self, x):
        """Run the generator on ``x`` and return its output."""
        return self.gen(x)

    @staticmethod
    def weights_init(m):
        """Initialize network weights.

        Intended for use with ``nn.Module.apply``. Conv2d / ConvTranspose2d
        weights are drawn from N(0, 0.02) and their biases zeroed;
        BatchNorm2d weights are drawn from N(1, 0.02) and their biases
        zeroed. Other module types are left untouched.

        Args:
            m: network module
        """
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
            nn.init.normal_(m.weight, 0.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0.0)
        elif isinstance(m, nn.BatchNorm2d):
            # BatchNorm2d(affine=False) has weight and bias set to None;
            # skip those instead of crashing inside nn.init.
            if m.weight is not None:
                nn.init.normal_(m.weight, 1.0, 0.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def _get_disc_inputs(self, real_images, target_images, fake_images):
        """Prepare discriminator inputs for the discriminator update.

        The generated images are detached so that the discriminator loss
        does not backpropagate into the generator.

        Returns:
            Tuple ``(real_AB, fake_AB)``. In the conditional setup each is
            the input image concatenated with the target / generated image
            along dim 1; otherwise they are the target and the detached
            generated image themselves.
        """
        detached_fake = fake_images.detach()
        if self.is_CGAN:
            real_AB = torch.cat([real_images, target_images], dim=1)
            fake_AB = torch.cat([real_images, detached_fake], dim=1)
        else:
            real_AB = target_images
            fake_AB = detached_fake
        return real_AB, fake_AB

    def _get_gen_inputs(self, real_images, fake_images):
        """Prepare the discriminator input used for the generator update.

        Unlike ``_get_disc_inputs`` the generated images are NOT detached,
        so gradients flow back into the generator.
        """
        if self.is_CGAN:
            fake_AB = torch.cat([real_images, fake_images], dim=1)
        else:
            fake_AB = fake_images
        return fake_AB

    def step_discriminator(self, real_images, target_images, fake_images):
        """Discriminator forward pass and loss computation.

        Does not call ``backward`` or step the optimizer; ``train_step``
        does that.

        Args:
            real_images: Input images
            target_images: Ground truth images
            fake_images: Generated images

        Returns:
            Discriminator loss value (mean of the real and fake BCE terms)
        """
        real_AB, fake_AB = self._get_disc_inputs(real_images, target_images,
                                                 fake_images)

        pred_real = self.disc(real_AB)
        pred_fake = self.disc(fake_AB)

        # Real pairs are labelled 1, generated pairs 0.
        lossD_real = self.criterion(pred_real, torch.ones_like(pred_real))
        lossD_fake = self.criterion(pred_fake, torch.zeros_like(pred_fake))
        lossD = (lossD_real + lossD_fake) * 0.5
        return lossD

    def step_generator(self, real_images, target_images, fake_images):
        """Generator loss computation.

        Does not call ``backward`` or step the optimizer; ``train_step``
        does that.

        Args:
            real_images: Input images
            target_images: Ground truth images
            fake_images: Generated images (must still be attached to the
                generator's graph for the loss to train it)

        Returns:
            Tuple ``(lossG, losses)`` where ``lossG`` is the total generator
            loss tensor (adversarial term + ``lambda_L1`` * L1 term) and
            ``losses`` is a dict of Python floats with keys ``'loss_G'``,
            ``'loss_G_GAN'`` and ``'loss_G_L1'``.
        """
        fake_AB = self._get_gen_inputs(real_images, fake_images)

        pred_fake = self.disc(fake_AB)

        # The generator is rewarded when the discriminator labels its
        # output as real (target of ones).
        lossG_GAN = self.criterion(pred_fake, torch.ones_like(pred_fake))
        lossG_L1 = self.criterion_L1(fake_images, target_images)
        lossG = lossG_GAN + self.lambda_L1 * lossG_L1

        return lossG, {
            'loss_G': lossG.item(),
            'loss_G_GAN': lossG_GAN.item(),
            'loss_G_L1': lossG_L1.item()
        }

    def train_step(self, real_images, target_images):
        """Performs a single training step.

        Generates fake images once, updates the discriminator, then updates
        the generator against the freshly updated discriminator.

        Args:
            real_images: Input images
            target_images: Ground truth images

        Returns:
            Dictionary containing all loss values from this step
            (``'loss_D'``, ``'loss_G'``, ``'loss_G_GAN'``, ``'loss_G_L1'``)
        """
        fake_images = self.forward(real_images)

        # --- Discriminator update ---
        self.disc_optimizer.zero_grad()
        lossD = self.step_discriminator(real_images, target_images,
                                        fake_images)
        lossD.backward()
        self.disc_optimizer.step()

        # --- Generator update ---
        # backward() here also accumulates gradients in the discriminator's
        # parameters; they are cleared by the next disc zero_grad().
        self.gen_optimizer.zero_grad()
        lossG, G_losses = self.step_generator(real_images, target_images,
                                              fake_images)
        lossG.backward()
        self.gen_optimizer.step()

        return {
            'loss_D': lossD.item(),
            **G_losses
        }

    def get_current_visuals(self, real_images, target_images):
        """Return visualization images.

        The generator is run under ``torch.no_grad()``; the module's
        train/eval mode is left as it is.

        Args:
            real_images: Input images
            target_images: Ground truth images

        Returns:
            Dictionary containing input, target and generated images under
            the keys ``'real'``, ``'target'`` and ``'fake'``
        """
        with torch.no_grad():
            fake_images = self.gen(real_images)
        return {
            'real': real_images,
            'fake': fake_images,
            'target': target_images
        }