from gradio_imageslider import ImageSlider
import functools
import os
import shutil
import tempfile
import diffusers
import gradio as gr
import imageio as imageio
import numpy as np
import spaces
import torch as torch
from PIL import Image
from tqdm import tqdm
from pathlib import Path
import gradio
from gradio.utils import get_cache_folder
from infer import lotus, lotus_video

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Directory where the prediction files written by infer/infer_video are stored.
# main() removes it on startup.
_OUTPUT_DIR = "files/output"


def _normalize_seed(seed):
    """Return ``seed`` as an int, mapping a blank field (``None``) to 0.

    ``gr.Number`` passes ``None`` to the handler when the field is empty, which
    would otherwise bypass the ``seed=0`` default declared by the handlers.
    """
    return 0 if seed is None else int(seed)


def _list_examples(subdir):
    """Return the sorted paths of every entry in ``files/<subdir>``."""
    folder = os.path.join("files", subdir)
    return sorted(os.path.join(folder, name) for name in os.listdir(folder))


def infer(path_input, seed=0):
    """Predict normals for a single image and save both predictions.

    Args:
        path_input: Filesystem path of the input image.
        seed: Seed forwarded to ``lotus``; ``None`` is treated as 0.

    Returns:
        Two ``[input_path, output_path]`` pairs (generative, discriminative),
        the shape consumed by the ``ImageSlider`` outputs.
    """
    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    output_g, output_d = lotus(path_input, 'normal', _normalize_seed(seed), device)

    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    g_save_path = os.path.join(_OUTPUT_DIR, f"{name_base}_g{name_ext}")
    d_save_path = os.path.join(_OUTPUT_DIR, f"{name_base}_d{name_ext}")
    output_g.save(g_save_path)
    output_d.save(d_save_path)
    return [path_input, g_save_path], [path_input, d_save_path]


def infer_video(path_input, seed=0):
    """Predict normals for every frame of a video and write two .mp4 files.

    Args:
        path_input: Filesystem path of the input video.
        seed: Seed forwarded to ``lotus_video``; ``None`` is treated as 0.

    Returns:
        ``[generative_video_path, discriminative_video_path]``.
    """
    frames_g, frames_d = lotus_video(path_input, 'normal', _normalize_seed(seed), device)

    os.makedirs(_OUTPUT_DIR, exist_ok=True)
    name_base, _ = os.path.splitext(os.path.basename(path_input))
    g_save_path = os.path.join(_OUTPUT_DIR, f"{name_base}_g.mp4")
    d_save_path = os.path.join(_OUTPUT_DIR, f"{name_base}_d.mp4")
    imageio.mimsave(g_save_path, frames_g)
    imageio.mimsave(d_save_path, frames_d)
    return [g_save_path, d_save_path]


def run_demo_server():
    """Build the LOTUS (Normal) Gradio app and serve it on 0.0.0.0:7860."""
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="LOTUS (Normal)",
        css="""
        #download {
            height: 118px;
        }
        .slider .inner {
            width: 5px;
            background: #FFF;
        }
        .viewport {
            aspect-ratio: 4/3;
        }
        .tabs button.selected {
            font-size: 20px !important;
            color: crimson !important;
        }
        h1 {
            text-align: center;
            display: block;
        }
        h2 {
            text-align: center;
            display: block;
        }
        h3 {
            text-align: center;
            display: block;
        }
        .md_feedback li {
            margin-bottom: 0px !important;
        }
        """,
        head=""" """,
    ) as demo:
        # Keep the whole title on one line: a newline inside a Markdown heading
        # ends the heading and turns the remainder into a plain paragraph.
        gr.Markdown(
            "# LOTUS: Diffusion-based Visual Foundation Model for High-quality Dense Prediction"
        )

        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Tab("IMAGE"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(
                            label="Input Image",
                            type="filepath",
                        )
                        # Separate name from the video tab's seed: sharing one
                        # variable made the image button read the video seed.
                        # precision=0 makes Gradio deliver an int.
                        image_seed = gr.Number(
                            label="Seed (only for Generative mode)",
                            minimum=0,
                            maximum=999999999,
                            precision=0,
                        )
                        with gr.Row():
                            image_submit_btn = gr.Button(
                                value="Predict Normal!", variant="primary"
                            )
                            image_reset_btn = gr.Button(value="Reset")
                    with gr.Column():
                        image_output_g = ImageSlider(
                            label="Output (Generative)",
                            type="filepath",
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )
                        with gr.Row():
                            image_output_d = ImageSlider(
                                label="Output (Discriminative)",
                                type="filepath",
                                interactive=False,
                                elem_classes="slider",
                                position=0.25,
                            )

                gr.Examples(
                    fn=infer,
                    examples=_list_examples("images"),
                    inputs=[image_input],
                    outputs=[image_output_g, image_output_d],
                    cache_examples=True,
                )

            with gr.Tab("VIDEO"):
                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(
                            label="Input Video",
                            autoplay=True,
                            loop=True,
                        )
                        video_seed = gr.Number(
                            label="Seed (only for Generative mode)",
                            minimum=0,
                            maximum=999999999,
                            precision=0,
                        )
                        with gr.Row():
                            video_submit_btn = gr.Button(
                                value="Predict Normal!", variant="primary"
                            )
                            video_reset_btn = gr.Button(value="Reset")
                    with gr.Column():
                        video_output_g = gr.Video(
                            label="Output (Generative)",
                            interactive=False,
                            autoplay=True,
                            loop=True,
                            show_share_button=True,
                        )
                        with gr.Row():
                            video_output_d = gr.Video(
                                label="Output (Discriminative)",
                                interactive=False,
                                autoplay=True,
                                loop=True,
                                show_share_button=True,
                            )

                gr.Examples(
                    fn=infer_video,
                    examples=_list_examples("videos"),
                    inputs=[input_video],
                    outputs=[video_output_g, video_output_d],
                    cache_examples=True,
                )

        ### Image
        image_submit_btn.click(
            fn=infer,
            inputs=[image_input, image_seed],
            outputs=[image_output_g, image_output_d],
        )
        # Three Nones -> three components: clear the input and both outputs.
        image_reset_btn.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[image_input, image_output_g, image_output_d],
            queue=False,
        )

        ### Video
        video_submit_btn.click(
            fn=infer_video,
            inputs=[input_video, video_seed],
            outputs=[video_output_g, video_output_d],
            queue=True,
        )
        video_reset_btn.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[input_video, video_output_g, video_output_d],
            queue=False,
        )

    ### Server launch
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )


def main():
    """Print installed packages, clear old outputs, and start the demo."""
    os.system("pip freeze")
    # Remove prediction files left over from previous runs (no-op if absent).
    shutil.rmtree(_OUTPUT_DIR, ignore_errors=True)
    run_demo_server()


if __name__ == "__main__":
    main()