File size: 4,229 Bytes
d323598
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""Gradio interface for Vista model."""
from __future__ import annotations

import glob
import os
import queue
import threading

import gradio as gr
import gradio_rerun
import rerun as rr
import spaces

import vista


@spaces.GPU(duration=400)
@rr.thread_local_stream("Vista")
def generate_gradio(
    first_frame_file_name: str,
    n_rounds: float = 3,
    n_steps: float = 10,
    height=576,
    width=1024,
    n_frames=25,
    cfg_scale=2.5,
    cond_aug=0.0,
):
    """Run Vista sampling and stream the results to a Rerun viewer.

    This is a generator. Every ``yield`` hands back whatever is currently
    buffered in the Rerun binary stream, for the streaming
    ``gradio_rerun.Rerun`` output component to consume.

    Sampling itself runs in a background thread (``vista.run_sampling``), which
    reports progress through ``log_queue``. Each item read from the queue is
    either the string ``"done"`` or an ``(entity_path, entity, times)`` tuple,
    where ``times`` is an iterable of ``(timeline, time)`` pairs. ``int``
    times are logged as sequence indices and anything else as seconds.

    Args:
        first_frame_file_name: Path to the conditioning (first-frame) image.
        n_rounds: Number of segments to generate. It is coerced with ``int()``
            because Gradio sliders deliver floats.
        n_steps: Diffusion steps per segment, also coerced with ``int()``.
        height, width, n_frames, cfg_scale, cond_aug: Forwarded unchanged to
            ``vista.run_sampling``.

    Yields:
        Chunks of the Rerun binary stream.

    Raises:
        Exception: any exception raised by ``vista.run_sampling`` in the
            worker thread is re-raised here after the worker has finished.
    """
    n_rounds = int(n_rounds)
    n_steps = int(n_steps)

    # Use a queue to log immediately from internals
    log_queue = queue.SimpleQueue()

    # Private sentinel that the worker wrapper always enqueues when sampling
    # ends, whether run_sampling returned normally or raised.
    worker_finished = object()
    worker_errors: list[Exception] = []

    def _sample() -> None:
        """Run ``vista.run_sampling`` and always signal completion on the queue."""
        try:
            vista.run_sampling(
                log_queue,
                first_frame_file_name,
                height,
                width,
                n_rounds,
                n_frames,
                n_steps,
                cfg_scale,
                cond_aug,
                model,
            )
        except Exception as exc:  # re-raised in the generator after join()
            worker_errors.append(exc)
        finally:
            # Without this, an exception in run_sampling would leave the loop
            # below blocked on log_queue.get() until the GPU timeout.
            log_queue.put(worker_finished)

    stream = rr.binary_stream()

    blueprint = vista.generate_blueprint(n_rounds)
    rr.send_blueprint(blueprint)
    yield stream.read()

    handle = threading.Thread(target=_sample)
    handle.start()
    while True:
        msg = log_queue.get()
        if msg is worker_finished or msg == "done":
            break
        entity_path, entity, times = msg
        rr.reset_time()
        for timeline, time in times:
            if isinstance(time, int):
                rr.set_time_sequence(timeline, time)
            else:
                rr.set_time_seconds(timeline, time)
        rr.log(entity_path, entity)
        yield stream.read()
    handle.join()

    # Surface worker failures to Gradio instead of silently ending the stream.
    if worker_errors:
        raise worker_errors[0]


# Build the Vista model once, at import time. generate_gradio reads this
# module-level ``model`` and passes it to vista.run_sampling on every request.
model = vista.create_model()

# Assemble the UI. Gradio lays out components in the order they are created
# inside this Blocks context, so the statement order below is the page layout.
with gr.Blocks(css="style.css") as demo:
    # Header: title, authors, links, and a note about the per-run GPU budget
    # (400 s, the same value passed to @spaces.GPU on generate_gradio).
    gr.Markdown(
        """
        # Vista: A Generalizable Driving World Model with High Fidelity and Versatile Controllability

        [Shenyuan Gao](https://github.com/Little-Podi), [Jiazhi Yang](https://scholar.google.com/citations?user=Ju7nGX8AAAAJ&hl=en), [Li Chen](https://scholar.google.com/citations?user=ulZxvY0AAAAJ&hl=en), [Kashyap Chitta](https://kashyap7x.github.io/), [Yihang Qiu](https://scholar.google.com/citations?user=qgRUOdIAAAAJ&hl=en), [Andreas Geiger](https://www.cvlibs.net/), [Jun Zhang](https://eejzhang.people.ust.hk/), [Hongyang Li](https://lihongyang.info/)

        This is a demo of the [Vista model](https://github.com/OpenDriveLab/Vista), a driving world model that can be used to simulate a variety of driving scenarios. This demo uses [Rerun](https://rerun.io/)'s custom [gradio component](https://www.gradio.app/custom-components/gallery?id=radames%2Fgradio_rerun) to livestream the model's output and show intermediate results.

         [📜technical report](https://arxiv.org/abs/2405.17398),  [🎬video demos](https://vista-demo.github.io/),  [🤗model weights](https://huggingface.co/OpenDriveLab/Vista)

         Note that the GPU time is limited to 400 seconds per run. If you need more time, you can run the model locally or on your own server.
        """
    )
    # Conditioning image. type="filepath" makes the click handler receive a
    # file path string (generate_gradio's first_frame_file_name).
    first_frame = gr.Image(sources="upload", type="filepath")
    # Clickable examples: every "*.*" file in the example_images directory
    # next to this script, sorted by path. Examples are not pre-computed.
    example_dir_path = os.path.join(os.path.dirname(__file__), "example_images")
    example_file_paths = sorted(glob.glob(os.path.join(example_dir_path, "*.*")))
    example_gallery = gr.Examples(
        examples=example_file_paths,
        inputs=first_frame,
        cache_examples=False,
    )

    btn = gr.Button("Generate video")
    # Slider values arrive as numbers; generate_gradio casts them with int().
    num_rounds = gr.Slider(
        label="Segments",
        info="Number of 25 frame segments to generate. Higher values lead to longer videos. Try to keep the product of segments and steps below 30 to avoid running out of time.",
        minimum=1,
        maximum=5,
        value=2,
        step=1
    )
    num_steps = gr.Slider(
        label="Diffusion Steps",
        info="Number of diffusion steps per segment. Higher values lead to more detailed videos. Try to keep the product of segments and steps below 30 to avoid running out of time.",
        minimum=1,
        maximum=50,
        value=15,
        step=1
    )

    with gr.Row():
        # Streaming viewer that consumes the byte chunks yielded by
        # generate_gradio.
        viewer = gradio_rerun.Rerun(streaming=True)
    # Only the image and the two sliders are wired up. The remaining
    # generate_gradio parameters (height, width, n_frames, cfg_scale,
    # cond_aug) keep their defaults.
    btn.click(
        generate_gradio,
        inputs=[first_frame, num_rounds, num_steps],
        outputs=[viewer],
    )

# Start the Gradio server. This runs at import time; there is no
# ``if __name__ == "__main__"`` guard.
demo.launch()