# Scraped Hugging Face file-page header (page chrome converted to comments):
# author: menghanxia
# last commit 40d12a9: "fixed checkpoint loading requires GPU issue"
# file size: 3.47 kB
import gradio as gr
import os, requests
import numpy as np
import torch.nn.functional as F
from model.model import ResHalf
from inference import Inferencer
from utils import util
## local | remote
# "local": nothing is downloaded and the app is served on a fixed address
# (see the demo.launch call at the end of the file). Any other value: fetch the
# checkpoint and the example images from the Hugging Face repo first.
RUN_MODE = "remote"

_HF_REPO_URL = "https://huggingface.co/menghanxia/ReversibleHalftoning/resolve/main"


def _fetch(url, dest):
    """Download *url* to *dest* unless *dest* already exists.

    Creates the parent directory when needed and raises on HTTP errors, so a
    failed download surfaces here instead of as a confusing checkpoint-loading
    error later. Data is written to a temporary ``.part`` file and moved into
    place only once complete, so an interrupted download is never mistaken for
    a finished one on the next start.
    """
    if os.path.exists(dest):
        return
    parent = os.path.dirname(dest)
    if parent:
        os.makedirs(parent, exist_ok=True)
    tmp_path = dest + ".part"
    with requests.get(url, stream=True, timeout=60) as resp:
        resp.raise_for_status()
        with open(tmp_path, "wb") as fh:
            for chunk in resp.iter_content(chunk_size=1 << 20):
                fh.write(chunk)
    os.replace(tmp_path, dest)


if RUN_MODE != "local":
    _fetch(f"{_HF_REPO_URL}/model_best.pth.tar", "./checkpoints/model_best.pth.tar")
    ## examples
    for _example_name in ("girl.png", "wave.png", "painting.png"):
        _fetch(f"{_HF_REPO_URL}/{_example_name}", _example_name)
## step 1: set up model
# Inference runs on the CPU. click_run forwards this `device` to run_invhalf,
# which moves each input tensor onto it, so the model flag is derived from the
# same value to keep the two in agreement.
device = "cpu"
checkpt_path = "checkpoints/model_best.pth.tar"
invhalfer = Inferencer(
    checkpoint_path=checkpt_path,
    model=ResHalf(train=False),
    use_cuda=(device == "cuda"),
    multi_gpu=False,
)
def prepare_data(input_img, decoding_only=False):
    """Convert an image array into a model input tensor.

    Args:
        input_img: image array, as delivered by the ``gr.Image(type="numpy")``
            input. It is divided by 255, so pixel values are expected in
            [0, 255]. Presumably (H, W, C) -- confirm against the Gradio
            version in use. A 2-D (H, W) grayscale array is also accepted in
            decoding mode.
        decoding_only: truthy for restoration mode. In that mode only the
            first channel of the image is kept.

    Returns:
        The result of ``util.img2tensor`` on the image rescaled to [-1, 1].

    Raises:
        ValueError: if no image was provided (``input_img`` is None).
    """
    if input_img is None:
        # Gradio hands the callback None when "Run" is clicked before an
        # image has been uploaded; fail with a readable message.
        raise ValueError("No input image provided; please upload an image first.")
    input_img = np.array(input_img / 255., np.float32)
    if decoding_only:
        # Keep a single channel as a trailing axis. A 2-D grayscale image has
        # no channel axis to slice, so add one instead.
        if input_img.ndim == 2:
            input_img = input_img[:, :, None]
        else:
            input_img = input_img[:, :, :1]
    # Map [0, 1] to [-1, 1]; run_invhalf applies the inverse (`/ 2. + 0.5`)
    # to the model output.
    return util.img2tensor(input_img * 2. - 1.)
def run_invhalf(invhalfer, input_img, decoding_only, device="cuda"):
    """Run the reversible-halftoning model on a single image.

    Args:
        invhalfer: model wrapper, called as
            ``invhalfer(tensor, decoding_only=...)``.
        input_img: image array accepted by ``prepare_data``.
        decoding_only: truthy -> restoration mode (halftone to image);
            falsy -> halftoning mode (image to halftone).
        device: torch device the input tensor is moved to. Note the default
            is "cuda"; the app's click_run passes the module-level ``device``.

    Returns:
        The output image as an ``np.uint8`` array clipped to [0, 255].
    """
    input_img = prepare_data(input_img, decoding_only).to(device)
    if decoding_only:
        print('>>>:restoration mode')
        result = invhalfer(input_img, decoding_only=decoding_only)
    else:
        print('>>>:halftoning mode')
        # The wrapper returns (halftone, restored image) in this mode; only
        # the halftone is displayed.
        result, _ = invhalfer(input_img, decoding_only=decoding_only)
    # Undo the [-1, 1] scaling from prepare_data, then scale to [0, 255].
    output = util.tensor2img(result / 2. + 0.5) * 255.
    # Round to nearest before casting: astype(np.uint8) truncates toward
    # zero, which would darken every non-integer pixel by up to one level.
    return np.clip(np.rint(output), 0, 255).astype(np.uint8)
def click_run(input_img, decoding_only):
    """Gradio callback for the "Run" button.

    ``decoding_only`` is the index selected in the mode radio (created with
    ``type="index"``): 0 = halftoning, 1 = restoration. The module-level
    ``invhalfer`` model and ``device`` are used.
    """
    return run_invhalf(invhalfer, input_img, decoding_only, device)
## step 2: configure interface
# Labels for the two running modes. The radio below uses type="index", so the
# callback receives the position in this order: 0 = halftoning, 1 = restoration.
_MODE_HALFTONE = "Halftoning (Photo2Halftone)"
_MODE_RESTORE = "Restoration (Halftone2Photo)"

demo = gr.Blocks(title="ReversibleHalftoning")
with demo:
    gr.Markdown(value="""
**Gradio demo for ReversibleHalftoning: Deep Halftoning with Reversible Binary Pattern**. Check our [github page](https://github.com/MenghanXia/ReversibleHalftoning) 😛.
""")
    with gr.Row():
        # Left column: input image, mode selector and run button.
        with gr.Column():
            Image_input = gr.Image(type="numpy", label="Input", interactive=True)
            with gr.Row():
                Radio_mode = gr.Radio(
                    type="index",
                    choices=[_MODE_HALFTONE, _MODE_RESTORE],
                    label="Choose a running mode",
                    value=_MODE_HALFTONE,
                )
            Button_run = gr.Button(value="Run")
        # Right column: model output.
        with gr.Column():
            Image_output = gr.Image(type="numpy", label="Output").style(height=480)
    Button_run.click(fn=click_run, inputs=[Image_input, Radio_mode], outputs=Image_output)
    # Example images are only offered outside local mode.
    if RUN_MODE != "local":
        gr.Examples(
            examples=[
                ['girl.png', _MODE_HALFTONE],
                ['wave.png', _MODE_HALFTONE],
                ['painting.png', _MODE_RESTORE],
            ],
            inputs=[Image_input, Radio_mode],
            outputs=[Image_output],
            label="Examples",
        )
# Local runs bind to a fixed address and port; otherwise Gradio's defaults apply.
if RUN_MODE == "local":
    launch_kwargs = {"server_name": '9.134.253.83', "server_port": 7788}
else:
    launch_kwargs = {}
demo.launch(**launch_kwargs)