|
|
|
|
|
from __future__ import annotations |
|
|
|
import json
import os
import shlex
import subprocess
import sys

import gradio as gr
|
|
|
|
|
def run(image_path: str, class_index: int, sigma_y: float) -> str:
    """Colorize an image with DDNM's ``main.py`` and return the result path.

    Runs ``main.py`` inside ``DDNM/hq_demo`` with the ``colorization``
    degradation and the ``confs/inet256.yml`` config, waiting for it to
    finish.

    Args:
        image_path: Path to the input image. Relative paths are resolved
            against the current working directory (the app's), not against
            ``DDNM/hq_demo``.
        class_index: Value forwarded to ``main.py`` as ``--class``.
        sigma_y: Value forwarded to ``main.py`` as ``--sigma_y``.

    Returns:
        ``DDNM/hq_demo/results/<name>/final/00000.png``, where ``<name>`` is
        the input file name up to its first ``.`` (also used as
        ``--save_path``).

    Raises:
        ValueError: If no image path or no class index was supplied.
        subprocess.CalledProcessError: If ``main.py`` exits with a non-zero
            status.
    """
    if not image_path:
        raise ValueError('An input image is required.')
    if class_index is None:
        raise ValueError('A class must be selected.')

    out_name = image_path.split('/')[-1].split('.')[0]

    # The child runs with cwd='DDNM/hq_demo', so a relative input path must be
    # made absolute here or it would be resolved against the wrong directory.
    input_path = os.path.abspath(image_path)

    # Build argv as a list instead of shlex-splitting a formatted string:
    # paths containing spaces or quotes would otherwise be split apart.
    command = [
        sys.executable,  # same interpreter/environment as this app
        'main.py',
        '--config', 'confs/inet256.yml',
        '--deg', 'colorization',
        '--scale', '1',
        '--class', str(class_index),
        '--path_y', input_path,
        '--save_path', out_name,
        '--sigma_y', str(sigma_y),
    ]
    # check=True: without it a failed run would still return a path to a
    # result file that was never produced.
    subprocess.run(command, cwd='DDNM/hq_demo', check=True)

    return f'DDNM/hq_demo/results/{out_name}/final/00000.png'
|
|
|
|
|
def create_demo():
    """Build (but do not launch) the Gradio UI for the colorization demo.

    The UI collects an input image, a class chosen by name from
    ``imagenet_classes.json`` (delivered to ``run`` as its index because the
    dropdown uses ``type='index'``) and ``sigma_y``. It shows the image path
    returned by ``run``.

    Returns:
        The constructed ``gr.Blocks`` object.
    """
    # Each example row is [input image path, class name, sigma_y]. The class
    # name has to be one of the dropdown's choices so Gradio can map it to an
    # index.
    examples = [
        [
            'sample_images/monarch_gray.png',
            'monarch, monarch butterfly, milkweed butterfly, Danaus plexippus',
            0,
        ],
        [
            'sample_images/tiger_gray.png',
            'tiger, Panthera tigris',
            0,
        ],
    ]

    # Explicit UTF-8 so loading does not depend on the platform's locale.
    with open('imagenet_classes.json', encoding='utf-8') as f:
        imagenet_class_names = json.load(f)

    # A Dropdown's `value` is a choice (the class name), even with
    # type='index'; passing the bare int 950 would not match any choice.
    # Fall back to no selection if the list is unexpectedly short.
    default_class_index = 950
    default_class_name = (imagenet_class_names[default_class_index]
                          if default_class_index < len(imagenet_class_names)
                          else None)

    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                image = gr.Image(label='Input image', type='filepath')
                class_index = gr.Dropdown(label='Class name',
                                          choices=imagenet_class_names,
                                          type='index',
                                          value=default_class_name)
                sigma_y = gr.Number(label='sigma_y', value=0, precision=2)
                run_button = gr.Button('Run')
            with gr.Column():
                result = gr.Image(label='Result', type='filepath')

        gr.Examples(
            examples=examples,
            inputs=[
                image,
                class_index,
                sigma_y,
            ],
        )

        run_button.click(
            fn=run,
            inputs=[
                image,
                class_index,
                sigma_y,
            ],
            outputs=result,
        )

    return demo
|
|