# Page-header residue from the hosting site: "Spaces: Runtime error".
import os | |
import glob | |
import time | |
import numpy as np | |
from PIL import Image | |
from pathlib import Path | |
from tqdm.notebook import tqdm | |
import matplotlib.pyplot as plt | |
from skimage.color import rgb2lab, lab2rgb | |
import torch | |
from torch import nn, optim | |
from torchvision import transforms | |
from torchvision.utils import make_grid | |
from torch.utils.data import Dataset, DataLoader | |
# Compute device: "cuda" when torch reports CUDA as available, otherwise "cpu".
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
class PatchDiscriminator(nn.Module):
    """Convolutional discriminator assembled from stacked conv blocks.

    Layout, in order:

    1. one conv block without normalization (``input_c -> num_filters``);
    2. ``n_down`` conv blocks with batch norm, each doubling the channel
       count; all use stride 2 except the last one, which uses stride 1;
    3. one bare convolution (no normalization, no activation) that maps
       ``num_filters * 2**n_down`` channels down to a single channel.

    Because the final layer has no activation, the output is a
    single-channel map of raw, unbounded values.

    Args:
        input_c: Number of channels of the input tensor.
        num_filters: Output channels of the first block.
        n_down: Number of channel-doubling blocks after the first one.
    """

    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()

        # First block: no normalization on the raw input.
        blocks = [self.get_layers(input_c, num_filters, norm=False)]

        # Channel-doubling blocks. Only the last one keeps stride 1;
        # every earlier one halves the spatial resolution.
        for i in range(n_down):
            stride = 2 if i != n_down - 1 else 1
            width_in = num_filters * 2**i
            blocks.append(self.get_layers(width_in, width_in * 2, s=stride))

        # Output head: plain convolution, no normalization or activation.
        blocks.append(
            self.get_layers(
                num_filters * 2**n_down, 1, s=1, norm=False, act=False
            )
        )

        self.model = nn.Sequential(*blocks)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        """Build one conv block: Conv2d, then optional BatchNorm2d / LeakyReLU.

        Args:
            ni: Input channels of the convolution.
            nf: Output channels of the convolution.
            k: Kernel size.
            s: Stride.
            p: Padding.
            norm: If true, append ``BatchNorm2d(nf)`` and build the
                convolution without a bias term.
            act: If true, append an in-place ``LeakyReLU`` with slope 0.2.

        Returns:
            An ``nn.Sequential`` holding the one to three layers.
        """
        # The conv bias is dropped whenever batch norm follows, since the
        # norm layer carries its own learnable shift.
        conv = nn.Conv2d(
            ni, nf, kernel_size=k, stride=s, padding=p, bias=not norm
        )
        stages = [conv]
        if norm:
            stages.append(nn.BatchNorm2d(nf))
        if act:
            stages.append(nn.LeakyReLU(negative_slope=0.2, inplace=True))
        return nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the stacked blocks and return the result."""
        return self.model(x)