col / app.py
Erwann Millon
test
fb97152
raw
history blame
1.48 kB
import gradio as gr
from transformers import pipeline
from PLModel import PLColorDiff
from dataset import ColorizationDataset
from utils import lab_to_rgb, split_lab
import torch
import default_configs
from icecream import ic
from unet import SimpleUnet
def get_image(image):
print(image)
dataset = ColorizationDataset([image], split="val", config=conf, size=128)
lab_img = dataset.get_tensor_from_path(image)
batch = lab_img.unsqueeze(0)
# x_l, _ = split_lab(batch)
# bw = torch.cat((x_l, *[torch.zeros_like(x_l)] * 2), dim=1)
model.eval()
img = model.sample_plot_image(batch)
rgb_img = lab_to_rgb(*split_lab(img))
# model.test_step(batch)
return(rgb_img[0])
conf = SimpleUnetConfig = dict (
# device = get_device(),
device = "mps",
pin_memory = torch.cuda.is_available(),
T=300,
lr=5e-4,
batch_size=64,
img_size = 128,
sample=False,
log=False,
should_log=False,
sample_fn = None,
val_every=20,
epochs=100,
using_cond=False
)
ckpt_path = "checkpoints/epoch=1-step=706.ckpt"
ckpt = torch.load(ckpt_path, map_location=torch.device("mps"))
unet = SimpleUnet()
model = PLColorDiff(unet, None, None)
ic.disable()
model.load_state_dict(ckpt["state_dict"])
demo = gr.Interface(
get_image,
inputs=gr.inputs.Image(label="Upload a black and white face", type="filepath"),
outputs="image",
title="Upload a black and white face and get a colorized image!",
)
demo.launch()