import streamlit as st
from PIL import Image
import cv2 as cv

import os
import glob
import time
import numpy as np
from pathlib import Path
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
from skimage.color import rgb2lab, lab2rgb

import torch
from torch import nn, optim
from torchvision import transforms
from torchvision.utils import make_grid
from torch.utils.data import Dataset, DataLoader
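
# Image colorization in the Lab color space: a U-Net generator predicts the
# two color channels (a, b) from the lightness channel (L), and a PatchGAN
# discriminator judges (L, ab) pairs, trained pix2pix-style.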

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
use_colab = None

SIZE = 256
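
# Dataset for training/validation: loads RGB images from disk, resizes them
# (with a random horizontal flip for the training split), converts to Lab,
# and normalizes L (range [0, 100]) and ab (divided by 110) to roughly [-1, 1].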

class ColorizationDataset(Dataset):
    def __init__(self, paths, split='train'):
        if split == 'train':
            self.transforms = transforms.Compose([
                transforms.Resize((SIZE, SIZE), Image.BICUBIC),
                transforms.RandomHorizontalFlip(),
            ])
        elif split == 'val':
            self.transforms = transforms.Resize((SIZE, SIZE), Image.BICUBIC)

        self.split = split
        self.size = SIZE
        self.paths = paths

    def __getitem__(self, idx):
        img = Image.open(self.paths[idx]).convert("RGB")
        img = self.transforms(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        return {'L': L, 'ab': ab}

    def __len__(self):
        return len(self.paths)

def make_dataloaders(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    dataset = ColorizationDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory)
    return dataloader
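
# A single U-Net block. Blocks are nested recursively: each block downsamples,
# runs its submodule, upsamples, and (except for the outermost block)
# concatenates its input with its output to form the skip connection.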

class UnetBlock(nn.Module):
    def __init__(self, nf, ni, submodule=None, input_c=None, dropout=False,
                 innermost=False, outermost=False):
        super().__init__()
        self.outermost = outermost
        if input_c is None: input_c = nf
        downconv = nn.Conv2d(input_c, ni, kernel_size=4,
                             stride=2, padding=1, bias=False)
        downrelu = nn.LeakyReLU(0.2, True)
        downnorm = nn.BatchNorm2d(ni)
        uprelu = nn.ReLU(True)
        upnorm = nn.BatchNorm2d(nf)

        if outermost:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1)
            down = [downconv]
            up = [uprelu, upconv, nn.Tanh()]
            model = down + [submodule] + up
        elif innermost:
            upconv = nn.ConvTranspose2d(ni, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv]
            up = [uprelu, upconv, upnorm]
            model = down + up
        else:
            upconv = nn.ConvTranspose2d(ni * 2, nf, kernel_size=4,
                                        stride=2, padding=1, bias=False)
            down = [downrelu, downconv, downnorm]
            up = [uprelu, upconv, upnorm]
            if dropout: up += [nn.Dropout(0.5)]
            model = down + [submodule] + up
        self.model = nn.Sequential(*model)

    def forward(self, x):
        if self.outermost:
            return self.model(x)
        else:
            return torch.cat([x, self.model(x)], 1)
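
# The full U-Net generator, assembled from the innermost block outward:
# three extra bottleneck blocks with dropout, three blocks that halve the
# number of filters, and an outermost block that maps the 1-channel L input
# to the 2-channel ab output through a Tanh.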

class Unet(nn.Module):
    def __init__(self, input_c=1, output_c=2, n_down=8, num_filters=64):
        super().__init__()
        unet_block = UnetBlock(num_filters * 8, num_filters * 8, innermost=True)
        for _ in range(n_down - 5):
            unet_block = UnetBlock(num_filters * 8, num_filters * 8, submodule=unet_block, dropout=True)
        out_filters = num_filters * 8
        for _ in range(3):
            unet_block = UnetBlock(out_filters // 2, out_filters, submodule=unet_block)
            out_filters //= 2
        self.model = UnetBlock(output_c, out_filters, input_c=input_c, submodule=unet_block, outermost=True)

    def forward(self, x):
        return self.model(x)
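
# PatchGAN discriminator (as in pix2pix): a small convolutional net whose
# output is a map of logits, one per image patch, rather than a single
# real/fake score. No sigmoid here; GANLoss applies BCEWithLogitsLoss.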

class PatchDiscriminator(nn.Module):
    def __init__(self, input_c, num_filters=64, n_down=3):
        super().__init__()
        model = [self.get_layers(input_c, num_filters, norm=False)]
        model += [self.get_layers(num_filters * 2 ** i, num_filters * 2 ** (i + 1), s=1 if i == (n_down - 1) else 2)
                  for i in range(n_down)]
        model += [self.get_layers(num_filters * 2 ** n_down, 1, s=1, norm=False, act=False)]
        self.model = nn.Sequential(*model)

    def get_layers(self, ni, nf, k=4, s=2, p=1, norm=True, act=True):
        layers = [nn.Conv2d(ni, nf, k, s, p, bias=not norm)]
        if norm: layers += [nn.BatchNorm2d(nf)]
        if act: layers += [nn.LeakyReLU(0.2, True)]
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.model(x)
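
# Adversarial loss wrapper: builds a label tensor of the same shape as the
# discriminator output ("real" = 1, "fake" = 0) and applies either
# BCEWithLogitsLoss ('vanilla') or MSELoss ('lsgan').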

class GANLoss(nn.Module):
    def __init__(self, gan_mode='vanilla', real_label=1.0, fake_label=0.0):
        super().__init__()
        self.register_buffer('real_label', torch.tensor(real_label))
        self.register_buffer('fake_label', torch.tensor(fake_label))
        if gan_mode == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif gan_mode == 'lsgan':
            self.loss = nn.MSELoss()

    def get_labels(self, preds, target_is_real):
        if target_is_real:
            labels = self.real_label
        else:
            labels = self.fake_label
        return labels.expand_as(preds)

    def __call__(self, preds, target_is_real):
        labels = self.get_labels(preds, target_is_real)
        loss = self.loss(preds, labels)
        return loss
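
# Weight initialization helpers: conv layers get normal / xavier / kaiming
# init, batchnorm layers get weights ~ N(1, gain) and zero bias.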

def init_weights(net, init='norm', gain=0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and 'Conv' in classname:
            if init == 'norm':
                nn.init.normal_(m.weight.data, mean=0.0, std=gain)
            elif init == 'xavier':
                nn.init.xavier_normal_(m.weight.data, gain=gain)
            elif init == 'kaiming':
                nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')

            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.constant_(m.bias.data, 0.0)
        elif 'BatchNorm2d' in classname:
            nn.init.normal_(m.weight.data, 1., gain)
            nn.init.constant_(m.bias.data, 0.)

    net.apply(init_func)
    print(f"model initialized with {init} initialization")
    return net

def init_model(model, device):
    model = model.to(device)
    model = init_weights(model)
    return model
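
# The combined pix2pix-style model: a U-Net generator that predicts ab from L,
# and a patch discriminator that judges (L, ab) pairs. The discriminator is
# trained on real vs. detached fake pairs; the generator is trained with the
# adversarial loss plus lambda_L1 * L1 between predicted and real ab channels.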

class MainModel(nn.Module):
    def __init__(self, net_G=None, lr_G=2e-4, lr_D=2e-4,
                 beta1=0.5, beta2=0.999, lambda_L1=100.):
        super().__init__()

        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.lambda_L1 = lambda_L1

        if net_G is None:
            self.net_G = init_model(Unet(input_c=1, output_c=2, n_down=8, num_filters=64), self.device)
        else:
            self.net_G = net_G.to(self.device)
        self.net_D = init_model(PatchDiscriminator(input_c=3, n_down=3, num_filters=64), self.device)
        self.GANcriterion = GANLoss(gan_mode='vanilla').to(self.device)
        self.L1criterion = nn.L1Loss()
        self.opt_G = optim.Adam(self.net_G.parameters(), lr=lr_G, betas=(beta1, beta2))
        self.opt_D = optim.Adam(self.net_D.parameters(), lr=lr_D, betas=(beta1, beta2))

    def set_requires_grad(self, model, requires_grad=True):
        for p in model.parameters():
            p.requires_grad = requires_grad

    def setup_input(self, data):
        self.L = data['L'].to(self.device)
        self.ab = data['ab'].to(self.device)

    def forward(self):
        self.fake_color = self.net_G(self.L)

    def backward_D(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image.detach())
        self.loss_D_fake = self.GANcriterion(fake_preds, False)
        real_image = torch.cat([self.L, self.ab], dim=1)
        real_preds = self.net_D(real_image)
        self.loss_D_real = self.GANcriterion(real_preds, True)
        self.loss_D = (self.loss_D_fake + self.loss_D_real) * 0.5
        self.loss_D.backward()

    def backward_G(self):
        fake_image = torch.cat([self.L, self.fake_color], dim=1)
        fake_preds = self.net_D(fake_image)
        self.loss_G_GAN = self.GANcriterion(fake_preds, True)
        self.loss_G_L1 = self.L1criterion(self.fake_color, self.ab) * self.lambda_L1
        self.loss_G = self.loss_G_GAN + self.loss_G_L1
        self.loss_G.backward()

    def optimize(self):
        self.forward()
        self.net_D.train()
        self.set_requires_grad(self.net_D, True)
        self.opt_D.zero_grad()
        self.backward_D()
        self.opt_D.step()

        self.net_G.train()
        self.set_requires_grad(self.net_D, False)
        self.opt_G.zero_grad()
        self.backward_G()
        self.opt_G.step()
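
# Bookkeeping helpers: running averages for each loss term, updated once per
# batch and printed by log_results().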

class AverageMeter:
    def __init__(self):
        self.reset()

    def reset(self):
        self.count, self.avg, self.sum = [0.] * 3

    def update(self, val, count=1):
        self.count += count
        self.sum += count * val
        self.avg = self.sum / self.count

def create_loss_meters():
    loss_D_fake = AverageMeter()
    loss_D_real = AverageMeter()
    loss_D = AverageMeter()
    loss_G_GAN = AverageMeter()
    loss_G_L1 = AverageMeter()
    loss_G = AverageMeter()

    return {'loss_D_fake': loss_D_fake,
            'loss_D_real': loss_D_real,
            'loss_D': loss_D,
            'loss_G_GAN': loss_G_GAN,
            'loss_G_L1': loss_G_L1,
            'loss_G': loss_G}


def update_losses(model, loss_meter_dict, count):
    for loss_name, loss_meter in loss_meter_dict.items():
        loss = getattr(model, loss_name)
        loss_meter.update(loss.item(), count=count)
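
# Undo the dataset normalization (L back to [0, 100], ab back to roughly
# [-110, 110]) and convert a batch of Lab tensors to RGB numpy images.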

def lab_to_rgb(L, ab):
    """
    Takes a batch of images
    """
    L = (L + 1.) * 50.
    ab = ab * 110.
    Lab = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    rgb_imgs = []
    for img in Lab:
        img_rgb = lab2rgb(img)
        rgb_imgs.append(img_rgb)
    return np.stack(rgb_imgs, axis=0)
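
# Run the generator on one batch without tracking gradients and display the
# first colorized image in the Streamlit app, resized back to the original
# (height, width) of the uploaded image.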

def visualize(model, data, dims):
    model.net_G.eval()
    with torch.no_grad():
        model.setup_input(data)
        model.forward()
    model.net_G.train()
    fake_color = model.fake_color.detach()
    real_color = model.ab
    L = model.L
    fake_imgs = lab_to_rgb(L, fake_color)
    real_imgs = lab_to_rgb(L, real_color)
    for i in range(1):
        # lab_to_rgb returns float images in [0, 1]; resize back to the
        # original dimensions (cv.resize expects dsize as (width, height)).
        img = cv.resize(fake_imgs[i], dsize=(dims[1], dims[0]), interpolation=cv.INTER_CUBIC)
        st.image(img, caption="Output image", use_column_width='auto', clamp=True)

def log_results(loss_meter_dict):
    for loss_name, loss_meter in loss_meter_dict.items():
        print(f"{loss_name}: {loss_meter.avg:.5f}")
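
# Inference-time generator: a DynamicUnet built on a ResNet-18 body via fastai.
# Note: create_body's signature differs between fastai versions; some versions
# expect the architecture callable (resnet18) with pretrained=True, others an
# already-instantiated model. Adjust the call below if it raises a TypeError
# with the installed fastai.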

from fastai.vision.learner import create_body
from torchvision.models.resnet import resnet18
from fastai.vision.models.unet import DynamicUnet


def build_res_unet(n_input=1, n_output=2, size=256):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    body = create_body(resnet18(), pretrained=True, n_in=n_input, cut=-2)
    net_G = DynamicUnet(body, n_output, (size, size)).to(device)
    return net_G
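
# Load the pretrained generator weights, then the full GAN weights.
# Both checkpoint files are expected in the working directory.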

net_G = build_res_unet(n_input=1, n_output=2, size=256)
net_G.load_state_dict(torch.load("res18-unet.pt", map_location=device))
model = MainModel(net_G=net_G)
model.load_state_dict(torch.load("final_model_weights.pt", map_location=device))
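
# Inference dataset: like ColorizationDataset, but wraps in-memory PIL images
# (the Streamlit upload) instead of file paths, and only resizes.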

class MyDataset(torch.utils.data.Dataset):
    def __init__(self, img_list):
        super(MyDataset, self).__init__()
        self.img_list = img_list
        self.augmentations = transforms.Resize((SIZE, SIZE), Image.BICUBIC)

    def __len__(self):
        return len(self.img_list)

    def __getitem__(self, idx):
        img = self.img_list[idx]
        img = self.augmentations(img)
        img = np.array(img)
        img_lab = rgb2lab(img).astype("float32")
        img_lab = transforms.ToTensor()(img_lab)
        L = img_lab[[0], ...] / 50. - 1.
        ab = img_lab[[1, 2], ...] / 110.
        return {'L': L, 'ab': ab}


def make_dataloaders2(batch_size=16, n_workers=4, pin_memory=True, **kwargs):
    dataset = MyDataset(**kwargs)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=n_workers,
                            pin_memory=pin_memory)
    return dataloader
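
# Streamlit UI: upload a JPEG, show it, wrap it in a single-image dataloader,
# run the generator, and display the colorized result at the original size.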

file_up = st.file_uploader("Upload a jpg image", type="jpg")
if file_up is not None:
    # Ensure a 3-channel image; rgb2lab expects RGB input.
    im = Image.open(file_up).convert("RGB")
    st.text(body=f"Size of uploaded image: {im.size}")
    st.image(im, caption="Uploaded Image.", use_column_width='auto')
    # PIL's size is (width, height); visualize() expects (height, width).
    dims = (im.height, im.width)
    test_dl = make_dataloaders2(img_list=[im])
    for data in test_dl:
        # visualize() runs setup_input() and forward() itself, so no
        # optimization step is needed (or wanted) at inference time.
        visualize(model, data, dims)