robin-courant committed
Commit: f7a5cb1
Parent(s): 5e4c5a1
Add app
Files changed:
- app.py +218 -0
- configs/compnode/cpu.yaml +3 -0
- configs/config.yaml +17 -0
- configs/dataset/caption/caption.yaml +14 -0
- configs/dataset/char/char.yaml +15 -0
- configs/dataset/standardization/0300.yaml +15 -0
- configs/dataset/traj+caption+char.yaml +18 -0
- configs/dataset/trajectory/rot6d_trajectory.yaml +10 -0
- configs/diffuser/network/module/ca_director.yaml +17 -0
- configs/diffuser/network/rn_director.yaml +7 -0
- configs/diffuser/rn_director_edm.yaml +24 -0
- src/datasets/datamodule.py +68 -0
- src/datasets/modalities/caption_dataset.py +107 -0
- src/datasets/modalities/char_dataset.py +120 -0
- src/datasets/modalities/trajectory_dataset.py +152 -0
- src/datasets/multimodal_dataset.py +88 -0
- src/diffuser.py +221 -0
- src/models/modules/director.py +1154 -0
- src/models/networks.py +33 -0
- utils/common_viz.py +136 -0
- utils/file_utils.py +117 -0
- utils/random_utils.py +46 -0
- utils/rerun.py +74 -0
- utils/rotation_utils.py +178 -0
app.py
ADDED
@@ -0,0 +1,218 @@
from functools import partial
from typing import Any, Callable, Dict

import clip
import gradio as gr
from gradio_rerun import Rerun
import numpy as np
from pytorch3d.renderer import TexturesVertex
from pytorch3d.structures import Meshes
import rerun as rr
import torch

from utils.common_viz import init, get_batch
from utils.random_utils import set_random_seed
from utils.rerun import log_sample
from src.diffuser import Diffuser
from src.datasets.multimodal_dataset import MultimodalDataset

# ------------------------------------------------------------------------------------- #

batch_size, num_cams, num_verts = None, None, None

SAMPLE_IDS = [
    "2011_KAeAqaA0Llg_00005_00001",
    "2011_F_EuMeT2wBo_00014_00001",
    "2011_MCkKihQrNA4_00014_00000",
]
LABEL_TO_IDS = {
    "right": 0,
    "static": 1,
    "complex": 2,
}
EXAMPLES = [
    "While the character moves right, the camera trucks right.",
    "While the character moves right, the camera performs a push in.",
    "While the character moves right, the camera performs a pull out.",
    "While the character stays static, the camera performs a boom bottom.",
    "While the character stays static, the camera performs a boom top.",
    "While the character moves to the right, the camera trucks right alongside them. Once the character comes to a stop, the camera remains static.",  # noqa
    "While the character moves to the right, the camera remains static. Once the character comes to a stop, the camera pushes in.",  # noqa
]
DEFAULT_TEXT = [
    "While the character moves right, the camera [...].",
    "While the character remains static, [...].",
    "While the character moves to the right, the camera [...]. "
    "Once the character comes to a stop, the camera [...].",
]

HEADER = """

<div align="center">
<h1 style='text-align: center'>E.T. the Exceptional Trajectories</h1>
<a href="https://robincourant.github.io/info/"><strong>Robin Courant</strong></a>
·
<a href="https://nicolas-dufour.github.io/"><strong>Nicolas Dufour</strong></a>
·
<a href="https://triocrossing.github.io/"><strong>Xi Wang</strong></a>
·
<a href="http://people.irisa.fr/Marc.Christie/"><strong>Marc Christie</strong></a>
·
<a href="https://vicky.kalogeiton.info/"><strong>Vicky Kalogeiton</strong></a>
</div>


<div align="center">
<a href="https://www.lix.polytechnique.fr/vista/projects/2024_et_courant/" class="button"><b>[Webpage]</b></a>
<a href="https://github.com/robincourant/DIRECTOR" class="button"><b>[DIRECTOR]</b></a>
<a href="https://github.com/robincourant/CLaTr" class="button"><b>[CLaTr]</b></a>
<a href="https://github.com/robincourant/the-exceptional-trajectories" class="button"><b>[Data]</b></a>
</div>

<br/>
"""

# ------------------------------------------------------------------------------------- #


def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> torch.Tensor:
    num_frames, num_faces = vertices.shape[0], faces.shape[-2]
    faces = faces.expand(num_frames, num_faces, 3)

    verts_rgb = torch.ones_like(vertices)
    verts_rgb[:, :, 1] = 0
    textures = TexturesVertex(verts_features=verts_rgb)
    meshes = Meshes(verts=vertices, faces=faces, textures=textures)
    normals = meshes.verts_normals_padded()

    return normals, meshes


def generate(
    prompt: str,
    seed: int,
    guidance_weight: float,
    sample_label: str,
    # ----------------------------------- #
    dataset: MultimodalDataset,
    device: torch.device,
    diffuser: Diffuser,
    clip_model: clip.model.CLIP,
) -> Dict[str, Any]:
    # Set arguments
    set_random_seed(seed)
    diffuser.gen_seeds = np.array([seed])
    diffuser.guidance_weight = guidance_weight

    # Inference
    sample_id = SAMPLE_IDS[LABEL_TO_IDS[sample_label]]
    seq_feat = diffuser.net.model.clip_sequential
    batch = get_batch(prompt, sample_id, clip_model, dataset, seq_feat, device)
    with torch.no_grad():
        out = diffuser.predict_step(batch, 0)

    # Run visualization
    padding_mask = out["padding_mask"][0].to(bool).cpu()
    padded_traj = out["gen_samples"][0].cpu()
    traj = padded_traj[padding_mask]
    padded_vertices = out["char_raw"]["char_vertices"][0]
    vertices = padded_vertices[padding_mask]
    faces = out["char_raw"]["char_faces"][0]
    normals, meshes = get_normals(vertices, faces)
    fx, fy, cx, cy = out["intrinsics"][0].cpu().numpy()
    K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
    caption = out["caption_raw"][0]

    rr.init(f"{sample_id}")
    rr.save(".tmp_gr.rrd")
    log_sample(
        root_name="world",
        traj=traj.numpy(),
        K=K,
        vertices=vertices.numpy(),
        faces=faces.numpy(),
        normals=normals.numpy(),
        caption=caption,
        mesh_masks=None,
    )
    return "./.tmp_gr.rrd"


# ------------------------------------------------------------------------------------- #


def main(gen_fn: Callable):
    theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")

    with gr.Blocks(theme=theme) as demo:
        gr.Markdown(HEADER)

        with gr.Row():
            with gr.Column(scale=3):
                with gr.Column(scale=2):
                    sample_str = gr.Dropdown(
                        choices=["static", "right", "complex"],
                        label="Character trajectory",
                        value="right",
                        interactive=True,
                    )
                    text = gr.Textbox(
                        placeholder="Type the camera motion you want to generate",
                        show_label=True,
                        label="Text prompt",
                        value=DEFAULT_TEXT[LABEL_TO_IDS[sample_str.value]],
                    )
                    seed = gr.Number(value=33, label="Seed")
                    guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)

                with gr.Column(scale=1):
                    btn = gr.Button("Generate", variant="primary")

            with gr.Column(scale=2):
                examples = gr.Examples(
                    examples=[[x, None, None] for x in EXAMPLES],
                    inputs=[text],
                )

        with gr.Row():
            output = Rerun()

        def load_example(example_id):
            processed_example = examples.non_none_processed_examples[example_id]
            return gr.utils.resolve_singleton(processed_example)

        def change_fn(change):
            sample_index = LABEL_TO_IDS[change]
            return gr.update(value=DEFAULT_TEXT[sample_index])

        sample_str.change(fn=change_fn, inputs=[sample_str], outputs=[text])

        inputs = [text, seed, guidance, sample_str]
        examples.dataset.click(
            load_example,
            inputs=[examples.dataset],
            outputs=examples.inputs_with_examples,
            show_progress=False,
            postprocess=False,
            queue=False,
        ).then(fn=gen_fn, inputs=inputs, outputs=[output])
        btn.click(fn=gen_fn, inputs=inputs, outputs=[output])
        text.submit(fn=gen_fn, inputs=inputs, outputs=[output])
    demo.launch(share=False)


# ------------------------------------------------------------------------------------- #


if __name__ == "__main__":
    # Initialize the models and dataset
    diffuser, clip_model, dataset, device = init("config")
    generate_sample = partial(
        generate,
        dataset=dataset,
        device=device,
        diffuser=diffuser,
        clip_model=clip_model,
    )

    main(generate_sample)
configs/compnode/cpu.yaml
ADDED
@@ -0,0 +1,3 @@
device: cpu
num_gpus: 1
num_workers: 8
configs/config.yaml
ADDED
@@ -0,0 +1,17 @@
defaults:
  - dataset: traj+caption+char
  - diffuser: rn_director_edm
  - compnode: cpu
  - _self_

dataset:
  char:
    load_vertices: true

checkpoint_path: 'checkpoints/ca-mixed-e449.ckpt'
batch_size: 128
data_dir: data

hydra:
  run:
    dir: ./${results_dir}/${xp_name}/${timestamp}
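Note: the demo entry point calls utils.common_viz.init("config") (that file is part of this commit but its contents are not reproduced on this page). As an illustration only, a minimal sketch of how this Hydra config could be composed and its _target_ entries instantiated, assuming standard Hydra 1.1+ APIs; it also assumes the `eval` resolver used by the dataset configs has been registered first (see the resolver note after configs/dataset/char/char.yaml below), and some interpolations (e.g. ${diffuser.do_projection}, ${results_dir}) depend on values defined outside the files shown here:

import hydra
from hydra.utils import instantiate

with hydra.initialize(version_base=None, config_path="configs"):
    cfg = hydra.compose(config_name="config")

# Recursively instantiates RnEDMPrecond and CrossAttentionDirector from their _target_ keys.
diffuser = instantiate(cfg.diffuser)
# MultimodalDataset wrapping the trajectory, caption, and char sub-datasets.
dataset = instantiate(cfg.dataset)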
configs/dataset/caption/caption.yaml
ADDED
@@ -0,0 +1,14 @@
_target_: src.datasets.modalities.caption_dataset.CaptionDataset

name: caption

dataset_dir: ${dataset.dataset_dir}
segment_dir: ${dataset.dataset_dir}/cam_segments
raw_caption_dir: ${dataset.dataset_dir}/caption
feat_caption_dir: ${dataset.dataset_dir}/caption_clip

num_segments: 27
num_feats: 512
num_cams: ${dataset.standardization.num_cams}
sequential: ${diffuser.network.module.clip_sequential}
max_feat_length: 77
configs/dataset/char/char.yaml
ADDED
@@ -0,0 +1,15 @@
_target_: src.datasets.modalities.char_dataset.CharacterDataset

name: char
dataset_dir: ${dataset.dataset_dir}
num_cams: ${dataset.num_cams}
num_raw_feats: 3
num_frequencies: 10
min_freq: 0
max_freq: 4
num_encoding: 3 # ${eval:'2 * ${dataset.char.num_frequencies} * ${dataset.char.num_raw_feats}'}
sequential: ${diffuser.network.module.cond_sequential}
num_feats: ${eval:'${dataset.char.num_encoding} if ${dataset.char.sequential} else ${dataset.num_cams} * ${dataset.char.num_encoding}'}
standardize: ${dataset.trajectory.standardize}
standardization: ${dataset.standardization}
load_vertices: ${diffuser.do_projection}
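Note: the ${eval:'...'} interpolations used here (and in standardization/0300.yaml) are not built into OmegaConf; a custom resolver must be registered before these values are resolved. The registration itself is not among the files shown in this commit, but it would typically look like the following sketch:

from omegaconf import OmegaConf

# Register an "eval" resolver so ${eval:'...'} interpolations can be resolved.
# The repository presumably performs an equivalent registration before composing the config.
OmegaConf.register_new_resolver("eval", eval, replace=True)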
configs/dataset/standardization/0300.yaml
ADDED
@@ -0,0 +1,15 @@
name: '0300'
num_interframes: 0
num_cams: 300
num_total_frames: ${eval:'${dataset.standardization.num_interframes} * (${dataset.standardization.num_cams} - 1) + ${dataset.standardization.num_cams}'}

norm_mean: [7.93987673e-05, -9.98621393e-05, 4.12940653e-04]
norm_std: [0.027841, 0.01819818, 0.03138536]

shift_mean: [0.00201079, -0.27488501, -1.23616805]
shift_std: [1.13433516, 1.19061042, 1.58744263]

norm_mean_h: [6.676e-05, -5.084e-05, -7.782e-04]
norm_std_h: [0.0105, 0.006958, 0.01145]

velocity: true
configs/dataset/traj+caption+char.yaml
ADDED
@@ -0,0 +1,18 @@
_target_: src.datasets.multimodal_dataset.MultimodalDataset

defaults:
  - _self_
  - trajectory: rot6d_trajectory
  - char: char
  - caption: caption
  - standardization: '0300'

name: "${dataset.standardization.name}-t:${dataset.trajectory.name}|c:${dataset.caption.name}|h:${dataset.char.name}"
dataset_name: ${dataset.standardization.name}
dataset_dir: ${data_dir}

num_rawfeats: 12
num_cams: ${dataset.standardization.num_cams}
feature_type: ${dataset.trajectory.name}
num_feats: ${dataset.trajectory.num_feats}
num_cond_feats: ['${dataset.char.num_feats}','${dataset.caption.num_feats}']
configs/dataset/trajectory/rot6d_trajectory.yaml
ADDED
@@ -0,0 +1,10 @@
_target_: src.datasets.modalities.trajectory_dataset.TrajectoryDataset

name: rot6d
set_name: null
dataset_dir: ${dataset.dataset_dir}
num_feats: 9
num_rawfeats: ${dataset.num_rawfeats}
num_cams: ${dataset.num_cams}
standardize: true
standardization: ${dataset.standardization}
configs/diffuser/network/module/ca_director.yaml
ADDED
@@ -0,0 +1,17 @@
_target_: src.models.modules.director.CrossAttentionDirector
name: ca_director
num_feats: ${dataset.num_feats}
num_rawfeats: ${dataset.num_rawfeats}
num_cams: ${dataset.num_cams}
num_cond_feats: ${dataset.num_cond_feats}
latent_dim: 512
mlp_multiplier: 4
num_layers: 8
num_heads: 16
dropout: 0.1
stochastic_depth: 0.1
label_dropout: 0.1
num_text_registers: 16
clip_sequential: True
cond_sequential: True
device: ${compnode.device}
configs/diffuser/network/rn_director.yaml
ADDED
@@ -0,0 +1,7 @@
_target_: src.models.networks.RnEDMPrecond

defaults:
  - module: ca_director

name: rn_director
sigma_data: 0.5
configs/diffuser/rn_director_edm.yaml
ADDED
@@ -0,0 +1,24 @@
_target_: src.diffuser.Diffuser

defaults:
  - _self_
  - network: rn_director

guidance_weight: 1.4
edm2_normalization: true

# EMA
ema_kwargs:
  beta: 0.9999
  update_every: 1

# Sampling
sampling_kwargs:
  num_steps: 10
  sigma_min: 0.002
  sigma_max: 80
  rho: 40
  S_churn: 0
  S_min: 0
  S_max: inf
  S_noise: 1
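Note: these sampling_kwargs feed Diffuser.edm_sampler (src/diffuser.py below), which builds the Karras-style schedule sigma_i = (sigma_max^(1/rho) + i/(N-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho. A small standalone sketch of that schedule with the values above (purely illustrative; the sampler computes it inline):

import torch

num_steps, sigma_min, sigma_max, rho = 10, 0.002, 80.0, 40.0
i = torch.arange(num_steps)
t_steps = (
    sigma_max ** (1 / rho)
    + i / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
) ** rho
t_steps = torch.cat([t_steps, torch.zeros(1)])  # append t_N = 0, as in edm_sampler
print(t_steps)  # runs from 80 down to 0.002, then 0; rho=40 concentrates steps at low noise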
src/datasets/datamodule.py
ADDED
@@ -0,0 +1,68 @@
from lightning import LightningDataModule
from torch.utils.data import Dataset, DataLoader


class Datamodule(LightningDataModule):
    def __init__(
        self,
        train_dataset: Dataset,
        eval_dataset: Dataset,
        batch_train_size: int,
        num_workers: int,
        eval_batch_size: int = None,
    ):
        super().__init__()

        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset

        self.batch_train_size = batch_train_size
        self.eval_batch_size = (
            eval_batch_size if eval_batch_size is not None else batch_train_size
        )

        self.num_workers = num_workers

    def train_dataloader(self) -> DataLoader:
        """Load train set loader."""
        persistent_workers = True if self.num_workers > 0 else False

        dataloader = DataLoader(
            self.train_dataset,
            batch_size=self.batch_train_size,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=persistent_workers,
        )
        return dataloader

    def val_dataloader(self) -> DataLoader:
        """Load val set loader."""
        persistent_workers = True if self.num_workers > 0 else False

        dataloader = DataLoader(
            self.eval_dataset,
            batch_size=self.eval_batch_size,
            num_workers=self.num_workers,
            pin_memory=True,
            persistent_workers=persistent_workers,
        )
        return dataloader

    def predict_dataloader(self) -> DataLoader:
        """Load predict set loader."""
        dataloader = DataLoader(
            self.eval_dataset,
            batch_size=self.eval_batch_size,
            num_workers=self.num_workers,
        )
        return dataloader

    def test_dataloader(self) -> DataLoader:
        """Load test set loader."""
        dataloader = DataLoader(
            self.eval_dataset,
            batch_size=self.eval_batch_size,
            num_workers=self.num_workers,
        )
        return dataloader
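Note: app.py bypasses this Datamodule and calls diffuser.predict_step directly, but Diffuser.on_predict_start expects trainer.datamodule.train_dataset / eval_dataset to exist. A hedged sketch of how the pieces could be wired for batch prediction; train_set, eval_set, and diffuser are placeholders for already-instantiated objects (the actual training/evaluation scripts are not part of this commit):

import lightning as L

# train_set / eval_set: MultimodalDataset instances after .set_split("train") / .set_split("test")
datamodule = Datamodule(
    train_dataset=train_set,
    eval_dataset=eval_set,
    batch_train_size=128,
    num_workers=8,
)
trainer = L.Trainer(accelerator="cpu", devices=1)
predictions = trainer.predict(diffuser, datamodule=datamodule)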
src/datasets/modalities/caption_dataset.py
ADDED
@@ -0,0 +1,107 @@
from collections import Counter
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

from utils.file_utils import load_txt


class CaptionDataset(Dataset):
    def __init__(
        self,
        name: str,
        dataset_dir: str,
        num_cams: int,
        num_feats: int,
        num_segments: int,
        sequential: bool,
        **kwargs,
    ):
        super().__init__()
        self.modality = name
        self.name = name
        self.dataset_dir = Path(dataset_dir)
        # Set data paths (segments, captions, etc...)
        for name, field in kwargs.items():
            if isinstance(field, str):
                field = Path(field)
                if name == "feat_caption_dir":
                    field = field / "seq" if sequential else field / "token"
            setattr(self, name, field)

        self.filenames = None

        self.clip_seq_dir = self.dataset_dir / "caption_clip" / "seq"  # For CLaTrScore
        self.num_cams = num_cams
        self.num_feats = num_feats
        self.num_segments = num_segments
        self.sequential = sequential

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        filename = self.filenames[index]

        # Load data
        if hasattr(self, "segment_dir"):
            raw_segments = torch.from_numpy(
                np.load((self.segment_dir / (filename + ".npy")))
            )
            padded_raw_segments = F.pad(
                raw_segments,
                (0, self.num_cams - len(raw_segments)),
                value=self.num_segments,
            )
        if hasattr(self, "raw_caption_dir"):
            raw_caption = load_txt(self.raw_caption_dir / (filename + ".txt"))
        if hasattr(self, "feat_caption_dir"):
            feat_caption = torch.from_numpy(
                np.load((self.feat_caption_dir / (filename + ".npy")))
            )
            if self.sequential:
                feat_caption = F.pad(
                    feat_caption.to(torch.float32),
                    (0, 0, 0, self.max_feat_length - feat_caption.shape[0]),
                )

        if self.modality == "caption":
            raw_data = {"caption": raw_caption, "segments": padded_raw_segments}
            feat_data = (
                feat_caption.permute(1, 0) if feat_caption.dim() == 2 else feat_caption
            )
        elif self.modality == "segments":
            raw_data = {"segments": padded_raw_segments}
            # Shift by one for padding
            feat_data = F.one_hot(
                padded_raw_segments, num_classes=self.num_segments + 1
            ).to(torch.float32)
            if self.sequential:
                feat_data = feat_data.permute(1, 0)
            else:
                feat_data = feat_data.reshape(-1)
        elif self.modality == "class":
            raw_data = {"segments": padded_raw_segments}
            most_frequent_segment = Counter(raw_segments).most_common(1)[0][0]
            feat_data = F.one_hot(
                torch.tensor(most_frequent_segment), num_classes=self.num_segments
            ).to(torch.float32)
        else:
            raise ValueError(f"Modality {self.modality} not supported")

        clip_seq_caption = torch.from_numpy(
            np.load((self.clip_seq_dir / (filename + ".npy")))
        )
        padding_mask = torch.ones((self.max_feat_length))
        padding_mask[clip_seq_caption.shape[0] :] = 0
        clip_seq_caption = F.pad(
            clip_seq_caption.to(torch.float32),
            (0, 0, 0, self.max_feat_length - clip_seq_caption.shape[0]),
        )
        raw_data["clip_seq_caption"] = clip_seq_caption
        raw_data["clip_seq_mask"] = padding_mask

        return filename, feat_data, raw_data
src/datasets/modalities/char_dataset.py
ADDED
@@ -0,0 +1,120 @@
from pathlib import Path

import numpy as np
import torch
from torch.utils.data import Dataset
import torch.nn.functional as F

# ------------------------------------------------------------------------------------- #

num_frequencies = None

# ------------------------------------------------------------------------------------- #


class CharacterDataset(Dataset):
    def __init__(
        self,
        name: str,
        dataset_dir: str,
        standardize: bool,
        num_feats: int,
        num_cams: int,
        sequential: bool,
        num_frequencies: int,
        min_freq: int,
        max_freq: int,
        load_vertices: bool,
        **kwargs,
    ):
        super().__init__()
        self.modality = "char"
        self.name = name
        self.dataset_dir = Path(dataset_dir)
        self.traj_dir = self.dataset_dir / "traj"
        self.data_dir = self.dataset_dir / self.name
        self.vert_dir = self.dataset_dir / "vert_raw"
        self.center_dir = self.dataset_dir / "char_raw"

        self.filenames = None
        self.standardize = standardize
        if self.standardize:
            mean_std = kwargs["standardization"]
            self.norm_mean = torch.Tensor(mean_std["norm_mean_h"])[:, None]
            self.norm_std = torch.Tensor(mean_std["norm_std_h"])[:, None]
            self.velocity = mean_std["velocity"]

        self.num_cams = num_cams
        self.num_feats = num_feats
        self.sequential = sequential
        self.num_frequencies = num_frequencies
        self.min_freq = min_freq
        self.max_freq = max_freq

        self.load_vertices = load_vertices

    def __len__(self):
        return len(self.filenames)

    def __getitem__(self, index):
        filename = self.filenames[index]

        char_filename = filename + ".npy"
        char_path = self.data_dir / char_filename

        raw_char_feature = torch.from_numpy(np.load((char_path))).to(torch.float32)
        padding_size = self.num_cams - raw_char_feature.shape[0]
        padded_raw_char_feature = F.pad(
            raw_char_feature, (0, 0, 0, padding_size)
        ).permute(1, 0)

        center_path = self.center_dir / char_filename  # Center to offset mesh
        center_offset = torch.from_numpy(np.load(center_path)[0]).to(torch.float32)
        if self.load_vertices:
            vert_path = self.vert_dir / char_filename
            raw_verts = np.load(vert_path, allow_pickle=True)[()]
            if raw_verts["vertices"] is None:
                num_frames = raw_char_feature.shape[0]
                verts = torch.zeros((num_frames, 6890, 3), dtype=torch.float32)
                padded_verts = torch.zeros(
                    (self.num_cams, 6890, 3), dtype=torch.float32
                )
                faces = torch.zeros((13776, 3), dtype=torch.int16)
            else:
                verts = torch.from_numpy(raw_verts["vertices"]).to(torch.float32)
                verts -= center_offset
                padded_verts = F.pad(verts, (0, 0, 0, 0, 0, padding_size))
                faces = torch.from_numpy(raw_verts["faces"]).to(torch.int16)

        char_feature = raw_char_feature.clone()
        if self.velocity:
            velocity = char_feature[1:].clone() - char_feature[:-1].clone()
            char_feature = torch.cat([raw_char_feature[0][None], velocity])

        if self.standardize:
            # Normalize the first frame (origin) and the rest (velocity) separately
            if len(self.norm_mean) == 6:
                char_feature[0] -= self.norm_mean[:3, 0].to(raw_char_feature.device)
                char_feature[0] /= self.norm_std[:3, 0].to(raw_char_feature.device)
                char_feature[1:] -= self.norm_mean[3:, 0].to(raw_char_feature.device)
                char_feature[1:] /= self.norm_std[3:, 0].to(raw_char_feature.device)
            # Normalize all in one
            else:
                char_feature -= self.norm_mean[:, 0].to(raw_char_feature.device)
                char_feature /= self.norm_std[:, 0].to(raw_char_feature.device)
        padded_char_feature = F.pad(
            char_feature,
            (0, 0, 0, self.num_cams - char_feature.shape[0]),
        )

        if self.sequential:
            padded_char_feature = padded_char_feature.permute(1, 0)
        else:
            padded_char_feature = padded_char_feature.reshape(-1)

        raw_feats = {"char_raw_feat": padded_raw_char_feature}
        if self.load_vertices:
            raw_feats["char_vertices"] = padded_verts
            raw_feats["char_faces"] = faces

        return char_filename, padded_char_feature, raw_feats
src/datasets/modalities/trajectory_dataset.py
ADDED
@@ -0,0 +1,152 @@
from pathlib import Path

from evo.tools.file_interface import read_kitti_poses_file
import numpy as np
import torch
from torch.utils.data import Dataset
from torchtyping import TensorType
import torch.nn.functional as F
from typing import Tuple

from utils.file_utils import load_txt
from utils.rotation_utils import compute_rotation_matrix_from_ortho6d

num_cams = None


# ------------------------------------------------------------------------------------- #


class TrajectoryDataset(Dataset):
    def __init__(
        self,
        name: str,
        set_name: str,
        dataset_dir: str,
        num_rawfeats: int,
        num_feats: int,
        num_cams: int,
        standardize: bool,
        **kwargs,
    ):
        super().__init__()
        self.name = name
        self.set_name = set_name
        self.dataset_dir = Path(dataset_dir)
        if name == "relative":
            self.data_dir = self.dataset_dir / "traj_raw"
            self.relative_dir = self.dataset_dir / "relative"
        else:
            self.data_dir = self.dataset_dir / "traj"
        self.intrinsics_dir = self.dataset_dir / "intrinsics"

        self.num_rawfeats = num_rawfeats
        self.num_feats = num_feats
        self.num_cams = num_cams

        self.augmentation = None
        self.standardize = standardize
        if self.standardize:
            mean_std = kwargs["standardization"]
            self.norm_mean = torch.Tensor(mean_std["norm_mean"])
            self.norm_std = torch.Tensor(mean_std["norm_std"])
            self.shift_mean = torch.Tensor(mean_std["shift_mean"])
            self.shift_std = torch.Tensor(mean_std["shift_std"])
            self.velocity = mean_std["velocity"]

    # --------------------------------------------------------------------------------- #

    def set_split(self, split: str, train_rate: float = 1.0):
        self.split = split
        split_path = Path(self.dataset_dir) / f"{split}_split.txt"
        split_traj = load_txt(split_path).split("\n")
        self.filenames = sorted(split_traj)

        return self

    # --------------------------------------------------------------------------------- #

    def get_feature(
        self, raw_matrix_trajectory: TensorType["num_cams", 4, 4]
    ) -> TensorType[9, "num_cams"]:
        matrix_trajectory = torch.clone(raw_matrix_trajectory)

        raw_trans = torch.clone(matrix_trajectory[:, :3, 3])
        if self.velocity:
            velocity = raw_trans[1:] - raw_trans[:-1]
            raw_trans = torch.cat([raw_trans[0][None], velocity])
        if self.standardize:
            raw_trans[0] -= self.shift_mean
            raw_trans[0] /= self.shift_std
            raw_trans[1:] -= self.norm_mean
            raw_trans[1:] /= self.norm_std

        # Compute the 6D continuous rotation
        raw_rot = matrix_trajectory[:, :3, :3]
        rot6d = raw_rot[:, :, :2].permute(0, 2, 1).reshape(-1, 6)

        # Stack rotation 6D and translation
        rot6d_trajectory = torch.hstack([rot6d, raw_trans]).permute(1, 0)

        return rot6d_trajectory

    def get_matrix(
        self, raw_rot6d_trajectory: TensorType[9, "num_cams"]
    ) -> TensorType["num_cams", 4, 4]:
        rot6d_trajectory = torch.clone(raw_rot6d_trajectory)
        device = rot6d_trajectory.device

        num_cams = rot6d_trajectory.shape[1]
        matrix_trajectory = torch.eye(4, device=device)[None].repeat(num_cams, 1, 1)

        raw_trans = rot6d_trajectory[6:].permute(1, 0)
        if self.standardize:
            raw_trans[0] *= self.shift_std.to(device)
            raw_trans[0] += self.shift_mean.to(device)
            raw_trans[1:] *= self.norm_std.to(device)
            raw_trans[1:] += self.norm_mean.to(device)
        if self.velocity:
            raw_trans = torch.cumsum(raw_trans, dim=0)
        matrix_trajectory[:, :3, 3] = raw_trans

        rot6d = rot6d_trajectory[:6].permute(1, 0)
        raw_rot = compute_rotation_matrix_from_ortho6d(rot6d)
        matrix_trajectory[:, :3, :3] = raw_rot

        return matrix_trajectory

    # --------------------------------------------------------------------------------- #

    def __getitem__(self, index: int) -> Tuple[str, TensorType["num_cams", 4, 4]]:
        filename = self.filenames[index]

        trajectory_filename = filename + ".txt"
        trajectory_path = self.data_dir / trajectory_filename

        trajectory = read_kitti_poses_file(trajectory_path)
        matrix_trajectory = torch.from_numpy(np.array(trajectory.poses_se3)).to(
            torch.float32
        )

        trajectory_feature = self.get_feature(matrix_trajectory)

        padded_trajectory_feature = F.pad(
            trajectory_feature, (0, self.num_cams - trajectory_feature.shape[1])
        )
        # Padding mask: 1 for valid cams, 0 for padded cams
        padding_mask = torch.ones((self.num_cams))
        padding_mask[trajectory_feature.shape[1] :] = 0

        intrinsics_filename = filename + ".npy"
        intrinsics_path = self.intrinsics_dir / intrinsics_filename
        intrinsics = np.load(intrinsics_path)

        return (
            trajectory_filename,
            padded_trajectory_feature,
            padding_mask,
            intrinsics,
        )

    def __len__(self):
        return len(self.filenames)
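Note: get_feature keeps only the first two rotation columns as a continuous 6D rotation, and get_matrix restores a full rotation via compute_rotation_matrix_from_ortho6d from utils/rotation_utils.py (included in this commit but not reproduced on this page). As an assumption about what that helper does, a minimal sketch of the standard Gram-Schmidt reconstruction; the repository's exact convention may differ:

import torch
import torch.nn.functional as F

def ortho6d_to_matrix(rot6d: torch.Tensor) -> torch.Tensor:
    """Rebuild (N, 3, 3) rotations from (N, 6) features laid out as [x_col, y_col]."""
    x_raw, y_raw = rot6d[:, :3], rot6d[:, 3:]
    x = F.normalize(x_raw, dim=-1)                          # first column, unit length
    z = F.normalize(torch.cross(x, y_raw, dim=-1), dim=-1)  # third column, orthogonal to x and y_raw
    y = torch.cross(z, x, dim=-1)                           # second column, completes the basis
    return torch.stack([x, y, z], dim=-1)                   # stack as matrix columns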
src/datasets/multimodal_dataset.py
ADDED
@@ -0,0 +1,88 @@
from copy import deepcopy as dp
from pathlib import Path

from torch.utils.data import Dataset


class MultimodalDataset(Dataset):
    def __init__(
        self,
        name,
        dataset_name,
        dataset_dir,
        trajectory,
        feature_type,
        num_rawfeats,
        num_feats,
        num_cams,
        num_cond_feats,
        standardization,
        augmentation=None,
        **modalities,
    ):
        self.dataset_dir = Path(dataset_dir)
        self.name = name
        self.dataset_name = dataset_name
        self.feature_type = feature_type
        self.num_rawfeats = num_rawfeats
        self.num_feats = num_feats
        self.num_cams = num_cams
        self.trajectory_dataset = trajectory
        self.standardization = standardization
        self.modality_datasets = modalities

        if augmentation is not None:
            self.augmentation = True
            self.augmentation_rate = augmentation.rate
            self.trajectory_dataset.set_augmentation(augmentation.trajectory)
            if hasattr(augmentation, "modalities"):
                for modality, augments in augmentation.modalities:
                    self.modality_datasets[modality].set_augmentation(augments)
        else:
            self.augmentation = False

    # --------------------------------------------------------------------------------- #

    def set_split(self, split: str, train_rate: float = 1.0):
        self.split = split

        # Get trajectory split
        self.trajectory_dataset = dp(self.trajectory_dataset).set_split(
            split, train_rate
        )
        self.root_filenames = self.trajectory_dataset.filenames

        # Get modality split
        for modality_name in self.modality_datasets.keys():
            self.modality_datasets[modality_name].filenames = self.root_filenames

        self.get_feature = self.trajectory_dataset.get_feature
        self.get_matrix = self.trajectory_dataset.get_matrix

        return self

    # --------------------------------------------------------------------------------- #

    def __getitem__(self, index):
        traj_out = self.trajectory_dataset[index]
        traj_filename, traj_feature, padding_mask, intrinsics = traj_out

        out = {
            "traj_filename": traj_filename,
            "traj_feat": traj_feature,
            "padding_mask": padding_mask,
            "intrinsics": intrinsics,
        }

        for modality_name, modality_dataset in self.modality_datasets.items():
            modality_filename, modality_feature, modality_raw = modality_dataset[index]
            assert traj_filename.split(".")[0] == modality_filename.split(".")[0]
            out[f"{modality_name}_filename"] = modality_filename
            out[f"{modality_name}_feat"] = modality_feature
            out[f"{modality_name}_raw"] = modality_raw
            out[f"{modality_name}_padding_mask"] = padding_mask

        return out

    def __len__(self):
        return len(self.trajectory_dataset)
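Note: after set_split, each item is a flat dict whose keys are what Diffuser.predict_step and app.py read. Continuing the composition sketch shown after configs/config.yaml (so `dataset` is the instantiated MultimodalDataset with caption and char as modalities), a rough look at one item; key names follow __getitem__ above, shapes are approximate:

# dataset[i] yields roughly:
#   "traj_filename", "traj_feat" (9, 300), "padding_mask" (300,), "intrinsics" (4,),
#   "caption_filename", "caption_feat", "caption_raw" {"caption", "segments",
#       "clip_seq_caption", "clip_seq_mask"}, "caption_padding_mask",
#   "char_filename", "char_feat", "char_raw" {"char_raw_feat", "char_vertices",
#       "char_faces"}, "char_padding_mask"
item = dataset.set_split("test")[0]
print(sorted(item.keys()))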
src/diffuser.py
ADDED
@@ -0,0 +1,221 @@
from omegaconf.dictconfig import DictConfig
from typing import List, Tuple

from ema_pytorch import EMA
import numpy as np
import torch
from torchtyping import TensorType
import torch.nn as nn
import lightning as L

from utils.random_utils import StackedRandomGenerator

# ------------------------------------------------------------------------------------- #

batch_size, num_samples = None, None
num_feats, num_rawfeats, num_cams = None, None, None
RawTrajectory = TensorType["num_samples", "num_rawfeats", "num_cams"]

# ------------------------------------------------------------------------------------- #


class Diffuser(L.LightningModule):
    def __init__(
        self,
        network: nn.Module,
        guidance_weight: float,
        ema_kwargs: DictConfig,
        sampling_kwargs: DictConfig,
        edm2_normalization: bool,
        **kwargs,
    ):
        super().__init__()

        # Network and EMA
        self.net = network
        self.ema = EMA(self.net, **ema_kwargs)
        self.guidance_weight = guidance_weight
        self.edm2_normalization = edm2_normalization
        self.sigma_data = network.sigma_data

        # Sampling
        self.num_steps = sampling_kwargs.num_steps
        self.sigma_min = sampling_kwargs.sigma_min
        self.sigma_max = sampling_kwargs.sigma_max
        self.rho = sampling_kwargs.rho
        self.S_churn = sampling_kwargs.S_churn
        self.S_noise = sampling_kwargs.S_noise
        self.S_min = sampling_kwargs.S_min
        self.S_max = (
            sampling_kwargs.S_max
            if isinstance(sampling_kwargs.S_max, float)
            else float("inf")
        )

    # --------------------------------------------------------------------------------- #

    def on_predict_start(self):
        eval_dataset = self.trainer.datamodule.eval_dataset
        self.modalities = list(eval_dataset.modality_datasets.keys())

        self.get_matrix = self.trainer.datamodule.train_dataset.get_matrix
        self.v_get_matrix = self.trainer.datamodule.eval_dataset.get_matrix

    def predict_step(self, batch, batch_idx):
        ref_samples, mask = batch["traj_feat"], batch["padding_mask"]

        if len(self.modalities) > 0:
            cond_k = [x for x in batch.keys() if "traj" not in x and "feat" in x]
            cond_data = [batch[cond] for cond in cond_k]
            conds = {}
            for cond in cond_k:
                cond_name = cond.replace("_feat", "")
                if isinstance(batch[f"{cond_name}_raw"], dict):
                    for cond_name_, x in batch[f"{cond_name}_raw"].items():
                        conds[cond_name_] = x
                else:
                    conds[cond_name] = batch[f"{cond_name}_raw"]
            batch["conds"] = conds
        else:
            cond_data = None

        # cf edm2 sigma_data normalization / https://arxiv.org/pdf/2312.02696.pdf
        if self.edm2_normalization:
            ref_samples *= self.sigma_data
        _, gen_samples = self.sample(self.ema.ema_model, ref_samples, cond_data, mask)

        batch["ref_samples"] = torch.stack([self.v_get_matrix(x) for x in ref_samples])
        batch["gen_samples"] = torch.stack([self.get_matrix(x) for x in gen_samples])

        return batch

    # --------------------------------------------------------------------------------- #

    def sample(
        self,
        net: torch.nn.Module,
        traj_samples: RawTrajectory,
        cond_samples: TensorType["num_samples", "num_feats"],
        mask: TensorType["num_samples", "num_feats"],
        external_seeds: List[int] = None,
    ) -> Tuple[RawTrajectory, RawTrajectory]:
        # Pick latents
        num_samples = traj_samples.shape[0]
        seeds = self.gen_seeds if hasattr(self, "gen_seeds") else range(num_samples)
        rnd = StackedRandomGenerator(self.device, seeds)

        sz = [num_samples, self.net.num_feats, self.net.num_cams]
        latents = rnd.randn_rn(sz, device=self.device)
        # Generate trajectories.
        generations = self.edm_sampler(
            net,
            latents,
            class_labels=cond_samples,
            mask=mask,
            randn_like=rnd.randn_like,
            guidance_weight=self.guidance_weight,
            # ----------------------------------- #
            num_steps=self.num_steps,
            sigma_min=self.sigma_min,
            sigma_max=self.sigma_max,
            rho=self.rho,
            S_churn=self.S_churn,
            S_min=self.S_min,
            S_max=self.S_max,
            S_noise=self.S_noise,
        )

        return latents, generations

    @staticmethod
    def edm_sampler(
        net,
        latents,
        class_labels=None,
        mask=None,
        guidance_weight=2.0,
        randn_like=torch.randn_like,
        num_steps=18,
        sigma_min=0.002,
        sigma_max=80,
        rho=7,
        S_churn=0,
        S_min=0,
        S_max=float("inf"),
        S_noise=1,
    ):
        # Time step discretization.
        step_indices = torch.arange(num_steps, device=latents.device)
        t_steps = (
            sigma_max ** (1 / rho)
            + step_indices
            / (num_steps - 1)
            * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
        ) ** rho
        t_steps = torch.cat(
            [torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]
        )  # t_N = 0

        # Main sampling loop.
        bool_mask = ~mask.to(bool)
        x_next = latents * t_steps[0]
        bs = latents.shape[0]
        for i, (t_cur, t_next) in enumerate(
            zip(t_steps[:-1], t_steps[1:])
        ):  # 0, ..., N-1
            x_cur = x_next

            # Increase noise temporarily.
            gamma = (
                min(S_churn / num_steps, np.sqrt(2) - 1)
                if S_min <= t_cur <= S_max
                else 0
            )
            t_hat = torch.as_tensor(t_cur + gamma * t_cur)
            x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)

            # Euler step.
            if class_labels is not None:
                class_label_knot = [torch.zeros_like(label) for label in class_labels]
                x_hat_both = torch.cat([x_hat, x_hat], dim=0)
                y_label_both = [
                    torch.cat([y, y_knot], dim=0)
                    for y, y_knot in zip(class_labels, class_label_knot)
                ]

                bool_mask_both = torch.cat([bool_mask, bool_mask], dim=0)
                t_hat_both = torch.cat([t_hat.expand(bs), t_hat.expand(bs)], dim=0)
                cond_denoised, denoised = net(
                    x_hat_both, t_hat_both, y=y_label_both, mask=bool_mask_both
                ).chunk(2, dim=0)
                denoised = denoised + (cond_denoised - denoised) * guidance_weight
            else:
                denoised = net(x_hat, t_hat.expand(bs), mask=bool_mask)
            d_cur = (x_hat - denoised) / t_hat
            x_next = x_hat + (t_next - t_hat) * d_cur

            # Apply 2nd order correction.
            if i < num_steps - 1:
                if class_labels is not None:
                    class_label_knot = [
                        torch.zeros_like(label) for label in class_labels
                    ]
                    x_next_both = torch.cat([x_next, x_next], dim=0)
                    y_label_both = [
                        torch.cat([y, y_knot], dim=0)
                        for y, y_knot in zip(class_labels, class_label_knot)
                    ]
                    bool_mask_both = torch.cat([bool_mask, bool_mask], dim=0)
                    t_next_both = torch.cat(
                        [t_next.expand(bs), t_next.expand(bs)], dim=0
                    )
                    cond_denoised, denoised = net(
                        x_next_both, t_next_both, y=y_label_both, mask=bool_mask_both
                    ).chunk(2, dim=0)
                    denoised = denoised + (cond_denoised - denoised) * guidance_weight
                else:
                    denoised = net(x_next, t_next.expand(bs), mask=bool_mask)
                d_prime = (x_next - denoised) / t_next
                x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)

        return x_next
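Note: edm_sampler applies guidance in the style of classifier-free guidance: it batches the conditioned and zero-label inputs, splits the two denoised outputs, and blends them as denoised_uncond + guidance_weight * (denoised_cond - denoised_uncond). With the config's guidance_weight of 1.4 (> 1), the result extrapolates slightly past the conditional estimate. A tiny numeric illustration with made-up values:

import torch

denoised_uncond = torch.tensor([0.0])
denoised_cond = torch.tensor([1.0])
w = 1.4  # guidance_weight from configs/diffuser/rn_director_edm.yaml
blended = denoised_uncond + (denoised_cond - denoised_uncond) * w
print(blended)  # tensor([1.4000]) -- pushed beyond the conditional estimate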
src/models/modules/director.py
ADDED
@@ -0,0 +1,1154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torch import Tensor
|
4 |
+
import numpy as np
|
5 |
+
from einops import rearrange
|
6 |
+
|
7 |
+
from typing import Optional, List
|
8 |
+
from torchtyping import TensorType
|
9 |
+
from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
|
10 |
+
|
11 |
+
allow_ops_in_compiled_graph()
|
12 |
+
|
13 |
+
batch_size, num_cond_feats = None, None
|
14 |
+
|
15 |
+
|
16 |
+
class FusedMLP(nn.Sequential):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
dim_model: int,
|
20 |
+
dropout: float,
|
21 |
+
activation: nn.Module,
|
22 |
+
hidden_layer_multiplier: int = 4,
|
23 |
+
bias: bool = True,
|
24 |
+
):
|
25 |
+
super().__init__(
|
26 |
+
nn.Linear(dim_model, dim_model * hidden_layer_multiplier, bias=bias),
|
27 |
+
activation(),
|
28 |
+
nn.Dropout(dropout),
            nn.Linear(dim_model * hidden_layer_multiplier, dim_model, bias=bias),
        )


def _cast_if_autocast_enabled(tensor):
    if torch.is_autocast_enabled():
        if tensor.device.type == "cuda":
            dtype = torch.get_autocast_gpu_dtype()
        elif tensor.device.type == "cpu":
            dtype = torch.get_autocast_cpu_dtype()
        else:
            raise NotImplementedError()
        return tensor.to(dtype=dtype)
    return tensor


class LayerNorm16Bits(torch.nn.LayerNorm):
    """
    16-bit friendly version of torch.nn.LayerNorm
    """

    def __init__(
        self,
        normalized_shape,
        eps=1e-06,
        elementwise_affine=True,
        device=None,
        dtype=None,
    ):
        super().__init__(
            normalized_shape=normalized_shape,
            eps=eps,
            elementwise_affine=elementwise_affine,
            device=device,
            dtype=dtype,
        )

    def forward(self, x):
        module_device = x.device
        downcast_x = _cast_if_autocast_enabled(x)
        downcast_weight = (
            _cast_if_autocast_enabled(self.weight)
            if self.weight is not None
            else self.weight
        )
        downcast_bias = (
            _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        )
        with torch.autocast(enabled=False, device_type=module_device.type):
            return nn.functional.layer_norm(
                downcast_x,
                self.normalized_shape,
                downcast_weight,
                downcast_bias,
                self.eps,
            )


class StochatichDepth(nn.Module):
    def __init__(self, p: float):
        super().__init__()
        self.survival_prob = 1.0 - p

    def forward(self, x: Tensor) -> Tensor:
        if self.training and self.survival_prob < 1:
            mask = (
                torch.empty(x.shape[0], 1, 1, device=x.device).uniform_()
                + self.survival_prob
            )
            mask = mask.floor()
            if self.survival_prob > 0:
                mask = mask / self.survival_prob
            return x * mask
        else:
            return x


class CrossAttentionOp(nn.Module):
    def __init__(
        self, attention_dim, num_heads, dim_q, dim_kv, use_biases=True, is_sa=False
    ):
        super().__init__()
        self.dim_q = dim_q
        self.dim_kv = dim_kv
        self.attention_dim = attention_dim
        self.num_heads = num_heads
        self.use_biases = use_biases
        self.is_sa = is_sa
        if self.is_sa:
            self.qkv = nn.Linear(dim_q, attention_dim * 3, bias=use_biases)
        else:
            self.q = nn.Linear(dim_q, attention_dim, bias=use_biases)
            self.kv = nn.Linear(dim_kv, attention_dim * 2, bias=use_biases)
        self.out = nn.Linear(attention_dim, dim_q, bias=use_biases)

    def forward(self, x_to, x_from=None, attention_mask=None):
        if x_from is None:
            x_from = x_to
        if self.is_sa:
            q, k, v = self.qkv(x_to).chunk(3, dim=-1)
        else:
            q = self.q(x_to)
            k, v = self.kv(x_from).chunk(2, dim=-1)
        q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
        k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
        v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
        if attention_mask is not None:
            attention_mask = attention_mask.unsqueeze(1)
        x = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=attention_mask
        )
        x = rearrange(x, "b h n d -> b n (h d)")
        x = self.out(x)
        return x

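As a quick sanity check of the attention wrapper above, a minimal sketch (not part of the committed file; it assumes the module's own imports such as torch and einops.rearrange are available, and the dimensions are arbitrary) showing that both the self-attention and cross-attention paths return tensors shaped like the query input:

# Illustrative example, not part of the commit; dimensions are arbitrary.
import torch

sa = CrossAttentionOp(attention_dim=64, num_heads=4, dim_q=32, dim_kv=32, is_sa=True)
ca = CrossAttentionOp(attention_dim=64, num_heads=4, dim_q=32, dim_kv=16, is_sa=False)
tokens = torch.randn(2, 10, 32)   # queries: [batch, n_to, dim_q]
context = torch.randn(2, 7, 16)   # keys/values: [batch, n_from, dim_kv]
print(sa(tokens).shape)           # torch.Size([2, 10, 32])
print(ca(tokens, context).shape)  # torch.Size([2, 10, 32])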
|
145 |
+
class CrossAttentionBlock(nn.Module):
|
146 |
+
def __init__(
|
147 |
+
self,
|
148 |
+
dim_q: int,
|
149 |
+
dim_kv: int,
|
150 |
+
num_heads: int,
|
151 |
+
attention_dim: int = 0,
|
152 |
+
mlp_multiplier: int = 4,
|
153 |
+
dropout: float = 0.0,
|
154 |
+
stochastic_depth: float = 0.0,
|
155 |
+
use_biases: bool = True,
|
156 |
+
retrieve_attention_scores: bool = False,
|
157 |
+
use_layernorm16: bool = True,
|
158 |
+
):
|
159 |
+
super().__init__()
|
160 |
+
layer_norm = (
|
161 |
+
nn.LayerNorm
|
162 |
+
if not use_layernorm16 or retrieve_attention_scores
|
163 |
+
else LayerNorm16Bits
|
164 |
+
)
|
165 |
+
self.retrieve_attention_scores = retrieve_attention_scores
|
166 |
+
self.initial_to_ln = layer_norm(dim_q, eps=1e-6)
|
167 |
+
attention_dim = min(dim_q, dim_kv) if attention_dim == 0 else attention_dim
|
168 |
+
self.ca = CrossAttentionOp(
|
169 |
+
attention_dim, num_heads, dim_q, dim_kv, is_sa=False, use_biases=use_biases
|
170 |
+
)
|
171 |
+
self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
|
172 |
+
self.middle_ln = layer_norm(dim_q, eps=1e-6)
|
173 |
+
self.ffn = FusedMLP(
|
174 |
+
dim_model=dim_q,
|
175 |
+
dropout=dropout,
|
176 |
+
activation=nn.GELU,
|
177 |
+
hidden_layer_multiplier=mlp_multiplier,
|
178 |
+
bias=use_biases,
|
179 |
+
)
|
180 |
+
self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
|
181 |
+
|
182 |
+
def forward(
|
183 |
+
self,
|
184 |
+
to_tokens: Tensor,
|
185 |
+
from_tokens: Tensor,
|
186 |
+
to_token_mask: Optional[Tensor] = None,
|
187 |
+
from_token_mask: Optional[Tensor] = None,
|
188 |
+
) -> Tensor:
|
189 |
+
if to_token_mask is None and from_token_mask is None:
|
190 |
+
attention_mask = None
|
191 |
+
else:
|
192 |
+
if to_token_mask is None:
|
193 |
+
to_token_mask = torch.ones(
|
194 |
+
to_tokens.shape[0],
|
195 |
+
to_tokens.shape[1],
|
196 |
+
dtype=torch.bool,
|
197 |
+
device=to_tokens.device,
|
198 |
+
)
|
199 |
+
if from_token_mask is None:
|
200 |
+
from_token_mask = torch.ones(
|
201 |
+
from_tokens.shape[0],
|
202 |
+
from_tokens.shape[1],
|
203 |
+
dtype=torch.bool,
|
204 |
+
device=from_tokens.device,
|
205 |
+
)
|
206 |
+
attention_mask = from_token_mask.unsqueeze(1) * to_token_mask.unsqueeze(2)
|
207 |
+
attention_output = self.ca(
|
208 |
+
self.initial_to_ln(to_tokens),
|
209 |
+
from_tokens,
|
210 |
+
attention_mask=attention_mask,
|
211 |
+
)
|
212 |
+
to_tokens = to_tokens + self.ca_stochastic_depth(attention_output)
|
213 |
+
to_tokens = to_tokens + self.ffn_stochastic_depth(
|
214 |
+
self.ffn(self.middle_ln(to_tokens))
|
215 |
+
)
|
216 |
+
return to_tokens
|
217 |
+
|
218 |
+
|
219 |
+
class SelfAttentionBlock(nn.Module):
|
220 |
+
def __init__(
|
221 |
+
self,
|
222 |
+
dim_qkv: int,
|
223 |
+
num_heads: int,
|
224 |
+
attention_dim: int = 0,
|
225 |
+
mlp_multiplier: int = 4,
|
226 |
+
dropout: float = 0.0,
|
227 |
+
stochastic_depth: float = 0.0,
|
228 |
+
use_biases: bool = True,
|
229 |
+
use_layer_scale: bool = False,
|
230 |
+
layer_scale_value: float = 0.0,
|
231 |
+
use_layernorm16: bool = True,
|
232 |
+
):
|
233 |
+
super().__init__()
|
234 |
+
layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
|
235 |
+
self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
|
236 |
+
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
|
237 |
+
self.sa = CrossAttentionOp(
|
238 |
+
attention_dim,
|
239 |
+
num_heads,
|
240 |
+
dim_qkv,
|
241 |
+
dim_qkv,
|
242 |
+
is_sa=True,
|
243 |
+
use_biases=use_biases,
|
244 |
+
)
|
245 |
+
self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
|
246 |
+
self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
|
247 |
+
self.ffn = FusedMLP(
|
248 |
+
dim_model=dim_qkv,
|
249 |
+
dropout=dropout,
|
250 |
+
activation=nn.GELU,
|
251 |
+
hidden_layer_multiplier=mlp_multiplier,
|
252 |
+
bias=use_biases,
|
253 |
+
)
|
254 |
+
self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
|
255 |
+
self.use_layer_scale = use_layer_scale
|
256 |
+
if use_layer_scale:
|
257 |
+
self.layer_scale_1 = nn.Parameter(
|
258 |
+
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
|
259 |
+
)
|
260 |
+
self.layer_scale_2 = nn.Parameter(
|
261 |
+
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
|
262 |
+
)
|
263 |
+
|
264 |
+
def forward(
|
265 |
+
self,
|
266 |
+
tokens: torch.Tensor,
|
267 |
+
token_mask: Optional[torch.Tensor] = None,
|
268 |
+
):
|
269 |
+
if token_mask is None:
|
270 |
+
attention_mask = None
|
271 |
+
else:
|
272 |
+
attention_mask = token_mask.unsqueeze(1) * torch.ones(
|
273 |
+
tokens.shape[0],
|
274 |
+
tokens.shape[1],
|
275 |
+
1,
|
276 |
+
dtype=torch.bool,
|
277 |
+
device=tokens.device,
|
278 |
+
)
|
279 |
+
attention_output = self.sa(
|
280 |
+
self.initial_ln(tokens),
|
281 |
+
attention_mask=attention_mask,
|
282 |
+
)
|
283 |
+
if self.use_layer_scale:
|
284 |
+
tokens = tokens + self.sa_stochastic_depth(
|
285 |
+
self.layer_scale_1 * attention_output
|
286 |
+
)
|
287 |
+
tokens = tokens + self.ffn_stochastic_depth(
|
288 |
+
self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
tokens = tokens + self.sa_stochastic_depth(attention_output)
|
292 |
+
tokens = tokens + self.ffn_stochastic_depth(
|
293 |
+
self.ffn(self.middle_ln(tokens))
|
294 |
+
)
|
295 |
+
return tokens
|
296 |
+
|
297 |
+
|
298 |
+
class AdaLNSABlock(nn.Module):
|
299 |
+
def __init__(
|
300 |
+
self,
|
301 |
+
dim_qkv: int,
|
302 |
+
dim_cond: int,
|
303 |
+
num_heads: int,
|
304 |
+
attention_dim: int = 0,
|
305 |
+
mlp_multiplier: int = 4,
|
306 |
+
dropout: float = 0.0,
|
307 |
+
stochastic_depth: float = 0.0,
|
308 |
+
use_biases: bool = True,
|
309 |
+
use_layer_scale: bool = False,
|
310 |
+
layer_scale_value: float = 0.1,
|
311 |
+
use_layernorm16: bool = True,
|
312 |
+
):
|
313 |
+
super().__init__()
|
314 |
+
layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
|
315 |
+
self.initial_ln = layer_norm(dim_qkv, eps=1e-6, elementwise_affine=False)
|
316 |
+
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
|
317 |
+
self.adaln_modulation = nn.Sequential(
|
318 |
+
nn.SiLU(),
|
319 |
+
nn.Linear(dim_cond, dim_qkv * 6, bias=use_biases),
|
320 |
+
)
|
321 |
+
# Zero init
|
322 |
+
nn.init.zeros_(self.adaln_modulation[1].weight)
|
323 |
+
nn.init.zeros_(self.adaln_modulation[1].bias)
|
324 |
+
|
325 |
+
self.sa = CrossAttentionOp(
|
326 |
+
attention_dim,
|
327 |
+
num_heads,
|
328 |
+
dim_qkv,
|
329 |
+
dim_qkv,
|
330 |
+
is_sa=True,
|
331 |
+
use_biases=use_biases,
|
332 |
+
)
|
333 |
+
self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
|
334 |
+
self.middle_ln = layer_norm(dim_qkv, eps=1e-6, elementwise_affine=False)
|
335 |
+
self.ffn = FusedMLP(
|
336 |
+
dim_model=dim_qkv,
|
337 |
+
dropout=dropout,
|
338 |
+
activation=nn.GELU,
|
339 |
+
hidden_layer_multiplier=mlp_multiplier,
|
340 |
+
bias=use_biases,
|
341 |
+
)
|
342 |
+
self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
|
343 |
+
self.use_layer_scale = use_layer_scale
|
344 |
+
if use_layer_scale:
|
345 |
+
self.layer_scale_1 = nn.Parameter(
|
346 |
+
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
|
347 |
+
)
|
348 |
+
self.layer_scale_2 = nn.Parameter(
|
349 |
+
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
|
350 |
+
)
|
351 |
+
|
352 |
+
def forward(
|
353 |
+
self,
|
354 |
+
tokens: torch.Tensor,
|
355 |
+
cond: torch.Tensor,
|
356 |
+
token_mask: Optional[torch.Tensor] = None,
|
357 |
+
):
|
358 |
+
if token_mask is None:
|
359 |
+
attention_mask = None
|
360 |
+
else:
|
361 |
+
attention_mask = token_mask.unsqueeze(1) * torch.ones(
|
362 |
+
tokens.shape[0],
|
363 |
+
tokens.shape[1],
|
364 |
+
1,
|
365 |
+
dtype=torch.bool,
|
366 |
+
device=tokens.device,
|
367 |
+
)
|
368 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
369 |
+
self.adaln_modulation(cond).chunk(6, dim=-1)
|
370 |
+
)
|
371 |
+
attention_output = self.sa(
|
372 |
+
modulate_shift_and_scale(self.initial_ln(tokens), shift_msa, scale_msa),
|
373 |
+
attention_mask=attention_mask,
|
374 |
+
)
|
375 |
+
if self.use_layer_scale:
|
376 |
+
tokens = tokens + self.sa_stochastic_depth(
|
377 |
+
gate_msa.unsqueeze(1) * self.layer_scale_1 * attention_output
|
378 |
+
)
|
379 |
+
tokens = tokens + self.ffn_stochastic_depth(
|
380 |
+
gate_mlp.unsqueeze(1)
|
381 |
+
* self.layer_scale_2
|
382 |
+
* self.ffn(
|
383 |
+
modulate_shift_and_scale(
|
384 |
+
self.middle_ln(tokens), shift_mlp, scale_mlp
|
385 |
+
)
|
386 |
+
)
|
387 |
+
)
|
388 |
+
else:
|
389 |
+
tokens = tokens + gate_msa.unsqueeze(1) * self.sa_stochastic_depth(
|
390 |
+
attention_output
|
391 |
+
)
|
392 |
+
tokens = tokens + self.ffn_stochastic_depth(
|
393 |
+
gate_mlp.unsqueeze(1)
|
394 |
+
* self.ffn(
|
395 |
+
modulate_shift_and_scale(
|
396 |
+
self.middle_ln(tokens), shift_mlp, scale_mlp
|
397 |
+
)
|
398 |
+
)
|
399 |
+
)
|
400 |
+
return tokens
|
401 |
+
|
402 |
+
|
403 |
+
class CrossAttentionSABlock(nn.Module):
|
404 |
+
def __init__(
|
405 |
+
self,
|
406 |
+
dim_qkv: int,
|
407 |
+
dim_cond: int,
|
408 |
+
num_heads: int,
|
409 |
+
attention_dim: int = 0,
|
410 |
+
mlp_multiplier: int = 4,
|
411 |
+
dropout: float = 0.0,
|
412 |
+
stochastic_depth: float = 0.0,
|
413 |
+
use_biases: bool = True,
|
414 |
+
use_layer_scale: bool = False,
|
415 |
+
layer_scale_value: float = 0.0,
|
416 |
+
use_layernorm16: bool = True,
|
417 |
+
):
|
418 |
+
super().__init__()
|
419 |
+
layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
|
420 |
+
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
|
421 |
+
self.ca = CrossAttentionOp(
|
422 |
+
attention_dim,
|
423 |
+
num_heads,
|
424 |
+
dim_qkv,
|
425 |
+
dim_cond,
|
426 |
+
is_sa=False,
|
427 |
+
use_biases=use_biases,
|
428 |
+
)
|
429 |
+
self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
|
430 |
+
self.ca_ln = layer_norm(dim_qkv, eps=1e-6)
|
431 |
+
|
432 |
+
self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
|
433 |
+
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
|
434 |
+
|
435 |
+
self.sa = CrossAttentionOp(
|
436 |
+
attention_dim,
|
437 |
+
num_heads,
|
438 |
+
dim_qkv,
|
439 |
+
dim_qkv,
|
440 |
+
is_sa=True,
|
441 |
+
use_biases=use_biases,
|
442 |
+
)
|
443 |
+
self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
|
444 |
+
self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
|
445 |
+
self.ffn = FusedMLP(
|
446 |
+
dim_model=dim_qkv,
|
447 |
+
dropout=dropout,
|
448 |
+
activation=nn.GELU,
|
449 |
+
hidden_layer_multiplier=mlp_multiplier,
|
450 |
+
bias=use_biases,
|
451 |
+
)
|
452 |
+
self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
|
453 |
+
self.use_layer_scale = use_layer_scale
|
454 |
+
if use_layer_scale:
|
455 |
+
self.layer_scale_1 = nn.Parameter(
|
456 |
+
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
|
457 |
+
)
|
458 |
+
self.layer_scale_2 = nn.Parameter(
|
459 |
+
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
|
460 |
+
)
|
461 |
+
|
462 |
+
def forward(
|
463 |
+
self,
|
464 |
+
tokens: torch.Tensor,
|
465 |
+
cond: torch.Tensor,
|
466 |
+
token_mask: Optional[torch.Tensor] = None,
|
467 |
+
cond_mask: Optional[torch.Tensor] = None,
|
468 |
+
):
|
469 |
+
if cond_mask is None:
|
470 |
+
cond_attention_mask = None
|
471 |
+
else:
|
472 |
+
cond_attention_mask = torch.ones(
|
473 |
+
cond.shape[0],
|
474 |
+
1,
|
475 |
+
cond.shape[1],
|
476 |
+
dtype=torch.bool,
|
477 |
+
device=tokens.device,
|
478 |
+
) * token_mask.unsqueeze(2)
|
479 |
+
if token_mask is None:
|
480 |
+
attention_mask = None
|
481 |
+
else:
|
482 |
+
attention_mask = token_mask.unsqueeze(1) * torch.ones(
|
483 |
+
tokens.shape[0],
|
484 |
+
tokens.shape[1],
|
485 |
+
1,
|
486 |
+
dtype=torch.bool,
|
487 |
+
device=tokens.device,
|
488 |
+
)
|
489 |
+
ca_output = self.ca(
|
490 |
+
self.ca_ln(tokens),
|
491 |
+
cond,
|
492 |
+
attention_mask=cond_attention_mask,
|
493 |
+
)
|
494 |
+
ca_output = torch.nan_to_num(
|
495 |
+
ca_output, nan=0.0, posinf=0.0, neginf=0.0
|
496 |
+
) # Needed as some tokens get attention from no token so Nan
|
497 |
+
tokens = tokens + self.ca_stochastic_depth(ca_output)
|
498 |
+
attention_output = self.sa(
|
499 |
+
self.initial_ln(tokens),
|
500 |
+
attention_mask=attention_mask,
|
501 |
+
)
|
502 |
+
if self.use_layer_scale:
|
503 |
+
tokens = tokens + self.sa_stochastic_depth(
|
504 |
+
self.layer_scale_1 * attention_output
|
505 |
+
)
|
506 |
+
tokens = tokens + self.ffn_stochastic_depth(
|
507 |
+
self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
|
508 |
+
)
|
509 |
+
else:
|
510 |
+
tokens = tokens + self.sa_stochastic_depth(attention_output)
|
511 |
+
tokens = tokens + self.ffn_stochastic_depth(
|
512 |
+
self.ffn(self.middle_ln(tokens))
|
513 |
+
)
|
514 |
+
return tokens
|
515 |
+
|
516 |
+
|
517 |
+
class CAAdaLNSABlock(nn.Module):
|
518 |
+
def __init__(
|
519 |
+
self,
|
520 |
+
dim_qkv: int,
|
521 |
+
dim_cond: int,
|
522 |
+
num_heads: int,
|
523 |
+
attention_dim: int = 0,
|
524 |
+
mlp_multiplier: int = 4,
|
525 |
+
dropout: float = 0.0,
|
526 |
+
stochastic_depth: float = 0.0,
|
527 |
+
use_biases: bool = True,
|
528 |
+
use_layer_scale: bool = False,
|
529 |
+
layer_scale_value: float = 0.1,
|
530 |
+
use_layernorm16: bool = True,
|
531 |
+
):
|
532 |
+
super().__init__()
|
533 |
+
layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
|
534 |
+
self.ca = CrossAttentionOp(
|
535 |
+
attention_dim,
|
536 |
+
num_heads,
|
537 |
+
dim_qkv,
|
538 |
+
dim_cond,
|
539 |
+
is_sa=False,
|
540 |
+
use_biases=use_biases,
|
541 |
+
)
|
542 |
+
self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
|
543 |
+
self.ca_ln = layer_norm(dim_qkv, eps=1e-6)
|
544 |
+
self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
|
545 |
+
attention_dim = dim_qkv if attention_dim == 0 else attention_dim
|
546 |
+
self.adaln_modulation = nn.Sequential(
|
547 |
+
nn.SiLU(),
|
548 |
+
nn.Linear(dim_cond, dim_qkv * 6, bias=use_biases),
|
549 |
+
)
|
550 |
+
# Zero init
|
551 |
+
nn.init.zeros_(self.adaln_modulation[1].weight)
|
552 |
+
nn.init.zeros_(self.adaln_modulation[1].bias)
|
553 |
+
|
554 |
+
self.sa = CrossAttentionOp(
|
555 |
+
attention_dim,
|
556 |
+
num_heads,
|
557 |
+
dim_qkv,
|
558 |
+
dim_qkv,
|
559 |
+
is_sa=True,
|
560 |
+
use_biases=use_biases,
|
561 |
+
)
|
562 |
+
self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
|
563 |
+
self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
|
564 |
+
self.ffn = FusedMLP(
|
565 |
+
dim_model=dim_qkv,
|
566 |
+
dropout=dropout,
|
567 |
+
activation=nn.GELU,
|
568 |
+
hidden_layer_multiplier=mlp_multiplier,
|
569 |
+
bias=use_biases,
|
570 |
+
)
|
571 |
+
self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
|
572 |
+
self.use_layer_scale = use_layer_scale
|
573 |
+
if use_layer_scale:
|
574 |
+
self.layer_scale_1 = nn.Parameter(
|
575 |
+
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
|
576 |
+
)
|
577 |
+
self.layer_scale_2 = nn.Parameter(
|
578 |
+
torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
|
579 |
+
)
|
580 |
+
|
581 |
+
def forward(
|
582 |
+
self,
|
583 |
+
tokens: torch.Tensor,
|
584 |
+
cond_1: torch.Tensor,
|
585 |
+
cond_2: torch.Tensor,
|
586 |
+
cond_1_mask: Optional[torch.Tensor] = None,
|
587 |
+
token_mask: Optional[torch.Tensor] = None,
|
588 |
+
):
|
589 |
+
if token_mask is None and cond_1_mask is None:
|
590 |
+
cond_attention_mask = None
|
591 |
+
elif token_mask is None:
|
592 |
+
cond_attention_mask = cond_1_mask.unsqueeze(1) * torch.ones(
|
593 |
+
cond_1.shape[0],
|
594 |
+
cond_1.shape[1],
|
595 |
+
1,
|
596 |
+
dtype=torch.bool,
|
597 |
+
device=cond_1.device,
|
598 |
+
)
|
599 |
+
elif cond_1_mask is None:
|
600 |
+
cond_attention_mask = torch.ones(
|
601 |
+
tokens.shape[0],
|
602 |
+
1,
|
603 |
+
tokens.shape[1],
|
604 |
+
dtype=torch.bool,
|
605 |
+
device=tokens.device,
|
606 |
+
) * token_mask.unsqueeze(2)
|
607 |
+
else:
|
608 |
+
cond_attention_mask = cond_1_mask.unsqueeze(1) * token_mask.unsqueeze(2)
|
609 |
+
if token_mask is None:
|
610 |
+
attention_mask = None
|
611 |
+
else:
|
612 |
+
attention_mask = token_mask.unsqueeze(1) * torch.ones(
|
613 |
+
tokens.shape[0],
|
614 |
+
tokens.shape[1],
|
615 |
+
1,
|
616 |
+
dtype=torch.bool,
|
617 |
+
device=tokens.device,
|
618 |
+
)
|
619 |
+
ca_output = self.ca(
|
620 |
+
self.ca_ln(tokens),
|
621 |
+
cond_1,
|
622 |
+
attention_mask=cond_attention_mask,
|
623 |
+
)
|
624 |
+
ca_output = torch.nan_to_num(ca_output, nan=0.0, posinf=0.0, neginf=0.0)
|
625 |
+
tokens = tokens + self.ca_stochastic_depth(ca_output)
|
626 |
+
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
|
627 |
+
self.adaln_modulation(cond_2).chunk(6, dim=-1)
|
628 |
+
)
|
629 |
+
attention_output = self.sa(
|
630 |
+
modulate_shift_and_scale(self.initial_ln(tokens), shift_msa, scale_msa),
|
631 |
+
attention_mask=attention_mask,
|
632 |
+
)
|
633 |
+
if self.use_layer_scale:
|
634 |
+
tokens = tokens + self.sa_stochastic_depth(
|
635 |
+
gate_msa.unsqueeze(1) * self.layer_scale_1 * attention_output
|
636 |
+
)
|
637 |
+
tokens = tokens + self.ffn_stochastic_depth(
|
638 |
+
gate_mlp.unsqueeze(1)
|
639 |
+
* self.layer_scale_2
|
640 |
+
* self.ffn(
|
641 |
+
modulate_shift_and_scale(
|
642 |
+
self.middle_ln(tokens), shift_mlp, scale_mlp
|
643 |
+
)
|
644 |
+
)
|
645 |
+
)
|
646 |
+
else:
|
647 |
+
tokens = tokens + gate_msa.unsqueeze(1) * self.sa_stochastic_depth(
|
648 |
+
attention_output
|
649 |
+
)
|
650 |
+
tokens = tokens + self.ffn_stochastic_depth(
|
651 |
+
gate_mlp.unsqueeze(1)
|
652 |
+
* self.ffn(
|
653 |
+
modulate_shift_and_scale(
|
654 |
+
self.middle_ln(tokens), shift_mlp, scale_mlp
|
655 |
+
)
|
656 |
+
)
|
657 |
+
)
|
658 |
+
return tokens
|
659 |
+
|
660 |
+
|
661 |
+
class PositionalEmbedding(nn.Module):
|
662 |
+
"""
|
663 |
+
Taken from https://github.com/NVlabs/edm
|
664 |
+
"""
|
665 |
+
|
666 |
+
def __init__(self, num_channels, max_positions=10000, endpoint=False):
|
667 |
+
super().__init__()
|
668 |
+
self.num_channels = num_channels
|
669 |
+
self.max_positions = max_positions
|
670 |
+
self.endpoint = endpoint
|
671 |
+
freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32)
|
672 |
+
freqs = 2 * freqs / self.num_channels
|
673 |
+
freqs = (1 / self.max_positions) ** freqs
|
674 |
+
self.register_buffer("freqs", freqs)
|
675 |
+
|
676 |
+
def forward(self, x):
|
677 |
+
x = torch.outer(x, self.freqs)
|
678 |
+
out = torch.cat([x.cos(), x.sin()], dim=1)
|
679 |
+
return out.to(x.dtype)
|
680 |
+
|
681 |
+
|
682 |
+
class PositionalEncoding(nn.Module):
|
683 |
+
def __init__(self, d_model, dropout=0.0, max_len=10000):
|
684 |
+
super().__init__()
|
685 |
+
self.dropout = nn.Dropout(p=dropout)
|
686 |
+
|
687 |
+
pe = torch.zeros(max_len, d_model)
|
688 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
689 |
+
div_term = torch.exp(
|
690 |
+
torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
|
691 |
+
)
|
692 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
693 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
694 |
+
pe = pe.unsqueeze(0)
|
695 |
+
|
696 |
+
self.register_buffer("pe", pe)
|
697 |
+
|
698 |
+
def forward(self, x):
|
699 |
+
# not used in the final model
|
700 |
+
x = x + self.pe[:, : x.shape[1], :]
|
701 |
+
return self.dropout(x)
|
702 |
+
|
703 |
+
|
704 |
+
class TimeEmbedder(nn.Module):
|
705 |
+
def __init__(
|
706 |
+
self,
|
707 |
+
dim: int,
|
708 |
+
time_scaling: float,
|
709 |
+
expansion: int = 4,
|
710 |
+
):
|
711 |
+
super().__init__()
|
712 |
+
self.encode_time = PositionalEmbedding(num_channels=dim, endpoint=True)
|
713 |
+
|
714 |
+
self.time_scaling = time_scaling
|
715 |
+
self.map_time = nn.Sequential(
|
716 |
+
nn.Linear(dim, dim * expansion),
|
717 |
+
nn.SiLU(),
|
718 |
+
nn.Linear(dim * expansion, dim * expansion),
|
719 |
+
)
|
720 |
+
|
721 |
+
def forward(self, t: Tensor) -> Tensor:
|
722 |
+
time = self.encode_time(t * self.time_scaling)
|
723 |
+
time_mean = time.mean(dim=-1, keepdim=True)
|
724 |
+
time_std = time.std(dim=-1, keepdim=True)
|
725 |
+
time = (time - time_mean) / time_std
|
726 |
+
return self.map_time(time)
|
727 |
+
|
728 |
+
|
729 |
+
def modulate_shift_and_scale(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
|
730 |
+
return x * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)
|
731 |
+
|
732 |
+
|
733 |
+
# ------------------------------------------------------------------------------------- #
|
734 |
+
|
735 |
+
|
736 |
+
class BaseDirector(nn.Module):
|
737 |
+
def __init__(
|
738 |
+
self,
|
739 |
+
name: str,
|
740 |
+
num_feats: int,
|
741 |
+
num_cond_feats: int,
|
742 |
+
num_cams: int,
|
743 |
+
latent_dim: int,
|
744 |
+
mlp_multiplier: int,
|
745 |
+
num_layers: int,
|
746 |
+
num_heads: int,
|
747 |
+
dropout: float,
|
748 |
+
stochastic_depth: float,
|
749 |
+
label_dropout: float,
|
750 |
+
num_rawfeats: int,
|
751 |
+
clip_sequential: bool = False,
|
752 |
+
cond_sequential: bool = False,
|
753 |
+
device: str = "cuda",
|
754 |
+
**kwargs,
|
755 |
+
):
|
756 |
+
super().__init__()
|
757 |
+
self.name = name
|
758 |
+
self.label_dropout = label_dropout
|
759 |
+
self.num_rawfeats = num_rawfeats
|
760 |
+
self.num_feats = num_feats
|
761 |
+
self.num_cams = num_cams
|
762 |
+
self.clip_sequential = clip_sequential
|
763 |
+
self.cond_sequential = cond_sequential
|
764 |
+
self.use_layernorm16 = device == "cuda"
|
765 |
+
|
766 |
+
self.input_projection = nn.Sequential(
|
767 |
+
nn.Linear(num_feats, latent_dim),
|
768 |
+
PositionalEncoding(latent_dim),
|
769 |
+
)
|
770 |
+
self.time_embedding = TimeEmbedder(latent_dim // 4, time_scaling=1000)
|
771 |
+
self.init_conds_mappings(num_cond_feats, latent_dim)
|
772 |
+
self.init_backbone(
|
773 |
+
num_layers, latent_dim, mlp_multiplier, num_heads, dropout, stochastic_depth
|
774 |
+
)
|
775 |
+
self.init_output_projection(num_feats, latent_dim)
|
776 |
+
|
777 |
+
def forward(
|
778 |
+
self,
|
779 |
+
x: Tensor,
|
780 |
+
timesteps: Tensor,
|
781 |
+
y: List[Tensor] = None,
|
782 |
+
mask: Tensor = None,
|
783 |
+
) -> Tensor:
|
784 |
+
mask = mask.logical_not() if mask is not None else None
|
785 |
+
x = rearrange(x, "b c n -> b n c")
|
786 |
+
x = self.input_projection(x)
|
787 |
+
t = self.time_embedding(timesteps)
|
788 |
+
if y is not None:
|
789 |
+
y = self.mask_cond(y)
|
790 |
+
y = self.cond_mapping(y, mask, t)
|
791 |
+
|
792 |
+
x = self.backbone(x, y, mask)
|
793 |
+
x = self.output_projection(x, y)
|
794 |
+
return rearrange(x, "b n c -> b c n")
|
795 |
+
|
796 |
+
def init_conds_mappings(self, num_cond_feats, latent_dim):
|
797 |
+
raise NotImplementedError(
|
798 |
+
"This method should be implemented in the derived class"
|
799 |
+
)
|
800 |
+
|
801 |
+
def init_backbone(self):
|
802 |
+
raise NotImplementedError(
|
803 |
+
"This method should be implemented in the derived class"
|
804 |
+
)
|
805 |
+
|
806 |
+
def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
|
807 |
+
raise NotImplementedError(
|
808 |
+
"This method should be implemented in the derived class"
|
809 |
+
)
|
810 |
+
|
811 |
+
def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
|
812 |
+
raise NotImplementedError(
|
813 |
+
"This method should be implemented in the derived class"
|
814 |
+
)
|
815 |
+
|
816 |
+
def mask_cond(
|
817 |
+
self, cond: List[TensorType["batch_size", "num_cond_feats"]]
|
818 |
+
) -> TensorType["batch_size", "num_cond_feats"]:
|
819 |
+
bs = cond[0].shape[0]
|
820 |
+
if self.training and self.label_dropout > 0.0:
|
821 |
+
# 1-> use null_cond, 0-> use real cond
|
822 |
+
prob = torch.ones(bs, device=cond[0].device) * self.label_dropout
|
823 |
+
masked_cond = []
|
824 |
+
common_mask = torch.bernoulli(prob) # Common to all modalities
|
825 |
+
for _cond in cond:
|
826 |
+
modality_mask = torch.bernoulli(prob) # Modality only
|
827 |
+
mask = torch.clip(common_mask + modality_mask, 0, 1)
|
828 |
+
mask = mask.view(bs, 1, 1) if _cond.dim() == 3 else mask.view(bs, 1)
|
829 |
+
masked_cond.append(_cond * (1.0 - mask))
|
830 |
+
return masked_cond
|
831 |
+
else:
|
832 |
+
return cond
|
833 |
+
|
834 |
+
def init_output_projection(self, num_feats, latent_dim):
|
835 |
+
raise NotImplementedError(
|
836 |
+
"This method should be implemented in the derived class"
|
837 |
+
)
|
838 |
+
|
839 |
+
def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
|
840 |
+
raise NotImplementedError(
|
841 |
+
"This method should be implemented in the derived class"
|
842 |
+
)
|
843 |
+
|
844 |
+
|
845 |
+
class AdaLNDirector(BaseDirector):
|
846 |
+
def __init__(
|
847 |
+
self,
|
848 |
+
name: str,
|
849 |
+
num_feats: int,
|
850 |
+
num_cond_feats: int,
|
851 |
+
num_cams: int,
|
852 |
+
latent_dim: int,
|
853 |
+
mlp_multiplier: int,
|
854 |
+
num_layers: int,
|
855 |
+
num_heads: int,
|
856 |
+
dropout: float,
|
857 |
+
stochastic_depth: float,
|
858 |
+
label_dropout: float,
|
859 |
+
num_rawfeats: int,
|
860 |
+
clip_sequential: bool = False,
|
861 |
+
cond_sequential: bool = False,
|
862 |
+
device: str = "cuda",
|
863 |
+
**kwargs,
|
864 |
+
):
|
865 |
+
super().__init__(
|
866 |
+
name=name,
|
867 |
+
num_feats=num_feats,
|
868 |
+
num_cond_feats=num_cond_feats,
|
869 |
+
num_cams=num_cams,
|
870 |
+
latent_dim=latent_dim,
|
871 |
+
mlp_multiplier=mlp_multiplier,
|
872 |
+
num_layers=num_layers,
|
873 |
+
num_heads=num_heads,
|
874 |
+
dropout=dropout,
|
875 |
+
stochastic_depth=stochastic_depth,
|
876 |
+
label_dropout=label_dropout,
|
877 |
+
num_rawfeats=num_rawfeats,
|
878 |
+
clip_sequential=clip_sequential,
|
879 |
+
cond_sequential=cond_sequential,
|
880 |
+
device=device,
|
881 |
+
)
|
882 |
+
assert not (clip_sequential and cond_sequential)
|
883 |
+
|
884 |
+
def init_conds_mappings(self, num_cond_feats, latent_dim):
|
885 |
+
self.joint_cond_projection = nn.Linear(sum(num_cond_feats), latent_dim)
|
886 |
+
|
887 |
+
def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
|
888 |
+
c_emb = torch.cat(cond, dim=-1)
|
889 |
+
return self.joint_cond_projection(c_emb) + t
|
890 |
+
|
891 |
+
def init_backbone(
|
892 |
+
self,
|
893 |
+
num_layers,
|
894 |
+
latent_dim,
|
895 |
+
mlp_multiplier,
|
896 |
+
num_heads,
|
897 |
+
dropout,
|
898 |
+
stochastic_depth,
|
899 |
+
):
|
900 |
+
self.backbone_module = nn.ModuleList(
|
901 |
+
[
|
902 |
+
AdaLNSABlock(
|
903 |
+
dim_qkv=latent_dim,
|
904 |
+
dim_cond=latent_dim,
|
905 |
+
num_heads=num_heads,
|
906 |
+
mlp_multiplier=mlp_multiplier,
|
907 |
+
dropout=dropout,
|
908 |
+
stochastic_depth=stochastic_depth,
|
909 |
+
use_layernorm16=self.use_layernorm16,
|
910 |
+
)
|
911 |
+
for _ in range(num_layers)
|
912 |
+
]
|
913 |
+
)
|
914 |
+
|
915 |
+
def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
|
916 |
+
for block in self.backbone_module:
|
917 |
+
x = block(x, y, mask)
|
918 |
+
return x
|
919 |
+
|
920 |
+
def init_output_projection(self, num_feats, latent_dim):
|
921 |
+
layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
|
922 |
+
|
923 |
+
self.final_norm = layer_norm(latent_dim, eps=1e-6, elementwise_affine=False)
|
924 |
+
self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
|
925 |
+
self.final_adaln = nn.Sequential(
|
926 |
+
nn.SiLU(),
|
927 |
+
nn.Linear(latent_dim, latent_dim * 2, bias=True),
|
928 |
+
)
|
929 |
+
# Zero init
|
930 |
+
nn.init.zeros_(self.final_adaln[1].weight)
|
931 |
+
nn.init.zeros_(self.final_adaln[1].bias)
|
932 |
+
|
933 |
+
def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
|
934 |
+
shift, scale = self.final_adaln(y).chunk(2, dim=-1)
|
935 |
+
x = modulate_shift_and_scale(self.final_norm(x), shift, scale)
|
936 |
+
return self.final_linear(x)
|
937 |
+
|
938 |
+
|
939 |
+
class CrossAttentionDirector(BaseDirector):
|
940 |
+
def __init__(
|
941 |
+
self,
|
942 |
+
name: str,
|
943 |
+
num_feats: int,
|
944 |
+
num_cond_feats: int,
|
945 |
+
num_cams: int,
|
946 |
+
latent_dim: int,
|
947 |
+
mlp_multiplier: int,
|
948 |
+
num_layers: int,
|
949 |
+
num_heads: int,
|
950 |
+
dropout: float,
|
951 |
+
stochastic_depth: float,
|
952 |
+
label_dropout: float,
|
953 |
+
num_rawfeats: int,
|
954 |
+
num_text_registers: int,
|
955 |
+
clip_sequential: bool = True,
|
956 |
+
cond_sequential: bool = True,
|
957 |
+
device: str = "cuda",
|
958 |
+
**kwargs,
|
959 |
+
):
|
960 |
+
self.num_text_registers = num_text_registers
|
961 |
+
self.num_heads = num_heads
|
962 |
+
self.dropout = dropout
|
963 |
+
self.mlp_multiplier = mlp_multiplier
|
964 |
+
self.stochastic_depth = stochastic_depth
|
965 |
+
super().__init__(
|
966 |
+
name=name,
|
967 |
+
num_feats=num_feats,
|
968 |
+
num_cond_feats=num_cond_feats,
|
969 |
+
num_cams=num_cams,
|
970 |
+
latent_dim=latent_dim,
|
971 |
+
mlp_multiplier=mlp_multiplier,
|
972 |
+
num_layers=num_layers,
|
973 |
+
num_heads=num_heads,
|
974 |
+
dropout=dropout,
|
975 |
+
stochastic_depth=stochastic_depth,
|
976 |
+
label_dropout=label_dropout,
|
977 |
+
num_rawfeats=num_rawfeats,
|
978 |
+
clip_sequential=clip_sequential,
|
979 |
+
cond_sequential=cond_sequential,
|
980 |
+
device=device,
|
981 |
+
)
|
982 |
+
assert clip_sequential and cond_sequential
|
983 |
+
|
984 |
+
def init_conds_mappings(self, num_cond_feats, latent_dim):
|
985 |
+
self.cond_projection = nn.ModuleList(
|
986 |
+
[nn.Linear(num_cond_feat, latent_dim) for num_cond_feat in num_cond_feats]
|
987 |
+
)
|
988 |
+
self.cond_registers = nn.Parameter(
|
989 |
+
torch.randn(self.num_text_registers, latent_dim), requires_grad=True
|
990 |
+
)
|
991 |
+
nn.init.trunc_normal_(self.cond_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02)
|
992 |
+
self.cond_sa = nn.ModuleList(
|
993 |
+
[
|
994 |
+
SelfAttentionBlock(
|
995 |
+
dim_qkv=latent_dim,
|
996 |
+
num_heads=self.num_heads,
|
997 |
+
mlp_multiplier=self.mlp_multiplier,
|
998 |
+
dropout=self.dropout,
|
999 |
+
stochastic_depth=self.stochastic_depth,
|
1000 |
+
use_layernorm16=self.use_layernorm16,
|
1001 |
+
)
|
1002 |
+
for _ in range(2)
|
1003 |
+
]
|
1004 |
+
)
|
1005 |
+
self.cond_positional_embedding = PositionalEncoding(latent_dim, max_len=10000)
|
1006 |
+
|
1007 |
+
def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
|
1008 |
+
batch_size = cond[0].shape[0]
|
1009 |
+
cond_emb = [
|
1010 |
+
cond_proj(rearrange(c, "b c n -> b n c"))
|
1011 |
+
for cond_proj, c in zip(self.cond_projection, cond)
|
1012 |
+
]
|
1013 |
+
cond_emb = [
|
1014 |
+
self.cond_registers.unsqueeze(0).expand(batch_size, -1, -1),
|
1015 |
+
t.unsqueeze(1),
|
1016 |
+
] + cond_emb
|
1017 |
+
cond_emb = torch.cat(cond_emb, dim=1)
|
1018 |
+
cond_emb = self.cond_positional_embedding(cond_emb)
|
1019 |
+
for block in self.cond_sa:
|
1020 |
+
cond_emb = block(cond_emb)
|
1021 |
+
return cond_emb
|
1022 |
+
|
1023 |
+
def init_backbone(
|
1024 |
+
self,
|
1025 |
+
num_layers,
|
1026 |
+
latent_dim,
|
1027 |
+
mlp_multiplier,
|
1028 |
+
num_heads,
|
1029 |
+
dropout,
|
1030 |
+
stochastic_depth,
|
1031 |
+
):
|
1032 |
+
self.backbone_module = nn.ModuleList(
|
1033 |
+
[
|
1034 |
+
CrossAttentionSABlock(
|
1035 |
+
dim_qkv=latent_dim,
|
1036 |
+
dim_cond=latent_dim,
|
1037 |
+
num_heads=num_heads,
|
1038 |
+
mlp_multiplier=mlp_multiplier,
|
1039 |
+
dropout=dropout,
|
1040 |
+
stochastic_depth=stochastic_depth,
|
1041 |
+
use_layernorm16=self.use_layernorm16,
|
1042 |
+
)
|
1043 |
+
for _ in range(num_layers)
|
1044 |
+
]
|
1045 |
+
)
|
1046 |
+
|
1047 |
+
def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
|
1048 |
+
for block in self.backbone_module:
|
1049 |
+
x = block(x, y, mask, None)
|
1050 |
+
return x
|
1051 |
+
|
1052 |
+
def init_output_projection(self, num_feats, latent_dim):
|
1053 |
+
layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
|
1054 |
+
|
1055 |
+
self.final_norm = layer_norm(latent_dim, eps=1e-6)
|
1056 |
+
self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
|
1057 |
+
|
1058 |
+
def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
|
1059 |
+
return self.final_linear(self.final_norm(x))
|
1060 |
+
|
1061 |
+
|
1062 |
+
class InContextDirector(BaseDirector):
|
1063 |
+
def __init__(
|
1064 |
+
self,
|
1065 |
+
name: str,
|
1066 |
+
num_feats: int,
|
1067 |
+
num_cond_feats: int,
|
1068 |
+
num_cams: int,
|
1069 |
+
latent_dim: int,
|
1070 |
+
mlp_multiplier: int,
|
1071 |
+
num_layers: int,
|
1072 |
+
num_heads: int,
|
1073 |
+
dropout: float,
|
1074 |
+
stochastic_depth: float,
|
1075 |
+
label_dropout: float,
|
1076 |
+
num_rawfeats: int,
|
1077 |
+
clip_sequential: bool = False,
|
1078 |
+
cond_sequential: bool = False,
|
1079 |
+
device: str = "cuda",
|
1080 |
+
**kwargs,
|
1081 |
+
):
|
1082 |
+
super().__init__(
|
1083 |
+
name=name,
|
1084 |
+
num_feats=num_feats,
|
1085 |
+
num_cond_feats=num_cond_feats,
|
1086 |
+
num_cams=num_cams,
|
1087 |
+
latent_dim=latent_dim,
|
1088 |
+
mlp_multiplier=mlp_multiplier,
|
1089 |
+
num_layers=num_layers,
|
1090 |
+
num_heads=num_heads,
|
1091 |
+
dropout=dropout,
|
1092 |
+
stochastic_depth=stochastic_depth,
|
1093 |
+
label_dropout=label_dropout,
|
1094 |
+
num_rawfeats=num_rawfeats,
|
1095 |
+
clip_sequential=clip_sequential,
|
1096 |
+
cond_sequential=cond_sequential,
|
1097 |
+
device=device,
|
1098 |
+
)
|
1099 |
+
|
1100 |
+
def init_conds_mappings(self, num_cond_feats, latent_dim):
|
1101 |
+
self.cond_projection = nn.ModuleList(
|
1102 |
+
[nn.Linear(num_cond_feat, latent_dim) for num_cond_feat in num_cond_feats]
|
1103 |
+
)
|
1104 |
+
|
1105 |
+
def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
|
1106 |
+
for i in range(len(cond)):
|
1107 |
+
if cond[i].dim() == 3:
|
1108 |
+
cond[i] = rearrange(cond[i], "b c n -> b n c")
|
1109 |
+
cond_emb = [cond_proj(c) for cond_proj, c in zip(self.cond_projection, cond)]
|
1110 |
+
cond_emb = [c.unsqueeze(1) if c.dim() == 2 else cond_emb for c in cond_emb]
|
1111 |
+
cond_emb = torch.cat([t.unsqueeze(1)] + cond_emb, dim=1)
|
1112 |
+
return cond_emb
|
1113 |
+
|
1114 |
+
def init_backbone(
|
1115 |
+
self,
|
1116 |
+
num_layers,
|
1117 |
+
latent_dim,
|
1118 |
+
mlp_multiplier,
|
1119 |
+
num_heads,
|
1120 |
+
dropout,
|
1121 |
+
stochastic_depth,
|
1122 |
+
):
|
1123 |
+
self.backbone_module = nn.ModuleList(
|
1124 |
+
[
|
1125 |
+
SelfAttentionBlock(
|
1126 |
+
dim_qkv=latent_dim,
|
1127 |
+
num_heads=num_heads,
|
1128 |
+
mlp_multiplier=mlp_multiplier,
|
1129 |
+
dropout=dropout,
|
1130 |
+
stochastic_depth=stochastic_depth,
|
1131 |
+
use_layernorm16=self.use_layernorm16,
|
1132 |
+
)
|
1133 |
+
for _ in range(num_layers)
|
1134 |
+
]
|
1135 |
+
)
|
1136 |
+
|
1137 |
+
def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
|
1138 |
+
bs, n_y, _ = y.shape
|
1139 |
+
mask = torch.cat([torch.ones(bs, n_y, device=y.device), mask], dim=1)
|
1140 |
+
x = torch.cat([y, x], dim=1)
|
1141 |
+
for block in self.backbone_module:
|
1142 |
+
x = block(x, mask)
|
1143 |
+
return x
|
1144 |
+
|
1145 |
+
def init_output_projection(self, num_feats, latent_dim):
|
1146 |
+
layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
|
1147 |
+
|
1148 |
+
self.final_norm = layer_norm(latent_dim, eps=1e-6)
|
1149 |
+
self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
|
1150 |
+
|
1151 |
+
def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
|
1152 |
+
num_y = y.shape[1]
|
1153 |
+
x = x[:, num_y:]
|
1154 |
+
return self.final_linear(self.final_norm(x))
|
src/models/networks.py
ADDED
@@ -0,0 +1,33 @@
import torch.nn as nn


# ----------------------------------------------------------------------------
# Improved preconditioning proposed in the paper "Elucidating the Design
# Space of Diffusion-Based Generative Models" (EDM).


class RnEDMPrecond(nn.Module):
    def __init__(self, sigma_data: float = 0.5, module: nn.Module = None, **kwargs):
        super().__init__()
        self.sigma_data = sigma_data

        self.model = module
        self.num_rawfeats = module.num_rawfeats
        self.num_feats = module.num_feats
        self.num_cams = module.num_cams

    def forward(self, x, sigma, y=None, mask=None):
        """
        x: [batch_size, num_feats, max_frames], denoted x_t in the paper
        sigma: [batch_size] (int)
        """
        sigma = sigma.reshape(-1, 1, 1)
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
        c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
        c_noise = sigma.log() / 4

        F_x = self.model(c_in * x, c_noise.flatten(), y=y, mask=mask)
        D_x = c_skip * x + c_out * F_x

        return D_x
utils/common_viz.py
ADDED
@@ -0,0 +1,136 @@
from typing import Any, Dict, List, Tuple

import clip
from hydra import compose, initialize
from hydra.utils import instantiate
from omegaconf import OmegaConf
import torch
from torchtyping import TensorType
from torch.utils.data import DataLoader
import torch.nn.functional as F

from src.diffuser import Diffuser
from src.datasets.multimodal_dataset import MultimodalDataset

# ------------------------------------------------------------------------------------- #

batch_size, context_length = None, None
collate_fn = DataLoader([]).collate_fn

# ------------------------------------------------------------------------------------- #


def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            batch[key] = value.to(device)
    return batch


def load_clip_model(version: str, device: str) -> clip.model.CLIP:
    model, _ = clip.load(version, device=device, jit=False)
    model.eval()
    for p in model.parameters():
        p.requires_grad = False
    return model


def encode_text(
    caption_raws: List[str],  # batch_size
    clip_model: clip.model.CLIP,
    max_token_length: int,
    device: str,
) -> TensorType["batch_size", "context_length"]:
    if max_token_length is not None:
        default_context_length = 77
        context_length = max_token_length + 2  # start_token + 20 + end_token
        assert context_length < default_context_length
        # [bs, context_length] # if n_tokens > context_length -> will truncate
        texts = clip.tokenize(
            caption_raws, context_length=context_length, truncate=True
        )
        zero_pad = torch.zeros(
            [texts.shape[0], default_context_length - context_length],
            dtype=texts.dtype,
            device=texts.device,
        )
        texts = torch.cat([texts, zero_pad], dim=1)
    else:
        # [bs, context_length] # if n_tokens > 77 -> will truncate
        texts = clip.tokenize(caption_raws, truncate=True)

    # [batch_size, n_ctx, d_model]
    x = clip_model.token_embedding(texts.to(device)).type(clip_model.dtype)
    x = x + clip_model.positional_embedding.type(clip_model.dtype)
    x = x.permute(1, 0, 2)  # NLD -> LND
    x = clip_model.transformer(x)
    x = x.permute(1, 0, 2)  # LND -> NLD
    x = clip_model.ln_final(x).type(clip_model.dtype)
    # x.shape = [batch_size, n_ctx, transformer.width]
    # take features from the eot embedding (eot_token is the highest in each sequence)
    x_tokens = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)].float()
    x_seq = [x[k, : (m + 1)].float() for k, m in enumerate(texts.argmax(dim=-1))]

    return x_seq, x_tokens


def get_batch(
    prompt: str,
    sample_id: str,
    clip_model: clip.model.CLIP,
    dataset: MultimodalDataset,
    seq_feat: bool,
    device: torch.device,
) -> Dict[str, Any]:
    # Get base batch
    sample_index = dataset.root_filenames.index(sample_id)
    raw_batch = dataset[sample_index]
    batch = collate_fn([to_device(raw_batch, device)])

    # Encode text
    caption_seq, caption_tokens = encode_text([prompt], clip_model, None, device)

    if seq_feat:
        caption_feat = caption_seq[0]
        caption_feat = F.pad(caption_feat, (0, 0, 0, 77 - caption_feat.shape[0]))
        caption_feat = caption_feat.unsqueeze(0).permute(0, 2, 1)
    else:
        caption_feat = caption_tokens

    # Update batch
    batch["caption_raw"] = [prompt]
    batch["caption_feat"] = caption_feat

    return batch


def init(
    config_name: str,
) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset, torch.device]:
    with initialize(version_base="1.3", config_path="../configs"):
        config = compose(config_name=config_name)

    OmegaConf.register_new_resolver("eval", eval)

    # Initialize model
    device = torch.device(config.compnode.device)
    diffuser = instantiate(config.diffuser)
    state_dict = torch.load(config.checkpoint_path, map_location=device)["state_dict"]
    state_dict["ema.initted"] = diffuser.ema.initted
    state_dict["ema.step"] = diffuser.ema.step
    diffuser.load_state_dict(state_dict, strict=False)
    diffuser.to(device).eval()

    # Initialize CLIP model
    clip_model = load_clip_model("ViT-B/32", device)

    # Initialize dataset
    config.dataset.char.load_vertices = True
    config.batch_size = 1
    dataset = instantiate(config.dataset)
    dataset.set_split("demo")
    diffuser.modalities = list(dataset.modality_datasets.keys())
    diffuser.get_matrix = dataset.get_matrix
    diffuser.v_get_matrix = dataset.get_matrix

    return diffuser, clip_model, dataset, device
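Putting these helpers together, a hedged usage sketch (kept as comments because it needs the checkpoint and dataset referenced by the Hydra config; the config name, prompt, and sample id below are assumptions, not part of this commit):

# Illustrative usage, not part of the commit; requires the released checkpoint
# and dataset referenced by the Hydra config, so it is left commented out.
# diffuser, clip_model, dataset, device = init("config")
# batch = get_batch(
#     prompt="<caption describing the camera motion>",
#     sample_id=dataset.root_filenames[0],
#     clip_model=clip_model,
#     dataset=dataset,
#     seq_feat=True,
#     device=device,
# )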
utils/file_utils.py
ADDED
@@ -0,0 +1,117 @@
import json
import os
import os.path as osp
import pickle
import subprocess
from typing import Any

import h5py
import numpy as np
import pandas as pd
import torch
import torchaudio
from torchtyping import TensorType

num_channels, num_frames, height, width = None, None, None, None


def create_dir(dir_name: str):
    """Create a directory if it does not exist yet."""
    if not osp.exists(dir_name):
        os.makedirs(dir_name)


def move_files(source_path: str, destpath: str):
    """Move files from `source_path` to `dest_path`."""
    subprocess.call(["mv", source_path, destpath])


def load_pickle(pickle_path: str) -> Any:
    """Load a pickle file."""
    with open(pickle_path, "rb") as f:
        data = pickle.load(f)
    return data


def load_hdf5(hdf5_path: str) -> Any:
    with h5py.File(hdf5_path, "r") as h5file:
        data = {key: np.array(value) for key, value in h5file.items()}
    return data


def save_hdf5(data: Any, hdf5_path: str):
    with h5py.File(hdf5_path, "w") as h5file:
        for key, value in data.items():
            h5file.create_dataset(key, data=value)


def save_pickle(data: Any, pickle_path: str):
    """Save data in a pickle file."""
    with open(pickle_path, "wb") as f:
        pickle.dump(data, f, protocol=4)


def load_txt(txt_path: str):
    """Load a txt file."""
    with open(txt_path, "r") as f:
        data = f.read()
    return data


def save_txt(data: str, txt_path: str):
    """Save data in a txt file."""
    with open(txt_path, "w") as f:
        f.write(data)


def load_pth(pth_path: str) -> Any:
    """Load a pth (PyTorch) file."""
    data = torch.load(pth_path)
    return data


def save_pth(data: Any, pth_path: str):
    """Save data in a pth (PyTorch) file."""
    torch.save(data, pth_path)


def load_csv(csv_path: str, header: Any = None) -> pd.DataFrame:
    """Load a csv file."""
    try:
        data = pd.read_csv(csv_path, header=header)
    except pd.errors.EmptyDataError:
        data = pd.DataFrame()
    return data


def save_csv(data: Any, csv_path: str):
    """Save data in a csv file."""
    pd.DataFrame(data).to_csv(csv_path, header=False, index=False)


def load_json(json_path: str, header: Any = None) -> pd.DataFrame:
    """Load a json file."""
    with open(json_path, "r") as f:
        data = json.load(f)
    return data


def save_json(data: Any, json_path: str):
    """Save data in a json file."""
    with open(json_path, "w") as json_file:
        json.dump(data, json_file)


def load_audio(audio_path: str, **kwargs):
    """Load an audio file."""
    waveform, sample_rate = torchaudio.load(audio_path, **kwargs)
    return waveform, sample_rate


def save_audio(
    data: TensorType["num_channels", "num_frames"],
    audio_path: str,
    sample_rate: int = 44100,
):
    """Save data in an audio file."""
    torchaudio.save(audio_path, data, sample_rate)
utils/random_utils.py
ADDED
@@ -0,0 +1,46 @@
import numpy as np
import random
import torch


def set_random_seed(seed: int):
    torch.manual_seed((seed) % (1 << 31))
    torch.cuda.manual_seed((seed) % (1 << 31))
    torch.cuda.manual_seed_all((seed) % (1 << 31))
    np.random.seed((seed) % (1 << 31))
    random.seed((seed) % (1 << 31))
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


class StackedRandomGenerator:
    """
    Wrapper for torch.Generator that allows specifying a different random seed for each
    sample in a minibatch.
    """

    def __init__(self, device, seeds):
        super().__init__()
        self.generators = [
            torch.Generator(device).manual_seed(int(seed) % (1 << 31)) for seed in seeds
        ]

    def randn_rn(self, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack(
            [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]
        )

    def randn_like(self, input):
        return self.randn_rn(
            input.shape, dtype=input.dtype, layout=input.layout, device=input.device
        )

    def randint(self, *args, size, **kwargs):
        assert size[0] == len(self.generators)
        return torch.stack(
            [
                torch.randint(*args, size=size[1:], generator=gen, **kwargs)
                for gen in self.generators
            ]
        )
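A brief self-contained example of the per-sample generator above (not part of the committed file): each sample draws from its own seeded generator, so its noise is reproducible regardless of how the batch is composed.

# Illustrative example, not part of the commit.
import torch

gen = StackedRandomGenerator(device="cpu", seeds=[0, 1, 2, 3])
latents = gen.randn_rn(size=(4, 9, 30))  # one independent generator per sample
again = StackedRandomGenerator("cpu", [0, 1, 2, 3]).randn_rn(size=(4, 9, 30))
print(torch.equal(latents, again))       # True: same seeds give the same noise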
utils/rerun.py
ADDED
@@ -0,0 +1,74 @@
import numpy as np
from matplotlib import colormaps
import rerun as rr
from rerun.components import Material
from scipy.spatial import transform


def color_fn(x, cmap="tab10"):
    return colormaps[cmap](x % colormaps[cmap].N)


def log_sample(
    root_name: str,
    traj: np.ndarray,
    K: np.ndarray,
    vertices: np.ndarray,
    faces: np.ndarray,
    normals: np.ndarray,
    caption: str,
    mesh_masks: np.ndarray,
):
    num_cameras = traj.shape[0]

    rr.log(root_name, rr.ViewCoordinates.RIGHT_HAND_Y_DOWN, timeless=True)
    rr.log(
        f"{root_name}/trajectory/points",
        rr.Points3D(traj[:, :3, 3]),
        timeless=True,
    )
    rr.log(
        f"{root_name}/trajectory/line",
        rr.LineStrips3D(
            np.stack((traj[:, :3, 3][:-1], traj[:, :3, 3][1:]), axis=1),
            colors=[(1.0, 0.0, 1.0, 1.0)],
        ),
        timeless=True,
    )
    for k in range(num_cameras):
        rr.set_time_sequence("frame_idx", k)

        translation = traj[k][:3, 3]
        rotation_q = transform.Rotation.from_matrix(traj[k][:3, :3]).as_quat()
        rr.log(
            f"{root_name}/camera/image",
            rr.Pinhole(
                image_from_camera=K,
                width=K[0, -1] * 2,
                height=K[1, -1] * 2,
            ),
        )
        rr.log(
            f"{root_name}/camera",
            rr.Transform3D(
                translation=translation,
                rotation=rr.Quaternion(xyzw=rotation_q),
            ),
        )
        rr.set_time_sequence("image", k)

        # Null vertices
        if vertices[k].sum() == 0:
            rr.log(f"{root_name}/char/char", rr.Clear(recursive=False))
            rr.log(f"{root_name}/camera/image/bbox", rr.Clear(recursive=False))
            continue

        rr.log(
            f"{root_name}/char/char",
            rr.Mesh3D(
                vertex_positions=vertices[k],
                indices=faces,
                vertex_normals=normals[k],
                mesh_material=Material(albedo_factor=color_fn(0)),
            ),
        )
utils/rotation_utils.py
ADDED
@@ -0,0 +1,178 @@
import numpy as np
from scipy.spatial.transform import Rotation as R
import torch
from torchtyping import TensorType
from itertools import product

num_samples, num_cams = None, None


def rotvec_to_matrix(rotvec):
    return R.from_rotvec(rotvec).as_matrix()


def matrix_to_rotvec(mat):
    return R.from_matrix(mat).as_rotvec()


def compose_rotvec(r1, r2):
    """
    #TODO: adapt to torch
    Compose two rotation euler vectors.
    """
    r1 = r1.cpu().numpy() if isinstance(r1, torch.Tensor) else r1
    r2 = r2.cpu().numpy() if isinstance(r2, torch.Tensor) else r2

    R1 = rotvec_to_matrix(r1)
    R2 = rotvec_to_matrix(r2)
    cR = np.einsum("...ij,...jk->...ik", R1, R2)
    return torch.from_numpy(matrix_to_rotvec(cR))


def quat_to_rotvec(quat, eps=1e-6):
    # w > 0 to ensure 0 <= angle <= pi
    flip = (quat[..., :1] < 0).float()
    quat = (-1 * quat) * flip + (1 - flip) * quat

    angle = 2 * torch.atan2(torch.linalg.norm(quat[..., 1:], dim=-1), quat[..., 0])

    angle2 = angle * angle
    small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
    large_angle_scales = angle / torch.sin(angle / 2 + eps)

    small_angles = (angle <= 1e-3).float()
    rot_vec_scale = (
        small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
    )
    rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
    return rot_vec


# batch*n
def normalize_vector(v, return_mag=False):
    batch = v.shape[0]
    v_mag = torch.sqrt(v.pow(2).sum(1))  # batch
    v_mag = torch.max(
        v_mag, torch.autograd.Variable(torch.FloatTensor([1e-8])).to(v.device)
    )
    v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
    v = v / v_mag
    if return_mag is True:
        return v, v_mag[:, 0]
    else:
        return v


# u, v batch*n
def cross_product(u, v):
    batch = u.shape[0]
    i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
    j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
    k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]

    out = torch.cat(
        (i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1
    )  # [batch, 6]

    return out


def compute_rotation_matrix_from_ortho6d(ortho6d):
    x_raw = ortho6d[:, 0:3]  # [batch, 6]
    y_raw = ortho6d[:, 3:6]  # [batch, 6]

    x = normalize_vector(x_raw)  # [batch, 6]
    z = cross_product(x, y_raw)  # [batch, 6]
    z = normalize_vector(z)  # [batch, 6]
    y = cross_product(z, x)  # [batch, 6]

    x = x.view(-1, 3, 1)
    y = y.view(-1, 3, 1)
    z = z.view(-1, 3, 1)
    matrix = torch.cat((x, y, z), 2)  # [batch, 3, 3]
    return matrix


def invert_rotvec(rotvec: TensorType["num_samples", 3]):
    angle = torch.norm(rotvec, dim=-1)
    axis = rotvec / (angle.unsqueeze(-1) + 1e-6)
    inverted_rotvec = -angle.unsqueeze(-1) * axis
    return inverted_rotvec


def are_rotations(matrix: TensorType["num_samples", 3, 3]) -> TensorType["num_samples"]:
    """Check if a matrix is a rotation matrix."""
    # Check if the matrix is orthogonal
    identity = torch.eye(3, device=matrix.device)
    is_orthogonal = (
        torch.isclose(torch.bmm(matrix, matrix.transpose(1, 2)), identity, atol=1e-6)
        .all(dim=1)
        .all(dim=1)
    )

    # Check if the determinant is 1
    determinant = torch.det(matrix)
    is_determinant_one = torch.isclose(
        determinant, torch.tensor(1.0, device=matrix.device), atol=1e-6
    )

    return torch.logical_and(is_orthogonal, is_determinant_one)


def project_so3(
    matrix: TensorType["num_samples", 4, 4]
) -> TensorType["num_samples", 4, 4]:
    # Project rotation matrix to SO(3)
    # TODO: use torch
    rot = R.from_matrix(matrix[:, :3, :3].cpu().numpy()).as_matrix()

    projection = torch.eye(4).unsqueeze(0).repeat(matrix.shape[0], 1, 1).to(matrix)
    projection[:, :3, :3] = torch.from_numpy(rot).to(matrix)
    projection[:, :3, 3] = matrix[:, :3, 3]

    return projection


def pairwise_geodesic(
    R_x: TensorType["num_samples", "num_cams", 3, 3],
    R_y: TensorType["num_samples", "num_cams", 3, 3],
    reduction: str = "mean",
    block_size: int = 200,
):
    def arange(start, stop, step, endpoint=True):
        arr = torch.arange(start, stop, step)
        if endpoint and arr[-1] != stop - 1:
            arr = torch.cat((arr, torch.tensor([stop - 1], dtype=arr.dtype)))
        return arr

    # Geodesic distance
    # https://math.stackexchange.com/questions/2113634/comparing-two-rotation-matrices
    num_samples, num_cams, _, _ = R_x.shape

    C = torch.zeros(num_samples, num_samples, device=R_x.device)
    chunk_indices = arange(0, num_samples + 1, block_size, endpoint=True)
    for i, j in product(
        range(chunk_indices.shape[0] - 1), range(chunk_indices.shape[0] - 1)
    ):
        start_x, stop_x = chunk_indices[i], chunk_indices[i + 1]
        start_y, stop_y = chunk_indices[j], chunk_indices[j + 1]
        r_x, r_y = R_x[start_x:stop_x], R_y[start_y:stop_y]

        # Compute rotations between each pair of cameras of each sample
        r_xy = torch.einsum("anjk,bnlk->abnjl", r_x, r_y)  # b, b, N, 3, 3

        # Compute axis-angle representations: angle is the geodesic distance
        traces = r_xy.diagonal(dim1=-2, dim2=-1).sum(-1)
        c = torch.acos(torch.clamp((traces - 1) / 2, -1, 1)) / torch.pi

        # Average distance between cameras over samples
        if reduction == "mean":
            C[start_x:stop_x, start_y:stop_y] = c.mean(-1)
        elif reduction == "sum":
            C[start_x:stop_x, start_y:stop_y] = c.sum(-1)

        # Check for NaN values in traces
        if torch.isnan(c).any():
            raise ValueError("NaN values detected in traces")

    return C
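To make the rotation conventions above concrete, a small self-contained check (not part of the committed file): the 6D representation is mapped to rotation matrices, verified with are_rotations, and the pairwise geodesic distance of a set of rotations with itself is (numerically) zero on the diagonal, since distances are normalized angles in [0, 1].

# Illustrative check, not part of the commit.
import torch

ortho6d = torch.randn(4, 6)                       # two raw 3D axes per sample
rots = compute_rotation_matrix_from_ortho6d(ortho6d)
print(are_rotations(rots))                        # expected: all True

R_x = rots.unsqueeze(1)                           # [num_samples, num_cams=1, 3, 3]
dist = pairwise_geodesic(R_x, R_x)                # angle / pi, in [0, 1]
print(torch.allclose(dist.diagonal(), torch.zeros(4), atol=1e-3))  # True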