File size: 4,038 Bytes
443d045 0ca1021 443d045 0ca1021 443d045 0ca1021 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 |
import sys
from typing import Dict
sys.path.insert(0, 'gradio-modified')
import gradio as gr
import numpy as np
from PIL import Image
import torch
# Minimum free GPU memory (in bytes) required to run on CUDA; below this the
# CPU build of the weights is used instead.
MIN_FREE_GPU_BYTES = 2**32  # 4 GiB


def _select_device(min_free_bytes: int = MIN_FREE_GPU_BYTES) -> str:
    """Return ``'cuda'`` if GPU 0 has at least *min_free_bytes* free, else ``'cpu'``."""
    if not torch.cuda.is_available():
        return 'cpu'
    # mem_get_info reports the free/total memory seen by the driver, so memory
    # held by other processes on the same GPU is accounted for. The previous
    # ``total - memory_allocated`` only subtracted this process's own
    # allocations (zero at startup), i.e. it effectively compared total VRAM.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
    return 'cuda' if free_bytes >= min_free_bytes else 'cpu'


device = _select_device()
# Private TorchScript knob. Setting the bailout depth to 0 presumably stops the
# profiling executor from re-specializing the graph on early calls.
# NOTE(review): confirm intent; private API, may change between torch versions.
torch._C._jit_set_bailout_depth(0)
print('Use device:', device)
# The weights file is chosen per device, so one exported file per device is expected.
net = torch.jit.load(f'weights/pkp-v1.{device}.jit.pt')
def _pad_to_multiple(arr: np.ndarray, multiple: int = 64, fill: int = 255) -> np.ndarray:
    """Center-pad a 2-D array so both sides become multiples of *multiple*.

    The padding is split as evenly as possible between the two edges of each
    axis (any odd pixel goes to the bottom/right) and filled with *fill*.
    """
    h, w = arr.shape[-2:]
    pad_h = -h % multiple  # distance up to the next multiple (0 if already aligned)
    pad_w = -w % multiple
    return np.pad(
        arr,
        ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
        mode='constant',
        constant_values=fill,
    )


def resize_original(img: Image.Image):
    """Turn an uploaded sketch into the grayscale guide image used by ``colorize``.

    The image is converted to grayscale and Lanczos-resized so its shorter side
    is 256 px. It is then center-padded with white (255) so both sides are
    multiples of 64.

    Args:
        img: The uploaded image. This is either a PIL image or the sketch-tool
            dict whose ``"image"`` entry holds the PIL image. May be ``None``.

    Returns:
        A pair ``(canvas_update, stored_image)``:

        - ``canvas_update`` is a ``gr.Image.update`` that replaces the sketch
          canvas with the padded guide.
        - ``stored_image`` is the same guide as an RGBA PIL image.

        Returns ``(None, None)`` when *img* is ``None``.
    """
    if img is None:
        # The upload event is wired to two outputs, so return one value per
        # output; a single bare ``None`` does not match the output count.
        return None, None
    if isinstance(img, dict):
        img = img["image"]
    gray = img.convert('L')
    scale = 256 / min(gray.size)
    new_size = tuple(int(round(side * scale)) for side in gray.size)
    gray = gray.resize(new_size, Image.Resampling.LANCZOS)
    padded = _pad_to_multiple(np.asarray(gray), multiple=64, fill=255)
    guide_rgba = Image.fromarray(padded).convert('RGBA')
    return gr.Image.update(value=guide_rgba), guide_rgba
# Maps the hint-mode radio labels to the integer flag passed to the model.
_HINT_MODES = {"Roughly Hint": 0, "Precisely Hint": 1}
# Number of noise tensors passed to the model per call.
_NUM_NOISES = 16 + 1


def colorize(img: Dict[str, Image.Image], guide_img: Image.Image, seed: int, hint_mode: str):
    """Run the colorization model on the stored guide image plus the painted hints.

    Args:
        img: The sketch-canvas value. Its ``"mask"`` entry holds the colour
            hints drawn by the user.
        guide_img: The grayscale guide produced on upload (stored as RGBA).
        seed: Seed for the noise tensors. It is wrapped into NumPy's valid
            range [0, 2**32).
        hint_mode: One of the keys of ``_HINT_MODES``.

    Returns:
        The colorized RGB PIL image. If there is no canvas dict or no stored
        guide yet, returns ``gr.update(visible=True)`` instead.

    Raises:
        ValueError: If *hint_mode* is not a known label.
    """
    if not isinstance(img, dict) or guide_img is None:
        return gr.update(visible=True)
    try:
        hint_mode_int = _HINT_MODES[hint_mode]
    except KeyError:
        raise ValueError(f"Unknown hint mode: {hint_mode!r}") from None

    guide_img = guide_img.convert('L')
    # Relies on the modified gradio (gradio-modified on sys.path) so the mask
    # can carry colour.
    hint_img = img["mask"].convert('RGBA')
    # Map uint8 [0, 255] to float [-1, 1]; guide is (1, 1, H, W), hint is (1, 4, H, W).
    guide = torch.from_numpy(np.asarray(guide_img))[None, None].float().to(device) / 255.0 * 2 - 1
    hint = torch.from_numpy(np.asarray(hint_img)).permute(2, 0, 1)[None].float().to(device) / 255.0 * 2 - 1
    # Only (nearly) opaque mask pixels count as hints. Everything else gets the
    # out-of-range value -2, presumably so the model can tell "no hint" apart.
    hint_alpha = (hint[:, -1:] > 0.99).float()
    hint = hint[:, :3] * hint_alpha - 2 * (1 - hint_alpha)

    # A local RandomState yields the same stream as seeding the global RNG,
    # without mutating shared global state across concurrent requests. NumPy
    # only accepts seeds in [0, 2**32), so wrap out-of-range values
    # (e.g. 2**32) instead of raising.
    rng = np.random.RandomState(int(seed) % 2**32)
    b, c, h, w = hint.shape
    noise_shape = (b, c, h // 8, w // 8)  # noise at 1/8 of the hint resolution
    noises = [torch.from_numpy(rng.randn(*noise_shape)).float().to(device) for _ in range(_NUM_NOISES)]
    with torch.inference_mode():
        sample = net(noises, guide, hint, hint_mode_int)
    # Take the first batch item, move channels last and map [-1, 1] back to uint8.
    out = sample[0].cpu().numpy().transpose([1, 2, 0])
    out = np.uint8(((out + 1) / 2 * 255).clip(0, 255))
    return Image.fromarray(out).convert('RGB')
# Build the Gradio UI: sketch upload + hidden guide store, controls, output.
with gr.Blocks() as demo:
    gr.Markdown('''<center><h1>Image Colorization With Hint</h1></center>
<h2>Colorize your images/sketches with hint points.</h2>
<br />
''')
    with gr.Row():
        with gr.Column():
            # tool="sketch" rather than "color-sketch": color-sketch mixes the
            # uploaded image with the original.
            inp = gr.Image(
                source="upload",
                tool="sketch",
                type="pil",
                label="Sketch",
                interactive=True,
                elem_id="sketch-canvas",
            )
            # Hidden holder for the resized/padded guide written on upload.
            inp_store = gr.Image(
                type="pil",
                interactive=False,
                visible=False,
            )
        with gr.Column():
            # NumPy seeds must be < 2**32, so the slider stops at 2**32 - 1.
            seed = gr.Slider(1, 2**32 - 1, step=1, label="Seed", interactive=True, randomize=True)
            hint_mode = gr.Radio(["Roughly Hint", "Precisely Hint"], value="Roughly Hint", label="Hint Mode")
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Image(type="pil", label="Output", interactive=False)
    gr.Markdown('''
Upon uploading an image, kindly give color hints at specific points, and then run the model. Average inference time is about 52 seconds.
''')
    # On upload, replace the canvas with the normalized guide and keep a copy.
    inp.upload(
        resize_original,
        inp,
        [inp, inp_store],
    )
    btn.click(
        colorize,
        [inp, inp_store, seed, hint_mode],
        output,
    )
def main():
    """Launch the demo app with Gradio's ``share`` option enabled."""
    demo.launch(share=True)


if __name__ == "__main__":
    main()
|