# Hugging Face Space app (runs on ZeroGPU "Zero" hardware).
import functools
import os
import shutil
import subprocess
import sys
import tempfile
from pathlib import Path

import diffusers
import gradio
import gradio as gr
import imageio as imageio
import numpy as np
import spaces
import torch as torch
from gradio.utils import get_cache_folder
from gradio_imageslider import ImageSlider
from PIL import Image
from tqdm import tqdm

from infer import lotus, lotus_video
# Torch device used by every inference call below: the GPU when torch can
# see one, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def infer(path_input, seed):
    """Predict surface normals for one image with both LOTUS variants.

    Args:
        path_input: Filesystem path of the input image (the ``gr.Image``
            component is configured with ``type="filepath"``).
        seed: Seed forwarded to ``lotus``; per the UI label it only matters
            for the generative variant.

    Returns:
        Two ``[input_path, output_path]`` pairs, generative first and
        discriminative second. This is the before/after pair format the
        ``ImageSlider`` outputs are fed. Outputs are written to
        ``files/output`` as ``<name>_g<ext>`` and ``<name>_d<ext>``.

    Raises:
        gr.Error: if no image was provided.
    """
    if not path_input:
        # Clicking "Predict" with an empty input passes None; report that in
        # the UI instead of crashing inside os.path.basename.
        raise gr.Error("Please upload an image first.")

    name_base, name_ext = os.path.splitext(os.path.basename(path_input))
    output_g, output_d = lotus(path_input, 'normal', seed, device)

    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs("files/output", exist_ok=True)
    g_save_path = os.path.join("files/output", f"{name_base}_g{name_ext}")
    d_save_path = os.path.join("files/output", f"{name_base}_d{name_ext}")
    # Keep the input's extension so the saved format matches the upload.
    output_g.save(g_save_path)
    output_d.save(d_save_path)
    return [path_input, g_save_path], [path_input, d_save_path]
def infer_video(path_input, seed):
    """Predict surface normals for every frame of a video with both variants.

    Args:
        path_input: Filesystem path of the uploaded input video.
        seed: Seed forwarded to ``lotus_video``; per the UI label it only
            matters for the generative variant.

    Returns:
        ``[generative_mp4_path, discriminative_mp4_path]``, written to
        ``files/output`` as ``<name>_g.mp4`` and ``<name>_d.mp4``.

    Raises:
        gr.Error: if no video was provided.
    """
    if not path_input:
        # An empty video input arrives as None; surface a readable UI error.
        raise gr.Error("Please upload a video first.")

    frames_g, frames_d = lotus_video(path_input, 'normal', seed, device)

    # exist_ok=True avoids the check-then-create race of exists() + makedirs().
    os.makedirs("files/output", exist_ok=True)
    name_base, _ = os.path.splitext(os.path.basename(path_input))
    g_save_path = os.path.join("files/output", f"{name_base}_g.mp4")
    d_save_path = os.path.join("files/output", f"{name_base}_d.mp4")
    # NOTE(review): no fps is passed, so the output uses imageio's default
    # frame rate rather than the input video's; confirm this is intended.
    imageio.mimsave(g_save_path, frames_g)
    imageio.mimsave(d_save_path, frames_d)
    return [g_save_path, d_save_path]
# Page-level CSS for the Blocks app (slider handle, tab highlight, centred headings).
_CSS = """
#download {
    height: 118px;
}
.slider .inner {
    width: 5px;
    background: #FFF;
}
.viewport {
    aspect-ratio: 4/3;
}
.tabs button.selected {
    font-size: 20px !important;
    color: crimson !important;
}
h1, h2, h3 {
    text-align: center;
    display: block;
}
.md_feedback li {
    margin-bottom: 0px !important;
}
"""

# Extra <head> markup: Google Analytics tag.
_HEAD = """
<script async src="https://www.googletagmanager.com/gtag/js?id=G-1FWSVCGZTG"></script>
<script>
    window.dataLayer = window.dataLayer || [];
    function gtag() {dataLayer.push(arguments);}
    gtag('js', new Date());
    gtag('config', 'G-1FWSVCGZTG');
</script>
"""

# Title and badge links shown at the top of the page.
_HEADER_MD = """
# LOTUS: Diffusion-based Visual Foundation Model for High-quality Dense Prediction

<p align="center">
<a title="Page" href="https://lotus3d.github.io/" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/Project-Website-pink?logo=googlechrome&logoColor=white">
</a>
<a title="arXiv" href="https://arxiv.org/abs/2409.18124" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/badge/arXiv-Paper-b31b1b?logo=arxiv&logoColor=white">
</a>
<a title="Github" href="https://github.com/EnVision-Research/Lotus" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://img.shields.io/github/stars/EnVision-Research/Lotus?label=GitHub%20%E2%98%85&logo=github&color=C8C" alt="badge-github-stars">
</a>
<a title="Social" href="https://x.com/haodongli00/status/1839524569058582884" target="_blank" rel="noopener noreferrer" style="display: inline-block;">
<img src="https://www.obukhov.ai/img/badges/badge-social.svg" alt="social">
</a>
</p>
"""


def _example_rows(folder):
    """Return ``[path, 0]`` example rows (seed 0) for the files in *folder*, sorted by name."""
    # Hidden files (e.g. .DS_Store, .gitkeep) are skipped so they are not
    # offered, or pre-computed, as example inputs.
    return [
        [os.path.join(folder, name), 0]
        for name in sorted(os.listdir(folder))
        if not name.startswith(".")
    ]


def run_demo_server():
    """Build the LOTUS (Normal) Gradio UI and serve it on 0.0.0.0:7860.

    The page has two tabs:

    * IMAGE: runs ``infer`` on an uploaded image and shows the generative
      and discriminative predictions in before/after sliders.
    * VIDEO: runs ``infer_video`` on an uploaded video and shows both
      predicted videos.

    Example inputs are listed from ``files/images`` and ``files/videos`` and
    pre-computed at build time (``cache_examples=True``). The queue is enabled
    with ``api_open=False`` before launching.
    """
    gradio_theme = gr.themes.Default()

    with gr.Blocks(
        theme=gradio_theme,
        title="LOTUS (Normal)",
        css=_CSS,
        head=_HEAD,
    ) as demo:
        gr.Markdown(_HEADER_MD)

        with gr.Tabs(elem_classes=["tabs"]):
            with gr.Tab("IMAGE"):
                with gr.Row():
                    with gr.Column():
                        image_input = gr.Image(
                            label="Input Image",
                            type="filepath",
                        )
                        # Each tab owns a separately named seed widget. With
                        # one shared `seed` name, the VIDEO tab's widget
                        # shadowed this one and the image button read it.
                        # precision=0 delivers the seed as an int.
                        image_seed = gr.Number(
                            label="Seed (only for Generative mode)",
                            minimum=0,
                            maximum=999999999,
                            precision=0,
                        )
                        with gr.Row():
                            image_submit_btn = gr.Button(
                                value="Predict Normal!", variant="primary"
                            )
                            image_reset_btn = gr.Button(value="Reset")
                    with gr.Column():
                        image_output_g = ImageSlider(
                            label="Output (Generative)",
                            type="filepath",
                            interactive=False,
                            elem_classes="slider",
                            position=0.25,
                        )
                        with gr.Row():
                            image_output_d = ImageSlider(
                                label="Output (Discriminative)",
                                type="filepath",
                                interactive=False,
                                elem_classes="slider",
                                position=0.25,
                            )
                gr.Examples(
                    fn=infer,
                    examples=_example_rows(os.path.join("files", "images")),
                    inputs=[image_input, image_seed],
                    outputs=[image_output_g, image_output_d],
                    cache_examples=True,
                )

            with gr.Tab("VIDEO"):
                with gr.Row():
                    with gr.Column():
                        input_video = gr.Video(
                            label="Input Video",
                            autoplay=True,
                            loop=True,
                        )
                        video_seed = gr.Number(
                            label="Seed (only for Generative mode)",
                            minimum=0,
                            maximum=999999999,
                            precision=0,
                        )
                        with gr.Row():
                            video_submit_btn = gr.Button(
                                value="Predict Normal!", variant="primary"
                            )
                            video_reset_btn = gr.Button(value="Reset")
                    with gr.Column():
                        video_output_g = gr.Video(
                            label="Output (Generative)",
                            interactive=False,
                            autoplay=True,
                            loop=True,
                            show_share_button=True,
                        )
                        with gr.Row():
                            video_output_d = gr.Video(
                                label="Output (Discriminative)",
                                interactive=False,
                                autoplay=True,
                                loop=True,
                                show_share_button=True,
                            )
                gr.Examples(
                    fn=infer_video,
                    examples=_example_rows(os.path.join("files", "videos")),
                    inputs=[input_video, video_seed],
                    outputs=[video_output_g, video_output_d],
                    cache_examples=True,
                )

        ### Image
        image_submit_btn.click(
            fn=infer,
            inputs=[image_input, image_seed],
            outputs=[image_output_g, image_output_d],
        )
        # Reset clears the input and both outputs: three values, three outputs
        # (previously three values were returned for only two outputs).
        image_reset_btn.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[image_input, image_output_g, image_output_d],
            queue=False,
        )

        ### Video
        video_submit_btn.click(
            fn=infer_video,
            inputs=[input_video, video_seed],
            outputs=[video_output_g, video_output_d],
            queue=True,
        )
        video_reset_btn.click(
            fn=lambda: (None, None, None),
            inputs=[],
            outputs=[input_video, video_output_g, video_output_d],
            queue=False,
        )

    ### Server launch
    demo.queue(
        api_open=False,
    ).launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
def main():
    """Entry point: log installed packages, clear old outputs, start the UI.

    ``files/output`` is removed at startup, best-effort, before
    ``run_demo_server`` builds the app and regenerates results into it.
    """
    # Print installed package versions to the logs. Use the running
    # interpreter's pip rather than whatever "pip" resolves to on PATH.
    subprocess.run([sys.executable, "-m", "pip", "freeze"], check=False)
    # Replaces `os.system("rm -rf files/output")`: no shell involved. Errors
    # are still ignored, matching the original's unchecked exit status; a
    # missing directory is a no-op.
    shutil.rmtree("files/output", ignore_errors=True)
    run_demo_server()


if __name__ == "__main__":
    main()