File size: 1,694 Bytes
a0cc3ab
 
 
 
4a10914
2620eb0
 
a0cc3ab
e189f0a
e1b37de
2620eb0
28801c1
fb94260
a0cc3ab
 
 
e189f0a
 
2620eb0
4a10914
 
 
 
e189f0a
2620eb0
 
 
 
e189f0a
2620eb0
e189f0a
4a10914
 
 
 
e189f0a
a0cc3ab
4a10914
2620eb0
a0cc3ab
2620eb0
a0cc3ab
2620eb0
 
 
4a10914
a0cc3ab
 
 
 
e189f0a
a0cc3ab
 
28801c1
8d6bc60
 
a0cc3ab
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from utils import normalize_lab, denormalize_lab, pad_image
from model import Generator
import kornia.color as color


# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Build the generator, load the trained weights from model.pth (tensors only,
# mapped onto `device`), move it to `device` and switch to inference mode.
model = Generator()
model_weights = torch.load('model.pth', map_location=device, weights_only=True)
model.load_state_dict(model_weights)
model = model.to(device).eval()  # Module.eval() returns the module itself


def preprocess(image):
    """Turn a PIL image into the normalized lightness (L) input for the model.

    The image is forced to RGB, padded with ``pad_image``, converted to a
    float tensor in [0, 1] on ``device`` and then to CIE Lab with kornia.
    Only channel 0 (L) is kept, with its channel axis retained, and it is
    passed through ``normalize_lab``.

    Args:
        image: PIL image to colorize.

    Returns:
        Normalized L tensor with a leading batch axis of size 1 —
        (1, 1, H, W) where H, W are the padded image dimensions.
    """
    rgb = pad_image(image.convert('RGB'))
    # to_tensor is exactly what transforms.ToTensor() calls internally:
    # PIL (H, W, C) -> float tensor (C, H, W) scaled to [0, 1].
    lab = color.rgb_to_lab(transforms.functional.to_tensor(rgb).to(device))
    # Keep only the lightness channel, preserving a channel axis of size 1.
    # NOTE(review): the second argument (0) presumably stands in for an unused
    # ab input of normalize_lab — confirm against utils.
    lightness, _ = normalize_lab(lab[0:1], 0)
    return lightness[None]

def crop_to_original_size(image, original_size):
    """Crop *image* back to *original_size*, anchored at the top-left corner.

    Args:
        image: Tensor whose last two dimensions are (H, W).
        original_size: ``(width, height)`` pair, as returned by PIL's
            ``Image.size``.

    Returns:
        The top-left ``height x width`` region of *image*.
    """
    crop_width, crop_height = original_size
    # Anchored at (0, 0): presumably pad_image only adds padding on the
    # right/bottom edges — confirm against utils.
    return transforms.functional.crop(image, 0, 0, crop_height, crop_width)


def predict(image):
    """Colorize a PIL image and return the result as a PIL image.

    Args:
        image: PIL image from the Gradio input, or ``None`` when the user
            submits without providing an image.

    Returns:
        RGB PIL image with the same (width, height) as the input, or ``None``
        if no input image was given.
    """
    # Gradio invokes the function with None for an empty image input;
    # returning None leaves the output blank instead of raising.
    if image is None:
        return None

    original_size = image.size  # PIL convention: (width, height)
    L = preprocess(image)
    with torch.no_grad():
        # Presumably the model predicts the normalized ab channels for L.
        output = model(L)

    L, ab = denormalize_lab(L, output)
    lab = torch.cat([L, ab], dim=1)  # concatenate along the channel axis
    rgb = color.lab_to_rgb(lab)
    rgb = crop_to_original_size(rgb, original_size)
    # Drop only the batch axis: a bare squeeze() would also collapse a
    # spatial axis for 1-pixel-wide/tall images and mangle the PIL output.
    return transforms.ToPILImage()(rgb.squeeze(0).cpu())


# Gradio UI: one PIL image in, the colorized PIL image out.
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Photo Colorizer",
    description=(
        "This model colorizes grayscale images. Upload an image and see the "
        "magic happen! (works best with 256x256 size)"
    ),
)

# Start the web app; this runs at module level whenever the file is executed.
iface.launch()