Spaces:
Runtime error
Runtime error
File size: 2,809 Bytes
77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f 77b60c4 1c7b15f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import os
import glob
import time
import numpy as np
from PIL import Image
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb
import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
# Select CUDA when PyTorch reports it as available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class UnetBlock(nn.Module):
    """One encoder/decoder level of a U-Net, optionally wrapping a deeper level.

    The block downsamples its input with a stride-2, 4x4 convolution, runs the
    result through ``submodule`` (the next-deeper level; absent for the
    innermost block), and upsamples back with a stride-2, 4x4 transposed
    convolution. Every block except the outermost returns its input
    concatenated with its output along dim 1, giving the parent level a skip
    connection.

    Args:
        nf: Channels produced by this block's transposed convolution; also the
            default input channel count (see ``input_c``).
        ni: Channels produced by this block's downsampling convolution.
        submodule: The next-deeper block. Required unless ``innermost`` is
            True (and ``outermost`` is False). Its output must have ``ni * 2``
            channels, which a non-outermost ``UnetBlock`` whose ``nf`` equals
            this block's ``ni`` provides.
        input_c: Input channel count; defaults to ``nf``.
        dropout: Append ``nn.Dropout(0.5)`` after the decoder of a middle
            block. Has no effect on innermost or outermost blocks.
        innermost: Build the bottleneck block (no submodule; the transposed
            convolution consumes ``ni`` channels).
        outermost: Build the top-level block (no activation or norm before the
            downsampling conv, ``Tanh`` output, no skip concatenation). Takes
            precedence over ``innermost`` if both are set.

    Raises:
        ValueError: If ``submodule`` is None for a block that needs one.
    """

    def __init__(
        self,
        nf,
        ni,
        submodule=None,
        input_c=None,
        dropout=False,
        innermost=False,
        outermost=False,
    ):
        super().__init__()
        # Previously a missing submodule was stored as a None layer and only
        # failed on the first forward pass with an opaque TypeError.
        needs_submodule = outermost or not innermost
        if needs_submodule and submodule is None:
            raise ValueError(
                "UnetBlock requires a submodule unless innermost=True "
                "(and outermost=False)"
            )

        self.outermost = outermost
        if input_c is None:
            input_c = nf

        # downconv is created before upconv (below) so that weight
        # initialisation consumes the RNG in the same order as before.
        downconv = nn.Conv2d(
            input_c, ni, kernel_size=4, stride=2, padding=1, bias=False
        )
        # Both activations are in-place (second positional argument).
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)

        # The layer order below fixes the nn.Sequential indices, which are the
        # state_dict keys; do not reorder.
        if outermost:
            upconv = nn.ConvTranspose2d(
                ni * 2, nf, kernel_size=4, stride=2, padding=1
            )
            layers = [downconv, submodule, uprelu, upconv, nn.Tanh()]
        elif innermost:
            # No submodule, so upconv sees only this block's ni channels.
            upconv = nn.ConvTranspose2d(
                ni, nf, kernel_size=4, stride=2, padding=1, bias=False
            )
            layers = [downrelu, downconv, uprelu, upconv, upnorm]
        else:
            # upconv consumes the submodule's skip-concatenated ni * 2 channels.
            upconv = nn.ConvTranspose2d(
                ni * 2, nf, kernel_size=4, stride=2, padding=1, bias=False
            )
            layers = [
                downrelu,
                downconv,
                downnorm,
                submodule,
                uprelu,
                upconv,
                upnorm,
            ]
            if dropout:
                layers.append(nn.Dropout(0.5))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the block to ``x``.

        The outermost block returns ``self.model(x)`` directly. Every other
        block returns ``torch.cat([x, self.model(x)], 1)``, i.e.
        ``input_c + nf`` channels, so the parent gets a skip connection.

        NOTE: for non-outermost blocks the first layer of ``self.model`` is the
        in-place ``LeakyReLU``, which overwrites ``x`` before ``torch.cat``
        reads it. The skip connection therefore carries the activated input;
        switching to ``inplace=False`` would change the network's outputs.
        """
        if self.outermost:
            return self.model(x)
        return torch.cat([x, self.model(x)], 1)
class Unet(nn.Module):
    """U-Net generator assembled from nested ``UnetBlock`` levels.

    The network is built from the bottleneck outwards; each new block wraps
    the previously built one as its ``submodule``. Channel widths, from the
    deepest level up, are:

    * innermost block: ``num_filters * 8``;
    * ``n_down - 5`` extra blocks at ``num_filters * 8``, each with dropout;
    * three blocks halving the width: ``num_filters * 8`` -> ``* 4`` -> ``* 2``
      -> ``num_filters``;
    * outermost block: ``input_c`` in, ``output_c`` out.

    Args:
        input_c: Channels of the input tensor.
        output_c: Channels of the output tensor.
        n_down: Number of downsampling levels. Five levels (innermost, three
            halving levels, outermost) are always built, so any value below 5
            yields the same network as ``n_down=5``.
        num_filters: Base width; the deepest levels use ``num_filters * 8``
            channels.

    NOTE(review): spatial height/width presumably need to be divisible by
    ``2 ** max(n_down, 5)`` for the skip concatenations to line up — confirm.
    """

    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        widest = num_filters * 8

        block = UnetBlock(widest, widest, innermost=True)

        # Extra full-width levels; range() of a non-positive count is empty.
        for _ in range(n_down - 5):
            block = UnetBlock(widest, widest, submodule=block, dropout=True)

        # Each of these levels takes `inner_width` from below and exposes half
        # of it upward: widest -> widest//2 -> widest//4 -> widest//8.
        for inner_width in (widest, widest // 2, widest // 4):
            block = UnetBlock(inner_width // 2, inner_width, submodule=block)

        self.model = UnetBlock(
            output_c,
            widest // 8,
            input_c=input_c,
            submodule=block,
            outermost=True,
        )

    def forward(self, x):
        """Run ``x`` through the outermost block, and hence the whole U-Net."""
        return self.model(x)
|