Spaces:
Runtime error
Runtime error
File size: 1,694 Bytes
a0cc3ab 4a10914 2620eb0 a0cc3ab e189f0a e1b37de 2620eb0 28801c1 fb94260 a0cc3ab e189f0a 2620eb0 4a10914 e189f0a 2620eb0 e189f0a 2620eb0 e189f0a 4a10914 e189f0a a0cc3ab 4a10914 2620eb0 a0cc3ab 2620eb0 a0cc3ab 2620eb0 4a10914 a0cc3ab e189f0a a0cc3ab 28801c1 8d6bc60 a0cc3ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from utils import normalize_lab, denormalize_lab, pad_image
from model import Generator
import kornia.color as color
# Prefer the GPU whenever torch can see one; otherwise run on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Instantiate the generator and restore its trained parameters from model.pth.
# map_location places the loaded tensors on `device`; weights_only=True asks
# torch.load to restrict unpickling to tensors / plain containers.
model = Generator()
model_weights = torch.load('model.pth', map_location=device, weights_only=True)
model.load_state_dict(model_weights)

# nn.Module.eval() returns the module itself, so move + switch to eval mode
# can be chained into one rebinding.
model = model.to(device).eval()
def preprocess(image):
    """Build the normalized lightness-channel batch fed to the generator.

    The PIL image is forced to RGB, passed through ``pad_image``, turned into
    a float tensor on ``device``, and converted to Lab with kornia. Only the
    first (L) channel is kept and normalized with ``normalize_lab``.

    Args:
        image: a ``PIL.Image`` in any mode accepted by ``convert('RGB')``.

    Returns:
        The normalized L channel with a leading batch axis added — presumably
        shape (1, 1, H, W) for the padded H, W, assuming ``normalize_lab``
        preserves shape (TODO confirm in utils).
    """
    padded_rgb = pad_image(image.convert('RGB'))
    rgb_tensor = transforms.ToTensor()(padded_rgb).to(device)
    lab_tensor = color.rgb_to_lab(rgb_tensor)
    # Index with a list rather than an int so the channel axis is retained.
    lightness, _ = normalize_lab(lab_tensor[[0], ...], 0)
    return lightness.unsqueeze(0)
def crop_to_original_size(image, original_size):
    """Cut ``image`` down to ``original_size``, anchored at the top-left corner.

    Args:
        image: tensor image whose last two axes are (height, width).
        original_size: ``(width, height)`` pair, in ``PIL.Image.size`` order.

    Returns:
        The top-left ``height x width`` region of ``image``.

    NOTE(review): anchoring at (0, 0) is only correct if ``pad_image`` adds
    padding on the right/bottom edges — confirm against utils.
    """
    target_width, target_height = original_size
    return transforms.functional.crop(
        image,
        top=0,
        left=0,
        height=target_height,
        width=target_width,
    )
def predict(image):
    """Colorize a PIL image with the loaded generator.

    Args:
        image: the ``PIL.Image`` supplied by the Gradio input component, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The colorized ``PIL.Image`` cropped back to the input's original
        size, or ``None`` when no image was provided.
    """
    # Gradio hands over None for an empty input; return an empty output
    # instead of crashing on ``image.size``.
    if image is None:
        return None

    original_size = image.size  # PIL order: (width, height)
    L = preprocess(image)
    with torch.no_grad():
        prediction = model(L)
    L, ab = denormalize_lab(L, prediction)
    lab = torch.cat([L, ab], dim=1)  # join along the channel axis
    rgb = color.lab_to_rgb(lab)
    rgb = crop_to_original_size(rgb, original_size)
    # Drop only the batch axis (size 1, added in preprocess). A bare
    # squeeze() would also collapse a spatial axis of size 1, e.g. for a
    # 1-pixel-tall image, and hand ToPILImage a malformed tensor.
    return transforms.ToPILImage()(rgb.squeeze(0).cpu())
# Gradio UI: one PIL image in, the colorized PIL image out.
# (The stray " |" scrape artifact after launch() made the module a SyntaxError.)
iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Photo Colorizer",
    description="This model colorizes grayscale images. Upload an image and see the magic happen! (works best with 256x256 size)",
)
iface.launch()