#!/usr/bin/env python from __future__ import annotations import json import shlex import subprocess import gradio as gr def run(image_path: str, class_index: int, sigma_y: float) -> str: out_name = image_path.split("/")[-1].split(".")[0] subprocess.run( shlex.split( f"python main.py --config confs/inet256.yml --deg colorization --scale 1 --class {class_index} --path_y {image_path} --save_path {out_name} --sigma_y {sigma_y}" ), cwd="DDNM/hq_demo", ) return f"DDNM/hq_demo/results/{out_name}/final/00000.png" def create_demo(): examples = [ [ "sample_images/monarch_gray.png", "monarch, monarch butterfly, milkweed butterfly, Danaus plexippus", 0, ], [ "sample_images/tiger_gray.png", "tiger, Panthera tigris", 0, ], ] with open("imagenet_classes.json") as f: imagenet_class_names = json.load(f) with gr.Blocks() as demo: with gr.Row(): with gr.Column(): image = gr.Image(label="Input image", type="filepath") class_index = gr.Dropdown(label="Class name", choices=imagenet_class_names, type="index", value=950) sigma_y = gr.Number(label="sigma_y", value=0, precision=2) run_button = gr.Button("Run") with gr.Column(): result = gr.Image(label="Result", type="filepath") gr.Examples( examples=examples, inputs=[ image, class_index, sigma_y, ], ) run_button.click( fn=run, inputs=[ image, class_index, sigma_y, ], outputs=result, ) return demo