menghanxia's picture
support output to be moved to input
2ac9f6b
raw
history blame contribute delete
No virus
3.91 kB
import gradio as gr
import os, requests
import numpy as np
import torch.nn.functional as F
from model.model import ResHalf
from inference import Inferencer
from utils import util
## local | remote
RUN_MODE = "remote"

# Base URL of the Hugging Face repo hosting the checkpoint and example images.
_HF_REPO_URL = "https://huggingface.co/menghanxia/ReversibleHalftoning/resolve/main"


def _download(url, dest):
    """Download *url* to the local path *dest* unless *dest* already exists.

    The payload is streamed into ``dest + ".part"`` and moved into place only after
    the whole transfer succeeded. An interrupted download therefore never leaves a
    truncated file that the existence check would later trust.

    Raises:
        requests.HTTPError: if the server answers with an error status.
    """
    if os.path.exists(dest):
        # Skip re-downloading on restart (wget would also have created "<name>.1" copies).
        return
    # os.rename into ./checkpoints used to fail when the directory was missing.
    os.makedirs(os.path.dirname(dest) or ".", exist_ok=True)
    tmp_path = dest + ".part"
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()  # fail loudly instead of ignoring wget's exit status
        with open(tmp_path, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(tmp_path, dest)


if RUN_MODE != "local":
    # model weights, loaded below from checkpoints/model_best.pth.tar
    _download(f"{_HF_REPO_URL}/model_best.pth.tar", "./checkpoints/model_best.pth.tar")
    ## examples (referenced by gr.Examples in the UI)
    for _example_name in ("girl.png", "wave.png", "painting.png"):
        _download(f"{_HF_REPO_URL}/{_example_name}", _example_name)
## step 1: set up model
# Path of the trained ResHalf weights loaded by the Inferencer.
checkpt_path = "checkpoints/model_best.pth.tar"
# Inference is configured for CPU only (see use_cuda=False below); click_run forwards this.
device = "cpu"
_reshalf_net = ResHalf(train=False)
invhalfer = Inferencer(
    model=_reshalf_net,
    checkpoint_path=checkpt_path,
    use_cuda=False,
    multi_gpu=False,
)
def prepare_data(input_img, decoding_only=False):
    """Convert an image array into the model's input tensor.

    Pixel values are divided by 255 (as float32) and then mapped to [-1, 1] via
    ``x * 2 - 1`` before ``util.img2tensor`` is applied.

    Args:
        input_img: image array indexed as [row, col, channel]; presumably uint8 in
            [0, 255] from the Gradio numpy image component — TODO confirm.
        decoding_only: when truthy (restoration mode), keep only the first channel.

    Returns:
        Whatever ``util.img2tensor`` returns for the normalized array.
    """
    scaled = np.array(input_img / 255., np.float32)
    # Restoration consumes a single-channel halftone, so drop the other channels.
    channels = scaled[:, :, :1] if decoding_only else scaled
    return util.img2tensor(channels * 2. - 1.)
def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
    """Run the reversible-halftoning network on one image.

    Args:
        invhalfer: the callable Inferencer wrapping the ResHalf model.
        input_img: input image array, prepared with ``prepare_data``.
        decoding_only: truthy -> restoration (halftone to color);
            falsy -> halftoning (color to halftone).
        device: device the input tensor is moved to before inference.

    Returns:
        The output image as a uint8 array, clipped to [0, 255].
    """
    tensor = prepare_data(input_img, decoding_only).to(device)
    if not decoding_only:
        print('>>>:halftoning mode')
        # The network returns (halftone, restored color); only the halftone is shown.
        result, _ = invhalfer(tensor, decoding_only=decoding_only)
    else:
        print('>>>:restoration mode')
        result = invhalfer(tensor, decoding_only=decoding_only)
    # Map the model's [-1, 1] output back to [0, 255] pixel values.
    pixels = util.tensor2img(result / 2. + 0.5) * 255.
    return np.clip(pixels, 0, 255).astype(np.uint8)
def click_run(input_img, decoding_only):
    """Gradio callback for the "Run" button.

    Args:
        input_img: value of the numpy input image component, or ``None`` when no
            image has been provided yet.
        decoding_only: index of the selected radio choice; truthy selects
            restoration mode.

    Returns:
        The uint8 result image, or ``None`` (which clears the output component)
        when there is no input image.
    """
    if input_img is None:
        # Without this guard prepare_data would evaluate ``None / 255.`` and raise TypeError.
        print('>>>: no input image, nothing to run')
        return None
    return run_invhalf(invhalfer, input_img, decoding_only, device)
def click_move(output_img, decoding_only):
    """Gradio callback for "Use it as input".

    Moves the output image into the input slot, switches the radio to the opposite
    mode (a halftone output is restored next, a restored output is halftoned next),
    and clears the output component.

    Returns:
        Tuple of (new input image, new radio choice label, new output value).
    """
    next_mode = "Halftoning (Photo2Halftone)" if decoding_only else "Restoration (Halftone2Photo)"
    return (output_img, next_mode, None)
## step 2: configure interface
# Radio choice labels; index 0 = halftoning, index 1 = restoration (radio type="index").
_MODE_HALFTONING = "Halftoning (Photo2Halftone)"
_MODE_RESTORATION = "Restoration (Halftone2Photo)"

demo = gr.Blocks(title="ReversibleHalftoning")
with demo:
    # The emoji was stored as mojibake ("πŸ˜›", UTF-8 bytes of 😛 mis-decoded); restored here.
    gr.Markdown(value="""
**Gradio demo for ReversibleHalftoning: Deep Halftoning with Reversible Binary Pattern**. Check our [github page](https://github.com/MenghanXia/ReversibleHalftoning) 😛.
""")
    with gr.Row():
        with gr.Column():
            Image_input = gr.Image(type="numpy", label="Input", interactive=True).style(height=480)
            with gr.Row():
                Radio_mode = gr.Radio(
                    type="index",
                    choices=[_MODE_HALFTONING, _MODE_RESTORATION],
                    label="Choose a running mode",
                    value=_MODE_HALFTONING,
                )
                Button_run = gr.Button(value="Run")
        with gr.Column():
            Image_output = gr.Image(type="numpy", label="Output").style(height=480)
            Button_move = gr.Button(value="Use it as input")
    Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)
    # Moving the output also flips the mode and clears the output panel.
    Button_move.click(fn=click_move, inputs=[Image_output, Radio_mode], outputs=[Image_input, Radio_mode, Image_output])
    if RUN_MODE != "local":
        # Example files are the ones downloaded at startup in remote mode.
        gr.Examples(
            examples=[
                ['girl.png', _MODE_HALFTONING],
                ['wave.png', _MODE_HALFTONING],
                ['painting.png', _MODE_RESTORATION],
            ],
            inputs=[Image_input, Radio_mode],
            outputs=[Image_output],
            label="Examples",
        )
# Local runs bind to a fixed host/port; remote runs use Gradio's launch defaults.
if RUN_MODE == "local":
    _launch_kwargs = {"server_name": '9.134.253.83', "server_port": 7788}
else:
    _launch_kwargs = {}
demo.launch(**_launch_kwargs)