| from models.model import Model as AutoLink | |
| import gradio as gr | |
| import PIL | |
| import torch | |
| import os | |
| import imageio | |
| import numpy as np | |
# Run on the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Pretrained AutoLink weights; the directory name encodes the training config
# (presumably CelebA-in-the-wild, k=32 keypoints — TODO confirm).
autolink = AutoLink.load_from_checkpoint(
    os.path.join("checkpoints", "celeba_wild_k32_m0.8_b16_t0.00075_sklr512", "model.ckpt"))
autolink.to(device)
# This app only runs inference. Assuming `Model` is a Lightning module,
# load_from_checkpoint leaves it in training mode (Lightning's docs call
# .eval() explicitly before predicting), so switch any dropout/batch-norm
# layers to inference behavior here.
autolink.eval()
def predict_image(image_in: PIL.Image.Image) -> PIL.Image.Image:
    """Run AutoLink on a single uploaded image.

    Args:
        image_in: Image from the ``gr.Image(type="pil")`` component, or
            ``None`` when Run is pressed without an upload.

    Returns:
        The result of ``autolink(image_in)``, shown in a ``gr.Image`` output
        (annotated here as a PIL image).

    Raises:
        gr.Error: If no image was provided; Gradio shows the message to the user.
    """
    if image_in is None:
        raise gr.Error("Please upload a video or image.")
    # Inference only: no gradients are needed, so skip building the autograd graph.
    with torch.no_grad():
        edge_map = autolink(image_in)
    return edge_map
def predict_video(video_in: str) -> str:
    """Run AutoLink frame by frame on a video and write the result beside it.

    Args:
        video_in: Filesystem path of the uploaded video from the ``gr.Video``
            component, or ``None`` when Run is pressed without an upload.

    Returns:
        Path of the output video: the input path with its extension replaced
        by ``_out.mp4``, written at the input video's frame rate.

    Raises:
        gr.Error: If no video was provided; Gradio shows the message to the user.
    """
    if video_in is None:
        raise gr.Error("Please upload a video or image.")
    # splitext handles extensions of any length (".webm", ".mpeg"), unlike a
    # fixed [:-4] slice, which assumes a 4-character extension.
    video_out = os.path.splitext(video_in)[0] + '_out.mp4'
    # Context managers close the reader and writer even if inference fails
    # partway through, so file handles (and any ffmpeg process) are released.
    with imageio.get_reader(video_in) as reader:
        fps = reader.get_meta_data()['fps']
        with imageio.get_writer(video_out, mode='I', fps=fps) as writer, torch.no_grad():
            for frame in reader:
                # Frames arrive as arrays; the model is called with PIL images,
                # as in predict_image.
                edge_map = autolink(PIL.Image.fromarray(frame))
                writer.append_data(np.array(edge_map))
    return video_out
# Gradio UI: a header, then one tab for images and one for videos. Each tab
# wires a "Run" button and a clickable example to its predict_* function.
with gr.Blocks() as blocks:
    gr.Markdown("""
# AutoLink
## Self-supervised Learning of Human Skeletons and Object Outlines by Linking Keypoints
* [Paper](https://arxiv.org/abs/2205.10636)
* [Project Page](https://xingzhehe.github.io/autolink/)
* [GitHub](https://github.com/xingzhehe/AutoLink-Self-supervised-Learning-of-Human-Skeletons-and-Object-Outlines-by-Linking-Keypoints)
""")
    with gr.Tab("Image"):
        with gr.Row():
            with gr.Column():
                # type="pil" makes Gradio pass a PIL image to predict_image.
                image_in = gr.Image(source="upload", type="pil", visible=True)
            with gr.Column():
                image_out = gr.Image()
        run_btn = gr.Button("Run")
        run_btn.click(fn=predict_image, inputs=[image_in], outputs=[image_out])
        # Each example row holds exactly one value per input component (one
        # here). The previous row had a stray trailing None that matched no
        # input.
        gr.Examples(fn=predict_image, examples=[["assets/jackie_chan.jpg"]],
                    inputs=[image_in], outputs=[image_out],
                    cache_examples=False)
    with gr.Tab("Video"):
        with gr.Row():
            with gr.Column():
                # NOTE(review): gradio 3.x gr.Video documents `format=` for the
                # upload container; confirm the pinned version honors `type=`.
                video_in = gr.Video(source="upload", type="mp4")
            with gr.Column():
                video_out = gr.Video()
        run_btn = gr.Button("Run")
        run_btn.click(fn=predict_video, inputs=[video_in], outputs=[video_out])
        gr.Examples(fn=predict_video, examples=[["assets/00344.mp4"]],
                    inputs=[video_in], outputs=[video_out],
                    cache_examples=False)
blocks.launch()