"""Gradio demo app for ReversibleHalftoning.

Downloads the pretrained checkpoint and example images (remote mode only),
builds a CPU ``Inferencer`` and exposes two running modes through a Gradio
Blocks UI: photo -> halftone, and halftone -> photo restoration.
"""
import os

import gradio as gr
import numpy as np
import requests
import torch.nn.functional as F

from inference import Inferencer
from model.model import ResHalf
from utils import util

## local | remote
RUN_MODE = "remote"

HF_REPO_URL = "https://huggingface.co/menghanxia/ReversibleHalftoning/resolve/main"
EXAMPLE_FILES = ("girl.png", "wave.png", "painting.png")


def _download(url, dest):
    """Stream ``url`` to ``dest`` unless ``dest`` already exists.

    The payload is written to ``dest + ".part"`` first and moved into place
    only after the transfer finished, so an interrupted download never leaves
    a truncated file under the final name.

    Raises:
        requests.RequestException: on connection errors or a non-2xx status.
    """
    if os.path.exists(dest):
        return
    parent = os.path.dirname(dest)
    if parent:
        # The original rename into ./checkpoints failed when the directory
        # was missing; create it on demand.
        os.makedirs(parent, exist_ok=True)
    partial = dest + ".part"
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(partial, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(partial, dest)


if RUN_MODE != "local":
    # The checkpoint is mandatory: let a failed download abort start-up here
    # instead of surfacing later as an obscure checkpoint-loading error.
    _download(HF_REPO_URL + "/model_best.pth.tar", "./checkpoints/model_best.pth.tar")
    ## examples
    for example_name in EXAMPLE_FILES:
        try:
            _download(HF_REPO_URL + "/" + example_name, example_name)
        except requests.RequestException as err:
            # Examples are optional; the demo still works without them.
            print('>>>:failed to download example %s: %s' % (example_name, err))

## step 1: set up model
device = "cpu"
checkpt_path = "checkpoints/model_best.pth.tar"
invhalfer = Inferencer(checkpoint_path=checkpt_path, model=ResHalf(train=False), use_cuda=False, multi_gpu=False)


def prepare_data(input_img, decoding_only=False):
    """Convert an image array into the tensor fed to the model.

    The array is divided by 255, cast to float32 and rescaled with
    ``x * 2 - 1`` before being handed to ``util.img2tensor``. When
    ``decoding_only`` is truthy only the first channel is kept.

    Args:
        input_img: numpy image array indexed as ``[row, col, channel]``
            (presumably 8-bit values in 0..255 as delivered by Gradio --
            confirm against the caller).
        decoding_only: keep a single channel (restoration mode).

    Returns:
        Whatever ``util.img2tensor`` returns for the normalized array.
    """
    input_img = np.array(input_img / 255., np.float32)
    if decoding_only:
        input_img = input_img[:, :, :1]
    input_img = util.img2tensor(input_img * 2. - 1.)
    return input_img


def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
    """Run halftoning or restoration and return an 8-bit image array.

    Args:
        invhalfer: callable model wrapper; invoked as
            ``invhalfer(tensor, decoding_only=...)``.
        input_img: numpy image array, see ``prepare_data``.
        decoding_only: truthy -> restoration (the single model output is
            used); falsy -> halftoning (the first of the two model outputs
            is used).
        device: device the input tensor is moved to before inference.

    Returns:
        ``numpy.uint8`` array, clipped to ``[0, 255]``.
    """
    input_img = prepare_data(input_img, decoding_only)
    input_img = input_img.to(device)
    if decoding_only:
        print('>>>:restoration mode')
        resColor = invhalfer(input_img, decoding_only=decoding_only)
        output = util.tensor2img(resColor / 2. + 0.5) * 255.
    else:
        print('>>>:halftoning mode')
        # The restored color image is also returned here but is not shown.
        resHalftone, resColor = invhalfer(input_img, decoding_only=decoding_only)
        output = util.tensor2img(resHalftone / 2. + 0.5) * 255.
    return np.clip(output, 0, 255).astype(np.uint8)


def click_run(input_img, decoding_only):
    """Gradio callback for the "Run" button.

    Args:
        input_img: numpy image from the input widget, or ``None`` when no
            image has been loaded.
        decoding_only: index of the selected radio choice
            (0 = halftoning, 1 = restoration).

    Returns:
        The processed ``uint8`` image, or ``None`` (empty output) when there
        is no input image.
    """
    if input_img is None:
        # Pressing "Run" with an empty input used to crash on `None / 255.`.
        return None
    # The radio delivers an int index; hand a real boolean to the model code.
    output = run_invhalf(invhalfer, input_img, bool(decoding_only), device)
    return output


## step 2: configure interface
demo = gr.Blocks(title="ReversibleHalftoning")
with demo:
    gr.Markdown(
        value=(
            " **Gradio demo for ReversibleHalftoning: Deep Halftoning with Reversible Binary Pattern**. \n"
            "Check our [github page](https://github.com/MenghanXia/ReversibleHalftoning) 😛. "
        )
    )
    with gr.Row():
        with gr.Column():
            Image_input = gr.Image(type="numpy", label="Input", interactive=True)
            with gr.Row():
                Radio_mode = gr.Radio(
                    type="index",
                    choices=["Halftoning (Photo2Halftone)", "Restoration (Halftone2Photo)"],
                    label="Choose a running mode",
                    value="Halftoning (Photo2Halftone)",
                )
                Button_run = gr.Button(value="Run")
        with gr.Column():
            # NOTE(review): `.style()` is kept as in the original; confirm it
            # is still supported by the Gradio version pinned for this app.
            Image_output = gr.Image(type="numpy", label="Output").style(height=480)

    Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)

    if RUN_MODE != "local":
        example_rows = [
            ['girl.png', "Halftoning (Photo2Halftone)"],
            ['wave.png', "Halftoning (Photo2Halftone)"],
            ['painting.png', "Restoration (Halftone2Photo)"],
        ]
        # Only offer examples whose image file is actually on disk.
        available_rows = [row for row in example_rows if os.path.exists(row[0])]
        if available_rows:
            gr.Examples(
                examples=available_rows,
                inputs=[Image_input, Radio_mode],
                outputs=[Image_output],
                label="Examples",
            )

if RUN_MODE == "local":
    demo.launch(server_name='9.134.253.83', server_port=7788)
else:
    demo.launch()