Erwann Millon commited on
Commit
fb97152
1 Parent(s): 8ed6132
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from transformers import pipeline
3
+ from PLModel import PLColorDiff
4
+ from dataset import ColorizationDataset
5
+ from utils import lab_to_rgb, split_lab
6
+ import torch
7
+ import default_configs
8
+ from icecream import ic
9
+
10
+ from unet import SimpleUnet
11
+ def get_image(image):
12
+ print(image)
13
+ dataset = ColorizationDataset([image], split="val", config=conf, size=128)
14
+ lab_img = dataset.get_tensor_from_path(image)
15
+ batch = lab_img.unsqueeze(0)
16
+ # x_l, _ = split_lab(batch)
17
+ # bw = torch.cat((x_l, *[torch.zeros_like(x_l)] * 2), dim=1)
18
+ model.eval()
19
+ img = model.sample_plot_image(batch)
20
+ rgb_img = lab_to_rgb(*split_lab(img))
21
+ # model.test_step(batch)
22
+ return(rgb_img[0])
23
+ conf = SimpleUnetConfig = dict (
24
+ # device = get_device(),
25
+ device = "mps",
26
+ pin_memory = torch.cuda.is_available(),
27
+ T=300,
28
+ lr=5e-4,
29
+ batch_size=64,
30
+ img_size = 128,
31
+ sample=False,
32
+ log=False,
33
+ should_log=False,
34
+ sample_fn = None,
35
+ val_every=20,
36
+ epochs=100,
37
+ using_cond=False
38
+ )
39
+ ckpt_path = "checkpoints/epoch=1-step=706.ckpt"
40
+ ckpt = torch.load(ckpt_path, map_location=torch.device("mps"))
41
+ unet = SimpleUnet()
42
+ model = PLColorDiff(unet, None, None)
43
+ ic.disable()
44
+ model.load_state_dict(ckpt["state_dict"])
45
+
46
+ demo = gr.Interface(
47
+ get_image,
48
+ inputs=gr.inputs.Image(label="Upload a black and white face", type="filepath"),
49
+ outputs="image",
50
+ title="Upload a black and white face and get a colorized image!",
51
+ )
52
+
53
+ demo.launch()