# Colorizer / app.py
# Author: ialhashim
# Last commit: "Update app.py" (970bb54, verified)
import gradio as gr
import torch
import skimage
import skimage.filters
import skimage.io
from skimage.transform import rescale, resize
from skimage import io, color
import cv2

from colorizer import normalize_lab_channels, torch_normalized_lab_to_rgb
# Load the full pickled model object and switch it to inference mode.
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, so this
# path must only ever point at a trusted, locally shipped checkpoint.
# map_location='cpu' keeps the model on the CPU, matching the CPU tensors that
# colorize_img builds with torch.from_numpy, even if the checkpoint was saved
# from a CUDA device (which would otherwise fail on a CPU-only host, or produce
# a device mismatch at inference time on a GPU host).
model = torch.load('colorizer.pth', map_location='cpu', weights_only=False)
model = model.eval()
def colorize_img( img_float32, model, res=512, border=0.2, apply_blur=True):
    """Colorize an image with the Lab-space colorizer model.

    The image is padded with a white border and resized to ``res`` x ``res``.
    The L channel of the (optionally blurred) resized image, replicated into
    three channels, is fed to ``model``. The model's output replaces the a/b
    channels of the unblurred resized image, which is then converted back to RGB.

    Args:
        img_float32: Image as a float array, presumably RGB in [0, 1]
            (``process_image`` divides uint8 pixels by 255). 2-D grayscale
            and 4-channel RGBA inputs are converted to 3-channel RGB first.
        model: Callable applied to a (1, 3, res, res) tensor; its output must
            be assignable to the (1, 2, res, res) a/b slice of the Lab tensor.
        res: Side length of the square image fed to the model.
        border: Width of the white border, as a fraction of the image height.
        apply_blur: If True, Gaussian-blur (sigma=1) the image before
            extracting the L channel that is fed to the model.

    Returns:
        Numpy array of the colorized image at ``res`` x ``res`` resolution,
        white border included; presumably (res, res, 3) RGB, depending on
        ``torch_normalized_lab_to_rgb`` — TODO confirm.
    """
    img = img_float32
    # rgb2lab and the channel-wise blur below require exactly three channels,
    # so promote grayscale / RGBA input to RGB first.
    if img.ndim == 2:
        img = color.gray2rgb(img)
    elif img.ndim == 3 and img.shape[-1] == 4:
        img = color.rgba2rgb(img)

    # Add a white border (value 1.0) whose width is a fraction of the height.
    pad = int(img.shape[0] * border)
    padded = cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=(1., 1., 1.))

    # Resize to the resolution the model expects.
    img_resized = resize(padded, (res, res), anti_aliasing=True)

    # Optionally blur the copy that the model sees.
    if apply_blur:
        model_input = skimage.filters.gaussian(img_resized, sigma=1, channel_axis=-1)
    else:
        model_input = img_resized

    # Convert to Lab and normalize all channels (via normalize_lab_channels).
    model_input = normalize_lab_channels(color.rgb2lab(model_input))

    # Unblurred resized copy: its L channel is kept in the output; only the
    # a/b channels are replaced by the model's prediction.
    lab_resized = normalize_lab_channels(color.rgb2lab(img_resized))
    lab_resized = torch.from_numpy(lab_resized).permute(2, 0, 1).unsqueeze(dim=0)

    # Build the 'LLL' tensor: copy the L channel over the a and b channels.
    x = torch.from_numpy(model_input).permute(2, 0, 1).unsqueeze(dim=0)
    x[:, 1, :, :] = x[:, 0, :, :]
    x[:, 2, :, :] = x[:, 0, :, :]

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        x_hat_ab = model(x)
        x_hat = lab_resized.clone()
        # Slice assignment copies the data, so no extra clone is required.
        x_hat[:, 1:, :, :] = x_hat_ab
        colored_img = torch_normalized_lab_to_rgb(x_hat)

    return colored_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()
def process_image(img):
    """Gradio callback: colorize the image supplied by the ``gr.Image`` input.

    Args:
        img: Numpy image from Gradio, presumably uint8 in [0, 255] (it is
            scaled by 1/255 here), or None when the user submits no image.

    Returns:
        The colorized image produced by ``colorize_img``, or None when no
        image was provided (which leaves the output component empty).
    """
    # Gradio passes None for an empty input; `None / 255` would raise TypeError.
    if img is None:
        return None
    return colorize_img((img / 255).astype('float32'), model)
# UI text and assets for the Gradio interface.
title = "Colorizer"
description = "A model that colorizes b&w images."
examples = ['ka0001.jpg', 'ka0003.jpg', 'ka0009.jpg', 'ka0010.jpg']
css = ".h-60 {min-height: 512px !important;}"

# Build the interface at import time and expose it as `demo` (the default
# name Gradio's reload mode looks for). Only start the server when this file
# is executed as a script, so importing the module has no server side effect.
demo = gr.Interface(fn=process_image,
                    inputs=gr.Image(),
                    outputs=gr.Image(),
                    title=title,
                    description=description,
                    css=css,
                    examples=examples)

if __name__ == "__main__":
    demo.launch()