Spaces:
Runtime error
Runtime error
| import os | |
| import glob | |
| import time | |
| import numpy as np | |
| from PIL import Image | |
| from pathlib import Path | |
| from tqdm.notebook import tqdm | |
| import matplotlib.pyplot as plt | |
| from skimage.color import rgb2lab, lab2rgb | |
| import torch | |
| from torch import nn, optim | |
| from torchvision import transforms | |
| from torchvision.utils import make_grid | |
| from torch.utils.data import Dataset, DataLoader | |
# Use the current CUDA device when PyTorch reports one is available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class PatchDiscriminator(nn.Module):
    """Fully convolutional discriminator that maps an image to a 2-D score map.

    Layout of ``self.model`` (an ``nn.Sequential`` of ``nn.Sequential`` blocks):

    * an input block (Conv2d + LeakyReLU, no BatchNorm) producing
      ``num_filters`` channels;
    * ``n_down`` intermediate blocks, each doubling the channel count;
    * an output block that is a single Conv2d down to 1 channel, with no
      normalization and no activation.

    Every block uses a 4x4 kernel with padding 1. The last intermediate block
    and the output block use stride 1; all other blocks use stride 2.

    Args:
        input_c: number of channels in the input tensor.
        num_filters: number of channels produced by the input block.
        n_down: number of channel-doubling intermediate blocks.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        # Input block: BatchNorm is deliberately skipped on the very first layer.
        blocks = [self.get_layers(input_c, num_filters, norm=False)]

        for idx in range(n_down):
            in_ch = num_filters * 2**idx
            out_ch = in_ch * 2
            # Only the final intermediate block keeps stride 1; earlier ones downsample by 2.
            stride = 2 if idx < n_down - 1 else 1
            blocks.append(self.get_layers(in_ch, out_ch, s=stride))

        # Output block: a bare convolution to one channel (no norm, no activation).
        blocks.append(
            self.get_layers(num_filters * 2**n_down, 1, s=1, norm=False, act=False)
        )
        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one ``Conv2d -> [BatchNorm2d] -> [LeakyReLU]`` block.

        The repeated block is factored into this helper so ``__init__`` stays short.

        Args:
            ni: input channels of the convolution.
            nf: output channels of the convolution (and of the BatchNorm, if any).
            k: convolution kernel size.
            s: convolution stride.
            p: convolution padding.
            norm: if True, append ``BatchNorm2d`` and build the conv without a
                bias term (a BatchNorm that follows would subtract it out anyway).
            act: if True, append an in-place ``LeakyReLU`` with slope 0.2.

        Returns:
            ``nn.Sequential`` holding the requested layers in the order above.
        """
        conv = nn.Conv2d(ni, nf, k, s, p, bias=not norm)
        extras = []
        if norm:
            extras.append(nn.BatchNorm2d(nf))
        if act:
            extras.append(nn.LeakyReLU(0.2, inplace=True))
        return nn.Sequential(conv, *extras)

    def forward(self, x):
        """Pass ``x`` through all blocks and return the raw (un-activated) score map."""
        return self.model(x)