"""Gradio demo that colorizes black-and-white images with a pretrained model.

The model is loaded once at import time from ``colorizer.pth`` and served
through a simple ``gr.Interface`` (image in, image out).
"""
import gradio as gr
import torch
import skimage
# Imported explicitly: on older scikit-image releases ``import skimage`` does
# not load submodules, and ``skimage.filters.gaussian`` is used below.
import skimage.filters
import skimage.io
from skimage.transform import rescale, resize
from skimage import io, color
import cv2

from colorizer import normalize_lab_channels, torch_normalized_lab_to_rgb

# NOTE(review): weights_only=False unpickles arbitrary Python objects, so only
# load a checkpoint from a trusted source. map_location="cpu" lets a checkpoint
# that was saved on a GPU load on a CPU-only host as well.
model = torch.load('colorizer.pth', weights_only=False, map_location="cpu")
model = model.eval()


def colorize_img(img_float32, model, res=512, border=0.2, apply_blur=True):
    """Colorize an RGB float image with ``model`` and return the RGB result.

    Pipeline:
    1. Pad the image with a white border.
    2. Resize it to ``res`` x ``res``.
    3. Optionally blur it.
    4. Convert it to normalized Lab.
    5. Copy its L channel into all three channels ("LLL") to form the model
       input.

    The model's prediction replaces channels 1: of the *unblurred* resized
    image. That Lab tensor is then converted back to RGB.

    Args:
        img_float32: RGB image as a float array of shape (H, W, 3); values are
            presumably in [0, 1] (the Gradio callback divides by 255).
        model: the colorization network. It is called on a (1, 3, res, res)
            tensor; its output is written into channels 1: of the result
            (presumably the predicted a/b channels — TODO confirm).
        res: side length of the square image fed to the model.
        border: white padding added on every side, as a fraction of the image
            height.
        apply_blur: whether to Gaussian-blur (sigma=1) the model input.

    Returns:
        The colorized image as a numpy array in (H, W, C) layout, presumably
        (res, res, 3), with values clamped to [0, 1]. It still contains the
        white border and is resized to a square, so the input's aspect ratio
        is not preserved.
    """
    img = img_float32

    # Pad with a constant white border; the pad width is a fraction of the
    # image height and is applied on all four sides.
    pad = int(img.shape[0] * border)
    padded = cv2.copyMakeBorder(
        img, pad, pad, pad, pad, cv2.BORDER_CONSTANT, value=(1.0, 1.0, 1.0)
    )

    # Resize to the square resolution the model expects.
    img_resized = resize(padded, (res, res), anti_aliasing=True)

    # Optionally blur the model input a little; the final L channel comes from
    # the unblurred image below.
    if apply_blur:
        model_input = skimage.filters.gaussian(
            img_resized, sigma=1, channel_axis=-1
        )
    else:
        model_input = img_resized

    # Convert both images to Lab and normalize the channels.
    model_lab = normalize_lab_channels(color.rgb2lab(model_input))
    base_lab = normalize_lab_channels(color.rgb2lab(img_resized))
    base = torch.from_numpy(base_lab).permute(2, 0, 1).unsqueeze(dim=0)

    # Build the (1, 3, res, res) 'LLL' input by copying L into channels 1 and 2.
    x = torch.from_numpy(model_lab).permute(2, 0, 1).unsqueeze(dim=0)
    x[:, 1, :, :] = x[:, 0, :, :]
    x[:, 2, :, :] = x[:, 0, :, :]

    # Feed the input on whichever device the model's weights live on
    # (falls back to CPU for a parameter-less model).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    # Pure inference: no_grad avoids building an autograd graph per request.
    with torch.no_grad():
        x_hat_ab = model(x.to(device)).cpu()
        x_hat = base.clone()
        # Slice assignment copies the data, so no extra clone() is needed.
        x_hat[:, 1:, :, :] = x_hat_ab
        colored_img = torch_normalized_lab_to_rgb(x_hat)

    # Lab -> RGB can overshoot the valid range; keep the float image in [0, 1]
    # so the image output component receives well-formed data.
    colored_img = colored_img.clamp(0.0, 1.0)
    return colored_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()


def process_image(img):
    """Gradio callback: colorize the image array delivered by ``gr.Image``.

    The input is scaled by 1/255 (presumably uint8 RGB data) and cast to
    float32 before colorization. Returns None when no image was submitted, so
    the output stays empty instead of raising.
    """
    if img is None:
        return None
    return colorize_img((img / 255).astype('float32'), model)


title = "Colorizer"
description = "A model that colorizes b&w images."
examples = ['ka0001.jpg', 'ka0003.jpg', 'ka0009.jpg', 'ka0010.jpg']
# Gives elements with class ``h-60`` a 512px minimum height (presumably the
# image panels, matching the default ``res``).
css = ".h-60 {min-height: 512px !important;}"

demo = gr.Interface(
    fn=process_image,
    inputs=gr.Image(),
    outputs=gr.Image(),
    title=title,
    description=description,
    css=css,
    examples=examples,
)
demo.launch()