File size: 2,235 Bytes
147a8af
 
e5285b0
4820090
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e5285b0
4820090
e5285b0
 
 
 
4820090
8ec48a4
e5285b0
 
 
 
 
 
4820090
 
8ec48a4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from skimage.color import rgb2lab, lab2rgb
import numpy as np
import requests
from io import BytesIO

# Hugging Face Hub repository and file holding the generator checkpoint.
repo_id = "Hammad712/GAN-Colorization-Model"
model_filename = "generator.pt"
# Executed at module import: fetches the checkpoint from the Hub and keeps its
# local file path, which is later handed to torch.load when the model is built.
model_path = hf_hub_download(repo_id=repo_id, filename=model_filename)

from fastai.vision.learner import create_body
from torchvision.models import resnet34
from fastai.vision.models.unet import DynamicUnet

def build_generator(n_input=1, n_output=2, size=256):
    """Build the U-Net colorization generator and place it on a device.

    A ResNet-34 encoder (created via ``resnet34()`` with no weights argument,
    trimmed with ``cut=-2`` and given ``n_in=n_input`` input channels) is
    wrapped in a fastai ``DynamicUnet`` producing ``n_output`` channels for a
    ``(size, size)`` input.

    Args:
        n_input: Number of input channels for the encoder.
        n_output: Number of channels the U-Net emits.
        size: Spatial side length used to size the U-Net.

    Returns:
        The ``DynamicUnet`` module, moved to CUDA if available, else CPU.
    """
    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    encoder = create_body(resnet34(), pretrained=True, n_in=n_input, cut=-2)
    unet = DynamicUnet(encoder, n_output, (size, size))
    return unet.to(target)

# Initialize and load the model.
# Module-level device; preprocess_image also moves its input tensors here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
G_net = build_generator(n_input=1, n_output=2, size=256)
# The checkpoint is downloaded from a remote repository and torch.load is
# pickle-based, so restrict deserialization to tensors and plain containers
# (weights_only=True). That is all a state dict needs, and it prevents a
# tampered file from running arbitrary code at load time.
G_net.load_state_dict(
    torch.load(model_path, map_location=device, weights_only=True)
)
# Switch to inference behavior (affects layers such as dropout/batch-norm).
G_net.eval()

def preprocess_image(img):
    """Turn a PIL image into the normalized lightness tensor fed to the generator.

    The image is forced to RGB and resized to 256x256 with bicubic
    interpolation. It is then converted to CIE-Lab, and only the L channel is
    kept, mapped from [0, 100] to [-1, 1] via ``L / 50 - 1``.

    Args:
        img: A ``PIL.Image.Image`` in any mode PIL can convert to RGB.

    Returns:
        A float32 tensor of shape ``(1, 1, 256, 256)`` on the module-level
        ``device``.
    """
    img = img.convert("RGB")
    # Pass the InterpolationMode enum: handing PIL integer constants such as
    # Image.BICUBIC to torchvision transforms is deprecated.
    resize = transforms.Resize(
        (256, 256), interpolation=transforms.InterpolationMode.BICUBIC
    )
    img = np.array(resize(img))
    img_to_lab = rgb2lab(img).astype("float32")
    # HWC float array -> CHW tensor (float input is not rescaled by ToTensor).
    img_to_lab = transforms.ToTensor()(img_to_lab)
    # Keep only L (indexing with [0] preserves the channel axis), scaled to [-1, 1].
    L = img_to_lab[[0], ...] / 50. - 1.
    # Add the batch dimension expected by the model.
    return L.unsqueeze(0).to(device)

def colorize_image(img, model):
    """Predict colors for ``img`` with ``model`` and return RGB image(s).

    Args:
        img: A PIL image; it is normalized by ``preprocess_image``.
        model: Network mapping the normalized L tensor ``(N, 1, H, W)`` to ab
            predictions; presumably ``(N, 2, H, W)`` in [-1, 1], given the
            110x rescale below. TODO confirm against the training code.

    Returns:
        An ndarray of shape ``(N, H, W, 3)`` stacking the ``lab2rgb`` output
        for each image in the batch (N is 1 for the tensor built here).
    """
    L = preprocess_image(img)
    with torch.no_grad():
        ab = model(L)
    # Undo the normalization: L back to [0, 100], ab scaled by 110.
    L = (L + 1.) * 50.
    ab = ab * 110.
    # Channels-last (N, H, W, 3) NumPy layout for skimage.
    lab_batch = torch.cat([L, ab], dim=1).permute(0, 2, 3, 1).cpu().numpy()
    # Distinct loop name so the ``img`` parameter is not shadowed.
    return np.stack([lab2rgb(lab) for lab in lab_batch], axis=0)

def colorize(img):
    """Gradio callback: colorize ``img`` and return an 8-bit RGB PIL image.

    Args:
        img: The uploaded PIL image, or ``None`` when the input is empty
            (Gradio passes ``None`` if nothing was uploaded).

    Returns:
        A ``PIL.Image.Image`` at the 256x256 preprocessing resolution.

    Raises:
        gr.Error: If no image was provided. Gradio shows this message to the
            user instead of failing with an AttributeError on ``None``.
    """
    if img is None:
        raise gr.Error("Please upload an image to colorize.")
    colorized_image = colorize_image(img, G_net)[0]
    # Round (rather than truncate) and clamp before the uint8 cast so values
    # near 1.0 map to 255 and nothing can wrap around.
    pixels = np.clip(np.rint(colorized_image * 255), 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)

# Gradio UI: one PIL image in, one colorized PIL image out; flagging disabled.
app = gr.Interface(
    fn=colorize,
    inputs=gr.Image(type="pil", label="Upload Grayscale Image"),
    outputs=gr.Image(type="pil", label="Colorized Image"),
    title="AI Image Colorization",
    description="Upload a black and white image, and the AI will colorize it.",
    allow_flagging="never"
)

# Start the server only when run as a script, so importing this module does
# not also launch (and block on) the web app.
if __name__ == "__main__":
    app.launch()