|
import os |
|
import sys |
|
import time |
|
import glob |
|
import random |
|
import skimage |
|
import skimage.io |
|
import numpy as np |
|
from skimage import io, color |
|
|
|
import skimage |
|
import skimage.io |
|
from PIL import Image |
|
import cv2 |
|
|
|
import torch |
|
import torch.nn as nn |
|
from torch.nn import functional as F |
|
|
|
import timm |
|
import torchvision |
|
from torchvision.models.feature_extraction import create_feature_extractor |
|
|
|
# Numeric ranges used to map Lab channels to and from the unit interval.
L_range = 100                # divisor / multiplier applied to the L channel
ab_min, ab_max = -128, 127   # lower and upper bounds used for the a and b channels
ab_range = ab_max - ab_min   # width of the a/b interval (255)
|
|
|
def extract_zip(input_zip):
    """Read every member of a zip archive into memory.

    Args:
        input_zip: Anything ``zipfile.ZipFile`` accepts for reading, e.g. a
            filesystem path or a seekable binary file object.

    Returns:
        dict: Maps each member name (as listed by ``namelist()``) to the
        uncompressed ``bytes`` of that member.

    Raises:
        zipfile.BadZipFile: If ``input_zip`` is not a valid zip archive.
    """
    # Local import: ZipFile is not imported by the visible top-of-file
    # import block, so the bare name would otherwise raise NameError.
    from zipfile import ZipFile

    # The context manager closes the archive handle even if a read fails.
    with ZipFile(input_zip) as archive:
        return {name: archive.read(name) for name in archive.namelist()}
|
|
|
def normalize_lab_channels(x):
    """Rescale the three channels of ``x`` using the module's Lab ranges.

    Channel 0 is divided by ``L_range``; channels 1 and 2 are shifted by
    ``ab_min`` and divided by ``ab_range``.

    Args:
        x: Array indexed as ``x[:, :, c]`` for ``c`` in 0..2 (presumably an
            H x W x 3 Lab image -- confirm against callers).

    Returns:
        The same object ``x``; it is modified in place.
    """
    x[:, :, 0] = x[:, :, 0] / L_range
    # The a and b channels share the same offset and scale.
    for channel in (1, 2):
        shifted = x[:, :, channel] - ab_min
        x[:, :, channel] = shifted / ab_range
    return x
|
|
|
def normalized_lab_to_rgb(lab):
    """Undo the Lab normalisation and convert the result to RGB.

    Channel 0 is multiplied by ``L_range``; channels 1 and 2 are multiplied
    by ``ab_range`` and shifted by ``ab_min``. The rescaled array is then
    passed to ``color.lab2rgb``.

    Args:
        lab: Array indexed as ``lab[:, :, c]`` for ``c`` in 0..2. It is
            overwritten in place with the de-normalised Lab values.

    Returns:
        Whatever ``color.lab2rgb`` returns for the rescaled array.
    """
    lab[:, :, 0] = lab[:, :, 0] * L_range
    # The a and b channels share the same scale and offset.
    for channel in (1, 2):
        lab[:, :, channel] = lab[:, :, channel] * ab_range + ab_min
    rgb = color.lab2rgb(lab)
    return rgb
|
|
|
def torch_normalized_lab_to_rgb(lab):
    """Convert a batch of normalised Lab tensors to RGB, in place.

    Step 1 rescales every sample: channel 0 is multiplied by ``L_range`` and
    clipped to ``[0, L_range]``; channels 1 and 2 are multiplied by
    ``ab_range``, shifted by ``ab_min`` and clipped to ``[ab_min, ab_max]``.
    Step 2 converts each sample with ``color.lab2rgb`` on a detached CPU
    NumPy copy and writes the result back into the input tensor.

    Args:
        lab: 4-D tensor indexed as ``lab[sample, channel, :, :]`` with at
            least three channels (presumably N x 3 x H x W -- confirm).

    Returns:
        The same tensor ``lab``, overwritten with the converted values.
    """
    # Step 1: undo the normalisation for the whole batch at once.
    lab[:, 0, :, :] = torch.clip(lab[:, 0, :, :] * L_range, 0, L_range)
    for channel in (1, 2):
        rescaled = (lab[:, channel, :, :] * ab_range) + ab_min
        lab[:, channel, :, :] = torch.clip(rescaled, ab_min, ab_max)

    # Step 2: lab2rgb works on channel-last NumPy arrays, so convert sample
    # by sample and move the channel axis back to the front afterwards.
    for index in range(lab.shape[0]):
        channel_last = lab[index].permute(1, 2, 0).detach().cpu().numpy()
        rgb = color.lab2rgb(channel_last)
        lab[index] = torch.from_numpy(rgb).permute(2, 0, 1)

    return lab
|
|
|
class Encoder(nn.Module):
    """Feature extractor built on timm's pretrained ``efficientnetv2_rw_s``.

    The backbone is wrapped with torchvision's ``create_feature_extractor``
    so that a forward pass yields the activations at the nodes ``blocks.0``
    through ``blocks.3`` plus ``act2``.
    """

    def __init__(self):
        super().__init__()
        tapped_nodes = ['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'act2']
        # Kept as an attribute as well, so the raw backbone stays reachable.
        self.backend_model = timm.create_model('efficientnetv2_rw_s', pretrained=True)
        self.backend = create_feature_extractor(
            self.backend_model,
            return_nodes=tapped_nodes,
        )

    def forward(self, x):
        """Run the backbone on ``x`` and return the tapped activations as a list.

        The list holds the values of the extractor's output mapping, in the
        order that mapping yields them.
        """
        tapped = self.backend(x)
        return list(tapped.values())
|
|
|
class UpSample(nn.Sequential):
    """2x bilinear upsampling, optional skip concat, then two 3x3 convs.

    Both convolutions use reflect padding and no bias, and each is followed
    by ``LeakyReLU(0.2)``.

    Args:
        in_channels: Channels entering the first convolution, i.e. the
            upsampled input plus any concatenated skip tensor.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Both convolutions share everything except their input width.
        conv_options = dict(
            kernel_size=3,
            stride=1,
            padding=1,
            padding_mode='reflect',
            bias=False,
        )
        self.convA = nn.Conv2d(in_channels, out_channels, **conv_options)
        self.leakyreluA = nn.LeakyReLU(0.2)
        self.convB = nn.Conv2d(out_channels, out_channels, **conv_options)
        self.leakyreluB = nn.LeakyReLU(0.2)

        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, concat_with=None):
        """Upsample ``x``, optionally concatenate a skip tensor, then convolve.

        Args:
            x: Input feature map.
            concat_with: Optional tensor concatenated to the upsampled ``x``
                along dim 1; skipped when ``None``.

        Returns:
            Output of the second LeakyReLU.
        """
        merged = self.upsample(x)
        if concat_with is not None:
            merged = torch.cat([merged, concat_with], dim=1)

        hidden = self.leakyreluA(self.convA(merged))
        return self.leakyreluB(self.convB(hidden))
|
|
|
class Decoder(nn.Module):
    """Decoder that upsamples encoder features into a 2-channel map.

    A 1x1 convolution first maps the deepest feature map to ``features``
    channels. Four ``UpSample`` stages then each concatenate one skip tensor,
    a fifth stage upsamples without a skip, and a final 1x1 convolution
    produces 2 output channels.

    Args:
        num_features: Channel count of the deepest feature map passed to
            ``forward`` (the last element of ``features``).
        decoder_width: Multiplier applied to ``num_features`` to obtain the
            decoder's base channel count. ``None`` is treated as ``1.0``.
    """

    def __init__(self, num_features=1792 * 1, decoder_width=None):
        super(Decoder, self).__init__()
        # The declared default of None used to crash in the multiplication
        # below (int * None raises TypeError); fall back to full width.
        if decoder_width is None:
            decoder_width = 1.0
        features = int(num_features * decoder_width)

        # Channel counts of the skip tensors concatenated in up1..up4
        # (blocks3, blocks2, blocks1, blocks0 respectively in forward()).
        skip3_channels = 128
        skip2_channels = 64
        skip1_channels = 48
        skip0_channels = 24

        self.conv2 = nn.Sequential(
            nn.Conv2d(num_features, features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.LeakyReLU(0.2),
        )

        self.up1 = UpSample(in_channels=features + skip3_channels, out_channels=features // 2)
        self.up2 = UpSample(in_channels=features // 2 + skip2_channels, out_channels=features // 4)
        self.up3 = UpSample(in_channels=features // 4 + skip1_channels, out_channels=features // 8)
        self.up4 = UpSample(in_channels=features // 8 + skip0_channels, out_channels=features // 16)

        # Last stage has no skip connection, so its input width is unchanged.
        self.up5 = UpSample(in_channels=features // 16, out_channels=features // 16)

        self.conv3 = nn.Conv2d(features // 16, 2, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, features):
        """Decode a five-element feature sequence into a 2-channel tensor.

        Args:
            features: Sequence unpacked as ``(blocks0, blocks1, blocks2,
                blocks3, x)``, where ``x`` is the deepest map and the others
                are skip tensors, consumed from ``blocks3`` down to
                ``blocks0``.

        Returns:
            Output of the final 1x1 convolution (2 channels).
        """
        blocks0, blocks1, blocks2, blocks3, x = features

        x = self.conv2(x)

        # Each stage upsamples x and concatenates the matching skip tensor.
        x = self.up1(x, blocks3)
        x = self.up2(x, blocks2)
        x = self.up3(x, blocks1)
        x = self.up4(x, blocks0)

        x = self.up5(x)

        return self.conv3(x)
|
|
|
class ColorizeNet(nn.Module):
    """Encoder/decoder network: ``Encoder`` features fed into ``Decoder``.

    Args:
        decoder_width: Forwarded unchanged to ``Decoder`` as its
            ``decoder_width`` argument.
    """

    def __init__(self, decoder_width):
        super().__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(decoder_width=decoder_width)

    def forward(self, x):
        """Encode ``x`` and return the decoder's output for those features."""
        encoded = self.encoder(x)
        decoded = self.decoder(encoded)
        return decoded
|
|