robin-courant committed
Commit
f7a5cb1
1 Parent(s): 5e4c5a1
app.py ADDED
@@ -0,0 +1,218 @@
1
+ from functools import partial
2
+ from typing import Callable, Tuple
3
+
4
+ import clip
5
+ import gradio as gr
6
+ from gradio_rerun import Rerun
7
+ import numpy as np
8
+ from pytorch3d.renderer import TexturesVertex
9
+ from pytorch3d.structures import Meshes
10
+ import rerun as rr
11
+ import torch
12
+
13
+ from utils.common_viz import init, get_batch
14
+ from utils.random_utils import set_random_seed
15
+ from utils.rerun import log_sample
16
+ from src.diffuser import Diffuser
17
+ from src.datasets.multimodal_dataset import MultimodalDataset
18
+
19
+ # ------------------------------------------------------------------------------------- #
20
+
21
+ batch_size, num_cams, num_verts = None, None, None
22
+
23
+ SAMPLE_IDS = [
24
+ "2011_KAeAqaA0Llg_00005_00001",
25
+ "2011_F_EuMeT2wBo_00014_00001",
26
+ "2011_MCkKihQrNA4_00014_00000",
27
+ ]
28
+ LABEL_TO_IDS = {
29
+ "right": 0,
30
+ "static": 1,
31
+ "complex": 2,
32
+ }
33
+ EXAMPLES = [
34
+ "While the character moves right, the camera trucks right.",
35
+ "While the character moves right, the camera performs a push in.",
36
+ "While the character moves right, the camera performs a pull out.",
37
+ "While the character stays static, the camera performs a boom bottom.",
38
+ "While the character stays static, the camera performs a boom top.",
39
+ "While the character moves to the right, the camera trucks right alongside them. Once the character comes to a stop, the camera remains static.", # noqa
40
+ "While the character moves to the right, the camera remains static. Once the character comes to a stop, the camera pushes in.", # noqa
41
+ ]
42
+ DEFAULT_TEXT = [
43
+ "While the character moves right, the camera [...].",
44
+ "While the character remains static, [...].",
45
+ "While the character moves to the right, the camera [...]. "
46
+ "Once the character comes to a stop, the camera [...].",
47
+ ]
48
+
49
+ HEADER = """
50
+
51
+ <div align="center">
52
+ <h1 style='text-align: center'>E.T. the Exceptional Trajectories</h1>
53
+ <a href="https://robincourant.github.io/info/"><strong>Robin Courant</strong></a>
54
+ ·
55
+ <a href="https://nicolas-dufour.github.io/"><strong>Nicolas Dufour</strong></a>
56
+ ·
57
+ <a href="https://triocrossing.github.io/"><strong>Xi Wang</strong></a>
58
+ ·
59
+ <a href="http://people.irisa.fr/Marc.Christie/"><strong>Marc Christie</strong></a>
60
+ ·
61
+ <a href="https://vicky.kalogeiton.info/"><strong>Vicky Kalogeiton</strong></a>
62
+ </div>
63
+
64
+
65
+ <div align="center">
66
+ <a href="https://www.lix.polytechnique.fr/vista/projects/2024_et_courant/" class="button"><b>[Webpage]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
67
+ <a href="https://github.com/robincourant/DIRECTOR" class="button"><b>[DIRECTOR]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
68
+ <a href="https://github.com/robincourant/CLaTr" class="button"><b>[CLaTr]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
69
+ <a href="https://github.com/robincourant/the-exceptional-trajectories" class="button"><b>[Data]</b></a> &nbsp;&nbsp;&nbsp;&nbsp;
70
+ </div>
71
+
72
+ <br/>
73
+ """
74
+
75
+ # ------------------------------------------------------------------------------------- #
76
+
77
+
78
+ def get_normals(vertices: torch.Tensor, faces: torch.Tensor) -> Tuple[torch.Tensor, Meshes]:
79
+ num_frames, num_faces = vertices.shape[0], faces.shape[-2]
80
+ faces = faces.expand(num_frames, num_faces, 3)
81
+
82
+ verts_rgb = torch.ones_like(vertices)
83
+ verts_rgb[:, :, 1] = 0
84
+ textures = TexturesVertex(verts_features=verts_rgb)
85
+ meshes = Meshes(verts=vertices, faces=faces, textures=textures)
86
+ normals = meshes.verts_normals_padded()
87
+
88
+ return normals, meshes
89
+
90
+
91
+ def generate(
92
+ prompt: str,
93
+ seed: int,
94
+ guidance_weight: float,
95
+ sample_label: str,
96
+ # ----------------------------------- #
97
+ dataset: MultimodalDataset,
98
+ device: torch.device,
99
+ diffuser: Diffuser,
100
+ clip_model: clip.model.CLIP,
101
+ ) -> str:
102
+ # Set arguments
103
+ set_random_seed(seed)
104
+ diffuser.gen_seeds = np.array([seed])
105
+ diffuser.guidance_weight = guidance_weight
106
+
107
+ # Inference
108
+ sample_id = SAMPLE_IDS[LABEL_TO_IDS[sample_label]]
109
+ seq_feat = diffuser.net.model.clip_sequential
110
+ batch = get_batch(prompt, sample_id, clip_model, dataset, seq_feat, device)
111
+ with torch.no_grad():
112
+ out = diffuser.predict_step(batch, 0)
113
+
114
+ # Run visualization
115
+ padding_mask = out["padding_mask"][0].to(bool).cpu()
116
+ padded_traj = out["gen_samples"][0].cpu()
117
+ traj = padded_traj[padding_mask]
118
+ padded_vertices = out["char_raw"]["char_vertices"][0]
119
+ vertices = padded_vertices[padding_mask]
120
+ faces = out["char_raw"]["char_faces"][0]
121
+ normals, meshes = get_normals(vertices, faces)
122
+ fx, fy, cx, cy = out["intrinsics"][0].cpu().numpy()
123
+ K = np.array([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
124
+ caption = out["caption_raw"][0]
125
+
126
+ rr.init(f"{sample_id}")
127
+ rr.save(".tmp_gr.rrd")
128
+ log_sample(
129
+ root_name="world",
130
+ traj=traj.numpy(),
131
+ K=K,
132
+ vertices=vertices.numpy(),
133
+ faces=faces.numpy(),
134
+ normals=normals.numpy(),
135
+ caption=caption,
136
+ mesh_masks=None,
137
+ )
138
+ return "./.tmp_gr.rrd"
139
+
140
+
141
+ # ------------------------------------------------------------------------------------- #
142
+
143
+
144
+ def main(gen_fn: Callable):
145
+ theme = gr.themes.Default(primary_hue="blue", secondary_hue="gray")
146
+
147
+ with gr.Blocks(theme=theme) as demo:
148
+ gr.Markdown(HEADER)
149
+
150
+ with gr.Row():
151
+ with gr.Column(scale=3):
152
+ with gr.Column(scale=2):
153
+ sample_str = gr.Dropdown(
154
+ choices=["static", "right", "complex"],
155
+ label="Character trajectory",
156
+ value="right",
157
+ interactive=True,
158
+ )
159
+ text = gr.Textbox(
160
+ placeholder="Type the camera motion you want to generate",
161
+ show_label=True,
162
+ label="Text prompt",
163
+ value=DEFAULT_TEXT[LABEL_TO_IDS[sample_str.value]],
164
+ )
165
+ seed = gr.Number(value=33, label="Seed")
166
+ guidance = gr.Slider(0, 10, value=1.4, label="Guidance", step=0.1)
167
+
168
+ with gr.Column(scale=1):
169
+ btn = gr.Button("Generate", variant="primary")
170
+
171
+ with gr.Column(scale=2):
172
+ examples = gr.Examples(
173
+ examples=[[x, None, None] for x in EXAMPLES],
174
+ inputs=[text],
175
+ )
176
+
177
+ with gr.Row():
178
+ output = Rerun()
179
+
180
+ def load_example(example_id):
181
+ processed_example = examples.non_none_processed_examples[example_id]
182
+ return gr.utils.resolve_singleton(processed_example)
183
+
184
+ def change_fn(change):
185
+ sample_index = LABEL_TO_IDS[change]
186
+ return gr.update(value=DEFAULT_TEXT[sample_index])
187
+
188
+ sample_str.change(fn=change_fn, inputs=[sample_str], outputs=[text])
189
+
190
+ inputs = [text, seed, guidance, sample_str]
191
+ examples.dataset.click(
192
+ load_example,
193
+ inputs=[examples.dataset],
194
+ outputs=examples.inputs_with_examples,
195
+ show_progress=False,
196
+ postprocess=False,
197
+ queue=False,
198
+ ).then(fn=gen_fn, inputs=inputs, outputs=[output])
199
+ btn.click(fn=gen_fn, inputs=inputs, outputs=[output])
200
+ text.submit(fn=gen_fn, inputs=inputs, outputs=[output])
201
+ demo.launch(share=False)
202
+
203
+
204
+ # ------------------------------------------------------------------------------------- #
205
+
206
+
207
+ if __name__ == "__main__":
208
+ # Initialize the models and dataset
209
+ diffuser, clip_model, dataset, device = init("config")
210
+ generate_sample = partial(
211
+ generate,
212
+ dataset=dataset,
213
+ device=device,
214
+ diffuser=diffuser,
215
+ clip_model=clip_model,
216
+ )
217
+
218
+ main(generate_sample)
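A note on the wiring above: `generate` mixes the four Gradio inputs with the heavy objects created in `__main__`, and `functools.partial` pre-binds the latter so every event handler (`btn.click`, `text.submit`, the examples chain) only has to pass the UI values. A minimal, self-contained sketch of that pattern with placeholder objects (the strings below stand in for the real dataset, device, diffuser and CLIP model):

from functools import partial

def generate(prompt, seed, guidance_weight, sample_label, dataset, device, diffuser, clip_model):
    # The real version runs diffusion and writes a .rrd recording; here we just echo the inputs.
    return f"{prompt} | seed={seed} | w={guidance_weight} | sample={sample_label}"

# Heavy objects are bound once at startup...
gen_fn = partial(generate, dataset="<dataset>", device="cpu", diffuser="<diffuser>", clip_model="<clip>")

# ...so callbacks only supply the four UI inputs, in the order of `inputs = [text, seed, guidance, sample_str]`.
print(gen_fn("While the character moves right, the camera trucks right.", 33, 1.4, "right"))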
configs/compnode/cpu.yaml ADDED
@@ -0,0 +1,3 @@
1
+ device: cpu
2
+ num_gpus: 1
3
+ num_workers: 8
configs/config.yaml ADDED
@@ -0,0 +1,17 @@
1
+ defaults:
2
+ - dataset: traj+caption+char
3
+ - diffuser: rn_director_edm
4
+ - compnode: cpu
5
+ - _self_
6
+
7
+ dataset:
8
+ char:
9
+ load_vertices: true
10
+
11
+ checkpoint_path: 'checkpoints/ca-mixed-e449.ckpt'
12
+ batch_size: 128
13
+ data_dir: data
14
+
15
+ hydra:
16
+ run:
17
+ dir: ./${results_dir}/${xp_name}/${timestamp}
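configs/config.yaml is a Hydra root config that composes the dataset, diffuser and compnode groups; `init("config")` in app.py presumably resolves it before instantiating anything. A minimal sketch of composing it by hand with Hydra's Compose API (the relative `config_path` and the custom `${eval:...}` resolver used by the dataset configs are assumptions based on the files in this commit):

from hydra import compose, initialize
from omegaconf import OmegaConf

# The ${eval:'...'} interpolations in the dataset configs need a resolver before access.
OmegaConf.register_new_resolver("eval", eval, replace=True)

with initialize(version_base=None, config_path="configs"):
    cfg = compose(config_name="config")

print(cfg.diffuser.guidance_weight)            # expected: 1.4
print(cfg.dataset.standardization.num_cams)    # expected: 300
print(cfg.checkpoint_path)                     # expected: checkpoints/ca-mixed-e449.ckpt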
configs/dataset/caption/caption.yaml ADDED
@@ -0,0 +1,14 @@
1
+ _target_: src.datasets.modalities.caption_dataset.CaptionDataset
2
+
3
+ name: caption
4
+
5
+ dataset_dir: ${dataset.dataset_dir}
6
+ segment_dir: ${dataset.dataset_dir}/cam_segments
7
+ raw_caption_dir: ${dataset.dataset_dir}/caption
8
+ feat_caption_dir: ${dataset.dataset_dir}/caption_clip
9
+
10
+ num_segments: 27
11
+ num_feats: 512
12
+ num_cams: ${dataset.standardization.num_cams}
13
+ sequential: ${diffuser.network.module.clip_sequential}
14
+ max_feat_length: 77
configs/dataset/char/char.yaml ADDED
@@ -0,0 +1,15 @@
1
+ _target_: src.datasets.modalities.char_dataset.CharacterDataset
2
+
3
+ name: char
4
+ dataset_dir: ${dataset.dataset_dir}
5
+ num_cams: ${dataset.num_cams}
6
+ num_raw_feats: 3
7
+ num_frequencies: 10
8
+ min_freq: 0
9
+ max_freq: 4
10
+ num_encoding: 3 # ${eval:'2 * ${dataset.char.num_frequencies} * ${dataset.char.num_raw_feats}'}
11
+ sequential: ${diffuser.network.module.cond_sequential}
12
+ num_feats: ${eval:'${dataset.char.num_encoding} if ${dataset.char.sequential} else ${dataset.num_cams} * ${dataset.char.num_encoding}'}
13
+ standardize: ${dataset.trajectory.standardize}
14
+ standardization: ${dataset.standardization}
15
+ load_vertices: ${diffuser.do_projection}
configs/dataset/standardization/0300.yaml ADDED
@@ -0,0 +1,15 @@
1
+ name: '0300'
2
+ num_interframes: 0
3
+ num_cams: 300
4
+ num_total_frames: ${eval:'${dataset.standardization.num_interframes} * (${dataset.standardization.num_cams} - 1) + ${dataset.standardization.num_cams} '}
5
+
6
+ norm_mean: [7.93987673e-05, -9.98621393e-05, 4.12940653e-04]
7
+ norm_std: [0.027841, 0.01819818, 0.03138536]
8
+
9
+ shift_mean: [0.00201079, -0.27488501, -1.23616805]
10
+ shift_std: [1.13433516, 1.19061042, 1.58744263]
11
+
12
+ norm_mean_h: [6.676e-05, -5.084e-05, -7.782e-04]
13
+ norm_std_h: [0.0105, 0.006958, 0.01145]
14
+
15
+ velocity: true
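The standardization block above keeps two sets of statistics because `velocity: true`: the first frame of a translation track is standardized with `shift_mean`/`shift_std`, while the remaining frames are stored as frame-to-frame deltas standardized with `norm_mean`/`norm_std` (see `TrajectoryDataset.get_feature`/`get_matrix` later in this commit). A small self-contained sketch of that transform and its inverse, using rounded copies of the statistics above:

import torch

shift_mean = torch.tensor([0.0020, -0.2749, -1.2362])
shift_std = torch.tensor([1.1343, 1.1906, 1.5874])
norm_mean = torch.tensor([7.94e-05, -9.99e-05, 4.13e-04])
norm_std = torch.tensor([0.0278, 0.0182, 0.0314])

trans = torch.randn(300, 3)                             # absolute camera positions
feat = torch.cat([trans[:1], trans[1:] - trans[:-1]])   # first frame + velocities
feat[0] = (feat[0] - shift_mean) / shift_std            # first frame uses shift_* stats
feat[1:] = (feat[1:] - norm_mean) / norm_std            # velocities use norm_* stats

# Inverse: de-standardize, then integrate the velocities back into positions.
inv = feat.clone()
inv[0] = inv[0] * shift_std + shift_mean
inv[1:] = inv[1:] * norm_std + norm_mean
recovered = torch.cumsum(inv, dim=0)
assert torch.allclose(recovered, trans, atol=1e-4)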
configs/dataset/traj+caption+char.yaml ADDED
@@ -0,0 +1,18 @@
1
+ _target_: src.datasets.multimodal_dataset.MultimodalDataset
2
+
3
+ defaults:
4
+ - _self_
5
+ - trajectory: rot6d_trajectory
6
+ - char: char
7
+ - caption: caption
8
+ - standardization: '0300'
9
+
10
+ name: "${dataset.standardization.name}-t:${dataset.trajectory.name}|c:${dataset.caption.name}|h:${dataset.char.name}"
11
+ dataset_name: ${dataset.standardization.name}
12
+ dataset_dir: ${data_dir}
13
+
14
+ num_rawfeats: 12
15
+ num_cams: ${dataset.standardization.num_cams}
16
+ feature_type: ${dataset.trajectory.name}
17
+ num_feats: ${dataset.trajectory.num_feats}
18
+ num_cond_feats: ['${dataset.char.num_feats}','${dataset.caption.num_feats}']
configs/dataset/trajectory/rot6d_trajectory.yaml ADDED
@@ -0,0 +1,10 @@
1
+ _target_: src.datasets.modalities.trajectory_dataset.TrajectoryDataset
2
+
3
+ name: rot6d
4
+ set_name: null
5
+ dataset_dir: ${dataset.dataset_dir}
6
+ num_feats: 9
7
+ num_rawfeats: ${dataset.num_rawfeats}
8
+ num_cams: ${dataset.num_cams}
9
+ standardize: true
10
+ standardization: ${dataset.standardization}
configs/diffuser/network/module/ca_director.yaml ADDED
@@ -0,0 +1,17 @@
1
+ _target_: src.models.modules.director.CrossAttentionDirector
2
+ name: ca_director
3
+ num_feats: ${dataset.num_feats}
4
+ num_rawfeats: ${dataset.num_rawfeats}
5
+ num_cams: ${dataset.num_cams}
6
+ num_cond_feats: ${dataset.num_cond_feats}
7
+ latent_dim: 512
8
+ mlp_multiplier: 4
9
+ num_layers: 8
10
+ num_heads: 16
11
+ dropout: 0.1
12
+ stochastic_depth: 0.1
13
+ label_dropout: 0.1
14
+ num_text_registers: 16
15
+ clip_sequential: True
16
+ cond_sequential: True
17
+ device: ${compnode.device}
configs/diffuser/network/rn_director.yaml ADDED
@@ -0,0 +1,7 @@
1
+ _target_: src.models.networks.RnEDMPrecond
2
+
3
+ defaults:
4
+ - module: ca_director
5
+
6
+ name: rn_director
7
+ sigma_data: 0.5
configs/diffuser/rn_director_edm.yaml ADDED
@@ -0,0 +1,24 @@
1
+ _target_: src.diffuser.Diffuser
2
+
3
+ defaults:
4
+ - _self_
5
+ - network: rn_director
6
+
7
+ guidance_weight: 1.4
8
+ edm2_normalization: true
9
+
10
+ # EMA
11
+ ema_kwargs:
12
+ beta: 0.9999
13
+ update_every: 1
14
+
15
+ # Sampling
16
+ sampling_kwargs:
17
+ num_steps: 10
18
+ sigma_min: 0.002
19
+ sigma_max: 80
20
+ rho: 40
21
+ S_churn: 0
22
+ S_min: 0
23
+ S_max: inf
24
+ S_noise: 1
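The sampler config above uses only 10 steps with an unusually large `rho: 40`. Plugging those values into the same noise-level formula that `Diffuser.edm_sampler` (src/diffuser.py, later in this commit) uses shows what that does to the schedule: almost all steps are spent near `sigma_min`. A short worked computation:

import torch

num_steps, sigma_min, sigma_max, rho = 10, 0.002, 80.0, 40.0
idx = torch.arange(num_steps)
t_steps = (sigma_max ** (1 / rho)
           + idx / (num_steps - 1) * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
t_steps = torch.cat([t_steps, torch.zeros(1)])  # t_N = 0, as in edm_sampler
print([round(float(t), 4) for t in t_steps])
# With rho=40 the schedule falls off from sigma_max=80 very steeply; the EDM default rho=7
# would spread the intermediate noise levels out much more evenly.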
src/datasets/datamodule.py ADDED
@@ -0,0 +1,68 @@
1
+ from lightning import LightningDataModule
2
+ from torch.utils.data import Dataset, DataLoader
3
+
4
+
5
+ class Datamodule(LightningDataModule):
6
+ def __init__(
7
+ self,
8
+ train_dataset: Dataset,
9
+ eval_dataset: Dataset,
10
+ batch_train_size: int,
11
+ num_workers: int,
12
+ eval_batch_size: int = None,
13
+ ):
14
+ super().__init__()
15
+
16
+ self.train_dataset = train_dataset
17
+ self.eval_dataset = eval_dataset
18
+
19
+ self.batch_train_size = batch_train_size
20
+ self.eval_batch_size = (
21
+ eval_batch_size if eval_batch_size is not None else batch_train_size
22
+ )
23
+
24
+ self.num_workers = num_workers
25
+
26
+ def train_dataloader(self) -> DataLoader:
27
+ """Load train set loader."""
28
+ persistent_workers = True if self.num_workers > 0 else False
29
+
30
+ dataloader = DataLoader(
31
+ self.train_dataset,
32
+ batch_size=self.batch_train_size,
33
+ num_workers=self.num_workers,
34
+ pin_memory=True,
35
+ persistent_workers=persistent_workers,
36
+ )
37
+ return dataloader
38
+
39
+ def val_dataloader(self) -> DataLoader:
40
+ """Load val set loader."""
41
+ persistent_workers = True if self.num_workers > 0 else False
42
+
43
+ dataloader = DataLoader(
44
+ self.eval_dataset,
45
+ batch_size=self.eval_batch_size,
46
+ num_workers=self.num_workers,
47
+ pin_memory=True,
48
+ persistent_workers=persistent_workers,
49
+ )
50
+ return dataloader
51
+
52
+ def predict_dataloader(self) -> DataLoader:
53
+ """Load predict set loader."""
54
+ dataloader = DataLoader(
55
+ self.eval_dataset,
56
+ batch_size=self.eval_batch_size,
57
+ num_workers=self.num_workers,
58
+ )
59
+ return dataloader
60
+
61
+ def test_dataloader(self) -> DataLoader:
62
+ """Load test set loader."""
63
+ dataloader = DataLoader(
64
+ self.eval_dataset,
65
+ batch_size=self.eval_batch_size,
66
+ num_workers=self.num_workers,
67
+ )
68
+ return dataloader
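A minimal usage sketch of this Datamodule with tiny in-memory tensors standing in for the project's datasets (the shapes below are placeholders, not the real E.T. features):

import torch
from torch.utils.data import TensorDataset

from src.datasets.datamodule import Datamodule

# Hypothetical stand-ins for the train / eval datasets.
train_set = TensorDataset(torch.randn(64, 9, 300))
eval_set = TensorDataset(torch.randn(16, 9, 300))

dm = Datamodule(train_set, eval_set, batch_train_size=8, num_workers=0, eval_batch_size=4)
train_batch = next(iter(dm.train_dataloader()))
val_batch = next(iter(dm.val_dataloader()))
print(train_batch[0].shape, val_batch[0].shape)  # torch.Size([8, 9, 300]) torch.Size([4, 9, 300])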
src/datasets/modalities/caption_dataset.py ADDED
@@ -0,0 +1,107 @@
1
+ from collections import Counter
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ import torch.nn.functional as F
8
+
9
+ from utils.file_utils import load_txt
10
+
11
+
12
+ class CaptionDataset(Dataset):
13
+ def __init__(
14
+ self,
15
+ name: str,
16
+ dataset_dir: str,
17
+ num_cams: int,
18
+ num_feats: int,
19
+ num_segments: int,
20
+ sequential: bool,
21
+ **kwargs,
22
+ ):
23
+ super().__init__()
24
+ self.modality = name
25
+ self.name = name
26
+ self.dataset_dir = Path(dataset_dir)
27
+ # Set data paths (segments, captions, etc...)
28
+ for name, field in kwargs.items():
29
+ if isinstance(field, str):
30
+ field = Path(field)
31
+ if name == "feat_caption_dir":
32
+ field = field / "seq" if sequential else field / "token"
33
+ setattr(self, name, field)
34
+
35
+ self.filenames = None
36
+
37
+ self.clip_seq_dir = self.dataset_dir / "caption_clip" / "seq" # For CLaTrScore
38
+ self.num_cams = num_cams
39
+ self.num_feats = num_feats
40
+ self.num_segments = num_segments
41
+ self.sequential = sequential
42
+
43
+ def __len__(self):
44
+ return len(self.filenames)
45
+
46
+ def __getitem__(self, index):
47
+ filename = self.filenames[index]
48
+
49
+ # Load data
50
+ if hasattr(self, "segment_dir"):
51
+ raw_segments = torch.from_numpy(
52
+ np.load((self.segment_dir / (filename + ".npy")))
53
+ )
54
+ padded_raw_segments = F.pad(
55
+ raw_segments,
56
+ (0, self.num_cams - len(raw_segments)),
57
+ value=self.num_segments,
58
+ )
59
+ if hasattr(self, "raw_caption_dir"):
60
+ raw_caption = load_txt(self.raw_caption_dir / (filename + ".txt"))
61
+ if hasattr(self, "feat_caption_dir"):
62
+ feat_caption = torch.from_numpy(
63
+ np.load((self.feat_caption_dir / (filename + ".npy")))
64
+ )
65
+ if self.sequential:
66
+ feat_caption = F.pad(
67
+ feat_caption.to(torch.float32),
68
+ (0, 0, 0, self.max_feat_length - feat_caption.shape[0]),
69
+ )
70
+
71
+ if self.modality == "caption":
72
+ raw_data = {"caption": raw_caption, "segments": padded_raw_segments}
73
+ feat_data = (
74
+ feat_caption.permute(1, 0) if feat_caption.dim() == 2 else feat_caption
75
+ )
76
+ elif self.modality == "segments":
77
+ raw_data = {"segments": padded_raw_segments}
78
+ # Shift by one for padding
79
+ feat_data = F.one_hot(
80
+ padded_raw_segments, num_classes=self.num_segments + 1
81
+ ).to(torch.float32)
82
+ if self.sequential:
83
+ feat_data = feat_data.permute(1, 0)
84
+ else:
85
+ feat_data = feat_data.reshape(-1)
86
+ elif self.modality == "class":
87
+ raw_data = {"segments": padded_raw_segments}
88
+ most_frequent_segment = Counter(raw_segments).most_common(1)[0][0]
89
+ feat_data = F.one_hot(
90
+ torch.tensor(most_frequent_segment), num_classes=self.num_segments
91
+ ).to(torch.float32)
92
+ else:
93
+ raise ValueError(f"Modality {self.modality} not supported")
94
+
95
+ clip_seq_caption = torch.from_numpy(
96
+ np.load((self.clip_seq_dir / (filename + ".npy")))
97
+ )
98
+ padding_mask = torch.ones((self.max_feat_length))
99
+ padding_mask[clip_seq_caption.shape[0] :] = 0
100
+ clip_seq_caption = F.pad(
101
+ clip_seq_caption.to(torch.float32),
102
+ (0, 0, 0, self.max_feat_length - clip_seq_caption.shape[0]),
103
+ )
104
+ raw_data["clip_seq_caption"] = clip_seq_caption
105
+ raw_data["clip_seq_mask"] = padding_mask
106
+
107
+ return filename, feat_data, raw_data
src/datasets/modalities/char_dataset.py ADDED
@@ -0,0 +1,120 @@
1
+ from pathlib import Path
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import Dataset
6
+ import torch.nn.functional as F
7
+
8
+ # ------------------------------------------------------------------------------------- #
9
+
10
+ num_frequencies = None
11
+
12
+ # ------------------------------------------------------------------------------------- #
13
+
14
+
15
+ class CharacterDataset(Dataset):
16
+ def __init__(
17
+ self,
18
+ name: str,
19
+ dataset_dir: str,
20
+ standardize: bool,
21
+ num_feats: int,
22
+ num_cams: int,
23
+ sequential: bool,
24
+ num_frequencies: int,
25
+ min_freq: int,
26
+ max_freq: int,
27
+ load_vertices: bool,
28
+ **kwargs,
29
+ ):
30
+ super().__init__()
31
+ self.modality = "char"
32
+ self.name = name
33
+ self.dataset_dir = Path(dataset_dir)
34
+ self.traj_dir = self.dataset_dir / "traj"
35
+ self.data_dir = self.dataset_dir / self.name
36
+ self.vert_dir = self.dataset_dir / "vert_raw"
37
+ self.center_dir = self.dataset_dir / "char_raw"
38
+
39
+ self.filenames = None
40
+ self.standardize = standardize
41
+ if self.standardize:
42
+ mean_std = kwargs["standardization"]
43
+ self.norm_mean = torch.Tensor(mean_std["norm_mean_h"])[:, None]
44
+ self.norm_std = torch.Tensor(mean_std["norm_std_h"])[:, None]
45
+ self.velocity = mean_std["velocity"]
46
+
47
+ self.num_cams = num_cams
48
+ self.num_feats = num_feats
49
+ self.sequential = sequential
50
+ self.num_frequencies = num_frequencies
51
+ self.min_freq = min_freq
52
+ self.max_freq = max_freq
53
+
54
+ self.load_vertices = load_vertices
55
+
56
+ def __len__(self):
57
+ return len(self.filenames)
58
+
59
+ def __getitem__(self, index):
60
+ filename = self.filenames[index]
61
+
62
+ char_filename = filename + ".npy"
63
+ char_path = self.data_dir / char_filename
64
+
65
+ raw_char_feature = torch.from_numpy(np.load((char_path))).to(torch.float32)
66
+ padding_size = self.num_cams - raw_char_feature.shape[0]
67
+ padded_raw_char_feature = F.pad(
68
+ raw_char_feature, (0, 0, 0, padding_size)
69
+ ).permute(1, 0)
70
+
71
+ center_path = self.center_dir / char_filename # Center to offset mesh
72
+ center_offset = torch.from_numpy(np.load(center_path)[0]).to(torch.float32)
73
+ if self.load_vertices:
74
+ vert_path = self.vert_dir / char_filename
75
+ raw_verts = np.load(vert_path, allow_pickle=True)[()]
76
+ if raw_verts["vertices"] is None:
77
+ num_frames = raw_char_feature.shape[0]
78
+ verts = torch.zeros((num_frames, 6890, 3), dtype=torch.float32)
79
+ padded_verts = torch.zeros(
80
+ (self.num_cams, 6890, 3), dtype=torch.float32
81
+ )
82
+ faces = torch.zeros((13776, 3), dtype=torch.int16)
83
+ else:
84
+ verts = torch.from_numpy(raw_verts["vertices"]).to(torch.float32)
85
+ verts -= center_offset
86
+ padded_verts = F.pad(verts, (0, 0, 0, 0, 0, padding_size))
87
+ faces = torch.from_numpy(raw_verts["faces"]).to(torch.int16)
88
+
89
+ char_feature = raw_char_feature.clone()
90
+ if self.velocity:
91
+ velocity = char_feature[1:].clone() - char_feature[:-1].clone()
92
+ char_feature = torch.cat([raw_char_feature[0][None], velocity])
93
+
94
+ if self.standardize:
95
+ # Normalize the first frame (origin) and the rest (velocity) separately
96
+ if len(self.norm_mean) == 6:
97
+ char_feature[0] -= self.norm_mean[:3, 0].to(raw_char_feature.device)
98
+ char_feature[0] /= self.norm_std[:3, 0].to(raw_char_feature.device)
99
+ char_feature[1:] -= self.norm_mean[3:, 0].to(raw_char_feature.device)
100
+ char_feature[1:] /= self.norm_std[3:, 0].to(raw_char_feature.device)
101
+ # Normalize all in one
102
+ else:
103
+ char_feature -= self.norm_mean[:, 0].to(raw_char_feature.device)
104
+ char_feature /= self.norm_std[:, 0].to(raw_char_feature.device)
105
+ padded_char_feature = F.pad(
106
+ char_feature,
107
+ (0, 0, 0, self.num_cams - char_feature.shape[0]),
108
+ )
109
+
110
+ if self.sequential:
111
+ padded_char_feature = padded_char_feature.permute(1, 0)
112
+ else:
113
+ padded_char_feature = padded_char_feature.reshape(-1)
114
+
115
+ raw_feats = {"char_raw_feat": padded_raw_char_feature}
116
+ if self.load_vertices:
117
+ raw_feats["char_vertices"] = padded_verts
118
+ raw_feats["char_faces"] = faces
119
+
120
+ return char_filename, padded_char_feature, raw_feats
src/datasets/modalities/trajectory_dataset.py ADDED
@@ -0,0 +1,152 @@
1
+ from pathlib import Path
2
+
3
+ from evo.tools.file_interface import read_kitti_poses_file
4
+ import numpy as np
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from torchtyping import TensorType
8
+ import torch.nn.functional as F
9
+ from typing import Tuple
10
+
11
+ from utils.file_utils import load_txt
12
+ from utils.rotation_utils import compute_rotation_matrix_from_ortho6d
13
+
14
+ num_cams = None
15
+
16
+
17
+ # ------------------------------------------------------------------------------------- #
18
+
19
+
20
+ class TrajectoryDataset(Dataset):
21
+ def __init__(
22
+ self,
23
+ name: str,
24
+ set_name: str,
25
+ dataset_dir: str,
26
+ num_rawfeats: int,
27
+ num_feats: int,
28
+ num_cams: int,
29
+ standardize: bool,
30
+ **kwargs,
31
+ ):
32
+ super().__init__()
33
+ self.name = name
34
+ self.set_name = set_name
35
+ self.dataset_dir = Path(dataset_dir)
36
+ if name == "relative":
37
+ self.data_dir = self.dataset_dir / "traj_raw"
38
+ self.relative_dir = self.dataset_dir / "relative"
39
+ else:
40
+ self.data_dir = self.dataset_dir / "traj"
41
+ self.intrinsics_dir = self.dataset_dir / "intrinsics"
42
+
43
+ self.num_rawfeats = num_rawfeats
44
+ self.num_feats = num_feats
45
+ self.num_cams = num_cams
46
+
47
+ self.augmentation = None
48
+ self.standardize = standardize
49
+ if self.standardize:
50
+ mean_std = kwargs["standardization"]
51
+ self.norm_mean = torch.Tensor(mean_std["norm_mean"])
52
+ self.norm_std = torch.Tensor(mean_std["norm_std"])
53
+ self.shift_mean = torch.Tensor(mean_std["shift_mean"])
54
+ self.shift_std = torch.Tensor(mean_std["shift_std"])
55
+ self.velocity = mean_std["velocity"]
56
+
57
+ # --------------------------------------------------------------------------------- #
58
+
59
+ def set_split(self, split: str, train_rate: float = 1.0):
60
+ self.split = split
61
+ split_path = Path(self.dataset_dir) / f"{split}_split.txt"
62
+ split_traj = load_txt(split_path).split("\n")
63
+ self.filenames = sorted(split_traj)
64
+
65
+ return self
66
+
67
+ # --------------------------------------------------------------------------------- #
68
+
69
+ def get_feature(
70
+ self, raw_matrix_trajectory: TensorType["num_cams", 4, 4]
71
+ ) -> TensorType[9, "num_cams"]:
72
+ matrix_trajectory = torch.clone(raw_matrix_trajectory)
73
+
74
+ raw_trans = torch.clone(matrix_trajectory[:, :3, 3])
75
+ if self.velocity:
76
+ velocity = raw_trans[1:] - raw_trans[:-1]
77
+ raw_trans = torch.cat([raw_trans[0][None], velocity])
78
+ if self.standardize:
79
+ raw_trans[0] -= self.shift_mean
80
+ raw_trans[0] /= self.shift_std
81
+ raw_trans[1:] -= self.norm_mean
82
+ raw_trans[1:] /= self.norm_std
83
+
84
+ # Compute the 6D continuous rotation
85
+ raw_rot = matrix_trajectory[:, :3, :3]
86
+ rot6d = raw_rot[:, :, :2].permute(0, 2, 1).reshape(-1, 6)
87
+
88
+ # Stack rotation 6D and translation
89
+ rot6d_trajectory = torch.hstack([rot6d, raw_trans]).permute(1, 0)
90
+
91
+ return rot6d_trajectory
92
+
93
+ def get_matrix(
94
+ self, raw_rot6d_trajectory: TensorType[9, "num_cams"]
95
+ ) -> TensorType["num_cams", 4, 4]:
96
+ rot6d_trajectory = torch.clone(raw_rot6d_trajectory)
97
+ device = rot6d_trajectory.device
98
+
99
+ num_cams = rot6d_trajectory.shape[1]
100
+ matrix_trajectory = torch.eye(4, device=device)[None].repeat(num_cams, 1, 1)
101
+
102
+ raw_trans = rot6d_trajectory[6:].permute(1, 0)
103
+ if self.standardize:
104
+ raw_trans[0] *= self.shift_std.to(device)
105
+ raw_trans[0] += self.shift_mean.to(device)
106
+ raw_trans[1:] *= self.norm_std.to(device)
107
+ raw_trans[1:] += self.norm_mean.to(device)
108
+ if self.velocity:
109
+ raw_trans = torch.cumsum(raw_trans, dim=0)
110
+ matrix_trajectory[:, :3, 3] = raw_trans
111
+
112
+ rot6d = rot6d_trajectory[:6].permute(1, 0)
113
+ raw_rot = compute_rotation_matrix_from_ortho6d(rot6d)
114
+ matrix_trajectory[:, :3, :3] = raw_rot
115
+
116
+ return matrix_trajectory
117
+
118
+ # --------------------------------------------------------------------------------- #
119
+
120
+ def __getitem__(self, index: int) -> Tuple[str, TensorType["num_cams", 4, 4]]:
121
+ filename = self.filenames[index]
122
+
123
+ trajectory_filename = filename + ".txt"
124
+ trajectory_path = self.data_dir / trajectory_filename
125
+
126
+ trajectory = read_kitti_poses_file(trajectory_path)
127
+ matrix_trajectory = torch.from_numpy(np.array(trajectory.poses_se3)).to(
128
+ torch.float32
129
+ )
130
+
131
+ trajectory_feature = self.get_feature(matrix_trajectory)
132
+
133
+ padded_trajectory_feature = F.pad(
134
+ trajectory_feature, (0, self.num_cams - trajectory_feature.shape[1])
135
+ )
136
+ # Padding mask: 1 for valid cams, 0 for padded cams
137
+ padding_mask = torch.ones((self.num_cams))
138
+ padding_mask[trajectory_feature.shape[1] :] = 0
139
+
140
+ intrinsics_filename = filename + ".npy"
141
+ intrinsics_path = self.intrinsics_dir / intrinsics_filename
142
+ intrinsics = np.load(intrinsics_path)
143
+
144
+ return (
145
+ trajectory_filename,
146
+ padded_trajectory_feature,
147
+ padding_mask,
148
+ intrinsics
149
+ )
150
+
151
+ def __len__(self):
152
+ return len(self.filenames)
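`get_feature` keeps only the first two columns of each rotation matrix (the 6D rotation parameterization of Zhou et al.), and `get_matrix` relies on `compute_rotation_matrix_from_ortho6d` to Gram-Schmidt them back into a full rotation. That helper is not part of this commit, so the sketch below uses a hypothetical local re-implementation to show the round trip:

import torch
import torch.nn.functional as F

def rot6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
    """Gram-Schmidt recovery of a rotation from its first two columns (hypothetical stand-in
    for utils.rotation_utils.compute_rotation_matrix_from_ortho6d)."""
    a1, a2 = d6[..., :3], d6[..., 3:]
    b1 = F.normalize(a1, dim=-1)
    b2 = F.normalize(a2 - (b1 * a2).sum(-1, keepdim=True) * b1, dim=-1)
    b3 = torch.cross(b1, b2, dim=-1)
    return torch.stack([b1, b2, b3], dim=-1)   # columns are b1, b2, b3

# Round trip on a random proper rotation.
rot = torch.linalg.qr(torch.randn(3, 3)).Q
if torch.det(rot) < 0:
    rot[:, 0] = -rot[:, 0]                     # force det = +1
d6 = rot[:, :2].permute(1, 0).reshape(6)       # same packing as get_feature above
assert torch.allclose(rot6d_to_matrix(d6), rot, atol=1e-5)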
src/datasets/multimodal_dataset.py ADDED
@@ -0,0 +1,88 @@
1
+ from copy import deepcopy as dp
2
+ from pathlib import Path
3
+
4
+ from torch.utils.data import Dataset
5
+
6
+
7
+ class MultimodalDataset(Dataset):
8
+ def __init__(
9
+ self,
10
+ name,
11
+ dataset_name,
12
+ dataset_dir,
13
+ trajectory,
14
+ feature_type,
15
+ num_rawfeats,
16
+ num_feats,
17
+ num_cams,
18
+ num_cond_feats,
19
+ standardization,
20
+ augmentation=None,
21
+ **modalities,
22
+ ):
23
+ self.dataset_dir = Path(dataset_dir)
24
+ self.name = name
25
+ self.dataset_name = dataset_name
26
+ self.feature_type = feature_type
27
+ self.num_rawfeats = num_rawfeats
28
+ self.num_feats = num_feats
29
+ self.num_cams = num_cams
30
+ self.trajectory_dataset = trajectory
31
+ self.standardization = standardization
32
+ self.modality_datasets = modalities
33
+
34
+ if augmentation is not None:
35
+ self.augmentation = True
36
+ self.augmentation_rate = augmentation.rate
37
+ self.trajectory_dataset.set_augmentation(augmentation.trajectory)
38
+ if hasattr(augmentation, "modalities"):
39
+ for modality, augments in augmentation.modalities:
40
+ self.modality_datasets[modality].set_augmentation(augments)
41
+ else:
42
+ self.augmentation = False
43
+
44
+ # --------------------------------------------------------------------------------- #
45
+
46
+ def set_split(self, split: str, train_rate: float = 1.0):
47
+ self.split = split
48
+
49
+ # Get trajectory split
50
+ self.trajectory_dataset = dp(self.trajectory_dataset).set_split(
51
+ split, train_rate
52
+ )
53
+ self.root_filenames = self.trajectory_dataset.filenames
54
+
55
+ # Get modality split
56
+ for modality_name in self.modality_datasets.keys():
57
+ self.modality_datasets[modality_name].filenames = self.root_filenames
58
+
59
+ self.get_feature = self.trajectory_dataset.get_feature
60
+ self.get_matrix = self.trajectory_dataset.get_matrix
61
+
62
+ return self
63
+
64
+ # --------------------------------------------------------------------------------- #
65
+
66
+ def __getitem__(self, index):
67
+ traj_out = self.trajectory_dataset[index]
68
+ traj_filename, traj_feature, padding_mask, intrinsics = traj_out
69
+
70
+ out = {
71
+ "traj_filename": traj_filename,
72
+ "traj_feat": traj_feature,
73
+ "padding_mask": padding_mask,
74
+ "intrinsics": intrinsics,
75
+ }
76
+
77
+ for modality_name, modality_dataset in self.modality_datasets.items():
78
+ modality_filename, modality_feature, modality_raw = modality_dataset[index]
79
+ assert traj_filename.split(".")[0] == modality_filename.split(".")[0]
80
+ out[f"{modality_name}_filename"] = modality_filename
81
+ out[f"{modality_name}_feat"] = modality_feature
82
+ out[f"{modality_name}_raw"] = modality_raw
83
+ out[f"{modality_name}_padding_mask"] = padding_mask
84
+
85
+ return out
86
+
87
+ def __len__(self):
88
+ return len(self.trajectory_dataset)
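Each item returned by MultimodalDataset is a flat dict keyed by modality, which is what `Diffuser.predict_step` and the demo's `generate` index into (`traj_feat`, `padding_mask`, `char_raw`, `caption_raw`, ...). A small, hedged inspection snippet, assuming `dataset` is the instantiated MultimodalDataset from `init("config")` and that the E.T. data is present under `data/`:

sample = dataset.set_split("test")[0]   # the split name is an assumption
for key, value in sample.items():
    desc = tuple(value.shape) if hasattr(value, "shape") else type(value).__name__
    print(f"{key:24s} {desc}")
# Expected keys include traj_filename, traj_feat (9, 300), padding_mask (300,), intrinsics,
# char_feat, char_raw (a dict with char_vertices / char_faces when load_vertices=true),
# caption_feat and caption_raw.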
src/diffuser.py ADDED
@@ -0,0 +1,221 @@
1
+ from omegaconf.dictconfig import DictConfig
2
+ from typing import List, Tuple
3
+
4
+ from ema_pytorch import EMA
5
+ import numpy as np
6
+ import torch
7
+ from torchtyping import TensorType
8
+ import torch.nn as nn
9
+ import lightning as L
10
+
11
+ from utils.random_utils import StackedRandomGenerator
12
+
13
+ # ------------------------------------------------------------------------------------- #
14
+
15
+ batch_size, num_samples = None, None
16
+ num_feats, num_rawfeats, num_cams = None, None, None
17
+ RawTrajectory = TensorType["num_samples", "num_rawfeats", "num_cams"]
18
+
19
+ # ------------------------------------------------------------------------------------- #
20
+
21
+
22
+ class Diffuser(L.LightningModule):
23
+ def __init__(
24
+ self,
25
+ network: nn.Module,
26
+ guidance_weight: float,
27
+ ema_kwargs: DictConfig,
28
+ sampling_kwargs: DictConfig,
29
+ edm2_normalization: bool,
30
+ **kwargs,
31
+ ):
32
+ super().__init__()
33
+
34
+ # Network and EMA
35
+ self.net = network
36
+ self.ema = EMA(self.net, **ema_kwargs)
37
+ self.guidance_weight = guidance_weight
38
+ self.edm2_normalization = edm2_normalization
39
+ self.sigma_data = network.sigma_data
40
+
41
+ # Sampling
42
+ self.num_steps = sampling_kwargs.num_steps
43
+ self.sigma_min = sampling_kwargs.sigma_min
44
+ self.sigma_max = sampling_kwargs.sigma_max
45
+ self.rho = sampling_kwargs.rho
46
+ self.S_churn = sampling_kwargs.S_churn
47
+ self.S_noise = sampling_kwargs.S_noise
48
+ self.S_min = sampling_kwargs.S_min
49
+ self.S_max = (
50
+ sampling_kwargs.S_max
51
+ if isinstance(sampling_kwargs.S_max, float)
52
+ else float("inf")
53
+ )
54
+
55
+ # ---------------------------------------------------------------------------------- #
56
+
57
+ def on_predict_start(self):
58
+ eval_dataset = self.trainer.datamodule.eval_dataset
59
+ self.modalities = list(eval_dataset.modality_datasets.keys())
60
+
61
+ self.get_matrix = self.trainer.datamodule.train_dataset.get_matrix
62
+ self.v_get_matrix = self.trainer.datamodule.eval_dataset.get_matrix
63
+
64
+ def predict_step(self, batch, batch_idx):
65
+ ref_samples, mask = batch["traj_feat"], batch["padding_mask"]
66
+
67
+ if len(self.modalities) > 0:
68
+ cond_k = [x for x in batch.keys() if "traj" not in x and "feat" in x]
69
+ cond_data = [batch[cond] for cond in cond_k]
70
+ conds = {}
71
+ for cond in cond_k:
72
+ cond_name = cond.replace("_feat", "")
73
+ if isinstance(batch[f"{cond_name}_raw"], dict):
74
+ for cond_name_, x in batch[f"{cond_name}_raw"].items():
75
+ conds[cond_name_] = x
76
+ else:
77
+ conds[cond_name] = batch[f"{cond_name}_raw"]
78
+ batch["conds"] = conds
79
+ else:
80
+ cond_data = None
81
+
82
+ # cf edm2 sigma_data normalization / https://arxiv.org/pdf/2312.02696.pdf
83
+ if self.edm2_normalization:
84
+ ref_samples *= self.sigma_data
85
+ _, gen_samples = self.sample(self.ema.ema_model, ref_samples, cond_data, mask)
86
+
87
+ batch["ref_samples"] = torch.stack([self.v_get_matrix(x) for x in ref_samples])
88
+ batch["gen_samples"] = torch.stack([self.get_matrix(x) for x in gen_samples])
89
+
90
+ return batch
91
+
92
+ # --------------------------------------------------------------------------------- #
93
+
94
+ def sample(
95
+ self,
96
+ net: torch.nn.Module,
97
+ traj_samples: RawTrajectory,
98
+ cond_samples: TensorType["num_samples", "num_feats"],
99
+ mask: TensorType["num_samples", "num_feats"],
100
+ external_seeds: List[int] = None,
101
+ ) -> Tuple[RawTrajectory, RawTrajectory]:
102
+ # Pick latents
103
+ num_samples = traj_samples.shape[0]
104
+ seeds = self.gen_seeds if hasattr(self, "gen_seeds") else range(num_samples)
105
+ rnd = StackedRandomGenerator(self.device, seeds)
106
+
107
+ sz = [num_samples, self.net.num_feats, self.net.num_cams]
108
+ latents = rnd.randn_rn(sz, device=self.device)
109
+ # Generate trajectories.
110
+ generations = self.edm_sampler(
111
+ net,
112
+ latents,
113
+ class_labels=cond_samples,
114
+ mask=mask,
115
+ randn_like=rnd.randn_like,
116
+ guidance_weight=self.guidance_weight,
117
+ # ----------------------------------- #
118
+ num_steps=self.num_steps,
119
+ sigma_min=self.sigma_min,
120
+ sigma_max=self.sigma_max,
121
+ rho=self.rho,
122
+ S_churn=self.S_churn,
123
+ S_min=self.S_min,
124
+ S_max=self.S_max,
125
+ S_noise=self.S_noise,
126
+ )
127
+
128
+ return latents, generations
129
+
130
+ @staticmethod
131
+ def edm_sampler(
132
+ net,
133
+ latents,
134
+ class_labels=None,
135
+ mask=None,
136
+ guidance_weight=2.0,
137
+ randn_like=torch.randn_like,
138
+ num_steps=18,
139
+ sigma_min=0.002,
140
+ sigma_max=80,
141
+ rho=7,
142
+ S_churn=0,
143
+ S_min=0,
144
+ S_max=float("inf"),
145
+ S_noise=1,
146
+ ):
147
+ # Time step discretization.
148
+ step_indices = torch.arange(num_steps, device=latents.device)
149
+ t_steps = (
150
+ sigma_max ** (1 / rho)
151
+ + step_indices
152
+ / (num_steps - 1)
153
+ * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))
154
+ ) ** rho
155
+ t_steps = torch.cat(
156
+ [torch.as_tensor(t_steps), torch.zeros_like(t_steps[:1])]
157
+ ) # t_N = 0
158
+
159
+ # Main sampling loop.
160
+ bool_mask = ~mask.to(bool)
161
+ x_next = latents * t_steps[0]
162
+ bs = latents.shape[0]
163
+ for i, (t_cur, t_next) in enumerate(
164
+ zip(t_steps[:-1], t_steps[1:])
165
+ ): # 0, ..., N-1
166
+ x_cur = x_next
167
+
168
+ # Increase noise temporarily.
169
+ gamma = (
170
+ min(S_churn / num_steps, np.sqrt(2) - 1)
171
+ if S_min <= t_cur <= S_max
172
+ else 0
173
+ )
174
+ t_hat = torch.as_tensor(t_cur + gamma * t_cur)
175
+ x_hat = x_cur + (t_hat**2 - t_cur**2).sqrt() * S_noise * randn_like(x_cur)
176
+
177
+ # Euler step.
178
+ if class_labels is not None:
179
+ class_label_knot = [torch.zeros_like(label) for label in class_labels]
180
+ x_hat_both = torch.cat([x_hat, x_hat], dim=0)
181
+ y_label_both = [
182
+ torch.cat([y, y_knot], dim=0)
183
+ for y, y_knot in zip(class_labels, class_label_knot)
184
+ ]
185
+
186
+ bool_mask_both = torch.cat([bool_mask, bool_mask], dim=0)
187
+ t_hat_both = torch.cat([t_hat.expand(bs), t_hat.expand(bs)], dim=0)
188
+ cond_denoised, denoised = net(
189
+ x_hat_both, t_hat_both, y=y_label_both, mask=bool_mask_both
190
+ ).chunk(2, dim=0)
191
+ denoised = denoised + (cond_denoised - denoised) * guidance_weight
192
+ else:
193
+ denoised = net(x_hat, t_hat.expand(bs), mask=bool_mask)
194
+ d_cur = (x_hat - denoised) / t_hat
195
+ x_next = x_hat + (t_next - t_hat) * d_cur
196
+
197
+ # Apply 2nd order correction.
198
+ if i < num_steps - 1:
199
+ if class_labels is not None:
200
+ class_label_knot = [
201
+ torch.zeros_like(label) for label in class_labels
202
+ ]
203
+ x_next_both = torch.cat([x_next, x_next], dim=0)
204
+ y_label_both = [
205
+ torch.cat([y, y_knot], dim=0)
206
+ for y, y_knot in zip(class_labels, class_label_knot)
207
+ ]
208
+ bool_mask_both = torch.cat([bool_mask, bool_mask], dim=0)
209
+ t_next_both = torch.cat(
210
+ [t_next.expand(bs), t_next.expand(bs)], dim=0
211
+ )
212
+ cond_denoised, denoised = net(
213
+ x_next_both, t_next_both, y=y_label_both, mask=bool_mask_both
214
+ ).chunk(2, dim=0)
215
+ denoised = denoised + (cond_denoised - denoised) * guidance_weight
216
+ else:
217
+ denoised = net(x_next, t_next.expand(bs), mask=bool_mask)
218
+ d_prime = (x_next - denoised) / t_next
219
+ x_next = x_hat + (t_next - t_hat) * (0.5 * d_cur + 0.5 * d_prime)
220
+
221
+ return x_next
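In `edm_sampler`, each network call runs on a doubled batch (real conditioning plus zeroed labels) and the two denoised estimates are blended with `denoised + (cond_denoised - denoised) * guidance_weight`, i.e. standard classifier-free guidance. A tiny numeric illustration of that line with the demo's default weight:

import torch

guidance_weight = 1.4
uncond = torch.tensor([1.0, 1.0, 1.0])   # denoised estimate with zeroed (unconditional) labels
cond = torch.tensor([2.0, 0.5, 1.0])     # denoised estimate with the text / character conditioning

guided = uncond + (cond - uncond) * guidance_weight
print(guided)  # tensor([2.4000, 0.3000, 1.0000])
# guidance_weight = 1 recovers the conditional prediction; values > 1 extrapolate away
# from the unconditional one, trading diversity for prompt adherence.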
src/models/modules/director.py ADDED
@@ -0,0 +1,1154 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch import Tensor
4
+ import numpy as np
5
+ from einops import rearrange
6
+
7
+ from typing import Optional, List
8
+ from torchtyping import TensorType
9
+ from einops._torch_specific import allow_ops_in_compiled_graph # requires einops>=0.6.1
10
+
11
+ allow_ops_in_compiled_graph()
12
+
13
+ batch_size, num_cond_feats = None, None
14
+
15
+
16
+ class FusedMLP(nn.Sequential):
17
+ def __init__(
18
+ self,
19
+ dim_model: int,
20
+ dropout: float,
21
+ activation: nn.Module,
22
+ hidden_layer_multiplier: int = 4,
23
+ bias: bool = True,
24
+ ):
25
+ super().__init__(
26
+ nn.Linear(dim_model, dim_model * hidden_layer_multiplier, bias=bias),
27
+ activation(),
28
+ nn.Dropout(dropout),
29
+ nn.Linear(dim_model * hidden_layer_multiplier, dim_model, bias=bias),
30
+ )
31
+
32
+
33
+ def _cast_if_autocast_enabled(tensor):
34
+ if torch.is_autocast_enabled():
35
+ if tensor.device.type == "cuda":
36
+ dtype = torch.get_autocast_gpu_dtype()
37
+ elif tensor.device.type == "cpu":
38
+ dtype = torch.get_autocast_cpu_dtype()
39
+ else:
40
+ raise NotImplementedError()
41
+ return tensor.to(dtype=dtype)
42
+ return tensor
43
+
44
+
45
+ class LayerNorm16Bits(torch.nn.LayerNorm):
46
+ """
47
+ 16-bit friendly version of torch.nn.LayerNorm
48
+ """
49
+
50
+ def __init__(
51
+ self,
52
+ normalized_shape,
53
+ eps=1e-06,
54
+ elementwise_affine=True,
55
+ device=None,
56
+ dtype=None,
57
+ ):
58
+ super().__init__(
59
+ normalized_shape=normalized_shape,
60
+ eps=eps,
61
+ elementwise_affine=elementwise_affine,
62
+ device=device,
63
+ dtype=dtype,
64
+ )
65
+
66
+ def forward(self, x):
67
+ module_device = x.device
68
+ downcast_x = _cast_if_autocast_enabled(x)
69
+ downcast_weight = (
70
+ _cast_if_autocast_enabled(self.weight)
71
+ if self.weight is not None
72
+ else self.weight
73
+ )
74
+ downcast_bias = (
75
+ _cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
76
+ )
77
+ with torch.autocast(enabled=False, device_type=module_device.type):
78
+ return nn.functional.layer_norm(
79
+ downcast_x,
80
+ self.normalized_shape,
81
+ downcast_weight,
82
+ downcast_bias,
83
+ self.eps,
84
+ )
85
+
86
+
87
+ class StochatichDepth(nn.Module):
+ """Stochastic depth: during training, drops the wrapped residual branch for a random subset of samples and rescales the survivors by 1 / survival_prob."""
88
+ def __init__(self, p: float):
89
+ super().__init__()
90
+ self.survival_prob = 1.0 - p
91
+
92
+ def forward(self, x: Tensor) -> Tensor:
93
+ if self.training and self.survival_prob < 1:
94
+ mask = (
95
+ torch.empty(x.shape[0], 1, 1, device=x.device).uniform_()
96
+ + self.survival_prob
97
+ )
98
+ mask = mask.floor()
99
+ if self.survival_prob > 0:
100
+ mask = mask / self.survival_prob
101
+ return x * mask
102
+ else:
103
+ return x
104
+
105
+
106
+ class CrossAttentionOp(nn.Module):
107
+ def __init__(
108
+ self, attention_dim, num_heads, dim_q, dim_kv, use_biases=True, is_sa=False
109
+ ):
110
+ super().__init__()
111
+ self.dim_q = dim_q
112
+ self.dim_kv = dim_kv
113
+ self.attention_dim = attention_dim
114
+ self.num_heads = num_heads
115
+ self.use_biases = use_biases
116
+ self.is_sa = is_sa
117
+ if self.is_sa:
118
+ self.qkv = nn.Linear(dim_q, attention_dim * 3, bias=use_biases)
119
+ else:
120
+ self.q = nn.Linear(dim_q, attention_dim, bias=use_biases)
121
+ self.kv = nn.Linear(dim_kv, attention_dim * 2, bias=use_biases)
122
+ self.out = nn.Linear(attention_dim, dim_q, bias=use_biases)
123
+
124
+ def forward(self, x_to, x_from=None, attention_mask=None):
125
+ if x_from is None:
126
+ x_from = x_to
127
+ if self.is_sa:
128
+ q, k, v = self.qkv(x_to).chunk(3, dim=-1)
129
+ else:
130
+ q = self.q(x_to)
131
+ k, v = self.kv(x_from).chunk(2, dim=-1)
132
+ q = rearrange(q, "b n (h d) -> b h n d", h=self.num_heads)
133
+ k = rearrange(k, "b n (h d) -> b h n d", h=self.num_heads)
134
+ v = rearrange(v, "b n (h d) -> b h n d", h=self.num_heads)
135
+ if attention_mask is not None:
136
+ attention_mask = attention_mask.unsqueeze(1)
137
+ x = torch.nn.functional.scaled_dot_product_attention(
138
+ q, k, v, attn_mask=attention_mask
139
+ )
140
+ x = rearrange(x, "b h n d -> b n (h d)")
141
+ x = self.out(x)
142
+ return x
143
+
144
+
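CrossAttentionOp hands scaled_dot_product_attention a boolean mask of shape (batch, queries, keys), unsqueezed to broadcast over heads, where True means "may attend". A standalone sketch of that masking convention with toy shapes:

import torch
import torch.nn.functional as F

b, h, n_q, n_kv, d = 2, 4, 5, 7, 16
q = torch.randn(b, h, n_q, d)
k = torch.randn(b, h, n_kv, d)
v = torch.randn(b, h, n_kv, d)

# (batch, queries, keys) boolean mask; True = this query may attend to this key.
mask = torch.ones(b, n_q, n_kv, dtype=torch.bool)
mask[:, :, -2:] = False                                # e.g. the last two keys are padding
out = F.scaled_dot_product_attention(q, k, v, attn_mask=mask.unsqueeze(1))  # add head dim
print(out.shape)  # torch.Size([2, 4, 5, 16])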
145
+ class CrossAttentionBlock(nn.Module):
146
+ def __init__(
147
+ self,
148
+ dim_q: int,
149
+ dim_kv: int,
150
+ num_heads: int,
151
+ attention_dim: int = 0,
152
+ mlp_multiplier: int = 4,
153
+ dropout: float = 0.0,
154
+ stochastic_depth: float = 0.0,
155
+ use_biases: bool = True,
156
+ retrieve_attention_scores: bool = False,
157
+ use_layernorm16: bool = True,
158
+ ):
159
+ super().__init__()
160
+ layer_norm = (
161
+ nn.LayerNorm
162
+ if not use_layernorm16 or retrieve_attention_scores
163
+ else LayerNorm16Bits
164
+ )
165
+ self.retrieve_attention_scores = retrieve_attention_scores
166
+ self.initial_to_ln = layer_norm(dim_q, eps=1e-6)
167
+ attention_dim = min(dim_q, dim_kv) if attention_dim == 0 else attention_dim
168
+ self.ca = CrossAttentionOp(
169
+ attention_dim, num_heads, dim_q, dim_kv, is_sa=False, use_biases=use_biases
170
+ )
171
+ self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
172
+ self.middle_ln = layer_norm(dim_q, eps=1e-6)
173
+ self.ffn = FusedMLP(
174
+ dim_model=dim_q,
175
+ dropout=dropout,
176
+ activation=nn.GELU,
177
+ hidden_layer_multiplier=mlp_multiplier,
178
+ bias=use_biases,
179
+ )
180
+ self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
181
+
182
+ def forward(
183
+ self,
184
+ to_tokens: Tensor,
185
+ from_tokens: Tensor,
186
+ to_token_mask: Optional[Tensor] = None,
187
+ from_token_mask: Optional[Tensor] = None,
188
+ ) -> Tensor:
189
+ if to_token_mask is None and from_token_mask is None:
190
+ attention_mask = None
191
+ else:
192
+ if to_token_mask is None:
193
+ to_token_mask = torch.ones(
194
+ to_tokens.shape[0],
195
+ to_tokens.shape[1],
196
+ dtype=torch.bool,
197
+ device=to_tokens.device,
198
+ )
199
+ if from_token_mask is None:
200
+ from_token_mask = torch.ones(
201
+ from_tokens.shape[0],
202
+ from_tokens.shape[1],
203
+ dtype=torch.bool,
204
+ device=from_tokens.device,
205
+ )
206
+ attention_mask = from_token_mask.unsqueeze(1) * to_token_mask.unsqueeze(2)
207
+ attention_output = self.ca(
208
+ self.initial_to_ln(to_tokens),
209
+ from_tokens,
210
+ attention_mask=attention_mask,
211
+ )
212
+ to_tokens = to_tokens + self.ca_stochastic_depth(attention_output)
213
+ to_tokens = to_tokens + self.ffn_stochastic_depth(
214
+ self.ffn(self.middle_ln(to_tokens))
215
+ )
216
+ return to_tokens
217
+
218
+
219
+ class SelfAttentionBlock(nn.Module):
220
+ def __init__(
221
+ self,
222
+ dim_qkv: int,
223
+ num_heads: int,
224
+ attention_dim: int = 0,
225
+ mlp_multiplier: int = 4,
226
+ dropout: float = 0.0,
227
+ stochastic_depth: float = 0.0,
228
+ use_biases: bool = True,
229
+ use_layer_scale: bool = False,
230
+ layer_scale_value: float = 0.0,
231
+ use_layernorm16: bool = True,
232
+ ):
233
+ super().__init__()
234
+ layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
235
+ self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
236
+ attention_dim = dim_qkv if attention_dim == 0 else attention_dim
237
+ self.sa = CrossAttentionOp(
238
+ attention_dim,
239
+ num_heads,
240
+ dim_qkv,
241
+ dim_qkv,
242
+ is_sa=True,
243
+ use_biases=use_biases,
244
+ )
245
+ self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
246
+ self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
247
+ self.ffn = FusedMLP(
248
+ dim_model=dim_qkv,
249
+ dropout=dropout,
250
+ activation=nn.GELU,
251
+ hidden_layer_multiplier=mlp_multiplier,
252
+ bias=use_biases,
253
+ )
254
+ self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
255
+ self.use_layer_scale = use_layer_scale
256
+ if use_layer_scale:
257
+ self.layer_scale_1 = nn.Parameter(
258
+ torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
259
+ )
260
+ self.layer_scale_2 = nn.Parameter(
261
+ torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
262
+ )
263
+
264
+ def forward(
265
+ self,
266
+ tokens: torch.Tensor,
267
+ token_mask: Optional[torch.Tensor] = None,
268
+ ):
269
+ if token_mask is None:
270
+ attention_mask = None
271
+ else:
272
+ attention_mask = token_mask.unsqueeze(1) * torch.ones(
273
+ tokens.shape[0],
274
+ tokens.shape[1],
275
+ 1,
276
+ dtype=torch.bool,
277
+ device=tokens.device,
278
+ )
279
+ attention_output = self.sa(
280
+ self.initial_ln(tokens),
281
+ attention_mask=attention_mask,
282
+ )
283
+ if self.use_layer_scale:
284
+ tokens = tokens + self.sa_stochastic_depth(
285
+ self.layer_scale_1 * attention_output
286
+ )
287
+ tokens = tokens + self.ffn_stochastic_depth(
288
+ self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
289
+ )
290
+ else:
291
+ tokens = tokens + self.sa_stochastic_depth(attention_output)
292
+ tokens = tokens + self.ffn_stochastic_depth(
293
+ self.ffn(self.middle_ln(tokens))
294
+ )
295
+ return tokens
296
+
297
+
298
+ class AdaLNSABlock(nn.Module):
299
+ def __init__(
300
+ self,
301
+ dim_qkv: int,
302
+ dim_cond: int,
303
+ num_heads: int,
304
+ attention_dim: int = 0,
305
+ mlp_multiplier: int = 4,
306
+ dropout: float = 0.0,
307
+ stochastic_depth: float = 0.0,
308
+ use_biases: bool = True,
309
+ use_layer_scale: bool = False,
310
+ layer_scale_value: float = 0.1,
311
+ use_layernorm16: bool = True,
312
+ ):
313
+ super().__init__()
314
+ layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
315
+ self.initial_ln = layer_norm(dim_qkv, eps=1e-6, elementwise_affine=False)
316
+ attention_dim = dim_qkv if attention_dim == 0 else attention_dim
317
+ self.adaln_modulation = nn.Sequential(
318
+ nn.SiLU(),
319
+ nn.Linear(dim_cond, dim_qkv * 6, bias=use_biases),
320
+ )
321
+ # Zero init
322
+ nn.init.zeros_(self.adaln_modulation[1].weight)
323
+ nn.init.zeros_(self.adaln_modulation[1].bias)
324
+
325
+ self.sa = CrossAttentionOp(
326
+ attention_dim,
327
+ num_heads,
328
+ dim_qkv,
329
+ dim_qkv,
330
+ is_sa=True,
331
+ use_biases=use_biases,
332
+ )
333
+ self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
334
+ self.middle_ln = layer_norm(dim_qkv, eps=1e-6, elementwise_affine=False)
335
+ self.ffn = FusedMLP(
336
+ dim_model=dim_qkv,
337
+ dropout=dropout,
338
+ activation=nn.GELU,
339
+ hidden_layer_multiplier=mlp_multiplier,
340
+ bias=use_biases,
341
+ )
342
+ self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
343
+ self.use_layer_scale = use_layer_scale
344
+ if use_layer_scale:
345
+ self.layer_scale_1 = nn.Parameter(
346
+ torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
347
+ )
348
+ self.layer_scale_2 = nn.Parameter(
349
+ torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
350
+ )
351
+
352
+ def forward(
353
+ self,
354
+ tokens: torch.Tensor,
355
+ cond: torch.Tensor,
356
+ token_mask: Optional[torch.Tensor] = None,
357
+ ):
358
+ if token_mask is None:
359
+ attention_mask = None
360
+ else:
361
+ attention_mask = token_mask.unsqueeze(1) * torch.ones(
362
+ tokens.shape[0],
363
+ tokens.shape[1],
364
+ 1,
365
+ dtype=torch.bool,
366
+ device=tokens.device,
367
+ )
368
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
369
+ self.adaln_modulation(cond).chunk(6, dim=-1)
370
+ )
371
+ attention_output = self.sa(
372
+ modulate_shift_and_scale(self.initial_ln(tokens), shift_msa, scale_msa),
373
+ attention_mask=attention_mask,
374
+ )
375
+ if self.use_layer_scale:
376
+ tokens = tokens + self.sa_stochastic_depth(
377
+ gate_msa.unsqueeze(1) * self.layer_scale_1 * attention_output
378
+ )
379
+ tokens = tokens + self.ffn_stochastic_depth(
380
+ gate_mlp.unsqueeze(1)
381
+ * self.layer_scale_2
382
+ * self.ffn(
383
+ modulate_shift_and_scale(
384
+ self.middle_ln(tokens), shift_mlp, scale_mlp
385
+ )
386
+ )
387
+ )
388
+ else:
389
+ tokens = tokens + gate_msa.unsqueeze(1) * self.sa_stochastic_depth(
390
+ attention_output
391
+ )
392
+ tokens = tokens + self.ffn_stochastic_depth(
393
+ gate_mlp.unsqueeze(1)
394
+ * self.ffn(
395
+ modulate_shift_and_scale(
396
+ self.middle_ln(tokens), shift_mlp, scale_mlp
397
+ )
398
+ )
399
+ )
400
+ return tokens
401
+
402
+
403
+ class CrossAttentionSABlock(nn.Module):
404
+ def __init__(
405
+ self,
406
+ dim_qkv: int,
407
+ dim_cond: int,
408
+ num_heads: int,
409
+ attention_dim: int = 0,
410
+ mlp_multiplier: int = 4,
411
+ dropout: float = 0.0,
412
+ stochastic_depth: float = 0.0,
413
+ use_biases: bool = True,
414
+ use_layer_scale: bool = False,
415
+ layer_scale_value: float = 0.0,
416
+ use_layernorm16: bool = True,
417
+ ):
418
+ super().__init__()
419
+ layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
420
+ attention_dim = dim_qkv if attention_dim == 0 else attention_dim
421
+ self.ca = CrossAttentionOp(
422
+ attention_dim,
423
+ num_heads,
424
+ dim_qkv,
425
+ dim_cond,
426
+ is_sa=False,
427
+ use_biases=use_biases,
428
+ )
429
+ self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
430
+ self.ca_ln = layer_norm(dim_qkv, eps=1e-6)
431
+
432
+ self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
433
+ attention_dim = dim_qkv if attention_dim == 0 else attention_dim
434
+
435
+ self.sa = CrossAttentionOp(
436
+ attention_dim,
437
+ num_heads,
438
+ dim_qkv,
439
+ dim_qkv,
440
+ is_sa=True,
441
+ use_biases=use_biases,
442
+ )
443
+ self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
444
+ self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
445
+ self.ffn = FusedMLP(
446
+ dim_model=dim_qkv,
447
+ dropout=dropout,
448
+ activation=nn.GELU,
449
+ hidden_layer_multiplier=mlp_multiplier,
450
+ bias=use_biases,
451
+ )
452
+ self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
453
+ self.use_layer_scale = use_layer_scale
454
+ if use_layer_scale:
455
+ self.layer_scale_1 = nn.Parameter(
456
+ torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
457
+ )
458
+ self.layer_scale_2 = nn.Parameter(
459
+ torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
460
+ )
461
+
462
+ def forward(
463
+ self,
464
+ tokens: torch.Tensor,
465
+ cond: torch.Tensor,
466
+ token_mask: Optional[torch.Tensor] = None,
467
+ cond_mask: Optional[torch.Tensor] = None,
468
+ ):
469
+ if cond_mask is None:
470
+ cond_attention_mask = None
471
+ else:
472
+ cond_attention_mask = torch.ones(
473
+ cond.shape[0],
474
+ 1,
475
+ cond.shape[1],
476
+ dtype=torch.bool,
477
+ device=tokens.device,
478
+ ) * token_mask.unsqueeze(2)
479
+ if token_mask is None:
480
+ attention_mask = None
481
+ else:
482
+ attention_mask = token_mask.unsqueeze(1) * torch.ones(
483
+ tokens.shape[0],
484
+ tokens.shape[1],
485
+ 1,
486
+ dtype=torch.bool,
487
+ device=tokens.device,
488
+ )
489
+ ca_output = self.ca(
490
+ self.ca_ln(tokens),
491
+ cond,
492
+ attention_mask=cond_attention_mask,
493
+ )
494
+ ca_output = torch.nan_to_num(
495
+ ca_output, nan=0.0, posinf=0.0, neginf=0.0
496
+ ) # Needed because fully masked tokens attend to no token, which produces NaNs
497
+ tokens = tokens + self.ca_stochastic_depth(ca_output)
498
+ attention_output = self.sa(
499
+ self.initial_ln(tokens),
500
+ attention_mask=attention_mask,
501
+ )
502
+ if self.use_layer_scale:
503
+ tokens = tokens + self.sa_stochastic_depth(
504
+ self.layer_scale_1 * attention_output
505
+ )
506
+ tokens = tokens + self.ffn_stochastic_depth(
507
+ self.layer_scale_2 * self.ffn(self.middle_ln(tokens))
508
+ )
509
+ else:
510
+ tokens = tokens + self.sa_stochastic_depth(attention_output)
511
+ tokens = tokens + self.ffn_stochastic_depth(
512
+ self.ffn(self.middle_ln(tokens))
513
+ )
514
+ return tokens
515
+
516
+
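In this block the boolean attention masks are built with shape `[batch, queries, keys]`; a padded query token attends to no key, so its softmax row becomes NaN, and the `torch.nan_to_num` call after cross-attention zeroes those rows out. Note also that `CrossAttentionDirector.backbone` below calls these blocks with `cond_mask=None`, so cross-attention over the conditioning sequence runs unmasked in this repository. A small illustration of why the guard is needed (toy shapes, not the block's actual attention op):

    import torch

    scores = torch.randn(1, 3, 4)                           # [batch, queries, keys]
    keep = torch.tensor([[[False] * 4, [True] * 4, [True] * 4]])
    attn = scores.masked_fill(~keep, float("-inf")).softmax(dim=-1)
    attn = torch.nan_to_num(attn, nan=0.0)                  # row 0 was all NaN before the guard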
517
+ class CAAdaLNSABlock(nn.Module):
518
+ def __init__(
519
+ self,
520
+ dim_qkv: int,
521
+ dim_cond: int,
522
+ num_heads: int,
523
+ attention_dim: int = 0,
524
+ mlp_multiplier: int = 4,
525
+ dropout: float = 0.0,
526
+ stochastic_depth: float = 0.0,
527
+ use_biases: bool = True,
528
+ use_layer_scale: bool = False,
529
+ layer_scale_value: float = 0.1,
530
+ use_layernorm16: bool = True,
531
+ ):
532
+ super().__init__()
533
+ layer_norm = LayerNorm16Bits if use_layernorm16 else nn.LayerNorm
534
+ self.ca = CrossAttentionOp(
535
+ attention_dim,
536
+ num_heads,
537
+ dim_qkv,
538
+ dim_cond,
539
+ is_sa=False,
540
+ use_biases=use_biases,
541
+ )
542
+ self.ca_stochastic_depth = StochatichDepth(stochastic_depth)
543
+ self.ca_ln = layer_norm(dim_qkv, eps=1e-6)
544
+ self.initial_ln = layer_norm(dim_qkv, eps=1e-6)
545
+ attention_dim = dim_qkv if attention_dim == 0 else attention_dim
546
+ self.adaln_modulation = nn.Sequential(
547
+ nn.SiLU(),
548
+ nn.Linear(dim_cond, dim_qkv * 6, bias=use_biases),
549
+ )
550
+ # Zero init
551
+ nn.init.zeros_(self.adaln_modulation[1].weight)
552
+ nn.init.zeros_(self.adaln_modulation[1].bias)
553
+
554
+ self.sa = CrossAttentionOp(
555
+ attention_dim,
556
+ num_heads,
557
+ dim_qkv,
558
+ dim_qkv,
559
+ is_sa=True,
560
+ use_biases=use_biases,
561
+ )
562
+ self.sa_stochastic_depth = StochatichDepth(stochastic_depth)
563
+ self.middle_ln = layer_norm(dim_qkv, eps=1e-6)
564
+ self.ffn = FusedMLP(
565
+ dim_model=dim_qkv,
566
+ dropout=dropout,
567
+ activation=nn.GELU,
568
+ hidden_layer_multiplier=mlp_multiplier,
569
+ bias=use_biases,
570
+ )
571
+ self.ffn_stochastic_depth = StochatichDepth(stochastic_depth)
572
+ self.use_layer_scale = use_layer_scale
573
+ if use_layer_scale:
574
+ self.layer_scale_1 = nn.Parameter(
575
+ torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
576
+ )
577
+ self.layer_scale_2 = nn.Parameter(
578
+ torch.ones(dim_qkv) * layer_scale_value, requires_grad=True
579
+ )
580
+
581
+ def forward(
582
+ self,
583
+ tokens: torch.Tensor,
584
+ cond_1: torch.Tensor,
585
+ cond_2: torch.Tensor,
586
+ cond_1_mask: Optional[torch.Tensor] = None,
587
+ token_mask: Optional[torch.Tensor] = None,
588
+ ):
589
+ if token_mask is None and cond_1_mask is None:
590
+ cond_attention_mask = None
591
+ elif token_mask is None:
592
+ cond_attention_mask = cond_1_mask.unsqueeze(1) * torch.ones(
593
+ cond_1.shape[0],
594
+ cond_1.shape[1],
595
+ 1,
596
+ dtype=torch.bool,
597
+ device=cond_1.device,
598
+ )
599
+ elif cond_1_mask is None:
600
+ cond_attention_mask = torch.ones(
601
+ tokens.shape[0],
602
+ 1,
603
+ tokens.shape[1],
604
+ dtype=torch.bool,
605
+ device=tokens.device,
606
+ ) * token_mask.unsqueeze(2)
607
+ else:
608
+ cond_attention_mask = cond_1_mask.unsqueeze(1) * token_mask.unsqueeze(2)
609
+ if token_mask is None:
610
+ attention_mask = None
611
+ else:
612
+ attention_mask = token_mask.unsqueeze(1) * torch.ones(
613
+ tokens.shape[0],
614
+ tokens.shape[1],
615
+ 1,
616
+ dtype=torch.bool,
617
+ device=tokens.device,
618
+ )
619
+ ca_output = self.ca(
620
+ self.ca_ln(tokens),
621
+ cond_1,
622
+ attention_mask=cond_attention_mask,
623
+ )
624
+ ca_output = torch.nan_to_num(ca_output, nan=0.0, posinf=0.0, neginf=0.0)
625
+ tokens = tokens + self.ca_stochastic_depth(ca_output)
626
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
627
+ self.adaln_modulation(cond_2).chunk(6, dim=-1)
628
+ )
629
+ attention_output = self.sa(
630
+ modulate_shift_and_scale(self.initial_ln(tokens), shift_msa, scale_msa),
631
+ attention_mask=attention_mask,
632
+ )
633
+ if self.use_layer_scale:
634
+ tokens = tokens + self.sa_stochastic_depth(
635
+ gate_msa.unsqueeze(1) * self.layer_scale_1 * attention_output
636
+ )
637
+ tokens = tokens + self.ffn_stochastic_depth(
638
+ gate_mlp.unsqueeze(1)
639
+ * self.layer_scale_2
640
+ * self.ffn(
641
+ modulate_shift_and_scale(
642
+ self.middle_ln(tokens), shift_mlp, scale_mlp
643
+ )
644
+ )
645
+ )
646
+ else:
647
+ tokens = tokens + gate_msa.unsqueeze(1) * self.sa_stochastic_depth(
648
+ attention_output
649
+ )
650
+ tokens = tokens + self.ffn_stochastic_depth(
651
+ gate_mlp.unsqueeze(1)
652
+ * self.ffn(
653
+ modulate_shift_and_scale(
654
+ self.middle_ln(tokens), shift_mlp, scale_mlp
655
+ )
656
+ )
657
+ )
658
+ return tokens
659
+
660
+
661
+ class PositionalEmbedding(nn.Module):
662
+ """
663
+ Taken from https://github.com/NVlabs/edm
664
+ """
665
+
666
+ def __init__(self, num_channels, max_positions=10000, endpoint=False):
667
+ super().__init__()
668
+ self.num_channels = num_channels
669
+ self.max_positions = max_positions
670
+ self.endpoint = endpoint
671
+ freqs = torch.arange(start=0, end=self.num_channels // 2, dtype=torch.float32)
672
+ freqs = 2 * freqs / self.num_channels
673
+ freqs = (1 / self.max_positions) ** freqs
674
+ self.register_buffer("freqs", freqs)
675
+
676
+ def forward(self, x):
677
+ x = torch.outer(x, self.freqs)
678
+ out = torch.cat([x.cos(), x.sin()], dim=1)
679
+ return out.to(x.dtype)
680
+
681
+
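As the docstring notes, this sinusoidal embedding comes from the EDM codebase: a scalar per sample is projected onto `num_channels // 2` frequencies and the cosine and sine halves are concatenated. A quick shape check, assuming the class exactly as defined above:

    import torch

    emb = PositionalEmbedding(num_channels=8)
    out = emb(torch.tensor([0.0, 1.0, 10.0]))   # shape [3, 8]: 4 cosine + 4 sine features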
682
+ class PositionalEncoding(nn.Module):
683
+ def __init__(self, d_model, dropout=0.0, max_len=10000):
684
+ super().__init__()
685
+ self.dropout = nn.Dropout(p=dropout)
686
+
687
+ pe = torch.zeros(max_len, d_model)
688
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
689
+ div_term = torch.exp(
690
+ torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
691
+ )
692
+ pe[:, 0::2] = torch.sin(position * div_term)
693
+ pe[:, 1::2] = torch.cos(position * div_term)
694
+ pe = pe.unsqueeze(0)
695
+
696
+ self.register_buffer("pe", pe)
697
+
698
+ def forward(self, x):
699
+ # not used in the final model
700
+ x = x + self.pe[:, : x.shape[1], :]
701
+ return self.dropout(x)
702
+
703
+
704
+ class TimeEmbedder(nn.Module):
705
+ def __init__(
706
+ self,
707
+ dim: int,
708
+ time_scaling: float,
709
+ expansion: int = 4,
710
+ ):
711
+ super().__init__()
712
+ self.encode_time = PositionalEmbedding(num_channels=dim, endpoint=True)
713
+
714
+ self.time_scaling = time_scaling
715
+ self.map_time = nn.Sequential(
716
+ nn.Linear(dim, dim * expansion),
717
+ nn.SiLU(),
718
+ nn.Linear(dim * expansion, dim * expansion),
719
+ )
720
+
721
+ def forward(self, t: Tensor) -> Tensor:
722
+ time = self.encode_time(t * self.time_scaling)
723
+ time_mean = time.mean(dim=-1, keepdim=True)
724
+ time_std = time.std(dim=-1, keepdim=True)
725
+ time = (time - time_mean) / time_std
726
+ return self.map_time(time)
727
+
728
+
729
+ def modulate_shift_and_scale(x: Tensor, shift: Tensor, scale: Tensor) -> Tensor:
730
+ return x * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)
731
+
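`modulate_shift_and_scale` applies per-channel affine conditioning: `shift` and `scale` have shape `[batch, dim]` and are broadcast over the sequence dimension of `x` (`[batch, seq, dim]`); the `1 +` keeps the operation close to the identity when `scale` comes from a zero-initialised projection. A worked example with assumed toy shapes:

    import torch

    x = torch.ones(1, 2, 3)                          # [batch, seq, dim]
    shift = torch.tensor([[0.5, 0.0, -0.5]])         # [batch, dim]
    scale = torch.zeros(1, 3)                        # zero scale -> pure shift
    out = x * (1 + scale).unsqueeze(1) + shift.unsqueeze(1)
    # out[0] == [[1.5, 1.0, 0.5], [1.5, 1.0, 0.5]]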
732
+
733
+ # ------------------------------------------------------------------------------------- #
734
+
735
+
736
+ class BaseDirector(nn.Module):
737
+ def __init__(
738
+ self,
739
+ name: str,
740
+ num_feats: int,
741
+ num_cond_feats: int,
742
+ num_cams: int,
743
+ latent_dim: int,
744
+ mlp_multiplier: int,
745
+ num_layers: int,
746
+ num_heads: int,
747
+ dropout: float,
748
+ stochastic_depth: float,
749
+ label_dropout: float,
750
+ num_rawfeats: int,
751
+ clip_sequential: bool = False,
752
+ cond_sequential: bool = False,
753
+ device: str = "cuda",
754
+ **kwargs,
755
+ ):
756
+ super().__init__()
757
+ self.name = name
758
+ self.label_dropout = label_dropout
759
+ self.num_rawfeats = num_rawfeats
760
+ self.num_feats = num_feats
761
+ self.num_cams = num_cams
762
+ self.clip_sequential = clip_sequential
763
+ self.cond_sequential = cond_sequential
764
+ self.use_layernorm16 = device == "cuda"
765
+
766
+ self.input_projection = nn.Sequential(
767
+ nn.Linear(num_feats, latent_dim),
768
+ PositionalEncoding(latent_dim),
769
+ )
770
+ self.time_embedding = TimeEmbedder(latent_dim // 4, time_scaling=1000)
771
+ self.init_conds_mappings(num_cond_feats, latent_dim)
772
+ self.init_backbone(
773
+ num_layers, latent_dim, mlp_multiplier, num_heads, dropout, stochastic_depth
774
+ )
775
+ self.init_output_projection(num_feats, latent_dim)
776
+
777
+ def forward(
778
+ self,
779
+ x: Tensor,
780
+ timesteps: Tensor,
781
+ y: List[Tensor] = None,
782
+ mask: Tensor = None,
783
+ ) -> Tensor:
784
+ mask = mask.logical_not() if mask is not None else None
785
+ x = rearrange(x, "b c n -> b n c")
786
+ x = self.input_projection(x)
787
+ t = self.time_embedding(timesteps)
788
+ if y is not None:
789
+ y = self.mask_cond(y)
790
+ y = self.cond_mapping(y, mask, t)
791
+
792
+ x = self.backbone(x, y, mask)
793
+ x = self.output_projection(x, y)
794
+ return rearrange(x, "b n c -> b c n")
795
+
796
+ def init_conds_mappings(self, num_cond_feats, latent_dim):
797
+ raise NotImplementedError(
798
+ "This method should be implemented in the derived class"
799
+ )
800
+
801
+ def init_backbone(self, num_layers, latent_dim, mlp_multiplier, num_heads, dropout, stochastic_depth):
802
+ raise NotImplementedError(
803
+ "This method should be implemented in the derived class"
804
+ )
805
+
806
+ def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
807
+ raise NotImplementedError(
808
+ "This method should be implemented in the derived class"
809
+ )
810
+
811
+ def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
812
+ raise NotImplementedError(
813
+ "This method should be implemented in the derived class"
814
+ )
815
+
816
+ def mask_cond(
817
+ self, cond: List[TensorType["batch_size", "num_cond_feats"]]
818
+ ) -> TensorType["batch_size", "num_cond_feats"]:
819
+ bs = cond[0].shape[0]
820
+ if self.training and self.label_dropout > 0.0:
821
+ # 1-> use null_cond, 0-> use real cond
822
+ prob = torch.ones(bs, device=cond[0].device) * self.label_dropout
823
+ masked_cond = []
824
+ common_mask = torch.bernoulli(prob) # Common to all modalities
825
+ for _cond in cond:
826
+ modality_mask = torch.bernoulli(prob) # Modality only
827
+ mask = torch.clip(common_mask + modality_mask, 0, 1)
828
+ mask = mask.view(bs, 1, 1) if _cond.dim() == 3 else mask.view(bs, 1)
829
+ masked_cond.append(_cond * (1.0 - mask))
830
+ return masked_cond
831
+ else:
832
+ return cond
833
+
834
+ def init_output_projection(self, num_feats, latent_dim):
835
+ raise NotImplementedError(
836
+ "This method should be implemented in the derived class"
837
+ )
838
+
839
+ def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
840
+ raise NotImplementedError(
841
+ "This method should be implemented in the derived class"
842
+ )
843
+
844
+
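`BaseDirector.mask_cond` implements classifier-free-guidance-style label dropout during training: with probability `label_dropout`, a sample's conditioning is replaced by the all-zero null condition, either for every modality at once (`common_mask`) or for a single modality. A sketch of the dropout rule with assumed values:

    import torch

    bs, label_dropout = 4, 0.5
    prob = torch.ones(bs) * label_dropout
    common = torch.bernoulli(prob)                   # drops all modalities of a sample
    modality = torch.bernoulli(prob)                 # drops this modality only
    drop = torch.clip(common + modality, 0, 1)
    cond = torch.randn(bs, 8)
    masked = cond * (1.0 - drop.view(bs, 1))         # dropped rows become the zero (null) condition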
845
+ class AdaLNDirector(BaseDirector):
846
+ def __init__(
847
+ self,
848
+ name: str,
849
+ num_feats: int,
850
+ num_cond_feats: int,
851
+ num_cams: int,
852
+ latent_dim: int,
853
+ mlp_multiplier: int,
854
+ num_layers: int,
855
+ num_heads: int,
856
+ dropout: float,
857
+ stochastic_depth: float,
858
+ label_dropout: float,
859
+ num_rawfeats: int,
860
+ clip_sequential: bool = False,
861
+ cond_sequential: bool = False,
862
+ device: str = "cuda",
863
+ **kwargs,
864
+ ):
865
+ super().__init__(
866
+ name=name,
867
+ num_feats=num_feats,
868
+ num_cond_feats=num_cond_feats,
869
+ num_cams=num_cams,
870
+ latent_dim=latent_dim,
871
+ mlp_multiplier=mlp_multiplier,
872
+ num_layers=num_layers,
873
+ num_heads=num_heads,
874
+ dropout=dropout,
875
+ stochastic_depth=stochastic_depth,
876
+ label_dropout=label_dropout,
877
+ num_rawfeats=num_rawfeats,
878
+ clip_sequential=clip_sequential,
879
+ cond_sequential=cond_sequential,
880
+ device=device,
881
+ )
882
+ assert not (clip_sequential and cond_sequential)
883
+
884
+ def init_conds_mappings(self, num_cond_feats, latent_dim):
885
+ self.joint_cond_projection = nn.Linear(sum(num_cond_feats), latent_dim)
886
+
887
+ def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
888
+ c_emb = torch.cat(cond, dim=-1)
889
+ return self.joint_cond_projection(c_emb) + t
890
+
891
+ def init_backbone(
892
+ self,
893
+ num_layers,
894
+ latent_dim,
895
+ mlp_multiplier,
896
+ num_heads,
897
+ dropout,
898
+ stochastic_depth,
899
+ ):
900
+ self.backbone_module = nn.ModuleList(
901
+ [
902
+ AdaLNSABlock(
903
+ dim_qkv=latent_dim,
904
+ dim_cond=latent_dim,
905
+ num_heads=num_heads,
906
+ mlp_multiplier=mlp_multiplier,
907
+ dropout=dropout,
908
+ stochastic_depth=stochastic_depth,
909
+ use_layernorm16=self.use_layernorm16,
910
+ )
911
+ for _ in range(num_layers)
912
+ ]
913
+ )
914
+
915
+ def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
916
+ for block in self.backbone_module:
917
+ x = block(x, y, mask)
918
+ return x
919
+
920
+ def init_output_projection(self, num_feats, latent_dim):
921
+ layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
922
+
923
+ self.final_norm = layer_norm(latent_dim, eps=1e-6, elementwise_affine=False)
924
+ self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
925
+ self.final_adaln = nn.Sequential(
926
+ nn.SiLU(),
927
+ nn.Linear(latent_dim, latent_dim * 2, bias=True),
928
+ )
929
+ # Zero init
930
+ nn.init.zeros_(self.final_adaln[1].weight)
931
+ nn.init.zeros_(self.final_adaln[1].bias)
932
+
933
+ def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
934
+ shift, scale = self.final_adaln(y).chunk(2, dim=-1)
935
+ x = modulate_shift_and_scale(self.final_norm(x), shift, scale)
936
+ return self.final_linear(x)
937
+
938
+
939
+ class CrossAttentionDirector(BaseDirector):
940
+ def __init__(
941
+ self,
942
+ name: str,
943
+ num_feats: int,
944
+ num_cond_feats: int,
945
+ num_cams: int,
946
+ latent_dim: int,
947
+ mlp_multiplier: int,
948
+ num_layers: int,
949
+ num_heads: int,
950
+ dropout: float,
951
+ stochastic_depth: float,
952
+ label_dropout: float,
953
+ num_rawfeats: int,
954
+ num_text_registers: int,
955
+ clip_sequential: bool = True,
956
+ cond_sequential: bool = True,
957
+ device: str = "cuda",
958
+ **kwargs,
959
+ ):
960
+ self.num_text_registers = num_text_registers
961
+ self.num_heads = num_heads
962
+ self.dropout = dropout
963
+ self.mlp_multiplier = mlp_multiplier
964
+ self.stochastic_depth = stochastic_depth
965
+ super().__init__(
966
+ name=name,
967
+ num_feats=num_feats,
968
+ num_cond_feats=num_cond_feats,
969
+ num_cams=num_cams,
970
+ latent_dim=latent_dim,
971
+ mlp_multiplier=mlp_multiplier,
972
+ num_layers=num_layers,
973
+ num_heads=num_heads,
974
+ dropout=dropout,
975
+ stochastic_depth=stochastic_depth,
976
+ label_dropout=label_dropout,
977
+ num_rawfeats=num_rawfeats,
978
+ clip_sequential=clip_sequential,
979
+ cond_sequential=cond_sequential,
980
+ device=device,
981
+ )
982
+ assert clip_sequential and cond_sequential
983
+
984
+ def init_conds_mappings(self, num_cond_feats, latent_dim):
985
+ self.cond_projection = nn.ModuleList(
986
+ [nn.Linear(num_cond_feat, latent_dim) for num_cond_feat in num_cond_feats]
987
+ )
988
+ self.cond_registers = nn.Parameter(
989
+ torch.randn(self.num_text_registers, latent_dim), requires_grad=True
990
+ )
991
+ nn.init.trunc_normal_(self.cond_registers, std=0.02, a=-2 * 0.02, b=2 * 0.02)
992
+ self.cond_sa = nn.ModuleList(
993
+ [
994
+ SelfAttentionBlock(
995
+ dim_qkv=latent_dim,
996
+ num_heads=self.num_heads,
997
+ mlp_multiplier=self.mlp_multiplier,
998
+ dropout=self.dropout,
999
+ stochastic_depth=self.stochastic_depth,
1000
+ use_layernorm16=self.use_layernorm16,
1001
+ )
1002
+ for _ in range(2)
1003
+ ]
1004
+ )
1005
+ self.cond_positional_embedding = PositionalEncoding(latent_dim, max_len=10000)
1006
+
1007
+ def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
1008
+ batch_size = cond[0].shape[0]
1009
+ cond_emb = [
1010
+ cond_proj(rearrange(c, "b c n -> b n c"))
1011
+ for cond_proj, c in zip(self.cond_projection, cond)
1012
+ ]
1013
+ cond_emb = [
1014
+ self.cond_registers.unsqueeze(0).expand(batch_size, -1, -1),
1015
+ t.unsqueeze(1),
1016
+ ] + cond_emb
1017
+ cond_emb = torch.cat(cond_emb, dim=1)
1018
+ cond_emb = self.cond_positional_embedding(cond_emb)
1019
+ for block in self.cond_sa:
1020
+ cond_emb = block(cond_emb)
1021
+ return cond_emb
1022
+
1023
+ def init_backbone(
1024
+ self,
1025
+ num_layers,
1026
+ latent_dim,
1027
+ mlp_multiplier,
1028
+ num_heads,
1029
+ dropout,
1030
+ stochastic_depth,
1031
+ ):
1032
+ self.backbone_module = nn.ModuleList(
1033
+ [
1034
+ CrossAttentionSABlock(
1035
+ dim_qkv=latent_dim,
1036
+ dim_cond=latent_dim,
1037
+ num_heads=num_heads,
1038
+ mlp_multiplier=mlp_multiplier,
1039
+ dropout=dropout,
1040
+ stochastic_depth=stochastic_depth,
1041
+ use_layernorm16=self.use_layernorm16,
1042
+ )
1043
+ for _ in range(num_layers)
1044
+ ]
1045
+ )
1046
+
1047
+ def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
1048
+ for block in self.backbone_module:
1049
+ x = block(x, y, mask, None)
1050
+ return x
1051
+
1052
+ def init_output_projection(self, num_feats, latent_dim):
1053
+ layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
1054
+
1055
+ self.final_norm = layer_norm(latent_dim, eps=1e-6)
1056
+ self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
1057
+
1058
+ def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
1059
+ return self.final_linear(self.final_norm(x))
1060
+
1061
+
1062
+ class InContextDirector(BaseDirector):
1063
+ def __init__(
1064
+ self,
1065
+ name: str,
1066
+ num_feats: int,
1067
+ num_cond_feats: int,
1068
+ num_cams: int,
1069
+ latent_dim: int,
1070
+ mlp_multiplier: int,
1071
+ num_layers: int,
1072
+ num_heads: int,
1073
+ dropout: float,
1074
+ stochastic_depth: float,
1075
+ label_dropout: float,
1076
+ num_rawfeats: int,
1077
+ clip_sequential: bool = False,
1078
+ cond_sequential: bool = False,
1079
+ device: str = "cuda",
1080
+ **kwargs,
1081
+ ):
1082
+ super().__init__(
1083
+ name=name,
1084
+ num_feats=num_feats,
1085
+ num_cond_feats=num_cond_feats,
1086
+ num_cams=num_cams,
1087
+ latent_dim=latent_dim,
1088
+ mlp_multiplier=mlp_multiplier,
1089
+ num_layers=num_layers,
1090
+ num_heads=num_heads,
1091
+ dropout=dropout,
1092
+ stochastic_depth=stochastic_depth,
1093
+ label_dropout=label_dropout,
1094
+ num_rawfeats=num_rawfeats,
1095
+ clip_sequential=clip_sequential,
1096
+ cond_sequential=cond_sequential,
1097
+ device=device,
1098
+ )
1099
+
1100
+ def init_conds_mappings(self, num_cond_feats, latent_dim):
1101
+ self.cond_projection = nn.ModuleList(
1102
+ [nn.Linear(num_cond_feat, latent_dim) for num_cond_feat in num_cond_feats]
1103
+ )
1104
+
1105
+ def cond_mapping(self, cond: List[Tensor], mask: Tensor, t: Tensor) -> Tensor:
1106
+ for i in range(len(cond)):
1107
+ if cond[i].dim() == 3:
1108
+ cond[i] = rearrange(cond[i], "b c n -> b n c")
1109
+ cond_emb = [cond_proj(c) for cond_proj, c in zip(self.cond_projection, cond)]
1110
+ cond_emb = [c.unsqueeze(1) if c.dim() == 2 else c for c in cond_emb]
1111
+ cond_emb = torch.cat([t.unsqueeze(1)] + cond_emb, dim=1)
1112
+ return cond_emb
1113
+
1114
+ def init_backbone(
1115
+ self,
1116
+ num_layers,
1117
+ latent_dim,
1118
+ mlp_multiplier,
1119
+ num_heads,
1120
+ dropout,
1121
+ stochastic_depth,
1122
+ ):
1123
+ self.backbone_module = nn.ModuleList(
1124
+ [
1125
+ SelfAttentionBlock(
1126
+ dim_qkv=latent_dim,
1127
+ num_heads=num_heads,
1128
+ mlp_multiplier=mlp_multiplier,
1129
+ dropout=dropout,
1130
+ stochastic_depth=stochastic_depth,
1131
+ use_layernorm16=self.use_layernorm16,
1132
+ )
1133
+ for _ in range(num_layers)
1134
+ ]
1135
+ )
1136
+
1137
+ def backbone(self, x: Tensor, y: Tensor, mask: Tensor) -> Tensor:
1138
+ bs, n_y, _ = y.shape
1139
+ mask = torch.cat([torch.ones(bs, n_y, device=y.device), mask], dim=1)
1140
+ x = torch.cat([y, x], dim=1)
1141
+ for block in self.backbone_module:
1142
+ x = block(x, mask)
1143
+ return x
1144
+
1145
+ def init_output_projection(self, num_feats, latent_dim):
1146
+ layer_norm = LayerNorm16Bits if self.use_layernorm16 else nn.LayerNorm
1147
+
1148
+ self.final_norm = layer_norm(latent_dim, eps=1e-6)
1149
+ self.final_linear = nn.Linear(latent_dim, num_feats, bias=True)
1150
+
1151
+ def output_projection(self, x: Tensor, y: Tensor) -> Tensor:
1152
+ num_y = y.shape[1]
1153
+ x = x[:, num_y:]
1154
+ return self.final_linear(self.final_norm(x))
src/models/networks.py ADDED
@@ -0,0 +1,33 @@
1
+ import torch.nn as nn
2
+
3
+
4
+ # ----------------------------------------------------------------------------
5
+ # Improved preconditioning proposed in the paper "Elucidating the Design
6
+ # Space of Diffusion-Based Generative Models" (EDM).
7
+
8
+
9
+ class RnEDMPrecond(nn.Module):
10
+ def __init__(self, sigma_data: float = 0.5, module: nn.Module = None, **kwargs):
11
+ super().__init__()
12
+ self.sigma_data = sigma_data
13
+
14
+ self.model = module
15
+ self.num_rawfeats = module.num_rawfeats
16
+ self.num_feats = module.num_feats
17
+ self.num_cams = module.num_cams
18
+
19
+ def forward(self, x, sigma, y=None, mask=None):
20
+ """
21
+ x: [batch_size, num_feats, max_frames], denoted x_t in the paper
22
+ sigma: [batch_size], one noise level per sample
23
+ """
24
+ sigma = sigma.reshape(-1, 1, 1)
25
+ c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
26
+ c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2).sqrt()
27
+ c_in = 1 / (self.sigma_data**2 + sigma**2).sqrt()
28
+ c_noise = sigma.log() / 4
29
+
30
+ F_x = self.model(c_in * x, c_noise.flatten(), y=y, mask=mask)
31
+ D_x = c_skip * x + c_out * F_x
32
+
33
+ return D_x
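The coefficients in `forward` are the EDM preconditioning terms: `c_skip` blends the noisy input back in at low noise, `c_out` scales the network output, `c_in` normalises the input variance, and `c_noise = log(sigma) / 4` is the noise conditioning fed to the model. A small numerical sanity check, assuming the default `sigma_data = 0.5`:

    import torch

    sigma_data = 0.5
    sigma = torch.tensor([0.002, 1.0, 80.0]).reshape(-1, 1, 1)
    c_skip = sigma_data**2 / (sigma**2 + sigma_data**2)     # ~1 at low noise, ~0 at high noise
    c_out = sigma * sigma_data / (sigma**2 + sigma_data**2).sqrt()
    c_in = 1 / (sigma_data**2 + sigma**2).sqrt()            # keeps the model input near unit variance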
utils/common_viz.py ADDED
@@ -0,0 +1,136 @@
1
+ from typing import Any, Dict, List, Tuple
2
+
3
+ import clip
4
+ from hydra import compose, initialize
5
+ from hydra.utils import instantiate
6
+ from omegaconf import OmegaConf
7
+ import torch
8
+ from torchtyping import TensorType
9
+ from torch.utils.data import DataLoader
10
+ import torch.nn.functional as F
11
+
12
+ from src.diffuser import Diffuser
13
+ from src.datasets.multimodal_dataset import MultimodalDataset
14
+
15
+ # ------------------------------------------------------------------------------------- #
16
+
17
+ batch_size, context_length = None, None
18
+ collate_fn = DataLoader([]).collate_fn
19
+
20
+ # ------------------------------------------------------------------------------------- #
21
+
22
+
23
+ def to_device(batch: Dict[str, Any], device: torch.device) -> Dict[str, Any]:
24
+ for key, value in batch.items():
25
+ if isinstance(value, torch.Tensor):
26
+ batch[key] = value.to(device)
27
+ return batch
28
+
29
+
30
+ def load_clip_model(version: str, device: str) -> clip.model.CLIP:
31
+ model, _ = clip.load(version, device=device, jit=False)
32
+ model.eval()
33
+ for p in model.parameters():
34
+ p.requires_grad = False
35
+ return model
36
+
37
+
38
+ def encode_text(
39
+ caption_raws: List[str], # batch_size
40
+ clip_model: clip.model.CLIP,
41
+ max_token_length: int,
42
+ device: str,
43
+ ) -> Tuple[List[torch.Tensor], torch.Tensor]:
44
+ if max_token_length is not None:
45
+ default_context_length = 77
46
+ context_length = max_token_length + 2 # start_token + max_token_length + end_token
47
+ assert context_length < default_context_length
48
+ # [bs, context_length] # if n_tokens > context_length -> will truncate
49
+ texts = clip.tokenize(
50
+ caption_raws, context_length=context_length, truncate=True
51
+ )
52
+ zero_pad = torch.zeros(
53
+ [texts.shape[0], default_context_length - context_length],
54
+ dtype=texts.dtype,
55
+ device=texts.device,
56
+ )
57
+ texts = torch.cat([texts, zero_pad], dim=1)
58
+ else:
59
+ # [bs, context_length] # if n_tokens > 77 -> will truncate
60
+ texts = clip.tokenize(caption_raws, truncate=True)
61
+
62
+ # [batch_size, n_ctx, d_model]
63
+ x = clip_model.token_embedding(texts.to(device)).type(clip_model.dtype)
64
+ x = x + clip_model.positional_embedding.type(clip_model.dtype)
65
+ x = x.permute(1, 0, 2) # NLD -> LND
66
+ x = clip_model.transformer(x)
67
+ x = x.permute(1, 0, 2) # LND -> NLD
68
+ x = clip_model.ln_final(x).type(clip_model.dtype)
69
+ # x.shape = [batch_size, n_ctx, transformer.width]
70
+ # take features from the eot embedding (eot_token is the highest in each sequence)
71
+ x_tokens = x[torch.arange(x.shape[0]), texts.argmax(dim=-1)].float()
72
+ x_seq = [x[k, : (m + 1)].float() for k, m in enumerate(texts.argmax(dim=-1))]
73
+
74
+ return x_seq, x_tokens
75
+
76
+
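`encode_text` returns two views of the CLIP text encoding: `x_seq`, a list with one `[n_tokens, width]` tensor per caption (all tokens up to and including the EOT token), and `x_tokens`, the pooled EOT feature per caption. A hedged usage sketch, assuming the ViT-B/32 model loaded in `init` (transformer width 512):

    x_seq, x_tokens = encode_text(["a short caption"], clip_model, None, device)
    # x_tokens.shape == (1, 512); x_seq[0].shape == (n_tokens_incl_eot, 512)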
77
+ def get_batch(
78
+ prompt: str,
79
+ sample_id: str,
80
+ clip_model: clip.model.CLIP,
81
+ dataset: MultimodalDataset,
82
+ seq_feat: bool,
83
+ device: torch.device,
84
+ ) -> Dict[str, Any]:
85
+ # Get base batch
86
+ sample_index = dataset.root_filenames.index(sample_id)
87
+ raw_batch = dataset[sample_index]
88
+ batch = collate_fn([to_device(raw_batch, device)])
89
+
90
+ # Encode text
91
+ caption_seq, caption_tokens = encode_text([prompt], clip_model, None, device)
92
+
93
+ if seq_feat:
94
+ caption_feat = caption_seq[0]
95
+ caption_feat = F.pad(caption_feat, (0, 0, 0, 77 - caption_feat.shape[0]))
96
+ caption_feat = caption_feat.unsqueeze(0).permute(0, 2, 1)
97
+ else:
98
+ caption_feat = caption_tokens
99
+
100
+ # Update batch
101
+ batch["caption_raw"] = [prompt]
102
+ batch["caption_feat"] = caption_feat
103
+
104
+ return batch
105
+
106
+
107
+ def init(
108
+ config_name: str,
109
+ ) -> Tuple[Diffuser, clip.model.CLIP, MultimodalDataset, torch.device]:
110
+ with initialize(version_base="1.3", config_path="../configs"):
111
+ config = compose(config_name=config_name)
112
+
113
+ OmegaConf.register_new_resolver("eval", eval)
114
+
115
+ # Initialize model
116
+ device = torch.device(config.compnode.device)
117
+ diffuser = instantiate(config.diffuser)
118
+ state_dict = torch.load(config.checkpoint_path, map_location=device)["state_dict"]
119
+ state_dict["ema.initted"] = diffuser.ema.initted
120
+ state_dict["ema.step"] = diffuser.ema.step
121
+ diffuser.load_state_dict(state_dict, strict=False)
122
+ diffuser.to(device).eval()
123
+
124
+ # Initialize CLIP model
125
+ clip_model = load_clip_model("ViT-B/32", device)
126
+
127
+ # Initialize dataset
128
+ config.dataset.char.load_vertices = True
129
+ config.batch_size = 1
130
+ dataset = instantiate(config.dataset)
131
+ dataset.set_split("demo")
132
+ diffuser.modalities = list(dataset.modality_datasets.keys())
133
+ diffuser.get_matrix = dataset.get_matrix
134
+ diffuser.v_get_matrix = dataset.get_matrix
135
+
136
+ return diffuser, clip_model, dataset, device
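`init` wires everything needed by the demo: it instantiates the diffuser from the Hydra config, loads the checkpoint with the EMA bookkeeping keys patched in, loads a frozen CLIP ViT-B/32, and builds the dataset in its "demo" split. A hedged usage sketch; the config name is a placeholder, not a value taken from this repository:

    diffuser, clip_model, dataset, device = init("some_config_name")
    batch = get_batch(
        prompt="a text prompt describing the camera motion",
        sample_id=dataset.root_filenames[0],    # any id known to the dataset
        clip_model=clip_model,
        dataset=dataset,
        seq_feat=True,
        device=device,
    )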
utils/file_utils.py ADDED
@@ -0,0 +1,117 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ import pickle
5
+ import subprocess
6
+ from typing import Any
7
+
8
+ import h5py
9
+ import numpy as np
10
+ import pandas as pd
11
+ import torch
12
+ import torchaudio
13
+ from torchtyping import TensorType
14
+
15
+ num_channels, num_frames, height, width = None, None, None, None
16
+
17
+
18
+ def create_dir(dir_name: str):
19
+ """Create a directory if it does not exist yet."""
20
+ if not osp.exists(dir_name):
21
+ os.makedirs(dir_name)
22
+
23
+
24
+ def move_files(source_path: str, dest_path: str):
25
+ """Move files from `source_path` to `dest_path`."""
26
+ subprocess.call(["mv", source_path, dest_path])
27
+
28
+
29
+ def load_pickle(pickle_path: str) -> Any:
30
+ """Load a pickle file."""
31
+ with open(pickle_path, "rb") as f:
32
+ data = pickle.load(f)
33
+ return data
34
+
35
+
36
+ def load_hdf5(hdf5_path: str) -> Any:
37
+ with h5py.File(hdf5_path, "r") as h5file:
38
+ data = {key: np.array(value) for key, value in h5file.items()}
39
+ return data
40
+
41
+
42
+ def save_hdf5(data: Any, hdf5_path: str):
43
+ with h5py.File(hdf5_path, "w") as h5file:
44
+ for key, value in data.items():
45
+ h5file.create_dataset(key, data=value)
46
+
47
+
48
+ def save_pickle(data: Any, pickle_path: str):
49
+ """Save data in a pickle file."""
50
+ with open(pickle_path, "wb") as f:
51
+ pickle.dump(data, f, protocol=4)
52
+
53
+
54
+ def load_txt(txt_path: str):
55
+ """Load a txt file."""
56
+ with open(txt_path, "r") as f:
57
+ data = f.read()
58
+ return data
59
+
60
+
61
+ def save_txt(data: str, txt_path: str):
62
+ """Save data in a txt file."""
63
+ with open(txt_path, "w") as f:
64
+ f.write(data)
65
+
66
+
67
+ def load_pth(pth_path: str) -> Any:
68
+ """Load a pth (PyTorch) file."""
69
+ data = torch.load(pth_path)
70
+ return data
71
+
72
+
73
+ def save_pth(data: Any, pth_path: str):
74
+ """Save data in a pth (PyTorch) file."""
75
+ torch.save(data, pth_path)
76
+
77
+
78
+ def load_csv(csv_path: str, header: Any = None) -> pd.DataFrame:
79
+ """Load a csv file."""
80
+ try:
81
+ data = pd.read_csv(csv_path, header=header)
82
+ except pd.errors.EmptyDataError:
83
+ data = pd.DataFrame()
84
+ return data
85
+
86
+
87
+ def save_csv(data: Any, csv_path: str):
88
+ """Save data in a csv file."""
89
+ pd.DataFrame(data).to_csv(csv_path, header=False, index=False)
90
+
91
+
92
+ def load_json(json_path: str) -> Any:
93
+ """Load a json file."""
94
+ with open(json_path, "r") as f:
95
+ data = json.load(f)
96
+ return data
97
+
98
+
99
+ def save_json(data: Any, json_path: str):
100
+ """Save data in a json file."""
101
+ with open(json_path, "w") as json_file:
102
+ json.dump(data, json_file)
103
+
104
+
105
+ def load_audio(audio_path: str, **kwargs):
106
+ """Load an audio file."""
107
+ waveform, sample_rate = torchaudio.load(audio_path, **kwargs)
108
+ return waveform, sample_rate
109
+
110
+
111
+ def save_audio(
112
+ data: TensorType["num_channels", "num_frames"],
113
+ audio_path: str,
114
+ sample_rate: int = 44100,
115
+ ):
116
+ """Save data in an audio file."""
117
+ torchaudio.save(audio_path, data, sample_rate)
utils/random_utils.py ADDED
@@ -0,0 +1,46 @@
1
+ import numpy as np
2
+ import random
3
+ import torch
4
+
5
+
6
+ def set_random_seed(seed: int):
7
+ torch.manual_seed((seed) % (1 << 31))
8
+ torch.cuda.manual_seed((seed) % (1 << 31))
9
+ torch.cuda.manual_seed_all((seed) % (1 << 31))
10
+ np.random.seed((seed) % (1 << 31))
11
+ random.seed((seed) % (1 << 31))
12
+ torch.backends.cudnn.benchmark = False
13
+ torch.backends.cudnn.deterministic = True
14
+
15
+
16
+ class StackedRandomGenerator:
17
+ """
18
+ Wrapper for torch.Generator that allows specifying a different random seed for each
19
+ sample in a minibatch.
20
+ """
21
+
22
+ def __init__(self, device, seeds):
23
+ super().__init__()
24
+ self.generators = [
25
+ torch.Generator(device).manual_seed(int(seed) % (1 << 31)) for seed in seeds
26
+ ]
27
+
28
+ def randn_rn(self, size, **kwargs):
29
+ assert size[0] == len(self.generators)
30
+ return torch.stack(
31
+ [torch.randn(size[1:], generator=gen, **kwargs) for gen in self.generators]
32
+ )
33
+
34
+ def randn_like(self, input):
35
+ return self.randn_rn(
36
+ input.shape, dtype=input.dtype, layout=input.layout, device=input.device
37
+ )
38
+
39
+ def randint(self, *args, size, **kwargs):
40
+ assert size[0] == len(self.generators)
41
+ return torch.stack(
42
+ [
43
+ torch.randint(*args, size=size[1:], generator=gen, **kwargs)
44
+ for gen in self.generators
45
+ ]
46
+ )
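`StackedRandomGenerator` keeps one `torch.Generator` per batch element, so each sample's noise depends only on its own seed and stays reproducible regardless of how the batch is composed. A minimal usage sketch:

    import torch

    rnd = StackedRandomGenerator(device="cpu", seeds=[0, 1, 2])
    noise = rnd.randn_rn((3, 16))     # [3, 16]; row i is generated from seeds[i] alone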
utils/rerun.py ADDED
@@ -0,0 +1,74 @@
1
+ import numpy as np
2
+ from matplotlib import colormaps
3
+ import rerun as rr
4
+ from rerun.components import Material
5
+ from scipy.spatial import transform
6
+
7
+
8
+ def color_fn(x, cmap="tab10"):
9
+ return colormaps[cmap](x % colormaps[cmap].N)
10
+
11
+
12
+ def log_sample(
13
+ root_name: str,
14
+ traj: np.ndarray,
15
+ K: np.ndarray,
16
+ vertices: np.ndarray,
17
+ faces: np.ndarray,
18
+ normals: np.ndarray,
19
+ caption: str,
20
+ mesh_masks: np.ndarray,
21
+ ):
22
+ num_cameras = traj.shape[0]
23
+
24
+ rr.log(root_name, rr.ViewCoordinates.RIGHT_HAND_Y_DOWN, timeless=True)
25
+ rr.log(
26
+ f"{root_name}/trajectory/points",
27
+ rr.Points3D(traj[:, :3, 3]),
28
+ timeless=True,
29
+ )
30
+ rr.log(
31
+ f"{root_name}/trajectory/line",
32
+ rr.LineStrips3D(
33
+ np.stack((traj[:, :3, 3][:-1], traj[:, :3, 3][1:]), axis=1),
34
+ colors=[(1.0, 0.0, 1.0, 1.0)],
35
+ ),
36
+ timeless=True,
37
+ )
38
+ for k in range(num_cameras):
39
+ rr.set_time_sequence("frame_idx", k)
40
+
41
+ translation = traj[k][:3, 3]
42
+ rotation_q = transform.Rotation.from_matrix(traj[k][:3, :3]).as_quat()
43
+ rr.log(
44
+ f"{root_name}/camera/image",
45
+ rr.Pinhole(
46
+ image_from_camera=K,
47
+ width=K[0, -1] * 2,
48
+ height=K[1, -1] * 2,
49
+ ),
50
+ )
51
+ rr.log(
52
+ f"{root_name}/camera",
53
+ rr.Transform3D(
54
+ translation=translation,
55
+ rotation=rr.Quaternion(xyzw=rotation_q),
56
+ ),
57
+ )
58
+ rr.set_time_sequence("image", k)
59
+
60
+ # Null vertices
61
+ if vertices[k].sum() == 0:
62
+ rr.log(f"{root_name}/char/char", rr.Clear(recursive=False))
63
+ rr.log(f"{root_name}/camera/image/bbox", rr.Clear(recursive=False))
64
+ continue
65
+
66
+ rr.log(
67
+ f"{root_name}/char/char",
68
+ rr.Mesh3D(
69
+ vertex_positions=vertices[k],
70
+ indices=faces,
71
+ vertex_normals=normals[k],
72
+ mesh_material=Material(albedo_factor=color_fn(0)),
73
+ ),
74
+ )
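A note on the logging above: the pinhole resolution is recovered from the intrinsics by assuming the principal point lies at the image centre (width = 2 * K[0, -1], height = 2 * K[1, -1]); the trajectory is logged once as timeless geometry while the camera pose is re-logged per frame on the "frame_idx" timeline; and frames whose character vertices are all zero simply clear the mesh and bounding-box entities. The caller is expected to have an active Rerun recording before invoking `log_sample`.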
utils/rotation_utils.py ADDED
@@ -0,0 +1,178 @@
1
+ import numpy as np
2
+ from scipy.spatial.transform import Rotation as R
3
+ import torch
4
+ from torchtyping import TensorType
5
+ from itertools import product
6
+
7
+ num_samples, num_cams = None, None
8
+
9
+
10
+ def rotvec_to_matrix(rotvec):
11
+ return R.from_rotvec(rotvec).as_matrix()
12
+
13
+
14
+ def matrix_to_rotvec(mat):
15
+ return R.from_matrix(mat).as_rotvec()
16
+
17
+
18
+ def compose_rotvec(r1, r2):
19
+ """
20
+ #TODO: adapt to torch
21
+ Compose two axis-angle rotation vectors.
22
+ """
23
+ r1 = r1.cpu().numpy() if isinstance(r1, torch.Tensor) else r1
24
+ r2 = r2.cpu().numpy() if isinstance(r2, torch.Tensor) else r2
25
+
26
+ R1 = rotvec_to_matrix(r1)
27
+ R2 = rotvec_to_matrix(r2)
28
+ cR = np.einsum("...ij,...jk->...ik", R1, R2)
29
+ return torch.from_numpy(matrix_to_rotvec(cR))
30
+
31
+
32
+ def quat_to_rotvec(quat, eps=1e-6):
33
+ # w > 0 to ensure 0 <= angle <= pi
34
+ flip = (quat[..., :1] < 0).float()
35
+ quat = (-1 * quat) * flip + (1 - flip) * quat
36
+
37
+ angle = 2 * torch.atan2(torch.linalg.norm(quat[..., 1:], dim=-1), quat[..., 0])
38
+
39
+ angle2 = angle * angle
40
+ small_angle_scales = 2 + angle2 / 12 + 7 * angle2 * angle2 / 2880
41
+ large_angle_scales = angle / torch.sin(angle / 2 + eps)
42
+
43
+ small_angles = (angle <= 1e-3).float()
44
+ rot_vec_scale = (
45
+ small_angle_scales * small_angles + (1 - small_angles) * large_angle_scales
46
+ )
47
+ rot_vec = rot_vec_scale[..., None] * quat[..., 1:]
48
+ return rot_vec
49
+
50
+
51
+ # batch*n
52
+ def normalize_vector(v, return_mag=False):
53
+ batch = v.shape[0]
54
+ v_mag = torch.sqrt(v.pow(2).sum(1)) # batch
55
+ v_mag = torch.max(
56
+ v_mag, torch.tensor([1e-8], device=v.device)
57
+ )
58
+ v_mag = v_mag.view(batch, 1).expand(batch, v.shape[1])
59
+ v = v / v_mag
60
+ if return_mag is True:
61
+ return v, v_mag[:, 0]
62
+ else:
63
+ return v
64
+
65
+
66
+ # u, v batch*n
67
+ def cross_product(u, v):
68
+ batch = u.shape[0]
69
+ i = u[:, 1] * v[:, 2] - u[:, 2] * v[:, 1]
70
+ j = u[:, 2] * v[:, 0] - u[:, 0] * v[:, 2]
71
+ k = u[:, 0] * v[:, 1] - u[:, 1] * v[:, 0]
72
+
73
+ out = torch.cat(
74
+ (i.view(batch, 1), j.view(batch, 1), k.view(batch, 1)), 1
75
+ ) # [batch, 3]
76
+
77
+ return out
78
+
79
+
80
+ def compute_rotation_matrix_from_ortho6d(ortho6d):
81
+ x_raw = ortho6d[:, 0:3] # [batch, 3]
82
+ y_raw = ortho6d[:, 3:6] # [batch, 3]
83
+
84
+ x = normalize_vector(x_raw) # [batch, 3]
85
+ z = cross_product(x, y_raw) # [batch, 3]
86
+ z = normalize_vector(z) # [batch, 3]
87
+ y = cross_product(z, x) # [batch, 3]
88
+
89
+ x = x.view(-1, 3, 1)
90
+ y = y.view(-1, 3, 1)
91
+ z = z.view(-1, 3, 1)
92
+ matrix = torch.cat((x, y, z), 2) # [batch, 3, 3]
93
+ return matrix
94
+
95
+
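`compute_rotation_matrix_from_ortho6d` implements the continuous 6D rotation representation: the input is the first two columns of a rotation matrix, and a Gram-Schmidt-style procedure rebuilds a valid orthonormal frame. A quick round-trip check using the functions defined above:

    import torch

    R = torch.eye(3).unsqueeze(0)                            # [1, 3, 3]
    ortho6d = R[:, :, :2].permute(0, 2, 1).reshape(1, 6)     # columns 0 and 1 stacked
    R_rec = compute_rotation_matrix_from_ortho6d(ortho6d)    # recovers R up to numerical precision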
96
+ def invert_rotvec(rotvec: TensorType["num_samples", 3]):
97
+ angle = torch.norm(rotvec, dim=-1)
98
+ axis = rotvec / (angle.unsqueeze(-1) + 1e-6)
99
+ inverted_rotvec = -angle.unsqueeze(-1) * axis
100
+ return inverted_rotvec
101
+
102
+
103
+ def are_rotations(matrix: TensorType["num_samples", 3, 3]) -> TensorType["num_samples"]:
104
+ """Check if a matrix is a rotation matrix."""
105
+ # Check if the matrix is orthogonal
106
+ identity = torch.eye(3, device=matrix.device)
107
+ is_orthogonal = (
108
+ torch.isclose(torch.bmm(matrix, matrix.transpose(1, 2)), identity, atol=1e-6)
109
+ .all(dim=1)
110
+ .all(dim=1)
111
+ )
112
+
113
+ # Check if the determinant is 1
114
+ determinant = torch.det(matrix)
115
+ is_determinant_one = torch.isclose(
116
+ determinant, torch.tensor(1.0, device=matrix.device), atol=1e-6
117
+ )
118
+
119
+ return torch.logical_and(is_orthogonal, is_determinant_one)
120
+
121
+
122
+ def project_so3(
123
+ matrix: TensorType["num_samples", 4, 4]
124
+ ) -> TensorType["num_samples", 4, 4]:
125
+ # Project rotation matrix to SO(3)
126
+ # TODO: use torch
127
+ rot = R.from_matrix(matrix[:, :3, :3].cpu().numpy()).as_matrix()
128
+
129
+ projection = torch.eye(4).unsqueeze(0).repeat(matrix.shape[0], 1, 1).to(matrix)
130
+ projection[:, :3, :3] = torch.from_numpy(rot).to(matrix)
131
+ projection[:, :3, 3] = matrix[:, :3, 3]
132
+
133
+ return projection
134
+
135
+
136
+ def pairwise_geodesic(
137
+ R_x: TensorType["num_samples", "num_cams", 3, 3],
138
+ R_y: TensorType["num_samples", "num_cams", 3, 3],
139
+ reduction: str = "mean",
140
+ block_size: int = 200,
141
+ ):
142
+ def arange(start, stop, step, endpoint=True):
143
+ arr = torch.arange(start, stop, step)
144
+ if endpoint and arr[-1] != stop - 1:
145
+ arr = torch.cat((arr, torch.tensor([stop - 1], dtype=arr.dtype)))
146
+ return arr
147
+
148
+ # Geodesic distance
149
+ # https://math.stackexchange.com/questions/2113634/comparing-two-rotation-matrices
150
+ num_samples, num_cams, _, _ = R_x.shape
151
+
152
+ C = torch.zeros(num_samples, num_samples, device=R_x.device)
153
+ chunk_indices = arange(0, num_samples + 1, block_size, endpoint=True)
154
+ for i, j in product(
155
+ range(chunk_indices.shape[0] - 1), range(chunk_indices.shape[0] - 1)
156
+ ):
157
+ start_x, stop_x = chunk_indices[i], chunk_indices[i + 1]
158
+ start_y, stop_y = chunk_indices[j], chunk_indices[j + 1]
159
+ r_x, r_y = R_x[start_x:stop_x], R_y[start_y:stop_y]
160
+
161
+ # Compute rotations between each pair of cameras of each sample
162
+ r_xy = torch.einsum("anjk,bnlk->abnjl", r_x, r_y) # b, b, N, 3, 3
163
+
164
+ # Compute axis-angle representations: angle is the geodesic distance
165
+ traces = r_xy.diagonal(dim1=-2, dim2=-1).sum(-1)
166
+ c = torch.acos(torch.clamp((traces - 1) / 2, -1, 1)) / torch.pi
167
+
168
+ # Average distance between cameras over samples
169
+ if reduction == "mean":
170
+ C[start_x:stop_x, start_y:stop_y] = c.mean(-1)
171
+ elif reduction == "sum":
172
+ C[start_x:stop_x, start_y:stop_y] = c.sum(-1)
173
+
174
+ # Check for NaN values in traces
175
+ if torch.isnan(c).any():
176
+ raise ValueError("NaN values detected in traces")
177
+
178
+ return C
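`pairwise_geodesic` computes, block by block, the relative rotations R_x R_y^T for every pair of samples and every camera, and converts each to the geodesic angle theta = arccos((trace(R_x R_y^T) - 1) / 2), normalised by pi so distances lie in [0, 1] (0 for identical rotations, 1 for rotations that differ by 180 degrees). A tiny check with one sample and one camera per side:

    import torch

    I = torch.eye(3).view(1, 1, 3, 3)
    flip = torch.diag(torch.tensor([1.0, -1.0, -1.0])).view(1, 1, 3, 3)  # 180° rotation about x
    pairwise_geodesic(I, I)       # tensor([[0.]])
    pairwise_geodesic(I, flip)    # tensor([[1.]])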