# AutoLink / app.py
# (Hugging Face Space file — original commit: "try first commit", 91fc62a;
#  page chrome "raw / history / blame / 2.53 kB" removed.)
from models.model import Model as AutoLink
import gradio as gr
import PIL
import torch
import os
import imageio
import numpy as np
# Run on GPU when one is available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Load the pretrained AutoLink model from a local checkpoint.
# NOTE(review): the directory name suggests CelebA-Wild with k=32 keypoints —
# inferred from the path only; confirm against the training config.
autolink = AutoLink.load_from_checkpoint(os.path.join("checkpoints", "celeba_wild_k32_m0.8_b16_t0.00075_sklr512", "model.ckpt"))
autolink.to(device)
def predict_image(image_in: PIL.Image.Image) -> PIL.Image.Image:
    """Run AutoLink on a single image and return the predicted edge map.

    Args:
        image_in: Image uploaded through the Gradio UI, or None if the
            user pressed Run without uploading anything.

    Returns:
        The edge-map image produced by the model.

    Raises:
        gr.Error: If no image was provided.
    """
    # `is None`, not `== None`: identity is the Python idiom and avoids
    # invoking PIL.Image's __eq__ on a None comparand.
    if image_in is None:
        raise gr.Error("Please upload a video or image.")
    return autolink(image_in)
def predict_video(video_in: str) -> str:
    """Run AutoLink frame-by-frame over a video and write an edge-map video.

    Args:
        video_in: Path to the uploaded input video, or None if the user
            pressed Run without uploading anything.

    Returns:
        Path to the output video (input path with an ``_out.mp4`` suffix).

    Raises:
        gr.Error: If no video was provided.
    """
    if video_in is None:
        raise gr.Error("Please upload a video or image.")
    # splitext is robust to any extension length (the original `[:-4]`
    # assumed a 3-character extension).
    video_out = os.path.splitext(video_in)[0] + '_out.mp4'
    # Keep the input path in `video_in` and the imageio handle separate
    # (the original shadowed the path with the reader object).
    reader = imageio.get_reader(video_in)
    writer = imageio.get_writer(video_out, mode='I', fps=reader.get_meta_data()['fps'])
    try:
        for frame in reader:
            edge_map = autolink(PIL.Image.fromarray(frame))
            writer.append_data(np.array(edge_map))
    finally:
        # Close both handles even if the model raises mid-video, so the
        # partially written file and the reader are not leaked.
        writer.close()
        reader.close()
    return video_out
# Gradio UI: two tabs (single image / video), each wiring its predictor
# to an upload widget and an output widget.
with gr.Blocks() as blocks:
    gr.Markdown("""
# AutoLink
## Self-supervised Learning of Human Skeletons and Object Outlines by Linking Keypoints
* [Paper](https://arxiv.org/abs/2205.10636)
* [Project Page](https://xingzhehe.github.io/autolink/)
* [GitHub](https://github.com/xingzhehe/AutoLink-Self-supervised-Learning-of-Human-Skeletons-and-Object-Outlines-by-Linking-Keypoints)
""")
    with gr.Tab("Image"):
        with gr.Row():
            with gr.Column():
                image_in = gr.Image(source="upload", type="pil", visible=True)
            with gr.Column():
                image_out = gr.Image()
        run_btn = gr.Button("Run")
        run_btn.click(fn=predict_image, inputs=[image_in], outputs=[image_out])
        # Fix: the example row must match `inputs` (one component), so it is
        # a single-element list — the original passed a stray trailing None.
        gr.Examples(fn=predict_image, examples=[["assets/jackie_chan.jpg"]],
                    inputs=[image_in], outputs=[image_out],
                    cache_examples=False)
    with gr.Tab("Video"):  # unused `as tab` binding dropped
        with gr.Row():
            with gr.Column():
                video_in = gr.Video(source="upload", type="mp4")
            with gr.Column():
                video_out = gr.Video()
        run_btn = gr.Button("Run")
        run_btn.click(fn=predict_video, inputs=[video_in], outputs=[video_out])
        gr.Examples(fn=predict_video, examples=[["assets/00344.mp4"],],
                    inputs=[video_in], outputs=[video_out],
                    cache_examples=False)
blocks.launch()