File size: 4,229 Bytes
d323598 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
"""Gradio interface for Vista model."""
from __future__ import annotations
import glob
import os
import queue
import threading
import gradio as gr
import gradio_rerun
import rerun as rr
import spaces
import vista
@spaces.GPU(duration=400)
@rr.thread_local_stream("Vista")
def generate_gradio(
    first_frame_file_name: str,
    n_rounds: float = 3,
    n_steps: float = 10,
    height=576,
    width=1024,
    n_frames=25,
    cfg_scale=2.5,
    cond_aug=0.0,
):
    """Generate a Vista rollout and stream it to the Rerun viewer.

    Sampling runs in a worker thread (``vista.run_sampling``) that pushes log
    messages onto a queue. This generator drains the queue, logs each message
    to the Rerun recording created by ``@rr.thread_local_stream`` and yields
    the encoded Rerun bytes after every message so the viewer updates live.

    Args:
        first_frame_file_name: Path to the conditioning image (the Gradio
            image input is created with ``type="filepath"``).
        n_rounds: Number of segments to generate. Gradio sliders deliver
            floats, so it is cast to ``int``.
        n_steps: Number of diffusion steps per segment; cast to ``int``.
        height, width, n_frames, cfg_scale, cond_aug: Forwarded unchanged to
            ``vista.run_sampling`` (presumably output resolution, frames per
            segment, guidance scale and conditioning augmentation -- confirm
            against ``vista``).

    Yields:
        Chunks of encoded Rerun data read from ``rr.binary_stream()``: one
        after the blueprint is sent, then one after each logged message.

    Raises:
        gr.Error: If no first frame was provided.
        Exception: Whatever ``vista.run_sampling`` raised in the worker
            thread, re-raised here after the worker has stopped.
    """
    # gr.Image delivers None when nothing was uploaded; fail fast with a
    # user-visible message instead of letting sampling crash in the worker.
    if not first_frame_file_name:
        raise gr.Error("Please upload or select a first frame before generating.")

    n_rounds = int(n_rounds)
    n_steps = int(n_steps)

    # Worker -> generator channel. Items are either the "done" sentinel or
    # (entity_path, entity, times) tuples. They are logged on this thread
    # because the Rerun recording from @rr.thread_local_stream is bound to
    # the thread running this generator.
    log_queue = queue.SimpleQueue()
    worker_errors: list[Exception] = []

    def _sample() -> None:
        """Run sampling; always enqueue the sentinel, recording any failure."""
        try:
            vista.run_sampling(
                log_queue,
                first_frame_file_name,
                height,
                width,
                n_rounds,
                n_frames,
                n_steps,
                cfg_scale,
                cond_aug,
                model,
            )
        except Exception as exc:
            worker_errors.append(exc)
        finally:
            # Without this, a failing sampler would leave the blocking
            # log_queue.get() below waiting until the GPU time limit.
            # A duplicate "done" is harmless: the loop stops at the first.
            log_queue.put("done")

    stream = rr.binary_stream()
    blueprint = vista.generate_blueprint(n_rounds)
    rr.send_blueprint(blueprint)
    yield stream.read()

    handle = threading.Thread(target=_sample)
    handle.start()

    while True:
        msg = log_queue.get()
        if msg == "done":
            break
        entity_path, entity, times = msg
        # Clear timeline state so each message carries only its own times.
        rr.reset_time()
        for timeline, timestamp in times:
            # Integer times are sequence indices; anything else is seconds.
            if isinstance(timestamp, int):
                rr.set_time_sequence(timeline, timestamp)
            else:
                rr.set_time_seconds(timeline, timestamp)
        rr.log(entity_path, entity)
        yield stream.read()

    handle.join()
    if worker_errors:
        # Surface the worker's failure to Gradio instead of ending silently.
        raise worker_errors[0]
# Load the model once at import time; generate_gradio reads this
# module-level global on every request.
model = vista.create_model()

with gr.Blocks(css="style.css") as demo:
    gr.Markdown(
        """
# Vista: A Generalizable Driving World Model with High Fidelity and Versatile Controllability
[Shenyuan Gao](https://github.com/Little-Podi), [Jiazhi Yang](https://scholar.google.com/citations?user=Ju7nGX8AAAAJ&hl=en), [Li Chen](https://scholar.google.com/citations?user=ulZxvY0AAAAJ&hl=en), [Kashyap Chitta](https://kashyap7x.github.io/), [Yihang Qiu](https://scholar.google.com/citations?user=qgRUOdIAAAAJ&hl=en), [Andreas Geiger](https://www.cvlibs.net/), [Jun Zhang](https://eejzhang.people.ust.hk/), [Hongyang Li](https://lihongyang.info/)
This is a demo of the [Vista model](https://github.com/OpenDriveLab/Vista), a driving world model that can be used to simulate a variety of driving scenarios. This demo uses [Rerun](https://rerun.io/)'s custom [gradio component](https://www.gradio.app/custom-components/gallery?id=radames%2Fgradio_rerun) to livestream the model's output and show intermediate results.
[📜technical report](https://arxiv.org/abs/2405.17398), [🎬video demos](https://vista-demo.github.io/), [🤗model weights](https://huggingface.co/OpenDriveLab/Vista)
Note that the GPU time is limited to 400 seconds per run. If you need more time, you can run the model locally or on your own server.
"""
    )

    # Conditioning image; type="filepath" passes generate_gradio a path
    # string rather than pixel data.
    first_frame = gr.Image(sources="upload", type="filepath")

    # Example first frames stored in "example_images" next to this file.
    example_dir_path = os.path.join(os.path.dirname(__file__), "example_images")
    example_file_paths = sorted(glob.glob(os.path.join(example_dir_path, "*.*")))
    example_gallery = gr.Examples(
        examples=example_file_paths,
        inputs=first_frame,
        cache_examples=False,
    )

    btn = gr.Button("Generate video")

    # Both sliders deliver floats; generate_gradio casts them to int.
    num_rounds = gr.Slider(
        label="Segments",
        info="Number of 25 frame segments to generate. Higher values lead to longer videos. Try to keep the product of segments and steps below 30 to avoid running out of time.",
        minimum=1,
        maximum=5,
        value=2,
        step=1,
    )
    num_steps = gr.Slider(
        label="Diffusion Steps",
        info="Number of diffusion steps per segment. Higher values lead to more detailed videos. Try to keep the product of segments and steps below 30 to avoid running out of time.",
        minimum=1,
        maximum=50,
        value=15,
        step=1,
    )

    with gr.Row():
        # streaming=True lets the viewer consume each chunk yielded by
        # generate_gradio as it arrives.
        viewer = gradio_rerun.Rerun(streaming=True)

    btn.click(
        generate_gradio,
        inputs=[first_frame, num_rounds, num_steps],
        outputs=[viewer],
    )

# Only start the server when run as a script, so importing this module
# (e.g. by the `gradio` CLI or tests) does not launch it as a side effect.
if __name__ == "__main__":
    demo.launch()
|