Spaces:
Sleeping
Sleeping
File size: 7,492 Bytes
b389fc0 f7a5cb1 69a06cd f7a5cb1 69a06cd f7a5cb1 69a06cd f7a5cb1 45aa913 f7a5cb1 c2d401a f7a5cb1 69a06cd f7a5cb1 361d4aa f7a5cb1 66da254 f7a5cb1 361d4aa 0d05803 c395a63 0d05803 361d4aa 0d05803 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 |
import spaces
from functools import partial
from typing import Any, Callable, Dict
import clip
import gradio as gr
from gradio_rerun import Rerun
import numpy as np
import trimesh
import rerun as rr
import torch
from utils.common_viz import init, get_batch
from utils.random_utils import set_random_seed
from utils.rerun import log_sample
from src.diffuser import Diffuser
from src.datasets.multimodal_dataset import MultimodalDataset
# ------------------------------------------------------------------------------------- #
# NOTE(review): these three globals are not referenced anywhere else in this file.
batch_size, num_cams, num_verts = None, None, None
# Dataset sample IDs providing the character trajectory; indexed via LABEL_TO_IDS.
SAMPLE_IDS = [
    "2011_KAeAqaA0Llg_00005_00001",
    "2011_F_EuMeT2wBo_00014_00001",
    "2011_MCkKihQrNA4_00014_00000",
]
# Dropdown label -> index into both SAMPLE_IDS and DEFAULT_TEXT.
LABEL_TO_IDS = {
    "right": 0,
    "static": 1,
    "complex": 2,
}
# Prompts offered in the "Examples" panel of the UI.
EXAMPLES = [
    "While the character moves right, the camera trucks right.",
    "While the character moves right, the camera performs a push in.",
    "While the character moves right, the camera performs a pull out.",
    "While the character stays static, the camera performs a boom bottom.",
    "While the character stays static, the camera performs a boom top.",
    "While the character moves to the right, the camera trucks right alongside them. Once the character comes to a stop, the camera remains static.",  # noqa
    "While the character moves to the right, the camera remains static. Once the character comes to a stop, the camera pushes in.",  # noqa
]
# Prompt templates pre-filled in the text box, one per character trajectory
# (same order as SAMPLE_IDS); "[...]" marks the part the user should fill in.
DEFAULT_TEXT = [
    "While the character moves right, the camera [...].",
    "While the character remains static, [...].",
    "While the character moves to the right, the camera [...]. "
    "Once the character comes to a stop, the camera [...].",
]
# HTML banner rendered at the top of the demo. Author separators use the
# `&middot;` entity (previously a mis-decoded "·") so the markup stays
# encoding-independent; the title now opens and closes with matching <h1> tags.
HEADER = """
<div align="center">
<h1 style='text-align: center'>E.T. the Exceptional Trajectories</h1>
<a href="https://robincourant.github.io/info/"><strong>Robin Courant</strong></a>
&middot;
<a href="https://nicolas-dufour.github.io/"><strong>Nicolas Dufour</strong></a>
&middot;
<a href="https://triocrossing.github.io/"><strong>Xi Wang</strong></a>
&middot;
<a href="http://people.irisa.fr/Marc.Christie/"><strong>Marc Christie</strong></a>
&middot;
<a href="https://vicky.kalogeiton.info/"><strong>Vicky Kalogeiton</strong></a>
</div>
<div align="center">
<a href="https://www.lix.polytechnique.fr/vista/projects/2024_et_courant/" class="button"><b>[Webpage]</b></a>
<a href="https://github.com/robincourant/DIRECTOR" class="button"><b>[DIRECTOR]</b></a>
<a href="https://github.com/robincourant/CLaTr" class="button"><b>[CLaTr]</b></a>
<a href="https://github.com/robincourant/the-exceptional-trajectories" class="button"><b>[Data]</b></a>
</div>
<br/>
"""
# ------------------------------------------------------------------------------------- #
def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    """Compute per-frame vertex normals of an animated mesh using trimesh.

    Args:
        vertices: Vertex positions indexed by frame along dim 0; each
            ``vertices[i]`` is handed to trimesh as one frame's vertex array
            (assumed ``(num_frames, num_verts, 3)`` — TODO confirm).
        faces: Triangle indices whose last two dims are ``(num_faces, 3)``;
            expanded so every frame uses the same faces (a ``(num_faces, 3)``,
            ``(1, num_faces, 3)`` or ``(num_frames, num_faces, 3)`` tensor).

    Returns:
        Tensor stacking trimesh's ``vertex_normals`` for every frame. When
        there are no frames, an empty float64 tensor of shape
        ``(0, *vertices.shape[1:])``.
    """
    num_frames, num_faces = vertices.shape[0], faces.shape[-2]
    if num_frames == 0:
        # np.stack([]) raises; an empty sequence yields an explicitly empty result.
        return torch.zeros((0, *vertices.shape[1:]), dtype=torch.float64)

    # trimesh operates on NumPy arrays: detach/cpu make the conversion safe
    # for tensors that require grad or live on an accelerator.
    vertices_np = vertices.detach().cpu().numpy()
    faces_np = faces.expand(num_frames, num_faces, 3).detach().cpu().numpy()

    normals = [
        # process=False keeps trimesh from merging/reordering vertices, so the
        # normals stay aligned with the input vertex order.
        trimesh.Trimesh(vertices=v, faces=f, process=False).vertex_normals
        for v, f in zip(vertices_np, faces_np)
    ]
    return torch.from_numpy(np.stack(normals))
@spaces.GPU
def generate(
    prompt: str,
    seed: int,
    guidance_weight: float,
    sample_label: str,
    # ----------------------------------------------------------------------- #
    dataset: MultimodalDataset,
    device: torch.device,
    diffuser: Diffuser,
    clip_model: clip.model.CLIP,
) -> str:
    """Generate a camera trajectory for ``prompt`` and save it as a Rerun file.

    Args:
        prompt: Text description of the desired camera motion.
        seed: Random seed. It may arrive as a float (e.g. from a ``gr.Number``
            widget), so it is cast to ``int`` before use.
        guidance_weight: Value assigned to ``diffuser.guidance_weight``.
        sample_label: Key of ``LABEL_TO_IDS`` selecting the ``SAMPLE_IDS``
            entry that provides the character trajectory.
        dataset: Dataset passed to ``get_batch`` to build the model input.
        device: Device passed to ``get_batch``.
        diffuser: Diffusion model used for inference; its ``gen_seeds`` and
            ``guidance_weight`` attributes are overwritten on every call.
        clip_model: CLIP model passed to ``get_batch`` to encode the prompt.

    Returns:
        Path of the ``.rrd`` recording written through ``rr.save``.
    """
    # Set arguments. Seeds must be integers; a float seed would also turn
    # gen_seeds into a float array.
    seed = int(seed)
    set_random_seed(seed)
    diffuser.gen_seeds = np.array([seed])
    diffuser.guidance_weight = guidance_weight

    # Inference
    sample_id = SAMPLE_IDS[LABEL_TO_IDS[sample_label]]
    seq_feat = diffuser.net.model.clip_sequential
    batch = get_batch(prompt, sample_id, clip_model, dataset, seq_feat, device)
    with torch.no_grad():
        out = diffuser.predict_step(batch, 0)

    # Take the first sample of the batch and move everything to CPU, since
    # `.numpy()` below only works on CPU tensors. The padding mask drops
    # padded frames from the trajectory and the character vertices.
    padding_mask = out["padding_mask"][0].to(bool).cpu()
    traj = out["gen_samples"][0].cpu()[padding_mask]
    vertices = out["char_raw"]["char_vertices"][0].cpu()[padding_mask]
    faces = out["char_raw"]["char_faces"][0].cpu()
    normals = get_normals(vertices, faces)
    fx, fy, cx, cy = out["intrinsics"][0].cpu().numpy()
    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
    caption = out["caption_raw"][0]

    # Run visualization.
    # NOTE(review): the Rerun recording and the output path are shared by all
    # requests; concurrent calls could overwrite each other's file — confirm
    # the queue serializes calls or switch to per-request recordings/paths.
    rr.init(f"{sample_id}")
    rr.save(".tmp_gr.rrd")
    log_sample(
        root_name="world",
        traj=traj.numpy(),
        K=K,
        vertices=vertices.numpy(),
        faces=faces.numpy(),
        normals=normals.numpy(),
        caption=caption,
        mesh_masks=None,
    )
    return "./.tmp_gr.rrd"
# ------------------------------------------------------------------------------------- #
def launch_app(gen_fn: Callable):
    """Build the Gradio demo UI and launch it (blocking).

    Args:
        gen_fn: Callback invoked as ``gen_fn(prompt, seed, guidance, sample_label)``.
            Its return value is sent to the ``Rerun`` output component.
    """
    theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
    with gr.Blocks(theme=theme) as demo:
        gr.Markdown(HEADER)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Column(scale=2):
                    sample_dropdown = gr.Dropdown(
                        choices=["static", "right", "complex"],
                        label="Character trajectory",
                        value="right",
                        interactive=True,
                    )
                    text = gr.Textbox(
                        placeholder="Type the camera motion you want to generate",
                        show_label=True,
                        label="Text prompt",
                        value=DEFAULT_TEXT[LABEL_TO_IDS[sample_dropdown.value]],
                    )
                    # precision=0 makes Gradio deliver an int rather than a float.
                    seed = gr.Number(value=33, label="Seed", precision=0)
                    guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)

                with gr.Column(scale=1):
                    btn = gr.Button("Generate", variant="primary")

            with gr.Column(scale=2):
                # One value per row, matching the single `text` input.
                examples = gr.Examples(
                    examples=[[x] for x in EXAMPLES],
                    inputs=[text],
                )

        with gr.Row():
            output = Rerun()

        def load_example(example_id):
            # NOTE(review): relies on gr.Examples internals
            # (`non_none_processed_examples`); may break across Gradio versions.
            processed_example = examples.non_none_processed_examples[example_id]
            return gr.utils.resolve_singleton(processed_example)

        def change_fn(label):
            # Reset the prompt to the template matching the chosen trajectory.
            sample_index = LABEL_TO_IDS[label]
            return gr.update(value=DEFAULT_TEXT[sample_index])

        sample_dropdown.change(fn=change_fn, inputs=[sample_dropdown], outputs=[text])

        inputs = [text, seed, guidance, sample_dropdown]
        # Clicking an example first fills the prompt box (unqueued), then
        # chains straight into generation.
        examples.dataset.click(
            load_example,
            inputs=[examples.dataset],
            outputs=examples.inputs_with_examples,
            show_progress=False,
            postprocess=False,
            queue=False,
        ).then(fn=gen_fn, inputs=inputs, outputs=[output])
        btn.click(fn=gen_fn, inputs=inputs, outputs=[output])
        text.submit(fn=gen_fn, inputs=inputs, outputs=[output])

    demo.queue().launch(share=False)
# ------------------------------------------------------------------------------------- #
# Start-up: load models/data from the "config" setup, put both networks on
# the GPU, bind them into the generation callback, then start the UI.
diffuser, clip_model, dataset, device = init("config")
for _model in (diffuser, clip_model):
    _model.to("cuda")

generate_sample = partial(
    generate,
    dataset=dataset,
    device=device,
    diffuser=diffuser,
    clip_model=clip_model,
)
launch_app(generate_sample)
|