Spaces:
Runtime error
Runtime error
import os | |
import glob | |
import time | |
import numpy as np | |
from PIL import Image | |
from pathlib import Path | |
from tqdm.notebook import tqdm | |
import matplotlib.pyplot as plt | |
from skimage.color import rgb2lab, lab2rgb | |
import torch | |
from torch import nn, optim | |
from torchvision import transforms | |
from torchvision.utils import make_grid | |
from torch.utils.data import Dataset, DataLoader | |
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class UnetBlock(nn.Module):
    """One nesting level of a U-Net: downsample, run the inner level, upsample.

    The block wraps ``submodule`` between a stride-2 convolution and a
    stride-2 transposed convolution. Unless it is the outermost block, its
    output is the block input concatenated (along dim 1) with the upsampled
    result, i.e. a skip connection.

    Args:
        nf: Number of channels produced by the up-convolution. Also the number
            of input channels when ``input_c`` is not given.
        ni: Number of channels produced by the down-convolution.
        submodule: The nested block executed between the down and up paths.
            Required unless this is a pure ``innermost`` block; it must emit
            ``ni * 2`` channels to match the up-convolution.
        input_c: Number of input channels; defaults to ``nf``.
        dropout: If true, append ``Dropout(0.5)`` after the up path. Only
            honoured for intermediate (neither innermost nor outermost) blocks.
        innermost: Build the bottleneck block, which has no submodule.
        outermost: Build the top-level block: no normalisation, ``Tanh``
            output, and no skip concatenation. Takes precedence over
            ``innermost`` when both are set.

    Raises:
        ValueError: If the selected block kind needs a ``submodule`` and none
            was supplied.
    """

    def __init__(
        self,
        nf,
        ni,
        submodule=None,
        input_c=None,
        dropout=False,
        innermost=False,
        outermost=False,
    ):
        super().__init__()
        self.outermost = outermost

        # Only the pure innermost variant omits the nested module. Fail here
        # rather than letting nn.Sequential accept None and blow up in forward.
        if submodule is None and (outermost or not innermost):
            raise ValueError(
                "UnetBlock requires a submodule unless it is the innermost block"
            )

        if input_c is None:
            input_c = nf

        downconv = nn.Conv2d(
            input_c, ni, kernel_size=4, stride=2, padding=1, bias=False
        )
        uprelu = nn.ReLU(True)

        if outermost:
            # The final up-convolution keeps its bias since no norm follows it.
            upconv = nn.ConvTranspose2d(
                ni * 2, nf, kernel_size=4, stride=2, padding=1
            )
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # No skip tensor arrives from below, so the up path sees ni channels.
            upconv = nn.ConvTranspose2d(
                ni, nf, kernel_size=4, stride=2, padding=1, bias=False
            )
            layers = [
                nn.LeakyReLU(0.2, True),
                downconv,
                uprelu,
                upconv,
                nn.BatchNorm2d(nf),
            ]
        else:
            # The submodule returns its input concatenated with its output,
            # hence ni * 2 channels feeding the up-convolution.
            upconv = nn.ConvTranspose2d(
                ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False
            )
            layers = [
                nn.LeakyReLU(0.2, True),
                downconv,
                nn.BatchNorm2d(ni),
                submodule,
                uprelu,
                upconv,
                nn.BatchNorm2d(nf),
            ]
            if dropout:
                layers.append(nn.Dropout(0.5))

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the block on ``x``.

        Returns ``self.model(x)`` for the outermost block; otherwise returns
        ``x`` concatenated with ``self.model(x)`` along dim 1.
        """
        if self.outermost:
            return self.model(x)
        # NOTE: the leading LeakyReLU is in-place, so by the time the tensors
        # are concatenated ``x`` already holds the activated values.
        return torch.cat([x, self.model(x)], 1)
class Unet(nn.Module):
    """U-Net assembled from nested ``UnetBlock`` levels, built inside-out.

    Construction order:
      1. one innermost block at ``num_filters * 8`` channels;
      2. ``n_down - 5`` further blocks at the same width, with dropout;
      3. three blocks, each halving the channel width of the one before;
      4. the outermost block mapping ``input_c`` channels to ``output_c``.

    Args:
        input_c: Number of channels of the input tensor.
        output_c: Number of channels of the output tensor.
        n_down: Total number of nested blocks. Values below 5 add no
            constant-width blocks, so five blocks are built regardless.
        num_filters: Base channel width; the deepest blocks use eight times
            this value.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8

        # Bottleneck first, then wrap it level by level moving outwards.
        nested = UnetBlock(widest, widest, innermost=True)
        for _ in range(n_down - 5):
            nested = UnetBlock(widest, widest, submodule=nested, dropout=True)

        # Three levels that step the channel width down by half each time.
        width = widest
        for _ in range(3):
            halved = width // 2
            nested = UnetBlock(halved, width, submodule=nested)
            width = halved

        self.model = UnetBlock(
            output_c,
            width,
            input_c=input_c,
            submodule=nested,
            outermost=True,
        )

    def forward(self, x):
        """Apply the full nested U-Net to ``x`` and return its output."""
        result = self.model(x)
        return result