import os import glob import time import numpy as np from PIL import Image from pathlib import Path from tqdm.notebook import tqdm import matplotlib.pyplot as plt from skimage.color import rgb2lab, lab2rgb import torch from torch import nn, optim from torchvision import transforms from torchvision.utils import make_grid from torch.utils.data import Dataset, DataLoader device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class PatchDiscriminator(nn.Module): def __init__(self, input_c, num_filters=64, n_down=3): super().__init__() model = [self.get_layers(input_c, num_filters, norm=False)] model += [ self.get_layers( num_filters * 2**i, num_filters * 2 ** (i + 1), s=1 if i == (n_down - 1) else 2, ) for i in range(n_down) ] # the 'if' statement is taking care of not using # stride of 2 for the last block in this loop model += [ self.get_layers(num_filters * 2**n_down, 1, s=1, norm=False, act=False) ] # Make sure to not use normalization or # activation for the last layer of the model self.model = nn.Sequential(*model) def get_layers( self, ni, nf, k=4, s=2, p=1, norm=True, act=True ): # when needing to make some repeatitive blocks of layers, layers = [ nn.Conv2d(ni, nf, k, s, p, bias=not norm) ] # it's always helpful to make a separate method for that purpose if norm: layers += [nn.BatchNorm2d(nf)] if act: layers += [nn.LeakyReLU(0.2, True)] return nn.Sequential(*layers) def forward(self, x): return self.model(x)