import gradio as gr
import torch
from PIL import Image
from torchvision import transforms
from utils import normalize_lab, denormalize_lab, pad_image
from model import Generator
import kornia.color as color

# Load the trained generator and move it to GPU if available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Generator()
model_weights = torch.load('model.pth', map_location=device, weights_only=True)
model.load_state_dict(model_weights)
model = model.to(device)
model.eval()

def preprocess(image):
    """Convert a PIL image to a normalized L-channel tensor of shape (1, 1, H, W)."""
    image = image.convert('RGB')
    image = pad_image(image)
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    image = transform(image)
    image = image.to(device)
    image = color.rgb_to_lab(image)
    L = image[[0], ...]  # keep only the lightness (L) channel
    L, _ = normalize_lab(L, 0)
    return L.unsqueeze(0)  # add a batch dimension

def crop_to_original_size(image, original_size):
    """Crop the padded output tensor back to the input image's width and height."""
    width, height = original_size
    return transforms.functional.crop(image, top=0, left=0, height=height, width=width)

def predict(image):
    original_size = image.size
    L = preprocess(image)
    with torch.no_grad():
        output = model(L)  # predicted ab channels
    L, ab = denormalize_lab(L, output)
    output = torch.cat([L, ab], dim=1)  # reassemble the full LAB image
    output = color.lab_to_rgb(output)
    output = crop_to_original_size(output, original_size)
    image = transforms.ToPILImage()(output.squeeze().cpu())
    return image

iface = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs=gr.Image(type="pil"),
    title="Photo Colorizer",
    description="This model colorizes grayscale images. Upload an image and see the magic happen! (works best with 256x256 images)",
)
iface.launch()
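
The pad_image, normalize_lab, and denormalize_lab helpers come from the Space's local utils module, which is not shown here. The sketch below is only a guess at what they might look like, assuming the L channel is scaled from [0, 100] to [-1, 1], the ab channels from roughly [-110, 110] to [-1, 1], and inputs are padded so both sides are multiples of 32 for the generator; the actual utils.py may differ.

# Sketch only: hypothetical stand-ins for the Space's utils module.
from PIL import ImageOps

def pad_image(image, multiple=32):
    """Pad a PIL image on the right/bottom so width and height are multiples of `multiple`."""
    width, height = image.size
    pad_w = (multiple - width % multiple) % multiple
    pad_h = (multiple - height % multiple) % multiple
    return ImageOps.expand(image, border=(0, 0, pad_w, pad_h), fill=0)

def normalize_lab(L, ab):
    """Scale the L channel from [0, 100] and ab from ~[-110, 110] to [-1, 1]."""
    return L / 50.0 - 1.0, ab / 110.0

def denormalize_lab(L, ab):
    """Map normalized L/ab tensors back to their native LAB ranges."""
    return (L + 1.0) * 50.0, ab * 110.0

With a convention like this, normalize_lab(L, 0) in preprocess simply ignores the ab term, and denormalize_lab(L, output) rescales the generator's ab prediction before it is concatenated with L and converted back to RGB.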