|
import gradio as gr |
|
from transformers import pipeline |
|
from PLModel import PLColorDiff |
|
from dataset import ColorizationDataset |
|
from utils import lab_to_rgb, split_lab |
|
import torch |
|
import default_configs |
|
from icecream import ic |
|
|
|
from unet import SimpleUnet |
|
def get_image(image): |
|
print(image) |
|
dataset = ColorizationDataset([image], split="val", config=conf, size=128) |
|
lab_img = dataset.get_tensor_from_path(image) |
|
batch = lab_img.unsqueeze(0) |
|
|
|
|
|
model.eval() |
|
img = model.sample_plot_image(batch) |
|
rgb_img = lab_to_rgb(*split_lab(img)) |
|
|
|
return(rgb_img[0]) |
|
conf = SimpleUnetConfig = dict ( |
|
|
|
device = "mps", |
|
pin_memory = torch.cuda.is_available(), |
|
T=300, |
|
lr=5e-4, |
|
batch_size=64, |
|
img_size = 128, |
|
sample=False, |
|
log=False, |
|
should_log=False, |
|
sample_fn = None, |
|
val_every=20, |
|
epochs=100, |
|
using_cond=False |
|
) |
|
ckpt_path = "checkpoints/epoch=1-step=706.ckpt" |
|
ckpt = torch.load(ckpt_path, map_location=torch.device("mps")) |
|
unet = SimpleUnet() |
|
model = PLColorDiff(unet, None, None) |
|
ic.disable() |
|
model.load_state_dict(ckpt["state_dict"]) |
|
|
|
demo = gr.Interface( |
|
get_image, |
|
inputs=gr.inputs.Image(label="Upload a black and white face", type="filepath"), |
|
outputs="image", |
|
title="Upload a black and white face and get a colorized image!", |
|
) |
|
|
|
demo.launch() |