import gradio as gr

import torch

import skimage.filters
from skimage.transform import resize
from skimage import color

import cv2

from colorizer import normalize_lab_channels, torch_normalized_lab_to_rgb

# Load the trained colorizer network and put it in inference mode.
model = torch.load('colorizer.pth', weights_only=False)
model = model.eval()


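# colorize_img: pad the photo with a white border, resize it to the model's
# input resolution, convert to Lab, feed only the L (lightness) channel to the
# network, and merge the predicted a/b channels back with the original L.
# normalize_lab_channels / torch_normalized_lab_to_rgb come from the local
# colorizer module and are assumed to map Lab values to and from the
# normalized range the network was trained on.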
def colorize_img(img_float32, model, res=512, border=0.2, apply_blur=True):
    img = img_float32

    # Pad with a white border so the subject keeps some margin after the square resize.
    border = int(img.shape[0] * border)
    img2 = cv2.copyMakeBorder(img, border, border, border, border,
                              cv2.BORDER_CONSTANT, value=(1., 1., 1.))

    # Resize to the fixed resolution the network expects.
    img_resized = resize(img2, (res, res), anti_aliasing=True)

    # Optionally smooth the model input a little to suppress high-frequency noise.
    if apply_blur:
        img2 = skimage.filters.gaussian(img_resized, sigma=1, channel_axis=-1)
    else:
        img2 = img_resized

    # Convert both versions to normalized Lab space.
    img2 = normalize_lab_channels(color.rgb2lab(img2))
    img_resized = normalize_lab_channels(color.rgb2lab(img_resized))

    # (res, res, 3) -> (1, 3, res, res) tensors; force float32 since rgb2lab may promote to float64.
    img_resized = torch.from_numpy(img_resized).float()
    img_resized = img_resized.permute(2, 0, 1).unsqueeze(dim=0)

    x = torch.from_numpy(img2).float()
    x = x.permute(2, 0, 1).unsqueeze(dim=0)

    # Replicate the L channel into the a/b slots so the network only sees luminance.
    x[:, 1, :, :] = x[:, 0, :, :]
    x[:, 2, :, :] = x[:, 0, :, :]

    # Predict the a/b chrominance channels from the L-only input.
    with torch.no_grad():
        x_hat_ab = model(x)

    # Keep the original L channel and attach the predicted a/b channels.
    x_hat = img_resized.clone()
    x_hat[:, 1:, :, :] = x_hat_ab.clone()

    # Convert back to RGB for display.
    colored_img = torch_normalized_lab_to_rgb(x_hat)

    return colored_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()


def process_image(img):
    # Gradio supplies images as uint8 arrays; rescale to float32 in [0, 1].
    return colorize_img((img / 255).astype('float32'), model)


title = "Colorizer"
description = "A model that colorizes b&w images."

examples = ['ka0001.jpg', 'ka0003.jpg', 'ka0009.jpg', 'ka0010.jpg']
css = ".h-60 {min-height: 512px !important;}"

gr.Interface(fn=process_image,
             inputs=gr.Image(),
             outputs=gr.Image(),
             title=title,
             description=description,
             css=css,
             examples=examples).launch()