File size: 1,481 Bytes
fb97152
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import gradio as gr
from transformers import pipeline
from PLModel import PLColorDiff
from dataset import ColorizationDataset
from utils import lab_to_rgb, split_lab
import torch
import default_configs
from icecream import ic

from unet import SimpleUnet
def get_image(image):
    """Colorize a black-and-white image and return the RGB result.

    Parameters
    ----------
    image : str
        Filesystem path to the uploaded image (Gradio's ``type="filepath"``
        hands this function a path string).

    Returns
    -------
    The first (and only) colorized image of the batch, as produced by
    ``lab_to_rgb`` — presumably an H×W×3 RGB array; confirm against utils.

    Notes
    -----
    Relies on the module-level ``conf`` and ``model`` globals defined
    below; Gradio only calls this after the whole script has run.
    """
    # Reuse the dataset's validation-split preprocessing (resize to 128,
    # convert to Lab) rather than re-implementing it here.
    dataset = ColorizationDataset([image], split="val", config=conf, size=128)
    lab_img = dataset.get_tensor_from_path(image)
    # Add a leading batch dimension: the model samples over batches.
    batch = lab_img.unsqueeze(0)
    # Switch off dropout / batch-norm updates for inference.
    model.eval()
    img = model.sample_plot_image(batch)
    rgb_img = lab_to_rgb(*split_lab(img))
    return rgb_img[0]
conf = SimpleUnetConfig = dict (
    # device = get_device(),
    device = "mps",
    pin_memory = torch.cuda.is_available(),
    T=300,
    lr=5e-4,
    batch_size=64,
    img_size = 128,
    sample=False,
    log=False,
    should_log=False,
    sample_fn = None,
    val_every=20,
    epochs=100,
    using_cond=False
)
# Lightning checkpoint produced during training.
ckpt_path = "checkpoints/epoch=1-step=706.ckpt"
# NOTE(review): map_location is hard-coded to "mps", so this script only
# runs on Apple-silicon hardware — confirm the intended deployment target.
# NOTE(review): torch.load unpickles arbitrary objects; only load
# checkpoints from a trusted source.
ckpt = torch.load(ckpt_path, map_location=torch.device("mps"))
unet = SimpleUnet()
# Wrap the U-Net in the Lightning diffusion module; the two None args are
# presumably train/val data handles not needed for inference — verify
# against PLColorDiff's signature.
model = PLColorDiff(unet, None, None)
ic.disable()  # silence icecream debug output for the demo
# Restore the trained weights into the wrapped model.
model.load_state_dict(ckpt["state_dict"])

# Gradio UI: one image-upload input, one image output, wired to get_image.
demo = gr.Interface(
    get_image,
    # ``gr.inputs.Image`` is the Gradio 2.x namespace, deprecated in 3.x and
    # removed in 4.x; ``gr.Image`` is the supported component with the same
    # kwargs. type="filepath" passes get_image a path string, not an array.
    inputs=gr.Image(label="Upload a black and white face", type="filepath"),
    outputs="image",
    title="Upload a black and white face and get a colorized image!",
)

# Start the local web server (blocks until interrupted).
demo.launch()