import pytorch_lightning as pl import torch import torch.nn as nn import torch.nn.functional as F from typing import Dict, List, Optional, OrderedDict, Tuple class Discriminator(nn.Module): def __init__( self, hidden_size: Optional[int] = 64, channels: Optional[int] = 3, kernel_size: Optional[int] = 4, stride: Optional[int] = 2, padding: Optional[int] = 1, negative_slope: Optional[float] = 0.2, bias: Optional[bool] = False, ): """ Initializes the discriminator. Parameters ---------- hidden_size : int, optional The input size. (the default is 64) channels : int, optional The number of channels. (default: 3) kernel_size : int, optional The kernal size. (default: 4) stride : int, optional The stride. (default: 2) padding : int, optional The padding. (default: 1) negative_slope : float, optional The negative slope. (default: 0.2) bias : bool, optional Whether to use bias. (default: False) """ super().__init__() self.hidden_size = hidden_size self.channels = channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.negative_slope = negative_slope self.bias = bias self.model = nn.Sequential( nn.utils.spectral_norm( nn.Conv2d( self.channels, self.hidden_size, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias ), ), nn.LeakyReLU(self.negative_slope, inplace=True), nn.utils.spectral_norm( nn.Conv2d( hidden_size, hidden_size * 2, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias ), ), nn.BatchNorm2d(hidden_size * 2), nn.LeakyReLU(self.negative_slope, inplace=True), nn.utils.spectral_norm( nn.Conv2d( hidden_size * 2, hidden_size * 4, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias ), ), nn.BatchNorm2d(hidden_size * 4), nn.LeakyReLU(self.negative_slope, inplace=True), nn.utils.spectral_norm( nn.Conv2d( hidden_size * 4, hidden_size * 8, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias ), ), nn.BatchNorm2d(hidden_size * 8), nn.LeakyReLU(self.negative_slope, inplace=True), nn.utils.spectral_norm( nn.Conv2d(hidden_size * 8, 1, kernel_size=4, stride=1, padding=0, bias=self.bias), # output size: (1, 1, 1) ), nn.Flatten(), nn.Sigmoid(), ) def forward(self, input_img: torch.Tensor) -> torch.Tensor: """ Forward propagation. Parameters ---------- input_img : torch.Tensor The input image. Returns ------- torch.Tensor The output. """ logits = self.model(input_img) return logits class Generator(nn.Module): def __init__( self, hidden_size: Optional[int] = 64, latent_size: Optional[int] = 128, channels: Optional[int] = 3, kernel_size: Optional[int] = 4, stride: Optional[int] = 2, padding: Optional[int] = 1, bias: Optional[bool] = False, ): """ Initializes the generator. Parameters ---------- hidden_size : int, optional The hidden size. (default: 64) latent_size : int, optional The latent size. (default: 128) channels : int, optional The number of channels. (default: 3) kernel_size : int, optional The kernel size. (default: 4) stride : int, optional The stride. (default: 2) padding : int, optional The padding. (default: 1) bias : bool, optional Whether to use bias. (default: False) """ super().__init__() self.hidden_size = hidden_size self.latent_size = latent_size self.channels = channels self.kernel_size = kernel_size self.stride = stride self.padding = padding self.bias = bias self.model = nn.Sequential( nn.ConvTranspose2d( self.latent_size, self.hidden_size * 8, kernel_size=self.kernel_size, stride=1, padding=0, bias=self.bias ), nn.BatchNorm2d(self.hidden_size * 8), nn.ReLU(inplace=True), nn.ConvTranspose2d( self.hidden_size * 8, self.hidden_size * 4, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias ), nn.BatchNorm2d(self.hidden_size * 4), nn.ReLU(inplace=True), nn.ConvTranspose2d( self.hidden_size * 4, self.hidden_size * 2, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias ), nn.BatchNorm2d(self.hidden_size * 2), nn.ReLU(inplace=True), nn.ConvTranspose2d( self.hidden_size * 2, self.hidden_size, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias ), nn.BatchNorm2d(self.hidden_size), nn.ReLU(inplace=True), nn.ConvTranspose2d( self.hidden_size, self.channels, kernel_size=self.kernel_size, stride=self.stride, padding=self.padding, bias=self.bias ), nn.Tanh() # output size: (channels, 64, 64) ) def forward(self, input_noise: torch.Tensor) -> torch.Tensor: """ Forward propagation. Parameters ---------- input_noise : torch.Tensor The input image. Returns ------- torch.Tensor The output. """ fake_img = self.model(input_noise) return fake_img class DocuGAN(pl.LightningModule): def __init__( self, hidden_size: Optional[int] = 64, latent_size: Optional[int] = 128, num_channel: Optional[int] = 3, learning_rate: Optional[float] = 0.0002, batch_size: Optional[int] = 128, bias1: Optional[float] = 0.5, bias2: Optional[float] = 0.999, ): """ Initializes the LightningGan. Parameters ---------- hidden_size : int, optional The hidden size. (default: 64) latent_size : int, optional The latent size. (default: 128) num_channel : int, optional The number of channels. (default: 3) learning_rate : float, optional The learning rate. (default: 0.0002) batch_size : int, optional The batch size. (default: 128) bias1 : float, optional The bias1. (default: 0.5) bias2 : float, optional The bias2. (default: 0.999) """ super().__init__() self.hidden_size = hidden_size self.latent_size = latent_size self.num_channel = num_channel self.learning_rate = learning_rate self.batch_size = batch_size self.bias1 = bias1 self.bias2 = bias2 self.criterion = nn.BCELoss() self.validation = torch.randn(self.batch_size, self.latent_size, 1, 1) self.save_hyperparameters() self.generator = Generator( latent_size=self.latent_size, channels=self.num_channel, hidden_size=self.hidden_size ) self.generator.apply(self.weights_init) self.discriminator = Discriminator(channels=self.num_channel, hidden_size=self.hidden_size) self.discriminator.apply(self.weights_init) # self.model = InceptionV3() # For FID metric def weights_init(self, m: nn.Module) -> None: """ Initializes the weights. Parameters ---------- m : nn.Module The module. """ classname = m.__class__.__name__ if classname.find("Conv") != -1: nn.init.normal_(m.weight.data, 0.0, 0.02) elif classname.find("BatchNorm") != -1: nn.init.normal_(m.weight.data, 1.0, 0.02) nn.init.constant_(m.bias.data, 0) def configure_optimizers(self) -> Tuple[List[torch.optim.Optimizer], List]: """ Configures the optimizers. Returns ------- Tuple[List[torch.optim.Optimizer], List] The optimizers and the LR schedulers. """ opt_generator = torch.optim.Adam( self.generator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2) ) opt_discriminator = torch.optim.Adam( self.discriminator.parameters(), lr=self.learning_rate, betas=(self.bias1, self.bias2) ) return [opt_generator, opt_discriminator], [] def forward(self, z: torch.Tensor) -> torch.Tensor: """ Forward propagation. Parameters ---------- z : torch.Tensorh The latent vector. Returns ------- torch.Tensor The output. """ return self.generator(z) def training_step( self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int, optimizer_idx: int ) -> Dict: """ Training step. Parameters ---------- batch : Tuple[torch.Tensor, torch.Tensor] The batch. batch_idx : int The batch index. optimizer_idx : int The optimizer index. Returns ------- Dict The training loss. """ real_images = batch["tr_image"] if optimizer_idx == 0: # Only train the generator fake_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1) fake_random_noise = fake_random_noise.type_as(real_images) fake_images = self(fake_random_noise) # Try to fool the discriminator preds = self.discriminator(fake_images) loss = self.criterion(preds, torch.ones_like(preds)) self.log("g_loss", loss, on_step=False, on_epoch=True) tqdm_dict = {"g_loss": loss} output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict}) return output elif optimizer_idx == 1: # Only train the discriminator real_preds = self.discriminator(real_images) real_loss = self.criterion(real_preds, torch.ones_like(real_preds)) # Generate fake images real_random_noise = torch.randn(self.batch_size, self.latent_size, 1, 1) real_random_noise = real_random_noise.type_as(real_images) fake_images = self(real_random_noise) # Pass fake images though discriminator fake_preds = self.discriminator(fake_images) fake_loss = self.criterion(fake_preds, torch.zeros_like(fake_preds)) # Update discriminator weights loss = real_loss + fake_loss self.log("d_loss", loss, on_step=False, on_epoch=True) tqdm_dict = {"d_loss": loss} output = OrderedDict({"loss": loss, "progress_bar": tqdm_dict, "log": tqdm_dict}) return output