File size: 2,225 Bytes
00b1139
 
 
 
 
 
 
 
 
 
 
 
56346a3
00b1139
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8434f66
 
00b1139
 
 
 
 
 
 
 
 
d393972
 
00b1139
 
 
970bb54
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import gradio as gr

import torch

import skimage
import skimage.io
from skimage.transform import rescale, resize
from skimage import io, color

import cv2
from colorizer import normalize_lab_channels, torch_normalized_lab_to_rgb

# Load the full pickled colorizer module once at import time and switch it to
# inference mode.
# map_location='cpu': colorize_img always builds its input tensors on the CPU,
# so the model must live there too, even if the checkpoint was saved from a GPU.
# NOTE(review): weights_only=False unpickles arbitrary Python objects; only
# ever point this at a trusted, locally shipped checkpoint.
model = torch.load('colorizer.pth', map_location='cpu', weights_only=False)
model = model.eval()

def colorize_img(img_float32, model, res=512, border=0.2, apply_blur=True):
    """Colorize a float image with ``model``.

    Pipeline:
      1. Pad the image with a white frame.
      2. Resize it to ``res`` x ``res``.
      3. Convert to Lab, normalized with ``normalize_lab_channels``.
      4. Feed the model a tensor whose three channels all hold the L channel
         ('LLL'), taken from the optionally blurred image.
      5. Merge the predicted a/b channels with the un-blurred L channel.
      6. Convert back to RGB with ``torch_normalized_lab_to_rgb``.

    Args:
        img_float32: RGB image of shape (H, W, 3) with float values, as built by
            ``process_image`` (``img / 255``), so presumably in [0, 1].
            Grayscale (H, W) / (H, W, 1) and RGBA (H, W, 4) inputs are converted
            to RGB first.
        model: torch module taking the (1, 3, res, res) 'LLL' tensor. Its
            output is written into the a/b channels, so it presumably has shape
            (1, 2, res, res) -- TODO confirm.
        res: side length in pixels of the square image given to the model.
        border: width of the white padding added on every side, as a fraction
            of the image *height*.
        apply_blur: if True, Gaussian-blur (sigma=1) the resized image before
            extracting the L channel fed to the model.

    Returns:
        A numpy array laid out (res, res, channels); presumably RGB, depending
        on ``torch_normalized_lab_to_rgb``. The white border is NOT cropped,
        and non-square inputs are stretched to a square.
    """
    # Explicit submodule import: a bare `import skimage` does not guarantee
    # that `skimage.filters` is loaded on every scikit-image version.
    from skimage.filters import gaussian

    img = img_float32
    # rgb2lab needs exactly three channels; normalise other common layouts.
    if img.ndim == 3 and img.shape[-1] == 1:
        img = img[..., 0]
    if img.ndim == 2:
        img = color.gray2rgb(img)
    elif img.shape[-1] == 4:
        # Composite alpha over white, consistent with the white border below.
        img = color.rgba2rgb(img)

    # White frame; its width is derived from the image height only.
    pad = int(img.shape[0] * border)
    padded = cv2.copyMakeBorder(img, pad, pad, pad, pad, cv2.BORDER_CONSTANT,
                                value=(1., 1., 1.))

    resized = resize(padded, (res, res), anti_aliasing=True)
    model_rgb = gaussian(resized, sigma=1, channel_axis=-1) if apply_blur else resized

    # Model input: L channel of the (possibly blurred) image, copied into all
    # three channels to form the expected 'LLL' tensor.
    x = torch.from_numpy(normalize_lab_channels(color.rgb2lab(model_rgb)))
    x = x.permute(2, 0, 1).unsqueeze(dim=0)
    x[:, 1, :, :] = x[:, 0, :, :]
    x[:, 2, :, :] = x[:, 0, :, :]

    # The un-blurred image supplies the L channel of the final result.
    x_hat = torch.from_numpy(normalize_lab_channels(color.rgb2lab(resized)))
    x_hat = x_hat.permute(2, 0, 1).unsqueeze(dim=0)

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        x_hat[:, 1:, :, :] = model(x)
        colored_img = torch_normalized_lab_to_rgb(x_hat)

    return colored_img.detach().cpu().squeeze().permute(1, 2, 0).numpy()

def process_image(img):
    """Gradio callback: colorize the uploaded image with the module-level model.

    Args:
        img: value from the ``gr.Image`` input -- presumably an (H, W, C) uint8
            numpy array, or ``None`` when the user submits without an image.

    Returns:
        The array returned by ``colorize_img``, or ``None`` when no image was
        supplied (which leaves the output component empty instead of raising).
    """
    if img is None:
        return None
    # Scale 8-bit values (assumed uint8) to float32 in [0, 1] for colorize_img.
    return colorize_img((img / 255).astype('float32'), model)

# Gradio UI: one image in, one colorized image out.
title = "Colorizer"
description = "A model that colorizes b&w images."

# Clickable sample inputs. These are relative paths, presumably shipped
# alongside this script.
examples = ['ka0001.jpg', 'ka0003.jpg', 'ka0009.jpg', 'ka0010.jpg']
# NOTE(review): targets elements with the `h-60` class, presumably to keep the
# image panels at least 512px tall -- confirm against the Gradio version in use.
css = ".h-60 {min-height: 512px !important;}"

demo = gr.Interface(fn=process_image,
                    inputs=gr.Image(),
                    outputs=gr.Image(),
                    title=title,
                    description=description,
                    css=css,
                    examples=examples)
demo.launch()