Spaces:
Runtime error
Runtime error
import sys | |
from typing import Dict | |
sys.path.insert(0, 'gradio-modified') | |
import gradio as gr | |
import numpy as np | |
from PIL import Image | |
import torch | |
def select_device(min_free_bytes: int = 2**32) -> str:
    """Pick the torch device to run the model on.

    Returns ``'cuda'`` when a GPU is available and device 0 has at least
    *min_free_bytes* of memory not currently allocated by tensors in this
    process (total memory minus ``torch.cuda.memory_allocated``); otherwise
    returns ``'cpu'``. The default threshold is 4 GiB (2**32 bytes).
    """
    if not torch.cuda.is_available():
        return 'cpu'
    total = torch.cuda.get_device_properties(0).total_memory
    allocated = torch.cuda.memory_allocated(0)
    # Note: this is total minus allocated, not "free inside reserved".
    return 'cuda' if total - allocated >= min_free_bytes else 'cpu'


device = select_device()
print('Use device:', device)
# TorchScript checkpoint exported separately for each device; the file name
# embeds the device chosen above (e.g. weights/pkp-v1.cpu.jit.pt).
net = torch.jit.load(f=f'weights/pkp-v1.{device}.jit.pt')
def _pad_to_multiple(arr: np.ndarray, multiple: int = 64, fill: int = 255) -> np.ndarray:
    """Center-pad a 2-D array with *fill* so both dimensions are multiples of *multiple*."""
    h, w = arr.shape[-2:]
    # Amount needed to reach the next multiple (0 when already aligned).
    pad_h = -h % multiple
    pad_w = -w % multiple
    # Split evenly; an odd leftover pixel goes to the bottom/right edge.
    return np.pad(
        arr,
        ((pad_h // 2, pad_h - pad_h // 2), (pad_w // 2, pad_w - pad_w // 2)),
        mode='constant',
        constant_values=fill,
    )


def resize_original(img: "Image.Image"):
    """Normalize an uploaded sketch before it is shown on the canvas and stored.

    The image is converted to grayscale, resized with LANCZOS so that its
    shorter side is 256 px, then center-padded with white (255) up to the next
    multiple of 64 in each dimension.

    Args:
        img: The uploaded PIL image, the sketch-tool dict whose ``"image"``
            entry holds it, or ``None``.

    Returns:
        A pair ``(canvas_update, stored_image)``: a ``gr.Image.update`` that
        writes the padded RGBA image back to the canvas, and the same padded
        RGBA image for the hidden store. ``(None, None)`` when *img* is None.
    """
    if img is None:
        # The upload event feeds two output components, so return one value each.
        return None, None
    if isinstance(img, dict):
        img = img["image"]
    guide_img = img.convert('L')
    scale = 256 / min(guide_img.size)
    new_size = [int(round(side * scale)) for side in guide_img.size]
    guide_img = guide_img.resize(new_size, Image.Resampling.LANCZOS)
    guide = _pad_to_multiple(np.asarray(guide_img), multiple=64, fill=255)
    padded = Image.fromarray(guide).convert('RGBA')
    return gr.Image.update(value=padded), padded
# Maps the "Hint Mode" radio labels to the integer flag passed to the model.
_HINT_MODES = {"Roughly Hint": 0, "Precisely Hint": 1}


def colorize(img: "Dict[str, Image.Image]", guide_img: "Image.Image", seed: int, hint_mode: str):
    """Colorize the stored sketch using the user's colored hint strokes.

    Args:
        img: Value of the sketch canvas. When it is a dict, its ``"mask"``
            entry holds the user's colored strokes (enabled by the modified
            gradio). Any non-dict value means there is nothing to colorize.
        guide_img: The padded sketch produced by ``resize_original``; used as
            the grayscale guide.
        seed: Seed for the diffusion noise. It is reduced modulo 2**32 because
            NumPy seeds must lie in [0, 2**32 - 1].
        hint_mode: ``"Roughly Hint"`` or ``"Precisely Hint"``.

    Returns:
        The colorized RGB PIL image, or ``gr.update(visible=True)`` when *img*
        is not a dict.

    Raises:
        ValueError: if *hint_mode* is not one of the supported labels.
    """
    if not isinstance(img, dict):
        return gr.update(visible=True)
    if hint_mode not in _HINT_MODES:
        raise ValueError(f"Unknown hint mode: {hint_mode!r}")
    hint_mode_int = _HINT_MODES[hint_mode]

    # np.array (not np.asarray) copies PIL data into a writable buffer, which
    # torch.from_numpy expects.
    guide_arr = np.array(guide_img.convert('L'))
    # The modified gradio lets the sketch tool upload a colorful mask.
    hint_arr = np.array(img["mask"].convert('RGBA'))

    # Scale uint8 [0, 255] to [-1, 1]; guide -> (1, 1, H, W), hint -> (1, 4, H, W).
    guide = torch.from_numpy(guide_arr)[None, None].float().to(device) / 255.0 * 2 - 1
    hint = torch.from_numpy(hint_arr).permute(2, 0, 1)[None].float().to(device) / 255.0 * 2 - 1
    # Keep RGB only where the stroke alpha is (nearly) opaque; elsewhere write
    # -2, outside [-1, 1] -- presumably the model's "no hint here" marker.
    hint_alpha = (hint[:, -1:] > 0.99).float()
    hint = hint[:, :3] * hint_alpha - 2 * (1 - hint_alpha)

    # A private RandomState yields the same stream as np.random.seed() followed
    # by np.random.randn(), without mutating the global RNG. The modulo keeps
    # the seed inside NumPy's valid range even if the UI hands us 2**32.
    rng = np.random.RandomState(int(seed) % 2**32)
    b, c, h, w = hint.shape
    noise_shape = (b, c, h // 8, w // 8)
    noises = [torch.from_numpy(rng.randn(*noise_shape)).float().to(device) for _ in range(16 + 1)]

    with torch.inference_mode():
        sample = net(noises, guide, hint, hint_mode_int)
    # First batch element, CHW -> HWC, then [-1, 1] back to uint8 [0, 255].
    out = sample[0].cpu().numpy().transpose([1, 2, 0])
    out = np.uint8(((out + 1) / 2 * 255).clip(0, 255))
    return Image.fromarray(out).convert('RGB')
with gr.Blocks() as demo:
    gr.Markdown('''<center><h1>Anime Colorization With Hint</h1></center>
<h2>Colorize your anime sketches with hint points.</h2>
This is a modified version of
<a href="https://github.com/HighCWu/pixel-guide-diffusion-for-anime-colorization">
HighCWu/pixel-guide-diffusion-for-anime-colorization
</a> with hint points inputs.<br />
''')
    with gr.Row():
        with gr.Column():
            # Sketch canvas: the uploaded line art plus the user's hint strokes,
            # which colorize() reads from img["mask"]. "color-sketch" was not
            # used because it mixes the strokes into the uploaded image.
            inp = gr.Image(
                source="upload",
                tool="sketch",
                type="pil",
                label="Sketch",
                interactive=True,
                elem_id="sketch-canvas",
            )
            # Hidden copy of the resized/padded sketch written by resize_original;
            # colorize() uses it as the grayscale guide.
            inp_store = gr.Image(
                type="pil",
                interactive=False,
                visible=False,
            )
        with gr.Column():
            # NumPy seeds must lie in [0, 2**32 - 1]; cap the slider accordingly.
            seed = gr.Slider(1, 2**32 - 1, step=1, label="Seed", interactive=True, randomize=True)
            hint_mode = gr.Radio(["Roughly Hint", "Precisely Hint"], value="Roughly Hint", label="Hint Mode")
            btn = gr.Button("Run")
        with gr.Column():
            output = gr.Image(type="pil", label="Output", interactive=False)
    gr.Markdown('''
PS: Worse than the no hint version I thought. Probably because my model is underfitting in the super-resolution part<br />
I modified a little gradio codes for uploading the colorful hint points.
''')
    gr.Markdown(
        '<center><img src="https://visitor-badge.glitch.me/badge?page_id=highcwu.anime-colorization-with-hint" alt="visitor badge"/></center>'
    )
    # On upload, normalize the sketch and write it to both the canvas and the store.
    inp.upload(
        resize_original,
        inp,
        [inp, inp_store],
    )
    btn.click(
        colorize,
        [inp, inp_store, seed, hint_mode],
        output,
    )

if __name__ == "__main__":
    demo.launch()