Spaces:
Runtime error
Runtime error
import gradio as gr | |
import os, requests | |
import numpy as np | |
import torch.nn.functional as F | |
from model.model import ResHalf | |
from inference import Inferencer | |
from utils import util | |
## local | remote
RUN_MODE = "remote"
_HF_REPO_URL = "https://huggingface.co/menghanxia/ReversibleHalftoning/resolve/main/"


def _fetch(filename, dest):
    """Download `filename` from the model repository to `dest` using wget.

    The download is skipped when `dest` already exists, so restarting the
    app neither re-downloads the assets nor leaves `name.1` duplicates
    behind. The data is written to a temporary `.part` file and moved into
    place only after wget reports success, so an interrupted transfer is
    never mistaken for a complete file.

    Args:
        filename: name of the file inside the remote repository.
        dest: local path where the file should end up.

    Returns:
        True when `dest` exists after the call, False when wget failed.
    """
    if os.path.isfile(dest):
        return True
    dest_dir = os.path.dirname(dest)
    if dest_dir:
        # os.rename/os.replace cannot create the target directory themselves.
        os.makedirs(dest_dir, exist_ok=True)
    partial = dest + ".part"
    status = os.system('wget -O "{}" "{}"'.format(partial, _HF_REPO_URL + filename))
    if status != 0:
        if os.path.exists(partial):
            os.remove(partial)
        return False
    os.replace(partial, dest)
    return True


if RUN_MODE != "local":
    # The checkpoint is mandatory: fail loudly instead of crashing later
    # with an unrelated error when the model is loaded.
    if not _fetch("model_best.pth.tar", "./checkpoints/model_best.pth.tar"):
        raise RuntimeError("failed to download model_best.pth.tar")
    ## examples
    for _example_name in ("girl.png", "wave.png", "painting.png"):
        # Example images are best-effort, as before: report and carry on.
        if not _fetch(_example_name, _example_name):
            print("warning: could not download example image " + _example_name)
## step 1: set up model
# Inference is configured for CPU only; `device` is later handed to
# run_invhalf by the click handler.
device = "cpu"
checkpt_path = "checkpoints/model_best.pth.tar"
invhalfer = Inferencer(
    checkpoint_path=checkpt_path,
    model=ResHalf(train=False),
    use_cuda=False,
    multi_gpu=False,
)
def prepare_data(input_img, decoding_only=False):
    """Convert an input image array into the tensor fed to the model.

    The pixel values are divided by 255 and stored as float32, optionally
    reduced to the first channel, then rescaled with `x * 2 - 1` before
    being passed to `util.img2tensor`.

    Args:
        input_img: image array; assumed to be H x W x C with values in
            0..255 as delivered by the Gradio image input (TODO confirm).
        decoding_only: when truthy, only the first channel is kept.

    Returns:
        The result of `util.img2tensor` on the rescaled array.
    """
    scaled = np.array(input_img / 255., np.float32)
    # Restoration mode uses only the first channel of the input.
    selected = scaled[:, :, :1] if decoding_only else scaled
    return util.img2tensor(selected * 2. - 1.)
def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
    """Run the model in halftoning or restoration mode and return an image.

    Args:
        invhalfer: callable model wrapper, invoked as
            `invhalfer(tensor, decoding_only=...)`.
        input_img: input image array, preprocessed by `prepare_data`.
        decoding_only: truthy selects restoration (the call returns a single
            result); falsy selects halftoning (the call returns a pair, of
            which the first element is used).
        device: device the input tensor is moved to before inference.

    Returns:
        A uint8 array: the selected model output mapped through
        `x / 2 + 0.5`, converted with `util.tensor2img`, scaled by 255 and
        clipped to 0..255.
    """
    net_input = prepare_data(input_img, decoding_only).to(device)
    if not decoding_only:
        print('>>>:halftoning mode')
        # The second returned value (the restored colour image) is not needed here.
        selected, _ = invhalfer(net_input, decoding_only=decoding_only)
    else:
        print('>>>:restoration mode')
        selected = invhalfer(net_input, decoding_only=decoding_only)
    output = util.tensor2img(selected / 2. + 0.5) * 255.
    return np.clip(output, 0, 255).astype(np.uint8)
def click_run(input_img, decoding_only):
    """Handle a click on the "Run" button.

    Args:
        input_img: image from the input component, or None when nothing
            has been uploaded yet.
        decoding_only: mode selector from the radio component; truthy
            selects restoration, falsy selects halftoning.

    Returns:
        The image produced by `run_invhalf`, or None when no input image
        was supplied.
    """
    if input_img is None:
        # Pressing "Run" on an empty input would otherwise fail inside
        # prepare_data on `None / 255.`; return an empty output instead.
        return None
    return run_invhalf(invhalfer, input_img, decoding_only, device)
## step 2: configure interface
demo = gr.Blocks(title="ReversibleHalftoning")
with demo:
    gr.Markdown(
        value=(
            "\n"
            "**Gradio demo for ReversibleHalftoning: Deep Halftoning with Reversible Binary Pattern**. "
            "Check our [github page](https://github.com/MenghanXia/ReversibleHalftoning) 😛.\n"
        )
    )
    with gr.Row():
        # Left column: input image, mode selector and run button.
        with gr.Column():
            Image_input = gr.Image(type="numpy", label="Input", interactive=True)
            with gr.Row():
                # type="index" makes the handler receive 0 (halftoning) or
                # 1 (restoration), which click_run uses as `decoding_only`.
                Radio_mode = gr.Radio(
                    type="index",
                    choices=["Halftoning (Photo2Halftone)", "Restoration (Halftone2Photo)"],
                    label="Choose a running mode",
                    value="Halftoning (Photo2Halftone)",
                )
                Button_run = gr.Button(value="Run")
        # Right column: result image.
        with gr.Column():
            # `.style()` was removed in Gradio 4, where the height is a
            # constructor argument. Older releases reject that argument
            # with TypeError, so fall back to the legacy call there.
            try:
                Image_output = gr.Image(type="numpy", label="Output", height=480)
            except TypeError:
                Image_output = gr.Image(type="numpy", label="Output").style(height=480)
    Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)
    if RUN_MODE != "local":
        # The example images are only downloaded in remote mode.
        gr.Examples(
            examples=[
                ['girl.png', "Halftoning (Photo2Halftone)"],
                ['wave.png', "Halftoning (Photo2Halftone)"],
                ['painting.png', "Restoration (Halftone2Photo)"],
            ],
            inputs=[Image_input, Radio_mode],
            outputs=[Image_output],
            label="Examples",
        )
# Local runs bind to a fixed host and port; otherwise Gradio's defaults are used.
_launch_kwargs = {}
if RUN_MODE == "local":
    _launch_kwargs = {"server_name": '9.134.253.83', "server_port": 7788}
demo.launch(**_launch_kwargs)