import os
import sys
import time
import glob
import random
from zipfile import ZipFile

import numpy as np
import cv2
from PIL import Image
import skimage
import skimage.io
from skimage import io, color

import torch
import torch.nn as nn
from torch.nn import functional as F
import torchvision
import timm
from torchvision.models.feature_extraction import create_feature_extractor

# Bounds of the CIELAB colour space used for normalization.
L_range = 100
ab_min = -128
ab_max = 127
ab_range = ab_max - ab_min


def extract_zip(input_zip):
    """Read every file in a zip archive into a {name: bytes} dict."""
    input_zip = ZipFile(input_zip)
    return {name: input_zip.read(name) for name in input_zip.namelist()}


def normalize_lab_channels(x):
    """Scale an H x W x 3 LAB image so every channel lies in [0, 1] (in place)."""
    x[:, :, 0] = x[:, :, 0] / L_range                  # L: [0, 100]   -> [0, 1]
    x[:, :, 1] = (x[:, :, 1] - ab_min) / ab_range      # a: [-128, 127] -> [0, 1]
    x[:, :, 2] = (x[:, :, 2] - ab_min) / ab_range      # b: [-128, 127] -> [0, 1]
    return x


def normalized_lab_to_rgb(lab):
    """Undo normalize_lab_channels on an H x W x 3 array and convert it to RGB."""
    lab[:, :, 0] = lab[:, :, 0] * L_range
    lab[:, :, 1] = (lab[:, :, 1] * ab_range) + ab_min
    lab[:, :, 2] = (lab[:, :, 2] * ab_range) + ab_min
    return color.lab2rgb(lab)


def torch_normalized_lab_to_rgb(lab):
    """Denormalize a batch of N x 3 x H x W LAB tensors and convert each to RGB (in place)."""
    for i in range(lab.shape[0]):
        lab[i, 0, :, :] = torch.clip(lab[i, 0, :, :] * L_range, 0, L_range)
        lab[i, 1, :, :] = torch.clip((lab[i, 1, :, :] * ab_range) + ab_min, ab_min, ab_max)
        lab[i, 2, :, :] = torch.clip((lab[i, 2, :, :] * ab_range) + ab_min, ab_min, ab_max)
    for i in range(lab.shape[0]):
        lab[i] = torch.from_numpy(
            color.lab2rgb(lab[i].permute(1, 2, 0).detach().cpu().numpy())
        ).permute(2, 0, 1)
    return lab


class Encoder(nn.Module):
    """EfficientNetV2-S backbone returning intermediate feature maps for skip connections."""

    def __init__(self):
        super(Encoder, self).__init__()
        self.backend_model = timm.create_model('efficientnetv2_rw_s', pretrained=True)
        self.backend = create_feature_extractor(
            self.backend_model,
            return_nodes=['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'act2'],
        )

    def forward(self, x):
        features = self.backend(x)
        return list(features.values())


class UpSample(nn.Sequential):
    """Bilinear 2x upsampling followed by two 3x3 convolutions, with an optional skip connection."""

    def __init__(self, in_channels, out_channels):
        super(UpSample, self).__init__()
        self.convA = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, padding_mode='reflect', bias=False)
        self.leakyreluA = nn.LeakyReLU(0.2)
        self.convB = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1,
                               padding=1, padding_mode='reflect', bias=False)
        self.leakyreluB = nn.LeakyReLU(0.2)
        self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=False)

    def forward(self, x, concat_with=None):
        up_x = self.upsample(x)
        if concat_with is not None:
            up_x = torch.cat([up_x, concat_with], dim=1)
        return self.leakyreluB(self.convB(self.leakyreluA(self.convA(up_x))))


class Decoder(nn.Module):
    """Upsamples the 1792-channel bottleneck back to input resolution and predicts the 2 ab channels."""

    def __init__(self, num_features=1792, decoder_width=None):
        super(Decoder, self).__init__()
        # decoder_width scales the internal channel count and must be supplied by the caller.
        features = int(num_features * decoder_width)

        self.conv2 = nn.Sequential(
            nn.Conv2d(num_features, features, kernel_size=1, stride=1, padding=0, bias=False),
            nn.LeakyReLU(0.2),
        )
        # The added constants are the channel counts of the encoder skip connections
        # (blocks.3: 128, blocks.2: 64, blocks.1: 48, blocks.0: 24).
        self.up1 = UpSample(in_channels=features + 128, out_channels=features // 2)
        self.up2 = UpSample(in_channels=features // 2 + 64, out_channels=features // 4)
        self.up3 = UpSample(in_channels=features // 4 + 48, out_channels=features // 8)
        self.up4 = UpSample(in_channels=features // 8 + 24, out_channels=features // 16)
        self.up5 = UpSample(in_channels=features // 16, out_channels=features // 16)
        self.conv3 = nn.Conv2d(features // 16, 2, kernel_size=1, stride=1, padding=0, bias=False)

    def forward(self, features):
        blocks0, blocks1, blocks2, blocks3, x = features
        x = self.conv2(x)
        x = self.up1(x, blocks3)
        x = self.up2(x, blocks2)
        x = self.up3(x, blocks1)
        x = self.up4(x, blocks0)
        x = self.up5(x)
        x_final = self.conv3(x)
        return x_final


class ColorizeNet(nn.Module):
    """Encoder-decoder colorization network: 3-channel input -> 2-channel ab prediction."""

    def __init__(self, decoder_width):
        super(ColorizeNet, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder(decoder_width=decoder_width)

    def forward(self, x):
        features_x = self.encoder(x)
        return self.decoder(features_x)
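

# --- Encoder feature check (illustrative, not part of the original module) -------------
# A minimal sketch that prints the shapes of the five feature maps the Decoder consumes,
# so the skip-connection channel counts (24, 48, 64, 128) and the 1792-channel bottleneck
# can be verified against the Decoder's in_channels. The 224 x 224 input size is an
# assumption chosen for the demo; any size divisible by 32 should work.
def print_encoder_feature_shapes(input_size=224):
    enc = Encoder()
    enc.eval()
    with torch.no_grad():
        feats = enc(torch.randn(1, 3, input_size, input_size))
    for name, f in zip(['blocks.0', 'blocks.1', 'blocks.2', 'blocks.3', 'act2'], feats):
        print(name, tuple(f.shape))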
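

# --- Usage sketch (illustrative, not part of the original module) ----------------------
# A hedged end-to-end example: the training pipeline is not shown here, so feeding the
# network the normalized L channel replicated to 3 channels is an assumption, and the
# image path and decoder_width value are hypothetical, chosen only for the demo.
if __name__ == "__main__":
    model = ColorizeNet(decoder_width=0.5)  # decoder_width chosen arbitrarily for the demo
    model.eval()

    # Load an RGB image and convert it to normalized LAB (H x W x 3, all channels in [0, 1]).
    rgb = skimage.io.imread('example.jpg')            # hypothetical path
    rgb = cv2.resize(rgb, (224, 224))                 # spatial size must be divisible by 32
    lab = normalize_lab_channels(color.rgb2lab(rgb))

    # Assumed input convention: normalized L channel repeated across 3 channels.
    L = torch.from_numpy(lab[:, :, 0]).float()[None, None]   # 1 x 1 x H x W
    inp = L.repeat(1, 3, 1, 1)

    with torch.no_grad():
        ab_pred = model(inp)                          # 1 x 2 x H x W, normalized ab channels

    # Stitch the known L channel with the predicted ab channels and convert back to RGB.
    lab_pred = torch.cat([L, ab_pred], dim=1)
    rgb_pred = torch_normalized_lab_to_rgb(lab_pred.clone())
    out = (rgb_pred[0].permute(1, 2, 0).numpy() * 255).astype(np.uint8)
    skimage.io.imsave('colorized.jpg', out)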