File size: 2,225 Bytes
00b1139 56346a3 00b1139 8434f66 00b1139 d393972 00b1139 970bb54 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
import gradio as gr
import torch
import skimage
import skimage.filters
import skimage.io
from skimage.transform import rescale, resize
from skimage import io, color
import cv2

from colorizer import normalize_lab_channels, torch_normalized_lab_to_rgb
# Load the whole pickled model object (not just a state_dict), which is why
# weights_only=False is passed.
# NOTE(review): weights_only=False unpickles arbitrary objects, so only ever
# point this at a trusted checkpoint.
# map_location='cpu' keeps the model on the CPU. colorize_img builds its input
# tensors on the CPU (torch.from_numpy), so this also works when the checkpoint
# was saved from a GPU.
model = torch.load('colorizer.pth', weights_only=False, map_location='cpu')
# Put layers such as dropout / batch-norm into evaluation behavior for inference.
model = model.eval()
def _to_lab_tensor(rgb):
    """Convert an RGB float image to a normalized Lab tensor in NCHW layout.

    ``rgb`` is channel-last (``rgb2lab`` expects the colour axis last).
    ``normalize_lab_channels`` from the ``colorizer`` module scales all Lab
    channels to [0, 1]. The result is permuted from HWC to CHW and given a
    leading batch axis, i.e. (1, 3, H, W), assuming ``normalize_lab_channels``
    preserves the (H, W, 3) layout (TODO confirm).
    """
    lab = normalize_lab_channels(color.rgb2lab(rgb))
    return torch.from_numpy(lab).permute(2, 0, 1).unsqueeze(dim=0)


def colorize_img(img_float32, model, res=512, border=0.2, apply_blur=True):
    """Colorize an RGB image by predicting its Lab a/b channels with ``model``.

    Steps:

    1. Pad the image with a white border.
    2. Resize it to ``res`` x ``res``.
    3. Optionally apply a slight Gaussian blur.

    The model only sees the lightness channel, replicated three times ('LLL').
    The predicted a/b channels are combined with the lightness of the resized
    but *unblurred* image and converted back to RGB.

    Args:
        img_float32: channel-last RGB float image; presumably values in [0, 1]
            (the Gradio callback divides by 255) -- TODO confirm for other
            callers.
        model: callable taking the (1, 3, res, res) 'LLL' tensor and returning
            the predicted a/b channels.
        res: side length of the square image fed to the model. The output has
            this size, not the input's size.
        border: width of the white border as a fraction of the image height.
            It is applied on all four sides.
        apply_blur: whether to blur (Gaussian, sigma=1) the model input.

    Returns:
        The colorized image as a channel-last numpy array, presumably
        (res, res, 3).
    """
    # Add a white border whose width is a fraction of the image height.
    border_px = int(img_float32.shape[0] * border)
    padded = cv2.copyMakeBorder(img_float32, border_px, border_px, border_px,
                                border_px, cv2.BORDER_CONSTANT,
                                value=(1., 1., 1.))

    # Resize to the resolution the model expects.
    img_resized = resize(padded, (res, res), anti_aliasing=True)

    # Blur only the model input a bit. The unblurred image supplies the output
    # lightness below.
    if apply_blur:
        model_input = skimage.filters.gaussian(img_resized, sigma=1,
                                               channel_axis=-1)
    else:
        model_input = img_resized

    x = _to_lab_tensor(model_input)
    x_hat = _to_lab_tensor(img_resized)

    # The model is fed lightness only: copy L into channels 1 and 2 ('LLL').
    x[:, 1, :, :] = x[:, 0, :, :]
    x[:, 2, :, :] = x[:, 0, :, :]

    # Pure inference: don't build (and hold on to) an autograd graph.
    with torch.no_grad():
        x_hat_ab = model(x)
        # Keep the unblurred lightness and replace a/b with the prediction.
        x_hat[:, 1:, :, :] = x_hat_ab
        colored_img = torch_normalized_lab_to_rgb(x_hat)
    return colored_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()
def process_image(img):
    """Gradio callback: colorize the image received from the ``gr.Image`` input.

    Args:
        img: image array with values in 0..255 (presumably a uint8 RGB numpy
            array, Gradio's default -- TODO confirm), or ``None`` when the
            form is submitted without an image.

    Returns:
        The colorized image produced by ``colorize_img``, or ``None`` when no
        image was given (which leaves the output component empty).
    """
    if img is None:
        # Nothing uploaded: avoid a TypeError from ``None / 255``.
        return None
    # Scale to [0, 1] floats as colorize_img expects.
    return colorize_img((img / 255).astype('float32'), model)
title = "Colorizer"
description = "A model that colorizes b&w images."
# Example inputs, given as paths relative to the working directory
# (presumably shipped next to this script).
examples = ['ka0001.jpg', 'ka0003.jpg', 'ka0009.jpg', 'ka0010.jpg']
# Force elements with the `h-60` class to be at least 512px tall
# (presumably a class used by the Gradio frontend -- confirm for the
# installed Gradio version).
css = ".h-60 {min-height: 512px !important;}"

# Web UI: one image in, the colorized image out. Bound to a name so the
# interface can be reused (e.g. by reload tooling), then launched as before.
demo = gr.Interface(fn=process_image,
                    inputs=gr.Image(),
                    outputs=gr.Image(),
                    title=title,
                    description=description,
                    css=css,
                    examples=examples)
demo.launch()