IsaacLabe committed on
Commit 420af5a · verified · 1 Parent(s): 3d20002

Upload 66 files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +10 -0
  2. EXP6_SOMwoTrack_Fulld_EPOCHS200/.DS_Store +0 -0
  3. EXP6_SOMwoTrack_Fulld_EPOCHS200/Log_of_loss.txt +0 -0
  4. EXP6_SOMwoTrack_Fulld_EPOCHS200/cfg.yaml +70 -0
  5. EXP6_SOMwoTrack_Fulld_EPOCHS200/checkpoints/last.ckpt +3 -0
  6. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__init__.py +0 -0
  7. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/__init__.cpython-311.pyc +0 -0
  8. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/configs.cpython-311.pyc +0 -0
  9. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/init_utils.cpython-311.pyc +0 -0
  10. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/loss_utils.cpython-311.pyc +0 -0
  11. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/metrics.cpython-311.pyc +0 -0
  12. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/params.cpython-311.pyc +0 -0
  13. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/scene_model.cpython-311.pyc +0 -0
  14. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/tensor_dataclass.cpython-311.pyc +0 -0
  15. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/trainer.cpython-311.pyc +0 -0
  16. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/transforms.cpython-311.pyc +0 -0
  17. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/validator.cpython-311.pyc +0 -0
  18. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/configs.py +70 -0
  19. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__init__.py +39 -0
  20. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/__init__.cpython-311.pyc +0 -0
  21. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/base_dataset.cpython-311.pyc +0 -0
  22. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/casual_dataset.cpython-311.pyc +0 -0
  23. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/colmap.cpython-311.pyc +0 -0
  24. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/iphone_dataset.cpython-311.pyc +0 -0
  25. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/utils.cpython-311.pyc +0 -0
  26. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/base_dataset.py +77 -0
  27. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/casual_dataset.py +496 -0
  28. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/colmap.py +369 -0
  29. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/iphone_dataset.py +837 -0
  30. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/utils.py +360 -0
  31. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/init_utils.py +650 -0
  32. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/loss_utils.py +157 -0
  33. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/metrics.py +313 -0
  34. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/params.py +193 -0
  35. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/renderer.py +89 -0
  36. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/scene_model.py +343 -0
  37. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/tensor_dataclass.py +96 -0
  38. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/trainer.py +827 -0
  39. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/trajectories.py +200 -0
  40. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/transforms.py +129 -0
  41. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/validator.py +443 -0
  42. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__init__.py +0 -0
  43. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/__init__.cpython-311.pyc +0 -0
  44. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/playback_panel.cpython-311.pyc +0 -0
  45. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/render_panel.cpython-311.pyc +0 -0
  46. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/utils.cpython-311.pyc +0 -0
  47. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/viewer.cpython-311.pyc +0 -0
  48. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/playback_panel.py +68 -0
  49. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/render_panel.py +1165 -0
  50. EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/utils.py +544 -0
.gitattributes CHANGED
@@ -33,3 +33,13 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0100/depths.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0100/motion_coefs.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0100/PCA.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0100/rgbs.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0100/tracks_2d.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0199/depths.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0199/motion_coefs.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0199/PCA.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0199/rgbs.mp4 filter=lfs diff=lfs merge=lfs -text
+ EXP6_SOMwoTrack_Fulld_EPOCHS200/videos/epoch_0199/tracks_2d.mp4 filter=lfs diff=lfs merge=lfs -text
EXP6_SOMwoTrack_Fulld_EPOCHS200/.DS_Store ADDED
Binary file (6.15 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/Log_of_loss.txt ADDED
The diff for this file is too large to render. See raw diff
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/cfg.yaml ADDED
@@ -0,0 +1,70 @@
+ batch_size: 4
+ data:
+   camera_type: refined
+   data_dir: /content/DGD-SOM/backpack
+   depth_type: depth_anything_colmap
+   end: -1
+   load_from_cache: false
+   num_targets_per_frame: 4
+   scene_norm_dict: null
+   skip_load_imgs: false
+   split: train
+   start: 0
+   use_median_filter: false
+ feature_dim: 384
+ foundation_model: DINOv2
+ loss:
+   w_depth_const: 0.1
+   w_depth_grad: 1
+   w_depth_reg: 0.5
+   w_feature: 0.5
+   w_mask: 1.0
+   w_rgb: 1.0
+   w_scale_var: 0.01
+   w_smooth_bases: 0.1
+   w_smooth_tracks: 2.0
+   w_track: 2.0
+   w_z_accel: 1.0
+ lr:
+   bg:
+     colors: 0.01
+     features: 0.001
+     means: 0.00016
+     opacities: 0.05
+     quats: 0.001
+     scales: 0.005
+   fg:
+     colors: 0.01
+     features: 0.001
+     means: 0.00016
+     motion_coefs: 0.01
+     opacities: 0.01
+     quats: 0.001
+     scales: 0.005
+   motion_bases:
+     rots: 0.00016
+     transls: 0.00016
+ num_bg: 100000
+ num_dl_workers: 4
+ num_epochs: 200
+ num_fg: 40000
+ num_motion_bases: 10
+ optim:
+   control_every: 100
+   cull_opacity_threshold: 0.1
+   cull_scale_threshold: 0.5
+   cull_screen_threshold: 0.15
+   densify_scale_threshold: 0.01
+   densify_screen_threshold: 0.05
+   densify_xys_grad_threshold: 0.0002
+   max_steps: 5000
+   reset_opacity_every_n_controls: 30
+   stop_control_by_screen_steps: 4000
+   stop_control_steps: 4000
+   stop_densify_steps: 15000
+   warmup_steps: 200
+ port: null
+ save_videos_every: 100
+ validate_every: 100
+ vis_debug: false
+ work_dir: /content/DGD-SOM/EXP6_SOMwoTrack_Fulld_EPOCHS200
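The YAML above is the experiment snapshot written at training time. As a minimal sketch (not part of the commit), assuming PyYAML is installed and using the dataclasses from flow3d/configs.py included later in this commit, it can be read back into typed configs; the load_cfg helper below is hypothetical:

# Minimal sketch (not part of the commit): load cfg.yaml back into the config
# dataclasses from flow3d/configs.py. Assumes PyYAML is installed; the helper
# name `load_cfg` is hypothetical.
import yaml

from flow3d.configs import (
    BGLRConfig,
    FGLRConfig,
    LossesConfig,
    MotionLRConfig,
    OptimizerConfig,
    SceneLRConfig,
)


def load_cfg(path: str):
    with open(path) as f:
        raw = yaml.safe_load(f)
    lr = SceneLRConfig(
        fg=FGLRConfig(**raw["lr"]["fg"]),
        bg=BGLRConfig(**raw["lr"]["bg"]),
        motion_bases=MotionLRConfig(**raw["lr"]["motion_bases"]),
    )
    loss = LossesConfig(**raw["loss"])
    optim = OptimizerConfig(**raw["optim"])
    return raw, lr, loss, optim


raw, lr, loss, optim = load_cfg("EXP6_SOMwoTrack_Fulld_EPOCHS200/cfg.yaml")
print(raw["num_epochs"], loss.w_track, optim.max_steps)  # 200 2.0 5000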
EXP6_SOMwoTrack_Fulld_EPOCHS200/checkpoints/last.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:43e3f2d442e243a904193be0c159c38ca449f74c0c71427f17bdcf8ffa9c8bca
+ size 2193490744
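checkpoints/last.ckpt is stored as a Git LFS pointer to a roughly 2.2 GB object. A hedged sketch for inspecting it after `git lfs pull` has fetched the payload, assuming it is an ordinary torch.save dictionary (the key layout is not documented in this commit):

# Hedged sketch (not part of the commit): inspect the checkpoint once the LFS
# pointer has been replaced by the real payload. Assumes a plain torch.save
# dict; the actual contents depend on the trainer.
import torch

ckpt = torch.load(
    "EXP6_SOMwoTrack_Fulld_EPOCHS200/checkpoints/last.ckpt",
    map_location="cpu",
)
print(type(ckpt))
if isinstance(ckpt, dict):
    print(list(ckpt.keys()))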
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__init__.py ADDED
File without changes
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (144 Bytes).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/configs.cpython-311.pyc ADDED
Binary file (3.92 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/init_utils.cpython-311.pyc ADDED
Binary file (33.6 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/loss_utils.cpython-311.pyc ADDED
Binary file (9.8 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/metrics.cpython-311.pyc ADDED
Binary file (19.6 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/params.cpython-311.pyc ADDED
Binary file (12.5 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/scene_model.cpython-311.pyc ADDED
Binary file (16.4 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/tensor_dataclass.cpython-311.pyc ADDED
Binary file (6.77 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/trainer.cpython-311.pyc ADDED
Binary file (42.6 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/transforms.cpython-311.pyc ADDED
Binary file (6.64 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/__pycache__/validator.cpython-311.pyc ADDED
Binary file (25.6 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/configs.py ADDED
@@ -0,0 +1,70 @@
+ from dataclasses import dataclass
+
+
+ @dataclass
+ class FGLRConfig:
+     means: float = 1.6e-4
+     opacities: float = 1e-2
+     scales: float = 5e-3
+     quats: float = 1e-3
+     colors: float = 1e-2
+     features: float = 1e-3
+     motion_coefs: float = 1e-2
+
+
+ @dataclass
+ class BGLRConfig:
+     means: float = 1.6e-4
+     opacities: float = 5e-2
+     scales: float = 5e-3
+     quats: float = 1e-3
+     colors: float = 1e-2
+     features: float = 1e-3
+
+
+ @dataclass
+ class MotionLRConfig:
+     rots: float = 1.6e-4
+     transls: float = 1.6e-4
+
+
+ @dataclass
+ class SceneLRConfig:
+     fg: FGLRConfig
+     bg: BGLRConfig
+     motion_bases: MotionLRConfig
+
+
+ @dataclass
+ class LossesConfig:
+     w_rgb: float = 1.0
+     w_feature: float = 0.5
+     w_depth_reg: float = 0.5
+     w_depth_const: float = 0.1
+     w_depth_grad: float = 1
+     w_track: float = 2.0
+     w_mask: float = 1.0
+     w_smooth_bases: float = 0.1
+     w_smooth_tracks: float = 2.0
+     w_scale_var: float = 0.01
+     w_z_accel: float = 1.0
+
+
+ @dataclass
+ class OptimizerConfig:
+     max_steps: int = 5000
+     ## Adaptive gaussian control
+     warmup_steps: int = 200
+     control_every: int = 100
+     reset_opacity_every_n_controls: int = 30
+     stop_control_by_screen_steps: int = 4000
+     stop_control_steps: int = 4000
+     ### Densify.
+     densify_xys_grad_threshold: float = 0.0002
+     densify_scale_threshold: float = 0.01
+     densify_screen_threshold: float = 0.05
+     stop_densify_steps: int = 15000
+     ### Cull.
+     cull_opacity_threshold: float = 0.1
+     cull_scale_threshold: float = 0.5
+     cull_screen_threshold: float = 0.15
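The defaults above mirror the values recorded in cfg.yaml. A small usage sketch (not part of the commit) showing how the nested SceneLRConfig is assembled and flattened back to a dict:

# Small usage sketch (not part of the commit): build the nested learning-rate
# config from its defaults and dump it to a plain dict, which matches the `lr`
# block of cfg.yaml above.
from dataclasses import asdict

from flow3d.configs import BGLRConfig, FGLRConfig, MotionLRConfig, SceneLRConfig

lrs = SceneLRConfig(fg=FGLRConfig(), bg=BGLRConfig(), motion_bases=MotionLRConfig())
print(asdict(lrs)["fg"]["means"])      # 0.00016
print(asdict(lrs)["bg"]["opacities"])  # 0.05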
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__init__.py ADDED
@@ -0,0 +1,39 @@
+ from dataclasses import asdict, replace
+
+ from torch.utils.data import Dataset
+
+ from .base_dataset import BaseDataset
+ from .casual_dataset import CasualDataset, CustomDataConfig, DavisDataConfig
+ from .iphone_dataset import (
+     iPhoneDataConfig,
+     iPhoneDataset,
+     iPhoneDatasetKeypointView,
+     iPhoneDatasetVideoView,
+ )
+
+
+ def get_train_val_datasets(
+     data_cfg: iPhoneDataConfig | DavisDataConfig | CustomDataConfig, load_val: bool
+ ) -> tuple[BaseDataset, Dataset | None, Dataset | None, Dataset | None]:
+     train_video_view = None
+     val_img_dataset = None
+     val_kpt_dataset = None
+     if isinstance(data_cfg, iPhoneDataConfig):
+         train_dataset = iPhoneDataset(**asdict(data_cfg))
+         train_video_view = iPhoneDatasetVideoView(train_dataset)
+         if load_val:
+             val_img_dataset = (
+                 iPhoneDataset(
+                     **asdict(replace(data_cfg, split="val", load_from_cache=True))
+                 )
+                 if train_dataset.has_validation
+                 else None
+             )
+             val_kpt_dataset = iPhoneDatasetKeypointView(train_dataset)
+     elif isinstance(data_cfg, DavisDataConfig) or isinstance(
+         data_cfg, CustomDataConfig
+     ):
+         train_dataset = CasualDataset(**asdict(data_cfg))
+     else:
+         raise ValueError(f"Unknown data config: {data_cfg}")
+     return train_dataset, train_video_view, val_img_dataset, val_kpt_dataset
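A hypothetical usage sketch (not part of the commit) of the dispatch above for a custom sequence; the sequence name and root directory are illustrative placeholders, not paths verified in this commit:

# Hypothetical usage sketch (not part of the commit): build the training
# dataset for a custom sequence. For CustomDataConfig the three validation
# views come back as None.
from flow3d.data import CustomDataConfig, get_train_val_datasets

data_cfg = CustomDataConfig(seq_name="backpack", root_dir="/content/DGD-SOM")
datasets = get_train_val_datasets(data_cfg, load_val=False)
train_dataset, train_video_view, val_img_dataset, val_kpt_dataset = datasets
print(train_dataset.num_frames, train_dataset.get_img_wh())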
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (1.97 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/base_dataset.cpython-311.pyc ADDED
Binary file (4.17 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/casual_dataset.cpython-311.pyc ADDED
Binary file (31.1 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/colmap.cpython-311.pyc ADDED
Binary file (19.9 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/iphone_dataset.cpython-311.pyc ADDED
Binary file (39.7 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/__pycache__/utils.cpython-311.pyc ADDED
Binary file (18.1 kB).
 
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/base_dataset.py ADDED
@@ -0,0 +1,77 @@
+ from abc import abstractmethod
+
+ import torch
+ from torch.utils.data import Dataset, default_collate
+
+
+ class BaseDataset(Dataset):
+     @property
+     @abstractmethod
+     def num_frames(self) -> int: ...
+
+     @property
+     def keyframe_idcs(self) -> torch.Tensor:
+         return torch.arange(self.num_frames)
+
+     @abstractmethod
+     def get_w2cs(self) -> torch.Tensor: ...
+
+     @abstractmethod
+     def get_Ks(self) -> torch.Tensor: ...
+
+     @abstractmethod
+     def get_image(self, index: int) -> torch.Tensor: ...
+
+     @abstractmethod
+     def get_depth(self, index: int) -> torch.Tensor: ...
+
+     @abstractmethod
+     def get_mask(self, index: int) -> torch.Tensor: ...
+
+     def get_img_wh(self) -> tuple[int, int]: ...
+
+     @abstractmethod
+     def get_tracks_3d(
+         self, num_samples: int, **kwargs
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Returns 3D tracks:
+             coordinates (N, T, 3),
+             visibles (N, T),
+             invisibles (N, T),
+             confidences (N, T),
+             colors (N, 3)
+         """
+         ...
+
+     @abstractmethod
+     def get_bkgd_points(
+         self, num_samples: int, **kwargs
+     ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         """
+         Returns background points:
+             coordinates (N, 3),
+             normals (N, 3),
+             colors (N, 3)
+         """
+         ...
+
+     @staticmethod
+     def train_collate_fn(batch):
+         collated = {}
+         for k in batch[0]:
+             if k not in [
+                 "query_tracks_2d",
+                 "target_ts",
+                 "target_w2cs",
+                 "target_Ks",
+                 "target_tracks_2d",
+                 "target_visibles",
+                 "target_track_depths",
+                 "target_invisibles",
+                 "target_confidences",
+             ]:
+                 collated[k] = default_collate([sample[k] for sample in batch])
+             else:
+                 collated[k] = [sample[k] for sample in batch]
+         return collated
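train_collate_fn stacks fixed-size entries with default_collate and keeps the per-sample track fields, whose lengths vary per frame, as plain lists. A self-contained sketch (not part of the commit) with dummy tensors:

# Self-contained sketch (not part of the commit): fixed-size entries such as
# "imgs" are stacked into a batch tensor, while variable-length entries such
# as "query_tracks_2d" stay as a list of per-sample tensors.
import torch

from flow3d.data.base_dataset import BaseDataset

batch = [
    {"imgs": torch.zeros(4, 4, 3), "query_tracks_2d": torch.zeros(7, 2)},
    {"imgs": torch.zeros(4, 4, 3), "query_tracks_2d": torch.zeros(3, 2)},
]
collated = BaseDataset.train_collate_fn(batch)
print(collated["imgs"].shape)  # torch.Size([2, 4, 4, 3])
print([t.shape for t in collated["query_tracks_2d"]])  # [torch.Size([7, 2]), torch.Size([3, 2])]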
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/casual_dataset.py ADDED
@@ -0,0 +1,496 @@
1
+ import os
2
+ from dataclasses import dataclass
3
+ from functools import partial
4
+ from typing import Literal, cast
5
+
6
+ import cv2
7
+ import imageio
8
+ import numpy as np
9
+ import torch
10
+ import torch.nn.functional as F
11
+ import tyro
12
+ from loguru import logger as guru
13
+ from roma import roma
14
+ from tqdm import tqdm
15
+
16
+ from flow3d.data.base_dataset import BaseDataset
17
+ from flow3d.data.utils import (
18
+ UINT16_MAX,
19
+ SceneNormDict,
20
+ get_tracks_3d_for_query_frame,
21
+ median_filter_2d,
22
+ normal_from_depth_image,
23
+ normalize_coords,
24
+ parse_tapir_track_info,
25
+ )
26
+ from flow3d.transforms import rt_to_mat4
27
+
28
+
29
+ @dataclass
30
+ class DavisDataConfig:
31
+ seq_name: str
32
+ root_dir: str
33
+ start: int = 0
34
+ end: int = -1
35
+ res: str = "480p"
36
+ image_type: str = "JPEGImages"
37
+ mask_type: str = "Annotations"
38
+ depth_type: Literal[
39
+ "aligned_depth_anything",
40
+ "aligned_depth_anything_v2",
41
+ "depth_anything",
42
+ "depth_anything_v2",
43
+ "unidepth_disp",
44
+ ] = "aligned_depth_anything"
45
+ camera_type: Literal["droid_recon"] = "droid_recon"
46
+ track_2d_type: Literal["bootstapir", "tapir"] = "bootstapir"
47
+ mask_erosion_radius: int = 3
48
+ scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None
49
+ num_targets_per_frame: int = 4
50
+ load_from_cache: bool = False
51
+
52
+
53
+ @dataclass
54
+ class CustomDataConfig:
55
+ seq_name: str
56
+ root_dir: str
57
+ start: int = 0
58
+ end: int = -1
59
+ res: str = ""
60
+ image_type: str = "images"
61
+ mask_type: str = "masks"
62
+ depth_type: Literal[
63
+ "aligned_depth_anything",
64
+ "aligned_depth_anything_v2",
65
+ "depth_anything",
66
+ "depth_anything_v2",
67
+ "unidepth_disp",
68
+ ] = "aligned_depth_anything"
69
+ camera_type: Literal["droid_recon"] = "droid_recon"
70
+ track_2d_type: Literal["bootstapir", "tapir"] = "bootstapir"
71
+ mask_erosion_radius: int = 7
72
+ scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None
73
+ num_targets_per_frame: int = 4
74
+ load_from_cache: bool = False
75
+
76
+
77
+ class CasualDataset(BaseDataset):
78
+ def __init__(
79
+ self,
80
+ seq_name: str,
81
+ root_dir: str,
82
+ start: int = 0,
83
+ end: int = -1,
84
+ res: str = "480p",
85
+ image_type: str = "JPEGImages",
86
+ mask_type: str = "Annotations",
87
+ depth_type: Literal[
88
+ "aligned_depth_anything",
89
+ "aligned_depth_anything_v2",
90
+ "depth_anything",
91
+ "depth_anything_v2",
92
+ "unidepth_disp",
93
+ ] = "aligned_depth_anything",
94
+ camera_type: Literal["droid_recon"] = "droid_recon",
95
+ track_2d_type: Literal["bootstapir", "tapir"] = "bootstapir",
96
+ mask_erosion_radius: int = 3,
97
+ scene_norm_dict: SceneNormDict | None = None,
98
+ num_targets_per_frame: int = 4,
99
+ load_from_cache: bool = False,
100
+ **_,
101
+ ):
102
+ super().__init__()
103
+
104
+ self.seq_name = seq_name
105
+ self.root_dir = root_dir
106
+ self.res = res
107
+ self.depth_type = depth_type
108
+ self.num_targets_per_frame = num_targets_per_frame
109
+ self.load_from_cache = load_from_cache
110
+ self.has_validation = False
111
+ self.mask_erosion_radius = mask_erosion_radius
112
+
113
+ self.img_dir = f"{root_dir}/{image_type}/{res}/{seq_name}"
114
+ self.img_ext = os.path.splitext(os.listdir(self.img_dir)[0])[1]
115
+ self.depth_dir = f"{root_dir}/{depth_type}/{res}/{seq_name}"
116
+ self.mask_dir = f"{root_dir}/{mask_type}/{res}/{seq_name}"
117
+ self.tracks_dir = f"{root_dir}/{track_2d_type}/{res}/{seq_name}"
118
+ self.cache_dir = f"{root_dir}/flow3d_preprocessed/{res}/{seq_name}"
119
+ # self.cache_dir = f"datasets/davis/flow3d_preprocessed/{res}/{seq_name}"
120
+ frame_names = [os.path.splitext(p)[0] for p in sorted(os.listdir(self.img_dir))]
121
+
122
+ if end == -1:
123
+ end = len(frame_names)
124
+ self.start = start
125
+ self.end = end
126
+ self.frame_names = frame_names[start:end]
127
+
128
+ self.imgs: list[torch.Tensor | None] = [None for _ in self.frame_names]
129
+ self.depths: list[torch.Tensor | None] = [None for _ in self.frame_names]
130
+ self.masks: list[torch.Tensor | None] = [None for _ in self.frame_names]
131
+
132
+ # load cameras
133
+ if camera_type == "droid_recon":
134
+ img = self.get_image(0)
135
+ H, W = img.shape[:2]
136
+ w2cs, Ks, tstamps = load_cameras(
137
+ f"{root_dir}/{camera_type}/{seq_name}.npy", H, W
138
+ )
139
+ else:
140
+ raise ValueError(f"Unknown camera type: {camera_type}")
141
+ assert (
142
+ len(frame_names) == len(w2cs) == len(Ks)
143
+ ), f"{len(frame_names)}, {len(w2cs)}, {len(Ks)}"
144
+ self.w2cs = w2cs[start:end]
145
+ self.Ks = Ks[start:end]
146
+ tmask = (tstamps >= start) & (tstamps < end)
147
+ self._keyframe_idcs = tstamps[tmask] - start
148
+ self.scale = 1
149
+
150
+ if scene_norm_dict is None:
151
+ cached_scene_norm_dict_path = os.path.join(
152
+ self.cache_dir, "scene_norm_dict.pth"
153
+ )
154
+ if os.path.exists(cached_scene_norm_dict_path) and self.load_from_cache:
155
+ guru.info("loading cached scene norm dict...")
156
+ scene_norm_dict = torch.load(
157
+ os.path.join(self.cache_dir, "scene_norm_dict.pth")
158
+ )
159
+ else:
160
+ tracks_3d = self.get_tracks_3d(5000, step=self.num_frames // 10)[0]
161
+ scale, transfm = compute_scene_norm(tracks_3d, self.w2cs)
162
+ scene_norm_dict = SceneNormDict(scale=scale, transfm=transfm)
163
+ os.makedirs(self.cache_dir, exist_ok=True)
164
+ torch.save(scene_norm_dict, cached_scene_norm_dict_path)
165
+
166
+ # transform cameras
167
+ self.scene_norm_dict = cast(SceneNormDict, scene_norm_dict)
168
+ self.scale = self.scene_norm_dict["scale"]
169
+ transform = self.scene_norm_dict["transfm"]
170
+ guru.info(f"scene norm {self.scale=}, {transform=}")
171
+ self.w2cs = torch.einsum("nij,jk->nik", self.w2cs, torch.linalg.inv(transform))
172
+ self.w2cs[:, :3, 3] /= self.scale
173
+
174
+ @property
175
+ def num_frames(self) -> int:
176
+ return len(self.frame_names)
177
+
178
+ @property
179
+ def keyframe_idcs(self) -> torch.Tensor:
180
+ return self._keyframe_idcs
181
+
182
+ def __len__(self):
183
+ return len(self.frame_names)
184
+
185
+ def get_w2cs(self) -> torch.Tensor:
186
+ return self.w2cs
187
+
188
+ def get_Ks(self) -> torch.Tensor:
189
+ return self.Ks
190
+
191
+ def get_img_wh(self) -> tuple[int, int]:
192
+ return self.get_image(0).shape[1::-1]
193
+
194
+ def get_image(self, index) -> torch.Tensor:
195
+ if self.imgs[index] is None:
196
+ self.imgs[index] = self.load_image(index)
197
+ img = cast(torch.Tensor, self.imgs[index])
198
+ return img
199
+
200
+ def get_mask(self, index) -> torch.Tensor:
201
+ if self.masks[index] is None:
202
+ self.masks[index] = self.load_mask(index)
203
+ mask = cast(torch.Tensor, self.masks[index])
204
+ return mask
205
+
206
+ def get_depth(self, index) -> torch.Tensor:
207
+ if self.depths[index] is None:
208
+ self.depths[index] = self.load_depth(index)
209
+ return self.depths[index] / self.scale
210
+
211
+ def load_image(self, index) -> torch.Tensor:
212
+ path = f"{self.img_dir}/{self.frame_names[index]}{self.img_ext}"
213
+ return torch.from_numpy(imageio.imread(path)).float() / 255.0
214
+
215
+ def load_mask(self, index) -> torch.Tensor:
216
+ path = f"{self.mask_dir}/{self.frame_names[index]}.png"
217
+ r = self.mask_erosion_radius
218
+ mask = imageio.imread(path)
219
+ fg_mask = mask.reshape((*mask.shape[:2], -1)).max(axis=-1) > 0
220
+ bg_mask = ~fg_mask
221
+ fg_mask_erode = cv2.erode(
222
+ fg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1
223
+ )
224
+ bg_mask_erode = cv2.erode(
225
+ bg_mask.astype(np.uint8), np.ones((r, r), np.uint8), iterations=1
226
+ )
227
+ out_mask = np.zeros_like(fg_mask, dtype=np.float32)
228
+ out_mask[bg_mask_erode > 0] = -1
229
+ out_mask[fg_mask_erode > 0] = 1
230
+ return torch.from_numpy(out_mask).float()
231
+
232
+ def load_depth(self, index) -> torch.Tensor:
233
+ path = f"{self.depth_dir}/{self.frame_names[index]}.npy"
234
+ disp = np.load(path)
235
+ depth = 1.0 / np.clip(disp, a_min=1e-6, a_max=1e6)
236
+ depth = torch.from_numpy(depth).float()
237
+ depth = median_filter_2d(depth[None, None], 11, 1)[0, 0]
238
+ return depth
239
+
240
+ def load_target_tracks(
241
+ self, query_index: int, target_indices: list[int], dim: int = 1
242
+ ):
243
+ """
244
+ tracks are 2d, occs and uncertainties
245
+ :param dim (int), default 1: dimension to stack the time axis
246
+ return (N, T, 4) if dim=1, (T, N, 4) if dim=0
247
+ """
248
+ q_name = self.frame_names[query_index]
249
+ all_tracks = []
250
+ for ti in target_indices:
251
+ t_name = self.frame_names[ti]
252
+ path = f"{self.tracks_dir}/{q_name}_{t_name}.npy"
253
+ tracks = np.load(path).astype(np.float32)
254
+ all_tracks.append(tracks)
255
+ return torch.from_numpy(np.stack(all_tracks, axis=dim))
256
+
257
+ def get_tracks_3d(
258
+ self, num_samples: int, start: int = 0, end: int = -1, step: int = 1, **kwargs
259
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
260
+ num_frames = self.num_frames
261
+ if end < 0:
262
+ end = num_frames + 1 + end
263
+ query_idcs = list(range(start, end, step))
264
+ target_idcs = list(range(start, end, step))
265
+ masks = torch.stack([self.get_mask(i) for i in target_idcs], dim=0)
266
+ fg_masks = (masks == 1).float()
267
+ depths = torch.stack([self.get_depth(i) for i in target_idcs], dim=0)
268
+ inv_Ks = torch.linalg.inv(self.Ks[target_idcs])
269
+ c2ws = torch.linalg.inv(self.w2cs[target_idcs])
270
+
271
+ num_per_query_frame = int(np.ceil(num_samples / len(query_idcs)))
272
+ cur_num = 0
273
+ tracks_all_queries = []
274
+ for q_idx in query_idcs:
275
+ # (N, T, 4)
276
+ tracks_2d = self.load_target_tracks(q_idx, target_idcs)
277
+ num_sel = int(
278
+ min(num_per_query_frame, num_samples - cur_num, len(tracks_2d))
279
+ )
280
+ if num_sel < len(tracks_2d):
281
+ sel_idcs = np.random.choice(len(tracks_2d), num_sel, replace=False)
282
+ tracks_2d = tracks_2d[sel_idcs]
283
+ cur_num += tracks_2d.shape[0]
284
+ img = self.get_image(q_idx)
285
+ tidx = target_idcs.index(q_idx)
286
+ tracks_tuple = get_tracks_3d_for_query_frame(
287
+ tidx, img, tracks_2d, depths, fg_masks, inv_Ks, c2ws
288
+ )
289
+ tracks_all_queries.append(tracks_tuple)
290
+ tracks_3d, colors, visibles, invisibles, confidences = map(
291
+ partial(torch.cat, dim=0), zip(*tracks_all_queries)
292
+ )
293
+ return tracks_3d, visibles, invisibles, confidences, colors
294
+
295
+ def get_bkgd_points(
296
+ self,
297
+ num_samples: int,
298
+ use_kf_tstamps: bool = True,
299
+ stride: int = 8,
300
+ down_rate: int = 8,
301
+ min_per_frame: int = 64,
302
+ **kwargs,
303
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
304
+ start = 0
305
+ end = self.num_frames
306
+ H, W = self.get_image(0).shape[:2]
307
+ grid = torch.stack(
308
+ torch.meshgrid(
309
+ torch.arange(0, W, dtype=torch.float32),
310
+ torch.arange(0, H, dtype=torch.float32),
311
+ indexing="xy",
312
+ ),
313
+ dim=-1,
314
+ )
315
+
316
+ if use_kf_tstamps:
317
+ query_idcs = self.keyframe_idcs.tolist()
318
+ else:
319
+ num_query_frames = self.num_frames // stride
320
+ query_endpts = torch.linspace(start, end, num_query_frames + 1)
321
+ query_idcs = ((query_endpts[:-1] + query_endpts[1:]) / 2).long().tolist()
322
+
323
+ bg_geometry = []
324
+ print(f"{query_idcs=}")
325
+ for query_idx in tqdm(query_idcs, desc="Loading bkgd points", leave=False):
326
+ img = self.get_image(query_idx)
327
+ depth = self.get_depth(query_idx)
328
+ bg_mask = self.get_mask(query_idx) < 0
329
+ bool_mask = (bg_mask * (depth > 0)).to(torch.bool)
330
+ w2c = self.w2cs[query_idx]
331
+ K = self.Ks[query_idx]
332
+
333
+ # get the bounding box of previous points that reproject into frame
334
+ # inefficient but works for now
335
+ bmax_x, bmax_y, bmin_x, bmin_y = 0, 0, W, H
336
+ for p3d, _, _ in bg_geometry:
337
+ if len(p3d) < 1:
338
+ continue
339
+ # reproject into current frame
340
+ p2d = torch.einsum(
341
+ "ij,jk,pk->pi", K, w2c[:3], F.pad(p3d, (0, 1), value=1.0)
342
+ )
343
+ p2d = p2d[:, :2] / p2d[:, 2:].clamp(min=1e-6)
344
+ xmin, xmax = p2d[:, 0].min().item(), p2d[:, 0].max().item()
345
+ ymin, ymax = p2d[:, 1].min().item(), p2d[:, 1].max().item()
346
+
347
+ bmin_x = min(bmin_x, int(xmin))
348
+ bmin_y = min(bmin_y, int(ymin))
349
+ bmax_x = max(bmax_x, int(xmax))
350
+ bmax_y = max(bmax_y, int(ymax))
351
+
352
+ # don't include points that are covered by previous points
353
+ bmin_x = max(0, bmin_x)
354
+ bmin_y = max(0, bmin_y)
355
+ bmax_x = min(W, bmax_x)
356
+ bmax_y = min(H, bmax_y)
357
+ overlap_mask = torch.ones_like(bool_mask)
358
+ overlap_mask[bmin_y:bmax_y, bmin_x:bmax_x] = 0
359
+
360
+ bool_mask &= overlap_mask
361
+ if bool_mask.sum() < min_per_frame:
362
+ guru.debug(f"skipping {query_idx=}")
363
+ continue
364
+
365
+ points = (
366
+ torch.einsum(
367
+ "ij,pj->pi",
368
+ torch.linalg.inv(K),
369
+ F.pad(grid[bool_mask], (0, 1), value=1.0),
370
+ )
371
+ * depth[bool_mask][:, None]
372
+ )
373
+ points = torch.einsum(
374
+ "ij,pj->pi", torch.linalg.inv(w2c)[:3], F.pad(points, (0, 1), value=1.0)
375
+ )
376
+ point_normals = normal_from_depth_image(depth, K, w2c)[bool_mask]
377
+ point_colors = img[bool_mask]
378
+
379
+ num_sel = max(len(points) // down_rate, min_per_frame)
380
+ sel_idcs = np.random.choice(len(points), num_sel, replace=False)
381
+ points = points[sel_idcs]
382
+ point_normals = point_normals[sel_idcs]
383
+ point_colors = point_colors[sel_idcs]
384
+ guru.debug(f"{query_idx=} {points.shape=}")
385
+ bg_geometry.append((points, point_normals, point_colors))
386
+
387
+ bg_points, bg_normals, bg_colors = map(
388
+ partial(torch.cat, dim=0), zip(*bg_geometry)
389
+ )
390
+ if len(bg_points) > num_samples:
391
+ sel_idcs = np.random.choice(len(bg_points), num_samples, replace=False)
392
+ bg_points = bg_points[sel_idcs]
393
+ bg_normals = bg_normals[sel_idcs]
394
+ bg_colors = bg_colors[sel_idcs]
395
+
396
+ return bg_points, bg_normals, bg_colors
397
+
398
+ def __getitem__(self, index: int):
399
+ index = np.random.randint(0, self.num_frames)
400
+ data = {
401
+ # ().
402
+ "frame_names": self.frame_names[index],
403
+ # ().
404
+ "ts": torch.tensor(index),
405
+ # (4, 4).
406
+ "w2cs": self.w2cs[index],
407
+ # (3, 3).
408
+ "Ks": self.Ks[index],
409
+ # (H, W, 3).
410
+ "imgs": self.get_image(index),
411
+ "depths": self.get_depth(index),
412
+ }
413
+ tri_mask = self.get_mask(index)
414
+ valid_mask = tri_mask != 0 # not fg or bg
415
+ mask = tri_mask == 1 # fg mask
416
+ data["masks"] = mask.float()
417
+ data["valid_masks"] = valid_mask.float()
418
+
419
+ # (P, 2)
420
+ query_tracks = self.load_target_tracks(index, [index])[:, 0, :2]
421
+ target_inds = torch.from_numpy(
422
+ np.random.choice(
423
+ self.num_frames, (self.num_targets_per_frame,), replace=False
424
+ )
425
+ )
426
+ # (N, P, 4)
427
+ target_tracks = self.load_target_tracks(index, target_inds.tolist(), dim=0)
428
+ data["query_tracks_2d"] = query_tracks
429
+ data["target_ts"] = target_inds
430
+ data["target_w2cs"] = self.w2cs[target_inds]
431
+ data["target_Ks"] = self.Ks[target_inds]
432
+ data["target_tracks_2d"] = target_tracks[..., :2]
433
+ # (N, P).
434
+ (
435
+ data["target_visibles"],
436
+ data["target_invisibles"],
437
+ data["target_confidences"],
438
+ ) = parse_tapir_track_info(target_tracks[..., 2], target_tracks[..., 3])
439
+ # (N, H, W)
440
+ target_depths = torch.stack([self.get_depth(i) for i in target_inds], dim=0)
441
+ H, W = target_depths.shape[-2:]
442
+ data["target_track_depths"] = F.grid_sample(
443
+ target_depths[:, None],
444
+ normalize_coords(target_tracks[..., None, :2], H, W),
445
+ align_corners=True,
446
+ padding_mode="border",
447
+ )[:, 0, :, 0]
448
+ return data
449
+
450
+
451
+ def load_cameras(
452
+ path: str, H: int, W: int
453
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
454
+ assert os.path.exists(path), f"Camera file {path} does not exist."
455
+ recon = np.load(path, allow_pickle=True).item()
456
+ guru.debug(f"{recon.keys()=}")
457
+ traj_c2w = recon["traj_c2w"] # (N, 4, 4)
458
+ h, w = recon["img_shape"]
459
+ sy, sx = H / h, W / w
460
+ traj_w2c = np.linalg.inv(traj_c2w)
461
+ fx, fy, cx, cy = recon["intrinsics"] # (4,)
462
+ K = np.array([[fx * sx, 0, cx * sx], [0, fy * sy, cy * sy], [0, 0, 1]]) # (3, 3)
463
+ Ks = np.tile(K[None, ...], (len(traj_c2w), 1, 1)) # (N, 3, 3)
464
+ kf_tstamps = recon["tstamps"].astype("int")
465
+ return (
466
+ torch.from_numpy(traj_w2c).float(),
467
+ torch.from_numpy(Ks).float(),
468
+ torch.from_numpy(kf_tstamps),
469
+ )
470
+
471
+
472
+ def compute_scene_norm(
473
+ X: torch.Tensor, w2cs: torch.Tensor
474
+ ) -> tuple[float, torch.Tensor]:
475
+ """
476
+ :param X: [N*T, 3]
477
+ :param w2cs: [N, 4, 4]
478
+ """
479
+ X = X.reshape(-1, 3)
480
+ scene_center = X.mean(dim=0)
481
+ X = X - scene_center[None]
482
+ min_scale = X.quantile(0.05, dim=0)
483
+ max_scale = X.quantile(0.95, dim=0)
484
+ scale = (max_scale - min_scale).max().item() / 2.0
485
+ original_up = -F.normalize(w2cs[:, 1, :3].mean(0), dim=-1)
486
+ target_up = original_up.new_tensor([0.0, 0.0, 1.0])
487
+ R = roma.rotvec_to_rotmat(
488
+ F.normalize(original_up.cross(target_up), dim=-1)
489
+ * original_up.dot(target_up).acos_()
490
+ )
491
+ transfm = rt_to_mat4(R, torch.einsum("ij,j->i", -R, scene_center))
492
+ return scale, transfm
493
+
494
+
495
+ if __name__ == "__main__":
496
+ d = CasualDataset("bear", "/shared/vye/datasets/DAVIS", camera_type="droid_recon")
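CasualDataset.load_mask produces a tri-level mask (+1 for eroded foreground, -1 for eroded background, 0 for boundary pixels removed by the erosion), and __getitem__ derives the binary supervision masks from it. A small sketch (not part of the commit) of that convention:

# Sketch (not part of the commit) of the tri-level mask convention used by
# CasualDataset: +1 = eroded foreground, -1 = eroded background, 0 = uncertain
# boundary pixels dropped by the erosion.
import torch

tri_mask = torch.tensor(
    [
        [-1.0, -1.0, 0.0, 0.0],
        [-1.0, 0.0, 1.0, 1.0],
    ]
)
valid_mask = (tri_mask != 0).float()  # supervise only confident pixels
fg_mask = (tri_mask == 1).float()     # foreground pixels only
print(valid_mask)
print(fg_mask)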
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/colmap.py ADDED
@@ -0,0 +1,369 @@
1
+ import os
2
+ import struct
3
+ from dataclasses import dataclass
4
+ from pathlib import Path
5
+ from typing import Dict, Union
6
+
7
+ import numpy as np
8
+
9
+
10
+ def get_colmap_camera_params(colmap_dir, img_files):
11
+ cameras = read_cameras_binary(colmap_dir + "/cameras.bin")
12
+ images = read_images_binary(colmap_dir + "/images.bin")
13
+ colmap_image_idcs = {v.name: k for k, v in images.items()}
14
+ img_names = [os.path.basename(img_file) for img_file in img_files]
15
+ num_imgs = len(img_names)
16
+ K_all = np.zeros((num_imgs, 4, 4))
17
+ extrinsics_all = np.zeros((num_imgs, 4, 4))
18
+ for idx, name in enumerate(img_names):
19
+ key = colmap_image_idcs[name]
20
+ image = images[key]
21
+ assert image.name == name
22
+ K, extrinsics = get_intrinsics_extrinsics(image, cameras)
23
+ K_all[idx] = K
24
+ extrinsics_all[idx] = extrinsics
25
+
26
+ return K_all, extrinsics_all
27
+
28
+
29
+ @dataclass(frozen=True)
30
+ class CameraModel:
31
+ model_id: int
32
+ model_name: str
33
+ num_params: int
34
+
35
+
36
+ @dataclass(frozen=True)
37
+ class Camera:
38
+ id: int
39
+ model: str
40
+ width: int
41
+ height: int
42
+ params: np.ndarray
43
+
44
+
45
+ @dataclass(frozen=True)
46
+ class BaseImage:
47
+ id: int
48
+ qvec: np.ndarray
49
+ tvec: np.ndarray
50
+ camera_id: int
51
+ name: str
52
+ xys: np.ndarray
53
+ point3D_ids: np.ndarray
54
+
55
+
56
+ @dataclass(frozen=True)
57
+ class Point3D:
58
+ id: int
59
+ xyz: np.ndarray
60
+ rgb: np.ndarray
61
+ error: Union[float, np.ndarray]
62
+ image_ids: np.ndarray
63
+ point2D_idxs: np.ndarray
64
+
65
+
66
+ class Image(BaseImage):
67
+ def qvec2rotmat(self):
68
+ return qvec2rotmat(self.qvec)
69
+
70
+
71
+ CAMERA_MODELS = {
72
+ CameraModel(model_id=0, model_name="SIMPLE_PINHOLE", num_params=3),
73
+ CameraModel(model_id=1, model_name="PINHOLE", num_params=4),
74
+ CameraModel(model_id=2, model_name="SIMPLE_RADIAL", num_params=4),
75
+ CameraModel(model_id=3, model_name="RADIAL", num_params=5),
76
+ CameraModel(model_id=4, model_name="OPENCV", num_params=8),
77
+ CameraModel(model_id=5, model_name="OPENCV_FISHEYE", num_params=8),
78
+ CameraModel(model_id=6, model_name="FULL_OPENCV", num_params=12),
79
+ CameraModel(model_id=7, model_name="FOV", num_params=5),
80
+ CameraModel(model_id=8, model_name="SIMPLE_RADIAL_FISHEYE", num_params=4),
81
+ CameraModel(model_id=9, model_name="RADIAL_FISHEYE", num_params=5),
82
+ CameraModel(model_id=10, model_name="THIN_PRISM_FISHEYE", num_params=12),
83
+ }
84
+ CAMERA_MODEL_IDS = dict(
85
+ [(camera_model.model_id, camera_model) for camera_model in CAMERA_MODELS]
86
+ )
87
+
88
+
89
+ def read_next_bytes(fid, num_bytes, format_char_sequence, endian_character="<"):
90
+ """Read and unpack the next bytes from a binary file.
91
+ :param fid:
92
+ :param num_bytes: Sum of combination of {2, 4, 8}, e.g. 2, 6, 16, 30, etc.
93
+ :param format_char_sequence: List of {c, e, f, d, h, H, i, I, l, L, q, Q}.
94
+ :param endian_character: Any of {@, =, <, >, !}
95
+ :return: Tuple of read and unpacked values.
96
+ """
97
+ data = fid.read(num_bytes)
98
+ return struct.unpack(endian_character + format_char_sequence, data)
99
+
100
+
101
+ def read_cameras_text(path: Union[str, Path]) -> Dict[int, Camera]:
102
+ """
103
+ see: src/base/reconstruction.cc
104
+ void Reconstruction::WriteCamerasText(const std::string& path)
105
+ void Reconstruction::ReadCamerasText(const std::string& path)
106
+ """
107
+ cameras = {}
108
+ with open(path, "r") as fid:
109
+ while True:
110
+ line = fid.readline()
111
+ if not line:
112
+ break
113
+ line = line.strip()
114
+ if len(line) > 0 and line[0] != "#":
115
+ elems = line.split()
116
+ camera_id = int(elems[0])
117
+ model = elems[1]
118
+ width = int(elems[2])
119
+ height = int(elems[3])
120
+ params = np.array(tuple(map(float, elems[4:])))
121
+ cameras[camera_id] = Camera(
122
+ id=camera_id, model=model, width=width, height=height, params=params
123
+ )
124
+ return cameras
125
+
126
+
127
+ def read_cameras_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Camera]:
128
+ """
129
+ see: src/base/reconstruction.cc
130
+ void Reconstruction::WriteCamerasBinary(const std::string& path)
131
+ void Reconstruction::ReadCamerasBinary(const std::string& path)
132
+ """
133
+ cameras = {}
134
+ with open(path_to_model_file, "rb") as fid:
135
+ num_cameras = read_next_bytes(fid, 8, "Q")[0]
136
+ for camera_line_index in range(num_cameras):
137
+ camera_properties = read_next_bytes(
138
+ fid, num_bytes=24, format_char_sequence="iiQQ"
139
+ )
140
+ camera_id = camera_properties[0]
141
+ model_id = camera_properties[1]
142
+ model_name = CAMERA_MODEL_IDS[camera_properties[1]].model_name
143
+ width = camera_properties[2]
144
+ height = camera_properties[3]
145
+ num_params = CAMERA_MODEL_IDS[model_id].num_params
146
+ params = read_next_bytes(
147
+ fid, num_bytes=8 * num_params, format_char_sequence="d" * num_params
148
+ )
149
+ cameras[camera_id] = Camera(
150
+ id=camera_id,
151
+ model=model_name,
152
+ width=width,
153
+ height=height,
154
+ params=np.array(params),
155
+ )
156
+ assert len(cameras) == num_cameras
157
+ return cameras
158
+
159
+
160
+ def read_images_text(path: Union[str, Path]) -> Dict[int, Image]:
161
+ """
162
+ see: src/base/reconstruction.cc
163
+ void Reconstruction::ReadImagesText(const std::string& path)
164
+ void Reconstruction::WriteImagesText(const std::string& path)
165
+ """
166
+ images = {}
167
+ with open(path, "r") as fid:
168
+ while True:
169
+ line = fid.readline()
170
+ if not line:
171
+ break
172
+ line = line.strip()
173
+ if len(line) > 0 and line[0] != "#":
174
+ elems = line.split()
175
+ image_id = int(elems[0])
176
+ qvec = np.array(tuple(map(float, elems[1:5])))
177
+ tvec = np.array(tuple(map(float, elems[5:8])))
178
+ camera_id = int(elems[8])
179
+ image_name = elems[9]
180
+ elems = fid.readline().split()
181
+ xys = np.column_stack(
182
+ [tuple(map(float, elems[0::3])), tuple(map(float, elems[1::3]))]
183
+ )
184
+ point3D_ids = np.array(tuple(map(int, elems[2::3])))
185
+ images[image_id] = Image(
186
+ id=image_id,
187
+ qvec=qvec,
188
+ tvec=tvec,
189
+ camera_id=camera_id,
190
+ name=image_name,
191
+ xys=xys,
192
+ point3D_ids=point3D_ids,
193
+ )
194
+ return images
195
+
196
+
197
+ def read_images_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Image]:
198
+ """
199
+ see: src/base/reconstruction.cc
200
+ void Reconstruction::ReadImagesBinary(const std::string& path)
201
+ void Reconstruction::WriteImagesBinary(const std::string& path)
202
+ """
203
+ images = {}
204
+ with open(path_to_model_file, "rb") as fid:
205
+ num_reg_images = read_next_bytes(fid, 8, "Q")[0]
206
+ for image_index in range(num_reg_images):
207
+ binary_image_properties = read_next_bytes(
208
+ fid, num_bytes=64, format_char_sequence="idddddddi"
209
+ )
210
+ image_id = binary_image_properties[0]
211
+ qvec = np.array(binary_image_properties[1:5])
212
+ tvec = np.array(binary_image_properties[5:8])
213
+ camera_id = binary_image_properties[8]
214
+ image_name = ""
215
+ current_char = read_next_bytes(fid, 1, "c")[0]
216
+ while current_char != b"\x00": # look for the ASCII 0 entry
217
+ image_name += current_char.decode("utf-8")
218
+ current_char = read_next_bytes(fid, 1, "c")[0]
219
+ num_points2D = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
220
+ 0
221
+ ]
222
+ x_y_id_s = read_next_bytes(
223
+ fid,
224
+ num_bytes=24 * num_points2D,
225
+ format_char_sequence="ddq" * num_points2D,
226
+ )
227
+ xys = np.column_stack(
228
+ [tuple(map(float, x_y_id_s[0::3])), tuple(map(float, x_y_id_s[1::3]))]
229
+ )
230
+ point3D_ids = np.array(tuple(map(int, x_y_id_s[2::3])))
231
+ images[image_id] = Image(
232
+ id=image_id,
233
+ qvec=qvec,
234
+ tvec=tvec,
235
+ camera_id=camera_id,
236
+ name=image_name,
237
+ xys=xys,
238
+ point3D_ids=point3D_ids,
239
+ )
240
+ return images
241
+
242
+
243
+ def read_points3D_text(path: Union[str, Path]):
244
+ """
245
+ see: src/base/reconstruction.cc
246
+ void Reconstruction::ReadPoints3DText(const std::string& path)
247
+ void Reconstruction::WritePoints3DText(const std::string& path)
248
+ """
249
+ points3D = {}
250
+ with open(path, "r") as fid:
251
+ while True:
252
+ line = fid.readline()
253
+ if not line:
254
+ break
255
+ line = line.strip()
256
+ if len(line) > 0 and line[0] != "#":
257
+ elems = line.split()
258
+ point3D_id = int(elems[0])
259
+ xyz = np.array(tuple(map(float, elems[1:4])))
260
+ rgb = np.array(tuple(map(int, elems[4:7])))
261
+ error = float(elems[7])
262
+ image_ids = np.array(tuple(map(int, elems[8::2])))
263
+ point2D_idxs = np.array(tuple(map(int, elems[9::2])))
264
+ points3D[point3D_id] = Point3D(
265
+ id=point3D_id,
266
+ xyz=xyz,
267
+ rgb=rgb,
268
+ error=error,
269
+ image_ids=image_ids,
270
+ point2D_idxs=point2D_idxs,
271
+ )
272
+ return points3D
273
+
274
+
275
+ def read_points3d_binary(path_to_model_file: Union[str, Path]) -> Dict[int, Point3D]:
276
+ """
277
+ see: src/base/reconstruction.cc
278
+ void Reconstruction::ReadPoints3DBinary(const std::string& path)
279
+ void Reconstruction::WritePoints3DBinary(const std::string& path)
280
+ """
281
+ points3D = {}
282
+ with open(path_to_model_file, "rb") as fid:
283
+ num_points = read_next_bytes(fid, 8, "Q")[0]
284
+ for point_line_index in range(num_points):
285
+ binary_point_line_properties = read_next_bytes(
286
+ fid, num_bytes=43, format_char_sequence="QdddBBBd"
287
+ )
288
+ point3D_id = binary_point_line_properties[0]
289
+ xyz = np.array(binary_point_line_properties[1:4])
290
+ rgb = np.array(binary_point_line_properties[4:7])
291
+ error = np.array(binary_point_line_properties[7])
292
+ track_length = read_next_bytes(fid, num_bytes=8, format_char_sequence="Q")[
293
+ 0
294
+ ]
295
+ track_elems = read_next_bytes(
296
+ fid,
297
+ num_bytes=8 * track_length,
298
+ format_char_sequence="ii" * track_length,
299
+ )
300
+ image_ids = np.array(tuple(map(int, track_elems[0::2])))
301
+ point2D_idxs = np.array(tuple(map(int, track_elems[1::2])))
302
+ points3D[point3D_id] = Point3D(
303
+ id=point3D_id,
304
+ xyz=xyz,
305
+ rgb=rgb,
306
+ error=error,
307
+ image_ids=image_ids,
308
+ point2D_idxs=point2D_idxs,
309
+ )
310
+ return points3D
311
+
312
+
313
+ def qvec2rotmat(qvec):
314
+ return np.array(
315
+ [
316
+ [
317
+ 1 - 2 * qvec[2] ** 2 - 2 * qvec[3] ** 2,
318
+ 2 * qvec[1] * qvec[2] - 2 * qvec[0] * qvec[3],
319
+ 2 * qvec[3] * qvec[1] + 2 * qvec[0] * qvec[2],
320
+ ],
321
+ [
322
+ 2 * qvec[1] * qvec[2] + 2 * qvec[0] * qvec[3],
323
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[3] ** 2,
324
+ 2 * qvec[2] * qvec[3] - 2 * qvec[0] * qvec[1],
325
+ ],
326
+ [
327
+ 2 * qvec[3] * qvec[1] - 2 * qvec[0] * qvec[2],
328
+ 2 * qvec[2] * qvec[3] + 2 * qvec[0] * qvec[1],
329
+ 1 - 2 * qvec[1] ** 2 - 2 * qvec[2] ** 2,
330
+ ],
331
+ ]
332
+ )
333
+
334
+
335
+ def get_intrinsics_extrinsics(img, cameras):
336
+ # world to cam transformation
337
+ R = qvec2rotmat(img.qvec)
338
+ # translation
339
+ t = img.tvec
340
+ cam = cameras[img.camera_id]
341
+
342
+ if cam.model in ("SIMPLE_PINHOLE", "SIMPLE_RADIAL", "RADIAL"):
343
+ fx = fy = cam.params[0]
344
+ cx = cam.params[1]
345
+ cy = cam.params[2]
346
+ elif cam.model in (
347
+ "PINHOLE",
348
+ "OPENCV",
349
+ "OPENCV_FISHEYE",
350
+ "FULL_OPENCV",
351
+ ):
352
+ fx = cam.params[0]
353
+ fy = cam.params[1]
354
+ cx = cam.params[2]
355
+ cy = cam.params[3]
356
+ else:
357
+ raise Exception("Camera model not supported")
358
+
359
+ # intrinsics
360
+ K = np.identity(4)
361
+ K[0, 0] = fx
362
+ K[1, 1] = fy
363
+ K[0, 2] = cx
364
+ K[1, 2] = cy
365
+
366
+ extrinsics = np.eye(4)
367
+ extrinsics[:3, :3] = R
368
+ extrinsics[:3, 3] = t
369
+ return K, extrinsics
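get_intrinsics_extrinsics assembles a 4x4 intrinsics matrix and a world-to-camera extrinsics matrix from COLMAP's (w, x, y, z) quaternion and translation. A short sketch (not part of the commit) sanity-checking qvec2rotmat:

# Short sketch (not part of the commit): qvec2rotmat uses COLMAP's (w, x, y, z)
# quaternion order, so the identity quaternion maps to the identity rotation.
import numpy as np

from flow3d.data.colmap import qvec2rotmat

R = qvec2rotmat(np.array([1.0, 0.0, 0.0, 0.0]))
assert np.allclose(R, np.eye(3))

# A 180-degree rotation about z corresponds to q = (0, 0, 0, 1).
Rz = qvec2rotmat(np.array([0.0, 0.0, 0.0, 1.0]))
print(np.round(Rz))  # diag(-1, -1, 1)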
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/iphone_dataset.py ADDED
@@ -0,0 +1,837 @@
1
+ import json
2
+ import os
3
+ import os.path as osp
4
+ from dataclasses import dataclass
5
+ from glob import glob
6
+ from itertools import product
7
+ from typing import Literal
8
+
9
+ import imageio.v3 as iio
10
+ import numpy as np
11
+ import roma
12
+ import torch
13
+ import torch.nn.functional as F
14
+ import tyro
15
+ from loguru import logger as guru
16
+ from torch.utils.data import Dataset
17
+ from tqdm import tqdm
18
+
19
+ from flow3d.data.base_dataset import BaseDataset
20
+ from flow3d.data.colmap import get_colmap_camera_params
21
+ from flow3d.data.utils import (
22
+ SceneNormDict,
23
+ masked_median_blur,
24
+ normal_from_depth_image,
25
+ normalize_coords,
26
+ parse_tapir_track_info,
27
+ )
28
+ from flow3d.transforms import rt_to_mat4
29
+
30
+
31
+ @dataclass
32
+ class iPhoneDataConfig:
33
+ data_dir: str
34
+ start: int = 0
35
+ end: int = -1
36
+ split: Literal["train", "val"] = "train"
37
+ depth_type: Literal[
38
+ "midas",
39
+ "depth_anything",
40
+ "lidar",
41
+ "depth_anything_colmap",
42
+ ] = "depth_anything_colmap"
43
+ camera_type: Literal["original", "refined"] = "refined"
44
+ use_median_filter: bool = False
45
+ num_targets_per_frame: int = 4
46
+ scene_norm_dict: tyro.conf.Suppress[SceneNormDict | None] = None
47
+ load_from_cache: bool = False
48
+ skip_load_imgs: bool = False
49
+
50
+
51
+ class iPhoneDataset(BaseDataset):
52
+ def __init__(
53
+ self,
54
+ data_dir: str,
55
+ start: int = 0,
56
+ end: int = -1,
57
+ factor: int = 1,
58
+ split: Literal["train", "val"] = "train",
59
+ depth_type: Literal[
60
+ "midas",
61
+ "depth_anything",
62
+ "lidar",
63
+ "depth_anything_colmap",
64
+ ] = "depth_anything_colmap",
65
+ camera_type: Literal["original", "refined"] = "refined",
66
+ use_median_filter: bool = False,
67
+ num_targets_per_frame: int = 1,
68
+ scene_norm_dict: SceneNormDict | None = None,
69
+ load_from_cache: bool = False,
70
+ skip_load_imgs: bool = False,
71
+ **_,
72
+ ):
73
+ super().__init__()
74
+
75
+ self.data_dir = data_dir
76
+ self.training = split == "train"
77
+ self.split = split
78
+ self.factor = factor
79
+ self.start = start
80
+ self.end = end
81
+ self.depth_type = depth_type
82
+ self.camera_type = camera_type
83
+ self.use_median_filter = use_median_filter
84
+ self.num_targets_per_frame = num_targets_per_frame
85
+ self.scene_norm_dict = scene_norm_dict
86
+ self.load_from_cache = load_from_cache
87
+ self.cache_dir = osp.join(data_dir, "flow3d_preprocessed", "cache")
88
+ os.makedirs(self.cache_dir, exist_ok=True)
89
+
90
+ # Test if the current data has validation set.
91
+ with open(osp.join(data_dir, "splits", "val.json")) as f:
92
+ split_dict = json.load(f)
93
+ self.has_validation = len(split_dict["frame_names"]) > 0
94
+
95
+ # Load metadata.
96
+ with open(osp.join(data_dir, "splits", f"{split}.json")) as f:
97
+ split_dict = json.load(f)
98
+ full_len = len(split_dict["frame_names"])
99
+ end = min(end, full_len) if end > 0 else full_len
100
+ self.end = end
101
+ self.frame_names = split_dict["frame_names"][start:end]
102
+ time_ids = [t for t in split_dict["time_ids"] if t >= start and t < end]
103
+ self.time_ids = torch.tensor(time_ids) - start
104
+ guru.info(f"{self.time_ids.min()=} {self.time_ids.max()=}")
105
+ # with open(osp.join(data_dir, "dataset.json")) as f:
106
+ # dataset_dict = json.load(f)
107
+ # self.num_frames = dataset_dict["num_exemplars"]
108
+ guru.info(f"{self.num_frames=}")
109
+ with open(osp.join(data_dir, "extra.json")) as f:
110
+ extra_dict = json.load(f)
111
+ self.fps = float(extra_dict["fps"])
112
+
113
+ # Load cameras.
114
+ if self.camera_type == "original":
115
+ Ks, w2cs = [], []
116
+ for frame_name in self.frame_names:
117
+ with open(osp.join(data_dir, "camera", f"{frame_name}.json")) as f:
118
+ camera_dict = json.load(f)
119
+ focal_length = camera_dict["focal_length"]
120
+ principal_point = camera_dict["principal_point"]
121
+ Ks.append(
122
+ [
123
+ [focal_length, 0.0, principal_point[0]],
124
+ [0.0, focal_length, principal_point[1]],
125
+ [0.0, 0.0, 1.0],
126
+ ]
127
+ )
128
+ orientation = np.array(camera_dict["orientation"])
129
+ position = np.array(camera_dict["position"])
130
+ w2cs.append(
131
+ np.block(
132
+ [
133
+ [orientation, -orientation @ position[:, None]],
134
+ [np.zeros((1, 3)), np.ones((1, 1))],
135
+ ]
136
+ ).astype(np.float32)
137
+ )
138
+ self.Ks = torch.tensor(Ks)
139
+ self.Ks[:, :2] /= factor
140
+ self.w2cs = torch.from_numpy(np.array(w2cs))
141
+ elif self.camera_type == "refined":
142
+ Ks, w2cs = get_colmap_camera_params(
143
+ osp.join(data_dir, "flow3d_preprocessed/colmap/sparse/"),
144
+ [frame_name + ".png" for frame_name in self.frame_names],
145
+ )
146
+ self.Ks = torch.from_numpy(Ks[:, :3, :3].astype(np.float32))
147
+ self.Ks[:, :2] /= factor
148
+ self.w2cs = torch.from_numpy(w2cs.astype(np.float32))
149
+ if not skip_load_imgs:
150
+ # Load images.
151
+ imgs = torch.from_numpy(
152
+ np.array(
153
+ [
154
+ iio.imread(
155
+ osp.join(self.data_dir, f"rgb/{factor}x/{frame_name}.png")
156
+ )
157
+ for frame_name in tqdm(
158
+ self.frame_names,
159
+ desc=f"Loading {self.split} images",
160
+ leave=False,
161
+ )
162
+ ],
163
+ )
164
+ )
165
+ self.imgs = imgs[..., :3] / 255.0
166
+ self.valid_masks = imgs[..., 3] / 255.0
167
+ # Load masks.
168
+ self.masks = (
169
+ torch.from_numpy(
170
+ np.array(
171
+ [
172
+ iio.imread(
173
+ osp.join(
174
+ self.data_dir,
175
+ "flow3d_preprocessed/track_anything/",
176
+ f"{factor}x/{frame_name}.png",
177
+ )
178
+ )
179
+ for frame_name in tqdm(
180
+ self.frame_names,
181
+ desc=f"Loading {self.split} masks",
182
+ leave=False,
183
+ )
184
+ ],
185
+ )
186
+ )
187
+ / 255.0
188
+ )
189
+ if self.training:
190
+ # Load depths.
191
+ def load_depth(frame_name):
192
+ if self.depth_type == "lidar":
193
+ depth = np.load(
194
+ osp.join(
195
+ self.data_dir,
196
+ f"depth/{factor}x/{frame_name}.npy",
197
+ )
198
+ )[..., 0]
199
+ else:
200
+ depth = np.load(
201
+ osp.join(
202
+ self.data_dir,
203
+ f"flow3d_preprocessed/aligned_{self.depth_type}/",
204
+ f"{factor}x/{frame_name}.npy",
205
+ )
206
+ )
207
+ depth[depth < 1e-3] = 1e-3
208
+ depth = 1.0 / depth
209
+ return depth
210
+
211
+ self.depths = torch.from_numpy(
212
+ np.array(
213
+ [
214
+ load_depth(frame_name)
215
+ for frame_name in tqdm(
216
+ self.frame_names,
217
+ desc=f"Loading {self.split} depths",
218
+ leave=False,
219
+ )
220
+ ],
221
+ np.float32,
222
+ )
223
+ )
224
+ max_depth_values_per_frame = self.depths.reshape(
225
+ self.num_frames, -1
226
+ ).max(1)[0]
227
+ max_depth_value = max_depth_values_per_frame.median() * 2.5
228
+ print("max_depth_value", max_depth_value)
229
+ self.depths = torch.clamp(self.depths, 0, max_depth_value)
230
+ # Median filter depths.
231
+ # NOTE(hangg): This operator is very expensive.
232
+ if self.use_median_filter:
233
+ for i in tqdm(
234
+ range(self.num_frames), desc="Processing depths", leave=False
235
+ ):
236
+ depth = masked_median_blur(
237
+ self.depths[[i]].unsqueeze(1).to("cuda"),
238
+ (
239
+ self.masks[[i]]
240
+ * self.valid_masks[[i]]
241
+ * (self.depths[[i]] > 0)
242
+ )
243
+ .unsqueeze(1)
244
+ .to("cuda"),
245
+ )[0, 0].cpu()
246
+ self.depths[i] = depth * self.masks[i] + self.depths[i] * (
247
+ 1 - self.masks[i]
248
+ )
249
+ # Load the query pixels from 2D tracks.
250
+ self.query_tracks_2d = [
251
+ torch.from_numpy(
252
+ np.load(
253
+ osp.join(
254
+ self.data_dir,
255
+ "flow3d_preprocessed/2d_tracks/",
256
+ f"{factor}x/{frame_name}_{frame_name}.npy",
257
+ )
258
+ ).astype(np.float32)
259
+ )
260
+ for frame_name in self.frame_names
261
+ ]
262
+ guru.info(
263
+ f"{len(self.query_tracks_2d)=} {self.query_tracks_2d[0].shape=}"
264
+ )
265
+
266
+ # Load sam features.
267
+ # sam_feat_dir = osp.join(
268
+ # data_dir, f"flow3d_preprocessed/sam_features/{factor}x"
269
+ # )
270
+ # assert osp.exists(sam_feat_dir), f"SAM features not exist!"
271
+ # sam_features, original_size, input_size = load_sam_features(
272
+ # sam_feat_dir, self.frame_names
273
+ # )
274
+ # guru.info(f"{sam_features.shape=} {original_size=} {input_size=}")
275
+ # self.sam_features = sam_features
276
+ # self.sam_original_size = original_size
277
+ # self.sam_input_size = input_size
278
+ else:
279
+ # Load covisible masks.
280
+ self.covisible_masks = (
281
+ torch.from_numpy(
282
+ np.array(
283
+ [
284
+ iio.imread(
285
+ osp.join(
286
+ self.data_dir,
287
+ "flow3d_preprocessed/covisible/",
288
+ f"{factor}x/{split}/{frame_name}.png",
289
+ )
290
+ )
291
+ for frame_name in tqdm(
292
+ self.frame_names,
293
+ desc=f"Loading {self.split} covisible masks",
294
+ leave=False,
295
+ )
296
+ ],
297
+ )
298
+ )
299
+ / 255.0
300
+ )
301
+
302
+ if self.scene_norm_dict is None:
303
+ cached_scene_norm_dict_path = osp.join(
304
+ self.cache_dir, "scene_norm_dict.pth"
305
+ )
306
+ if osp.exists(cached_scene_norm_dict_path) and self.load_from_cache:
307
+ print("loading cached scene norm dict...")
308
+ self.scene_norm_dict = torch.load(
309
+ osp.join(self.cache_dir, "scene_norm_dict.pth")
310
+ )
311
+ elif self.training:
312
+ # Compute the scene scale and transform for normalization.
313
+ # Normalize the scene based on the foreground 3D tracks.
314
+ subsampled_tracks_3d = self.get_tracks_3d(
315
+ num_samples=10000, step=self.num_frames // 10, show_pbar=False
316
+ )[0]
317
+ scene_center = subsampled_tracks_3d.mean((0, 1))
318
+ tracks_3d_centered = subsampled_tracks_3d - scene_center
319
+ min_scale = tracks_3d_centered.quantile(0.05, dim=0)
320
+ max_scale = tracks_3d_centered.quantile(0.95, dim=0)
321
+ scale = torch.max(max_scale - min_scale).item() / 2.0
322
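+ # Rotate the scene so the average camera up direction (the negated y-row of the w2c rotations) maps onto world +z via an axis-angle rotation.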
+ original_up = -F.normalize(self.w2cs[:, 1, :3].mean(0), dim=-1)
323
+ target_up = original_up.new_tensor([0.0, 0.0, 1.0])
324
+ R = roma.rotvec_to_rotmat(
325
+ F.normalize(original_up.cross(target_up, dim=-1), dim=-1)
326
+ * original_up.dot(target_up).acos_()
327
+ )
328
+ transfm = rt_to_mat4(R, torch.einsum("ij,j->i", -R, scene_center))
329
+ self.scene_norm_dict = SceneNormDict(scale=scale, transfm=transfm)
330
+ torch.save(self.scene_norm_dict, cached_scene_norm_dict_path)
331
+ else:
332
+ raise ValueError("scene_norm_dict must be provided for validation.")
333
+
334
+ # Normalize the scene.
335
+ scale = self.scene_norm_dict["scale"]
336
+ transfm = self.scene_norm_dict["transfm"]
337
+ self.w2cs = self.w2cs @ torch.linalg.inv(transfm)
338
+ self.w2cs[:, :3, 3] /= scale
339
+ if self.training and not skip_load_imgs:
340
+ self.depths /= scale
341
+
342
+ if not skip_load_imgs:
343
+ guru.info(
344
+ f"{self.imgs.shape=} {self.valid_masks.shape=} {self.masks.shape=}"
345
+ )
346
+
347
+ @property
348
+ def num_frames(self) -> int:
349
+ return len(self.frame_names)
350
+
351
+ def __len__(self):
352
+ return self.imgs.shape[0]
353
+
354
+ def get_w2cs(self) -> torch.Tensor:
355
+ return self.w2cs
356
+
357
+ def get_Ks(self) -> torch.Tensor:
358
+ return self.Ks
359
+
360
+ def get_image(self, index: int) -> torch.Tensor:
361
+ return self.imgs[index]
362
+
363
+ def get_depth(self, index: int) -> torch.Tensor:
364
+ return self.depths[index]
365
+
366
+ def get_masks(self, index: int) -> torch.Tensor:
367
+ return self.masks[index]
368
+
369
+ def get_img_wh(self) -> tuple[int, int]:
370
+ return iio.imread(
371
+ osp.join(self.data_dir, f"rgb/{self.factor}x/{self.frame_names[0]}.png")
372
+ ).shape[1::-1]
373
+
374
+ # def get_sam_features(self) -> list[torch.Tensor, tuple[int, int], tuple[int, int]]:
375
+ # return self.sam_features, self.sam_original_size, self.sam_input_size
376
+
377
+ def get_tracks_3d(
378
+ self, num_samples: int, step: int = 1, show_pbar: bool = True, **kwargs
379
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
380
+ """Get 3D tracks from the dataset.
381
+
382
+ Args:
383
+ num_samples (int | None): The number of samples to fetch. If None,
384
+ fetch all samples. If not None, fetch roughly the same number of
385
+ samples from each frame. Note that this might result in
386
+ fewer samples than specified.
387
+ step (int): The step to temporally subsample the track.
388
+ """
389
+ assert (
390
+ self.split == "train"
391
+ ), "fetch_tracks_3d is only available for the training split."
392
+ cached_track_3d_path = osp.join(self.cache_dir, f"tracks_3d_{num_samples}.pth")
393
+ if osp.exists(cached_track_3d_path) and step == 1 and self.load_from_cache:
394
+ print("loading cached 3d tracks data...")
395
+ start, end = self.start, self.end
396
+ cached_track_3d_data = torch.load(cached_track_3d_path)
397
+ tracks_3d, visibles, invisibles, confidences, track_colors = (
398
+ cached_track_3d_data["tracks_3d"][:, start:end],
399
+ cached_track_3d_data["visibles"][:, start:end],
400
+ cached_track_3d_data["invisibles"][:, start:end],
401
+ cached_track_3d_data["confidences"][:, start:end],
402
+ cached_track_3d_data["track_colors"],
403
+ )
404
+ return tracks_3d, visibles, invisibles, confidences, track_colors
405
+
406
+ # Load 2D tracks.
407
+ raw_tracks_2d = []
408
+ candidate_frames = list(range(0, self.num_frames, step))
409
+ num_sampled_frames = len(candidate_frames)
410
+ for i in (
411
+ tqdm(candidate_frames, desc="Loading 2D tracks", leave=False)
412
+ if show_pbar
413
+ else candidate_frames
414
+ ):
415
+ curr_num_samples = self.query_tracks_2d[i].shape[0]
416
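+ # Distribute num_samples evenly over the sampled frames; the last frame takes the remainder so the per-frame counts sum to num_samples.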
+ num_samples_per_frame = (
417
+ int(np.floor(num_samples / num_sampled_frames))
418
+ if i != candidate_frames[-1]
419
+ else num_samples
420
+ - (num_sampled_frames - 1)
421
+ * int(np.floor(num_samples / num_sampled_frames))
422
+ )
423
+ if num_samples_per_frame < curr_num_samples:
424
+ track_sels = np.random.choice(
425
+ curr_num_samples, (num_samples_per_frame,), replace=False
426
+ )
427
+ else:
428
+ track_sels = np.arange(0, curr_num_samples)
429
+ curr_tracks_2d = []
430
+ for j in range(0, self.num_frames, step):
431
+ if i == j:
432
+ target_tracks_2d = self.query_tracks_2d[i]
433
+ else:
434
+ target_tracks_2d = torch.from_numpy(
435
+ np.load(
436
+ osp.join(
437
+ self.data_dir,
438
+ "flow3d_preprocessed/2d_tracks/",
439
+ f"{self.factor}x/"
440
+ f"{self.frame_names[i]}_"
441
+ f"{self.frame_names[j]}.npy",
442
+ )
443
+ ).astype(np.float32)
444
+ )
445
+ curr_tracks_2d.append(target_tracks_2d[track_sels])
446
+ raw_tracks_2d.append(torch.stack(curr_tracks_2d, dim=1))
447
+ guru.info(f"{step=} {len(raw_tracks_2d)=} {raw_tracks_2d[0].shape=}")
448
+
449
+ # Process 3D tracks.
450
+ inv_Ks = torch.linalg.inv(self.Ks)[::step]
451
+ c2ws = torch.linalg.inv(self.w2cs)[::step]
452
+ H, W = self.imgs.shape[1:3]
453
+ filtered_tracks_3d, filtered_visibles, filtered_track_colors = [], [], []
454
+ filtered_invisibles, filtered_confidences = [], []
455
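+ # Combine the foreground, validity, and positive-depth masks into a single binary mask for filtering tracks.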
+ masks = self.masks * self.valid_masks * (self.depths > 0)
456
+ masks = (masks > 0.5).float()
457
+ for i, tracks_2d in enumerate(raw_tracks_2d):
458
+ tracks_2d = tracks_2d.swapdims(0, 1)
459
+ tracks_2d, occs, dists = (
460
+ tracks_2d[..., :2],
461
+ tracks_2d[..., 2],
462
+ tracks_2d[..., 3],
463
+ )
464
+ # visibles = postprocess_occlusions(occs, dists)
465
+ visibles, invisibles, confidences = parse_tapir_track_info(occs, dists)
466
+ # Unproject 2D tracks to 3D.
467
+ track_depths = F.grid_sample(
468
+ self.depths[::step, None],
469
+ normalize_coords(tracks_2d[..., None, :], H, W),
470
+ align_corners=True,
471
+ padding_mode="border",
472
+ )[:, 0]
473
+ tracks_3d = (
474
+ torch.einsum(
475
+ "nij,npj->npi",
476
+ inv_Ks,
477
+ F.pad(tracks_2d, (0, 1), value=1.0),
478
+ )
479
+ * track_depths
480
+ )
481
+ tracks_3d = torch.einsum(
482
+ "nij,npj->npi", c2ws, F.pad(tracks_3d, (0, 1), value=1.0)
483
+ )[..., :3]
484
+ # Filter out out-of-mask tracks.
485
+ is_in_masks = (
486
+ F.grid_sample(
487
+ masks[::step, None],
488
+ normalize_coords(tracks_2d[..., None, :], H, W),
489
+ align_corners=True,
490
+ ).squeeze()
491
+ == 1
492
+ )
493
+ visibles *= is_in_masks
494
+ invisibles *= is_in_masks
495
+ confidences *= is_in_masks.float()
496
+ # Get track's color from the query frame.
497
+ track_colors = (
498
+ F.grid_sample(
499
+ self.imgs[i * step : i * step + 1].permute(0, 3, 1, 2),
500
+ normalize_coords(tracks_2d[i : i + 1, None, :], H, W),
501
+ align_corners=True,
502
+ padding_mode="border",
503
+ )
504
+ .squeeze()
505
+ .T
506
+ )
507
+ # at least visible 5% of the time, otherwise discard
508
+ visible_counts = visibles.sum(0)
509
+ valid = visible_counts >= min(
510
+ int(0.05 * self.num_frames),
511
+ visible_counts.float().quantile(0.1).item(),
512
+ )
513
+
514
+ filtered_tracks_3d.append(tracks_3d[:, valid])
515
+ filtered_visibles.append(visibles[:, valid])
516
+ filtered_invisibles.append(invisibles[:, valid])
517
+ filtered_confidences.append(confidences[:, valid])
518
+ filtered_track_colors.append(track_colors[valid])
519
+
520
+ filtered_tracks_3d = torch.cat(filtered_tracks_3d, dim=1).swapdims(0, 1)
521
+ filtered_visibles = torch.cat(filtered_visibles, dim=1).swapdims(0, 1)
522
+ filtered_invisibles = torch.cat(filtered_invisibles, dim=1).swapdims(0, 1)
523
+ filtered_confidences = torch.cat(filtered_confidences, dim=1).swapdims(0, 1)
524
+ filtered_track_colors = torch.cat(filtered_track_colors, dim=0)
525
+ if step == 1:
526
+ torch.save(
527
+ {
528
+ "tracks_3d": filtered_tracks_3d,
529
+ "visibles": filtered_visibles,
530
+ "invisibles": filtered_invisibles,
531
+ "confidences": filtered_confidences,
532
+ "track_colors": filtered_track_colors,
533
+ },
534
+ cached_track_3d_path,
535
+ )
536
+ return (
537
+ filtered_tracks_3d,
538
+ filtered_visibles,
539
+ filtered_invisibles,
540
+ filtered_confidences,
541
+ filtered_track_colors,
542
+ )
543
+
544
+ def get_bkgd_points(
545
+ self, num_samples: int, **kwargs
546
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
547
+ H, W = self.imgs.shape[1:3]
548
+ grid = torch.stack(
549
+ torch.meshgrid(
550
+ torch.arange(W, dtype=torch.float32),
551
+ torch.arange(H, dtype=torch.float32),
552
+ indexing="xy",
553
+ ),
554
+ dim=-1,
555
+ )
556
+ candidate_frames = list(range(self.num_frames))
557
+ num_sampled_frames = len(candidate_frames)
558
+ bkgd_points, bkgd_point_normals, bkgd_point_colors = [], [], []
559
+ for i in tqdm(candidate_frames, desc="Loading bkgd points", leave=False):
560
+ img = self.imgs[i]
561
+ depth = self.depths[i]
562
+ bool_mask = ((1.0 - self.masks[i]) * self.valid_masks[i] * (depth > 0)).to(
563
+ torch.bool
564
+ )
565
+ w2c = self.w2cs[i]
566
+ K = self.Ks[i]
567
+ points = (
568
+ torch.einsum(
569
+ "ij,pj->pi",
570
+ torch.linalg.inv(K),
571
+ F.pad(grid[bool_mask], (0, 1), value=1.0),
572
+ )
573
+ * depth[bool_mask][:, None]
574
+ )
575
+ points = torch.einsum(
576
+ "ij,pj->pi", torch.linalg.inv(w2c)[:3], F.pad(points, (0, 1), value=1.0)
577
+ )
578
+ point_normals = normal_from_depth_image(depth, K, w2c)[bool_mask]
579
+ point_colors = img[bool_mask]
580
+ curr_num_samples = points.shape[0]
581
+ num_samples_per_frame = (
582
+ int(np.floor(num_samples / num_sampled_frames))
583
+ if i != candidate_frames[-1]
584
+ else num_samples
585
+ - (num_sampled_frames - 1)
586
+ * int(np.floor(num_samples / num_sampled_frames))
587
+ )
588
+ if num_samples_per_frame < curr_num_samples:
589
+ point_sels = np.random.choice(
590
+ curr_num_samples, (num_samples_per_frame,), replace=False
591
+ )
592
+ else:
593
+ point_sels = np.arange(0, curr_num_samples)
594
+ bkgd_points.append(points[point_sels])
595
+ bkgd_point_normals.append(point_normals[point_sels])
596
+ bkgd_point_colors.append(point_colors[point_sels])
597
+ bkgd_points = torch.cat(bkgd_points, dim=0)
598
+ bkgd_point_normals = torch.cat(bkgd_point_normals, dim=0)
599
+ bkgd_point_colors = torch.cat(bkgd_point_colors, dim=0)
600
+ return bkgd_points, bkgd_point_normals, bkgd_point_colors
601
+
602
+ def get_video_dataset(self) -> Dataset:
603
+ return iPhoneDatasetVideoView(self)
604
+
605
+ def __getitem__(self, index: int):
606
+ if self.training:
607
+ index = np.random.randint(0, self.num_frames)
608
+ data = {
609
+ # ().
610
+ "frame_names": self.frame_names[index],
611
+ # ().
612
+ "ts": self.time_ids[index],
613
+ # (4, 4).
614
+ "w2cs": self.w2cs[index],
615
+ # (3, 3).
616
+ "Ks": self.Ks[index],
617
+ # (H, W, 3).
618
+ "imgs": self.imgs[index],
619
+ # (H, W).
620
+ "valid_masks": self.valid_masks[index],
621
+ # (H, W).
622
+ "masks": self.masks[index],
623
+ }
624
+ if self.training:
625
+ # (H, W).
626
+ data["depths"] = self.depths[index]
627
+ # (P, 2).
628
+ data["query_tracks_2d"] = self.query_tracks_2d[index][:, :2]
629
+ target_inds = torch.from_numpy(
630
+ np.random.choice(
631
+ self.num_frames, (self.num_targets_per_frame,), replace=False
632
+ )
633
+ )
634
+ # (N, P, 4).
635
+ target_tracks_2d = torch.stack(
636
+ [
637
+ torch.from_numpy(
638
+ np.load(
639
+ osp.join(
640
+ self.data_dir,
641
+ "flow3d_preprocessed/2d_tracks/",
642
+ f"{self.factor}x/"
643
+ f"{self.frame_names[index]}_"
644
+ f"{self.frame_names[target_index.item()]}.npy",
645
+ )
646
+ ).astype(np.float32)
647
+ )
648
+ for target_index in target_inds
649
+ ],
650
+ dim=0,
651
+ )
652
+ # (N,).
653
+ target_ts = self.time_ids[target_inds]
654
+ data["target_ts"] = target_ts
655
+ # (N, 4, 4).
656
+ data["target_w2cs"] = self.w2cs[target_ts]
657
+ # (N, 3, 3).
658
+ data["target_Ks"] = self.Ks[target_ts]
659
+ # (N, P, 2).
660
+ data["target_tracks_2d"] = target_tracks_2d[..., :2]
661
+ # (N, P).
662
+ (
663
+ data["target_visibles"],
664
+ data["target_invisibles"],
665
+ data["target_confidences"],
666
+ ) = parse_tapir_track_info(
667
+ target_tracks_2d[..., 2], target_tracks_2d[..., 3]
668
+ )
669
+ # (N, P).
670
+ data["target_track_depths"] = F.grid_sample(
671
+ self.depths[target_inds, None],
672
+ normalize_coords(
673
+ target_tracks_2d[..., None, :2],
674
+ self.imgs.shape[1],
675
+ self.imgs.shape[2],
676
+ ),
677
+ align_corners=True,
678
+ padding_mode="border",
679
+ )[:, 0, :, 0]
680
+ else:
681
+ # (H, W).
682
+ data["covisible_masks"] = self.covisible_masks[index]
683
+ return data
684
+
685
+ def preprocess(self, data):
686
+ return data
687
+
688
+
689
+ class iPhoneDatasetKeypointView(Dataset):
690
+ """Return a dataset view of the annotated keypoints."""
691
+
692
+ def __init__(self, dataset: iPhoneDataset):
693
+ super().__init__()
694
+ self.dataset = dataset
695
+ assert self.dataset.split == "train"
696
+ # Load 2D keypoints.
697
+ keypoint_paths = sorted(
698
+ glob(osp.join(self.dataset.data_dir, "keypoint/2x/train/0_*.json"))
699
+ )
700
+ keypoints = []
701
+ for keypoint_path in keypoint_paths:
702
+ with open(keypoint_path) as f:
703
+ keypoints.append(json.load(f))
704
+ time_ids = [
705
+ int(osp.basename(p).split("_")[1].split(".")[0]) for p in keypoint_paths
706
+ ]
707
+ # only use time ids that are in the dataset.
708
+ start = self.dataset.start
709
+ time_ids = [t - start for t in time_ids if t - start in self.dataset.time_ids]
710
+ self.time_ids = torch.tensor(time_ids)
711
+ self.time_pairs = torch.tensor(list(product(self.time_ids, repeat=2)))
712
+ self.index_pairs = torch.tensor(
713
+ list(product(range(len(self.time_ids)), repeat=2))
714
+ )
715
+ self.keypoints = torch.tensor(keypoints, dtype=torch.float32)
716
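+ # Keypoints are annotated at 2x downsampling; rescale their pixel coordinates to this dataset's factor.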
+ self.keypoints[..., :2] *= 2.0 / self.dataset.factor
717
+
718
+ def __len__(self):
719
+ return len(self.time_pairs)
720
+
721
+ def __getitem__(self, index: int):
722
+ ts = self.time_pairs[index]
723
+ return {
724
+ "ts": ts,
725
+ "w2cs": self.dataset.w2cs[ts],
726
+ "Ks": self.dataset.Ks[ts],
727
+ "imgs": self.dataset.imgs[ts],
728
+ "keypoints": self.keypoints[self.index_pairs[index]],
729
+ }
730
+
731
+
732
+ class iPhoneDatasetVideoView(Dataset):
733
+ """Return a dataset view of the video trajectory."""
734
+
735
+ def __init__(self, dataset: iPhoneDataset):
736
+ super().__init__()
737
+ self.dataset = dataset
738
+ self.fps = self.dataset.fps
739
+ assert self.dataset.split == "train"
740
+
741
+ def __len__(self):
742
+ return self.dataset.num_frames
743
+
744
+ def __getitem__(self, index):
745
+ return {
746
+ "frame_names": self.dataset.frame_names[index],
747
+ "ts": index,
748
+ "w2cs": self.dataset.w2cs[index],
749
+ "Ks": self.dataset.Ks[index],
750
+ "imgs": self.dataset.imgs[index],
751
+ "depths": self.dataset.depths[index],
752
+ "masks": self.dataset.masks[index],
753
+ }
754
+
755
+
756
+ """
757
+ class iPhoneDataModule(BaseDataModule[iPhoneDataset]):
758
+ def __init__(
759
+ self,
760
+ data_dir: str,
761
+ factor: int = 1,
762
+ start: int = 0,
763
+ end: int = -1,
764
+ depth_type: Literal[
765
+ "midas",
766
+ "depth_anything",
767
+ "lidar",
768
+ "depth_anything_colmap",
769
+ ] = "depth_anything_colmap",
770
+ camera_type: Literal["original", "refined"] = "refined",
771
+ use_median_filter: bool = False,
772
+ num_targets_per_frame: int = 1,
773
+ load_from_cache: bool = False,
774
+ **kwargs,
775
+ ):
776
+ super().__init__(dataset_cls=iPhoneDataset, **kwargs)
777
+ self.data_dir = data_dir
778
+ self.start = start
779
+ self.end = end
780
+ self.factor = factor
781
+ self.depth_type = depth_type
782
+ self.camera_type = camera_type
783
+ self.use_median_filter = use_median_filter
784
+ self.num_targets_per_frame = num_targets_per_frame
785
+ self.load_from_cache = load_from_cache
786
+
787
+ self.val_loader_tasks = ["img", "keypoint"]
788
+
789
+ def setup(self, *_, **__) -> None:
790
+ guru.info("Loading train dataset...")
791
+ self.train_dataset = self.dataset_cls(
792
+ data_dir=self.data_dir,
793
+ training=True,
794
+ split="train",
795
+ start=self.start,
796
+ end=self.end,
797
+ factor=self.factor,
798
+ depth_type=self.depth_type, # type: ignore
799
+ camera_type=self.camera_type, # type: ignore
800
+ use_median_filter=self.use_median_filter,
801
+ num_targets_per_frame=self.num_targets_per_frame,
802
+ max_steps=self.max_steps * self.batch_size,
803
+ load_from_cache=self.load_from_cache,
804
+ )
805
+ if self.train_dataset.has_validation:
806
+ guru.info("Loading val dataset...")
807
+ self.val_dataset = self.dataset_cls(
808
+ data_dir=self.data_dir,
809
+ training=False,
810
+ split="val",
811
+ start=self.start,
812
+ end=self.end,
813
+ factor=self.factor,
814
+ depth_type=self.depth_type, # type: ignore
815
+ camera_type=self.camera_type, # type: ignore
816
+ use_median_filter=self.use_median_filter,
817
+ scene_norm_dict=self.train_dataset.scene_norm_dict,
818
+ load_from_cache=self.load_from_cache,
819
+ )
820
+ else:
821
+ # Dummy validation set.
822
+ self.val_dataset = TensorDataset(torch.zeros(0)) # type: ignore
823
+ self.keypoint_dataset = iPhoneDatasetKeypointView(self.train_dataset)
824
+ self.video_dataset = self.train_dataset.get_video_dataset()
825
+ guru.success("Loading finished!")
826
+
827
+ def train_dataloader(self) -> DataLoader:
828
+ return DataLoader(
829
+ self.train_dataset,
830
+ batch_size=self.batch_size,
831
+ num_workers=self.num_workers,
832
+ collate_fn=iPhoneDataset.train_collate_fn,
833
+ )
834
+
835
+ def val_dataloader(self) -> list[DataLoader]:
836
+ return [DataLoader(self.val_dataset), DataLoader(self.keypoint_dataset)]
837
+ """
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/data/utils.py ADDED
@@ -0,0 +1,360 @@
1
+ from typing import List, Optional, Tuple, TypedDict
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+ from torch.nn.modules.utils import _pair, _quadruple
8
+
9
+ UINT16_MAX = 65535
10
+
11
+
12
+ class SceneNormDict(TypedDict):
13
+ scale: float
14
+ transfm: torch.Tensor
15
+
16
+
17
+ def to_device(batch, device):
18
+ if isinstance(batch, dict):
19
+ return {k: to_device(v, device) for k, v in batch.items()}
20
+ if isinstance(batch, (list, tuple)):
21
+ return [to_device(v, device) for v in batch]
22
+ if isinstance(batch, torch.Tensor):
23
+ return batch.to(device)
24
+ return batch
25
+
26
+
27
+ def normalize_coords(coords, h, w):
28
+ assert coords.shape[-1] == 2
29
+ return coords / torch.tensor([w - 1.0, h - 1.0], device=coords.device) * 2 - 1.0
30
+
31
+
32
+ def postprocess_occlusions(occlusions, expected_dist):
33
+ """Postprocess occlusions to boolean visible flag.
34
+
35
+ Args:
36
+ occlusions: [-inf, inf], np.float32
37
+ expected_dist: [-inf, inf], np.float32
38
+
39
+ Returns:
40
+ visibles: bool
41
+ """
42
+
43
+ def sigmoid(x):
44
+ if isinstance(x, np.ndarray):
45
+ return 1 / (1 + np.exp(-x))
46
+ else:
47
+ return torch.sigmoid(x)
48
+
49
+ visibles = (1 - sigmoid(occlusions)) * (1 - sigmoid(expected_dist)) > 0.5
50
+ return visibles
51
+
52
+
53
+ def parse_tapir_track_info(occlusions, expected_dist):
54
+ """
55
+ return:
56
+ valid_visible: mask of visible & confident points
57
+ valid_invisible: mask of invisible & confident points
58
+ confidence: clamped confidence scores (all < 0.5 -> 0)
59
+ """
60
+ visibility = 1 - F.sigmoid(occlusions)
61
+ confidence = 1 - F.sigmoid(expected_dist)
62
+ valid_visible = visibility * confidence > 0.5
63
+ valid_invisible = (1 - visibility) * confidence > 0.5
64
+ # set all confidence < 0.5 to 0
65
+ confidence = confidence * (valid_visible | valid_invisible).float()
66
+ return valid_visible, valid_invisible, confidence
67
+
68
+
69
+ def get_tracks_3d_for_query_frame(
70
+ query_index: int,
71
+ query_img: torch.Tensor,
72
+ tracks_2d: torch.Tensor,
73
+ depths: torch.Tensor,
74
+ masks: torch.Tensor,
75
+ inv_Ks: torch.Tensor,
76
+ c2ws: torch.Tensor,
77
+ ):
78
+ """
79
+ :param query_index (int)
80
+ :param query_img [H, W, 3]
81
+ :param tracks_2d [N, T, 4]
82
+ :param depths [T, H, W]
83
+ :param masks [T, H, W]
84
+ :param inv_Ks [T, 3, 3]
85
+ :param c2ws [T, 4, 4]
86
+ returns (
87
+ tracks_3d [N, T, 3]
88
+ track_colors [N, 3]
89
+ visibles [N, T]
90
+ invisibles [N, T]
91
+ confidences [N, T]
92
+ )
93
+ """
94
+ T, H, W = depths.shape
95
+ query_img = query_img[None].permute(0, 3, 1, 2) # (1, 3, H, W)
96
+ tracks_2d = tracks_2d.swapaxes(0, 1) # (T, N, 4)
97
+ tracks_2d, occs, dists = (
98
+ tracks_2d[..., :2],
99
+ tracks_2d[..., 2],
100
+ tracks_2d[..., 3],
101
+ )
102
+ # visibles = postprocess_occlusions(occs, dists)
103
+ # (T, N), (T, N), (T, N)
104
+ visibles, invisibles, confidences = parse_tapir_track_info(occs, dists)
105
+ # Unproject 2D tracks to 3D.
106
+ # (T, 1, H, W), (T, 1, N, 2) -> (T, 1, 1, N)
107
+ track_depths = F.grid_sample(
108
+ depths[:, None],
109
+ normalize_coords(tracks_2d[:, None], H, W),
110
+ align_corners=True,
111
+ padding_mode="border",
112
+ )[:, 0, 0]
113
+ tracks_3d = (
114
+ torch.einsum(
115
+ "nij,npj->npi",
116
+ inv_Ks,
117
+ F.pad(tracks_2d, (0, 1), value=1.0),
118
+ )
119
+ * track_depths[..., None]
120
+ )
121
+ tracks_3d = torch.einsum("nij,npj->npi", c2ws, F.pad(tracks_3d, (0, 1), value=1.0))[
122
+ ..., :3
123
+ ]
124
+ # Filter out out-of-mask tracks.
125
+ # (T, 1, H, W), (T, 1, N, 2) -> (T, 1, 1, N)
126
+ is_in_masks = (
127
+ F.grid_sample(
128
+ masks[:, None],
129
+ normalize_coords(tracks_2d[:, None], H, W),
130
+ align_corners=True,
131
+ )[:, 0, 0]
132
+ == 1
133
+ )
134
+ visibles *= is_in_masks
135
+ invisibles *= is_in_masks
136
+ confidences *= is_in_masks.float()
137
+
138
+ # valid if in the fg mask at least 40% of the time
139
+ # in_mask_counts = is_in_masks.sum(0)
140
+ # t = 0.25
141
+ # thresh = min(t * T, in_mask_counts.float().quantile(t).item())
142
+ # valid = in_mask_counts > thresh
143
+ valid = is_in_masks[query_index]
144
+ # valid if visible 5% of the time
145
+ visible_counts = visibles.sum(0)
146
+ valid = valid & (
147
+ visible_counts
148
+ >= min(
149
+ int(0.05 * T),
150
+ visible_counts.float().quantile(0.1).item(),
151
+ )
152
+ )
153
+
154
+ # Get track's color from the query frame.
155
+ # (1, 3, H, W), (1, 1, N, 2) -> (1, 3, 1, N) -> (N, 3)
156
+ track_colors = F.grid_sample(
157
+ query_img,
158
+ normalize_coords(tracks_2d[query_index : query_index + 1, None], H, W),
159
+ align_corners=True,
160
+ padding_mode="border",
161
+ )[0, :, 0].T
162
+ return (
163
+ tracks_3d[:, valid].swapdims(0, 1),
164
+ track_colors[valid],
165
+ visibles[:, valid].swapdims(0, 1),
166
+ invisibles[:, valid].swapdims(0, 1),
167
+ confidences[:, valid].swapdims(0, 1),
168
+ )
169
+
170
+
171
+ def _get_padding(x, k, stride, padding, same: bool):
172
+ if same:
173
+ ih, iw = x.size()[2:]
174
+ if ih % stride[0] == 0:
175
+ ph = max(k[0] - stride[0], 0)
176
+ else:
177
+ ph = max(k[0] - (ih % stride[0]), 0)
178
+ if iw % stride[1] == 0:
179
+ pw = max(k[1] - stride[1], 0)
180
+ else:
181
+ pw = max(k[1] - (iw % stride[1]), 0)
182
+ pl = pw // 2
183
+ pr = pw - pl
184
+ pt = ph // 2
185
+ pb = ph - pt
186
+ padding = (pl, pr, pt, pb)
187
+ else:
188
+ padding = padding
189
+ return padding
190
+
191
+
192
+ def median_filter_2d(x, kernel_size=3, stride=1, padding=1, same: bool = True):
193
+ """
194
+ :param x [B, C, H, W]
195
+ """
196
+ k = _pair(kernel_size)
197
+ stride = _pair(stride) # convert to tuple
198
+ padding = _quadruple(padding) # convert to l, r, t, b
199
+ # using existing pytorch functions and tensor ops so that we get autograd,
200
+ # would likely be more efficient to implement from scratch at C/Cuda level
201
+ x = F.pad(x, _get_padding(x, k, stride, padding, same), mode="reflect")
202
+ x = x.unfold(2, k[0], stride[0]).unfold(3, k[1], stride[1])
203
+ x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1)[0]
204
+ return x
205
+
206
+
207
+ def masked_median_blur(image, mask, kernel_size=11):
208
+ """
209
+ Args:
210
+ image: [B, C, H, W]
211
+ mask: [B, C, H, W]
212
+ kernel_size: int
213
+ """
214
+ assert image.shape == mask.shape
215
+ if not isinstance(image, torch.Tensor):
216
+ raise TypeError(f"Input type is not a torch.Tensor. Got {type(image)}")
217
+
218
+ if not len(image.shape) == 4:
219
+ raise ValueError(f"Invalid input shape, we expect BxCxHxW. Got: {image.shape}")
220
+
221
+ padding: Tuple[int, int] = _compute_zero_padding((kernel_size, kernel_size))
222
+
223
+ # prepare kernel
224
+ kernel: torch.Tensor = get_binary_kernel2d((kernel_size, kernel_size)).to(image)
225
+ b, c, h, w = image.shape
226
+
227
+ # map the local window to single vector
228
+ features: torch.Tensor = F.conv2d(
229
+ image.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1
230
+ )
231
+ masks: torch.Tensor = F.conv2d(
232
+ mask.reshape(b * c, 1, h, w), kernel, padding=padding, stride=1
233
+ )
234
+ features = features.view(b, c, -1, h, w).permute(
235
+ 0, 1, 3, 4, 2
236
+ ) # BxCxxHxWx(K_h * K_w)
237
+ min_value, max_value = features.min(), features.max()
238
+ masks = masks.view(b, c, -1, h, w).permute(0, 1, 3, 4, 2) # BxCxHxWx(K_h * K_w)
239
+ index_invalid = (1 - masks).nonzero(as_tuple=True)
240
+ index_b, index_c, index_h, index_w, index_k = index_invalid
241
+ features[(index_b[::2], index_c[::2], index_h[::2], index_w[::2], index_k[::2])] = (
242
+ min_value
243
+ )
244
+ features[
245
+ (index_b[1::2], index_c[1::2], index_h[1::2], index_w[1::2], index_k[1::2])
246
+ ] = max_value
247
+ # compute the median along the feature axis
248
+ median: torch.Tensor = torch.median(features, dim=-1)[0]
249
+
250
+ return median
251
+
252
+
253
+ def _compute_zero_padding(kernel_size: Tuple[int, int]) -> Tuple[int, int]:
254
+ r"""Utility function that computes zero padding tuple."""
255
+ computed: List[int] = [(k - 1) // 2 for k in kernel_size]
256
+ return computed[0], computed[1]
257
+
258
+
259
+ def get_binary_kernel2d(
260
+ window_size: tuple[int, int] | int,
261
+ *,
262
+ device: Optional[torch.device] = None,
263
+ dtype: torch.dtype = torch.float32,
264
+ ) -> torch.Tensor:
265
+ """
266
+ from kornia
267
+ Create a binary kernel to extract the patches.
268
+ If the window size is HxW will create a (H*W)x1xHxW kernel.
269
+ """
270
+ ky, kx = _unpack_2d_ks(window_size)
271
+
272
+ window_range = kx * ky
273
+
274
+ kernel = torch.zeros((window_range, window_range), device=device, dtype=dtype)
275
+ idx = torch.arange(window_range, device=device)
276
+ kernel[idx, idx] += 1.0
277
+ return kernel.view(window_range, 1, ky, kx)
278
+
279
+
280
+ def _unpack_2d_ks(kernel_size: tuple[int, int] | int) -> tuple[int, int]:
281
+ if isinstance(kernel_size, int):
282
+ ky = kx = kernel_size
283
+ else:
284
+ assert len(kernel_size) == 2, "2D Kernel size should have a length of 2."
285
+ ky, kx = kernel_size
286
+
287
+ ky = int(ky)
288
+ kx = int(kx)
289
+
290
+ return (ky, kx)
291
+
292
+
293
+ ## Functions from GaussianShader.
294
+ def ndc_2_cam(ndc_xyz, intrinsic, W, H):
295
+ inv_scale = torch.tensor([[W - 1, H - 1]], device=ndc_xyz.device)
296
+ cam_z = ndc_xyz[..., 2:3]
297
+ cam_xy = ndc_xyz[..., :2] * inv_scale * cam_z
298
+ cam_xyz = torch.cat([cam_xy, cam_z], dim=-1)
299
+ cam_xyz = cam_xyz @ torch.inverse(intrinsic[0, ...].t())
300
+ return cam_xyz
301
+
302
+
303
+ def depth2point_cam(sampled_depth, ref_intrinsic):
304
+ B, N, C, H, W = sampled_depth.shape
305
+ valid_z = sampled_depth
306
+ valid_x = torch.arange(W, dtype=torch.float32, device=sampled_depth.device) / (
307
+ W - 1
308
+ )
309
+ valid_y = torch.arange(H, dtype=torch.float32, device=sampled_depth.device) / (
310
+ H - 1
311
+ )
312
+ valid_y, valid_x = torch.meshgrid(valid_y, valid_x, indexing="ij")
313
+ # B,N,H,W
314
+ valid_x = valid_x[None, None, None, ...].expand(B, N, C, -1, -1)
315
+ valid_y = valid_y[None, None, None, ...].expand(B, N, C, -1, -1)
316
+ ndc_xyz = torch.stack([valid_x, valid_y, valid_z], dim=-1).view(
317
+ B, N, C, H, W, 3
318
+ ) # 1, 1, 5, 512, 640, 3
319
+ cam_xyz = ndc_2_cam(ndc_xyz, ref_intrinsic, W, H) # 1, 1, 5, 512, 640, 3
320
+ return ndc_xyz, cam_xyz
321
+
322
+
323
+ def depth2point_world(depth_image, intrinsic_matrix, extrinsic_matrix):
324
+ # depth_image: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4)
325
+ _, xyz_cam = depth2point_cam(
326
+ depth_image[None, None, None, ...], intrinsic_matrix[None, ...]
327
+ )
328
+ xyz_cam = xyz_cam.reshape(-1, 3)
329
+ xyz_world = torch.cat(
330
+ [xyz_cam, torch.ones_like(xyz_cam[..., 0:1])], dim=-1
331
+ ) @ torch.inverse(extrinsic_matrix).transpose(0, 1)
332
+ xyz_world = xyz_world[..., :3]
333
+
334
+ return xyz_world
335
+
336
+
337
+ def depth_pcd2normal(xyz):
338
+ hd, wd, _ = xyz.shape
339
+ bottom_point = xyz[..., 2:hd, 1 : wd - 1, :]
340
+ top_point = xyz[..., 0 : hd - 2, 1 : wd - 1, :]
341
+ right_point = xyz[..., 1 : hd - 1, 2:wd, :]
342
+ left_point = xyz[..., 1 : hd - 1, 0 : wd - 2, :]
343
+ left_to_right = right_point - left_point
344
+ bottom_to_top = top_point - bottom_point
345
+ xyz_normal = torch.cross(left_to_right, bottom_to_top, dim=-1)
346
+ xyz_normal = torch.nn.functional.normalize(xyz_normal, p=2, dim=-1)
347
+ xyz_normal = torch.nn.functional.pad(
348
+ xyz_normal.permute(2, 0, 1), (1, 1, 1, 1), mode="constant"
349
+ ).permute(1, 2, 0)
350
+ return xyz_normal
351
+
352
+
353
+ def normal_from_depth_image(depth, intrinsic_matrix, extrinsic_matrix):
354
+ # depth: (H, W), intrinsic_matrix: (3, 3), extrinsic_matrix: (4, 4)
355
+ # xyz_normal: (H, W, 3)
356
+ xyz_world = depth2point_world(depth, intrinsic_matrix, extrinsic_matrix) # (HxW, 3)
357
+ xyz_world = xyz_world.reshape(*depth.shape, 3)
358
+ xyz_normal = depth_pcd2normal(xyz_world)
359
+
360
+ return xyz_normal
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/init_utils.py ADDED
@@ -0,0 +1,650 @@
1
+ import time
2
+ from typing import Literal
3
+
4
+ import cupy as cp
5
+ import imageio.v3 as iio
6
+ import numpy as np
7
+
8
+ # from pytorch3d.ops import sample_farthest_points
9
+ import roma
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from cuml import HDBSCAN, KMeans
13
+ from loguru import logger as guru
14
+ from matplotlib.pyplot import get_cmap
15
+ from tqdm import tqdm
16
+ from viser import ViserServer
17
+
18
+ from flow3d.loss_utils import (
19
+ compute_accel_loss,
20
+ compute_se3_smoothness_loss,
21
+ compute_z_acc_loss,
22
+ get_weights_for_procrustes,
23
+ knn,
24
+ masked_l1_loss,
25
+ )
26
+ from flow3d.params import GaussianParams, MotionBases
27
+ from flow3d.tensor_dataclass import StaticObservations, TrackObservations
28
+ from flow3d.transforms import cont_6d_to_rmat, rt_to_mat4, solve_procrustes
29
+ from flow3d.vis.utils import draw_keypoints_video, get_server, project_2d_tracks
30
+
31
+
32
+ def init_fg_from_tracks_3d(
33
+ cano_t: int, tracks_3d: TrackObservations, motion_coefs: torch.Tensor
34
+ ) -> GaussianParams:
35
+ """
36
+ using dataclasses individual tensors so we know they're consistent
37
+ and are always masked/filtered together
38
+ """
39
+ num_fg = tracks_3d.xyz.shape[0]
40
+
41
+ # Initialize gaussian colors.
42
+ colors = torch.logit(tracks_3d.colors)
43
+ # Initialize gaussian features.
44
+ features = torch.rand(num_fg, 384)  #### MUST BE CHANGED: feature dimension (384) is hard-coded.
45
+ # Initialize gaussian scales: find the average of the three nearest
46
+ # neighbors in the first frame for each point and use that as the
47
+ # scale.
48
+ dists, _ = knn(tracks_3d.xyz[:, cano_t], 3)
49
+ dists = torch.from_numpy(dists)
50
+ scales = dists.mean(dim=-1, keepdim=True)
51
+ scales = scales.clamp(torch.quantile(scales, 0.05), torch.quantile(scales, 0.95))
52
+ scales = torch.log(scales.repeat(1, 3))
53
+ # Initialize gaussian means.
54
+ means = tracks_3d.xyz[:, cano_t]
55
+ # Initialize gaussian orientations as random.
56
+ quats = torch.rand(num_fg, 4)
57
+ # Initialize gaussian opacities.
58
+ opacities = torch.logit(torch.full((num_fg,), 0.7))
59
+ gaussians = GaussianParams(means, quats, scales, colors, features, opacities, motion_coefs)
60
+ return gaussians
61
+
62
+
63
+ def init_bg(
64
+ points: StaticObservations,
65
+ ) -> GaussianParams:
66
+ """
67
+ using dataclasses instead of individual tensors so we know they're consistent
68
+ and are always masked/filtered together
69
+ """
70
+ num_init_bg_gaussians = points.xyz.shape[0]
71
+ bg_scene_center = points.xyz.mean(0)
72
+ bg_points_centered = points.xyz - bg_scene_center
73
+ bg_min_scale = bg_points_centered.quantile(0.05, dim=0)
74
+ bg_max_scale = bg_points_centered.quantile(0.95, dim=0)
75
+ bg_scene_scale = torch.max(bg_max_scale - bg_min_scale).item() / 2.0
76
+ bkgd_colors = torch.logit(points.colors)
77
+
78
+ # Initialize gaussian features.
79
+ bg_features = torch.rand(num_init_bg_gaussians, 384)  #### MUST BE CHANGED: feature dimension (384) is hard-coded.
80
+
81
+ # Initialize gaussian scales: find the average of the three nearest
82
+ # neighbors in the first frame for each point and use that as the
83
+ # scale.
84
+ dists, _ = knn(points.xyz, 3)
85
+ dists = torch.from_numpy(dists)
86
+ bg_scales = dists.mean(dim=-1, keepdim=True)
87
+ bkgd_scales = torch.log(bg_scales.repeat(1, 3))
88
+
89
+ bg_means = points.xyz
90
+
91
+ # Initialize gaussian orientations by normals.
92
+ local_normals = points.normals.new_tensor([[0.0, 0.0, 1.0]]).expand_as(
93
+ points.normals
94
+ )
95
+ bg_quats = roma.rotvec_to_unitquat(
96
+ F.normalize(local_normals.cross(points.normals, dim=-1), dim=-1)
97
+ * (local_normals * points.normals).sum(-1, keepdim=True).acos_()
98
+ ).roll(1, dims=-1)
99
+ bg_opacities = torch.logit(torch.full((num_init_bg_gaussians,), 0.7))
100
+ gaussians = GaussianParams(
101
+ bg_means,
102
+ bg_quats,
103
+ bkgd_scales,
104
+ bkgd_colors,
105
+ bg_features,
106
+ bg_opacities,
107
+ scene_center=bg_scene_center,
108
+ scene_scale=bg_scene_scale,
109
+ )
110
+ return gaussians
111
+
112
+
113
+ def init_motion_params_with_procrustes(
114
+ tracks_3d: TrackObservations,
115
+ num_bases: int,
116
+ rot_type: Literal["quat", "6d"],
117
+ cano_t: int,
118
+ cluster_init_method: str = "kmeans",
119
+ min_mean_weight: float = 0.1,
120
+ vis: bool = False,
121
+ port: int | None = None,
122
+ ) -> tuple[MotionBases, torch.Tensor, TrackObservations]:
123
+ device = tracks_3d.xyz.device
124
+ num_frames = tracks_3d.xyz.shape[1]
125
+ # sample centers and get initial se3 motion bases by solving procrustes
126
+ means_cano = tracks_3d.xyz[:, cano_t].clone() # [num_gaussians, 3]
127
+
128
+ # remove outliers
129
+ scene_center = means_cano.median(dim=0).values
130
+ print(f"{scene_center=}")
131
+ dists = torch.norm(means_cano - scene_center, dim=-1)
132
+ dists_th = torch.quantile(dists, 0.95)
133
+ valid_mask = dists < dists_th
134
+
135
+ # remove tracks that are not visible in any frame
136
+ valid_mask = valid_mask & tracks_3d.visibles.any(dim=1)
137
+ print(f"{valid_mask.sum()=}")
138
+
139
+ tracks_3d = tracks_3d.filter_valid(valid_mask)
140
+
141
+ if vis and port is not None:
142
+ server = get_server(port)
143
+ try:
144
+ pts = tracks_3d.xyz.cpu().numpy()
145
+ clrs = tracks_3d.colors.cpu().numpy()
146
+ while True:
147
+ for t in range(num_frames):
148
+ server.scene.add_point_cloud("points", pts[:, t], clrs)
149
+ time.sleep(0.3)
150
+ except KeyboardInterrupt:
151
+ pass
152
+
153
+ means_cano = means_cano[valid_mask]
154
+
155
+ sampled_centers, num_bases, labels = sample_initial_bases_centers(
156
+ cluster_init_method, cano_t, tracks_3d, num_bases
157
+ )
158
+
159
+ # assign each point to the label to compute the cluster weight
160
+ ids, counts = labels.unique(return_counts=True)
161
+ ids = ids[counts > 100]
162
+ num_bases = len(ids)
163
+ sampled_centers = sampled_centers[:, ids]
164
+ print(f"{num_bases=} {sampled_centers.shape=}")
165
+
166
+ # compute basis weights from the distance to the cluster centers
167
+ dists2centers = torch.norm(means_cano[:, None] - sampled_centers, dim=-1)
168
+ motion_coefs = 10 * torch.exp(-dists2centers)
169
+
170
+ init_rots, init_ts = [], []
171
+
172
+ if rot_type == "quat":
173
+ id_rot = torch.tensor([1.0, 0.0, 0.0, 0.0], device=device)
174
+ rot_dim = 4
175
+ else:
176
+ id_rot = torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], device=device)
177
+ rot_dim = 6
178
+
179
+ init_rots = id_rot.reshape(1, 1, rot_dim).repeat(num_bases, num_frames, 1)
180
+ init_ts = torch.zeros(num_bases, num_frames, 3, device=device)
181
+ errs_before = np.full((num_bases, num_frames), -1.0)
182
+ errs_after = np.full((num_bases, num_frames), -1.0)
183
+
184
+ tgt_ts = list(range(cano_t - 1, -1, -1)) + list(range(cano_t, num_frames))
185
+ print(f"{tgt_ts=}")
186
+ skipped_ts = {}
187
+ for n, cluster_id in enumerate(ids):
188
+ mask_in_cluster = labels == cluster_id
189
+ cluster = tracks_3d.xyz[mask_in_cluster].transpose(
190
+ 0, 1
191
+ ) # [num_frames, n_pts, 3]
192
+ visibilities = tracks_3d.visibles[mask_in_cluster].swapaxes(
193
+ 0, 1
194
+ ) # [num_frames, n_pts]
195
+ confidences = tracks_3d.confidences[mask_in_cluster].swapaxes(
196
+ 0, 1
197
+ ) # [num_frames, n_pts]
198
+ weights = get_weights_for_procrustes(cluster, visibilities)
199
+ prev_t = cano_t
200
+ cluster_skip_ts = []
201
+ for cur_t in tgt_ts:
202
+ # compute pairwise transform from cano_t
203
+ procrustes_weights = (
204
+ weights[cano_t]
205
+ * weights[cur_t]
206
+ * (confidences[cano_t] + confidences[cur_t])
207
+ / 2
208
+ )
209
+ if procrustes_weights.sum() < min_mean_weight * num_frames:
210
+ init_rots[n, cur_t] = init_rots[n, prev_t]
211
+ init_ts[n, cur_t] = init_ts[n, prev_t]
212
+ cluster_skip_ts.append(cur_t)
213
+ else:
214
+ se3, (err, err_before) = solve_procrustes(
215
+ cluster[cano_t],
216
+ cluster[cur_t],
217
+ weights=procrustes_weights,
218
+ enforce_se3=True,
219
+ rot_type=rot_type,
220
+ )
221
+ init_rot, init_t, _ = se3
222
+ assert init_rot.shape[-1] == rot_dim
223
+ # Quaternion double cover: q and -q encode the same rotation, so pick the sign closer to the previous frame's estimate for temporal continuity.
224
+ if rot_type == "quat" and torch.linalg.norm(
225
+ init_rot - init_rots[n][prev_t]
226
+ ) > torch.linalg.norm(-init_rot - init_rots[n][prev_t]):
227
+ init_rot = -init_rot
228
+ init_rots[n, cur_t] = init_rot
229
+ init_ts[n, cur_t] = init_t
230
+ if np.isnan(err):
231
+ print(f"{cur_t=} {err=}")
232
+ print(f"{procrustes_weights.isnan().sum()=}")
233
+ if np.isnan(err_before):
234
+ print(f"{cur_t=} {err_before=}")
235
+ print(f"{procrustes_weights.isnan().sum()=}")
236
+ errs_after[n, cur_t] = err
237
+ errs_before[n, cur_t] = err_before
238
+ prev_t = cur_t
239
+ skipped_ts[cluster_id.item()] = cluster_skip_ts
240
+
241
+ guru.info(f"{skipped_ts=}")
242
+ guru.info(
243
+ "procrustes init median error: {:.5f} => {:.5f}".format(
244
+ np.median(errs_before[errs_before > 0]),
245
+ np.median(errs_after[errs_after > 0]),
246
+ )
247
+ )
248
+ guru.info(
249
+ "procrustes init mean error: {:.5f} => {:.5f}".format(
250
+ np.mean(errs_before[errs_before > 0]), np.mean(errs_after[errs_after > 0])
251
+ )
252
+ )
253
+ guru.info(f"{init_rots.shape=}, {init_ts.shape=}, {motion_coefs.shape=}")
254
+
255
+ if vis:
256
+ server = get_server(port)
257
+ center_idcs = torch.argmin(dists2centers, dim=0)
258
+ print(f"{dists2centers.shape=} {center_idcs.shape=}")
259
+ vis_se3_init_3d(server, init_rots, init_ts, means_cano[center_idcs])
260
+ vis_tracks_3d(server, tracks_3d.xyz[center_idcs].numpy(), name="center_tracks")
261
+ import ipdb
262
+
263
+ ipdb.set_trace()
264
+
265
+ bases = MotionBases(init_rots, init_ts)
266
+ return bases, motion_coefs, tracks_3d
267
+
268
+
269
+ def run_initial_optim(
270
+ fg: GaussianParams,
271
+ bases: MotionBases,
272
+ tracks_3d: TrackObservations,
273
+ Ks: torch.Tensor,
274
+ w2cs: torch.Tensor,
275
+ num_iters: int = 1000,
276
+ use_depth_range_loss: bool = False,
277
+ ):
278
+ """
279
+ :param motion_rots: [num_bases, num_frames, 4|6]
280
+ :param motion_transls: [num_bases, num_frames, 3]
281
+ :param motion_coefs: [num_bases, num_frames]
282
+ :param means: [num_gaussians, 3]
283
+ """
284
+ optimizer = torch.optim.Adam(
285
+ [
286
+ {"params": bases.params["rots"], "lr": 1e-2},
287
+ {"params": bases.params["transls"], "lr": 3e-2},
288
+ {"params": fg.params["motion_coefs"], "lr": 1e-2},
289
+ {"params": fg.params["means"], "lr": 1e-3},
290
+ ],
291
+ )
292
+ scheduler = torch.optim.lr_scheduler.ExponentialLR(
293
+ optimizer, gamma=0.1 ** (1 / num_iters)
294
+ )
295
+ G = fg.params["means"].shape[0]
296
+ num_frames = bases.num_frames
297
+ device = bases.params["rots"].device
298
+
299
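+ # Loss-weight schedule: stays at min_v until iteration th, then ramps linearly up to max_v by the last iteration.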
+ w_smooth_func = lambda i, min_v, max_v, th: (
300
+ min_v if i <= th else (max_v - min_v) * (i - th) / (num_iters - th) + min_v
301
+ )
302
+
303
+ gt_2d, gt_depth = project_2d_tracks(
304
+ tracks_3d.xyz.swapaxes(0, 1), Ks, w2cs, return_depth=True
305
+ )
306
+ # (G, T, 2)
307
+ gt_2d = gt_2d.swapaxes(0, 1)
308
+ # (G, T)
309
+ gt_depth = gt_depth.swapaxes(0, 1)
310
+
311
+ ts = torch.arange(0, num_frames, device=device)
312
+ ts_clamped = torch.clamp(ts, min=1, max=num_frames - 2)
313
+ ts_neighbors = torch.cat((ts_clamped - 1, ts_clamped, ts_clamped + 1)) # i (3B,)
314
+
315
+ pbar = tqdm(range(0, num_iters))
316
+ for i in pbar:
317
+ coefs = fg.get_coefs()
318
+ transfms = bases.compute_transforms(ts, coefs)
319
+ positions = torch.einsum(
320
+ "pnij,pj->pni",
321
+ transfms,
322
+ F.pad(fg.params["means"], (0, 1), value=1.0),
323
+ )
324
+
325
+ loss = 0.0
326
+ track_3d_loss = masked_l1_loss(
327
+ positions,
328
+ tracks_3d.xyz,
329
+ (tracks_3d.visibles.float() * tracks_3d.confidences)[..., None],
330
+ )
331
+ loss += track_3d_loss * 1.0
332
+
333
+ pred_2d, pred_depth = project_2d_tracks(
334
+ positions.swapaxes(0, 1), Ks, w2cs, return_depth=True
335
+ )
336
+ pred_2d = pred_2d.swapaxes(0, 1)
337
+ pred_depth = pred_depth.swapaxes(0, 1)
338
+
339
+ loss_2d = (
340
+ masked_l1_loss(
341
+ pred_2d,
342
+ gt_2d,
343
+ (tracks_3d.invisibles.float() * tracks_3d.confidences)[..., None],
344
+ quantile=0.95,
345
+ )
346
+ / Ks[0, 0, 0]
347
+ )
348
+ loss += 0.5 * loss_2d
349
+
350
+ if use_depth_range_loss:
351
+ near_depths = torch.quantile(gt_depth, 0.0, dim=0, keepdim=True)
352
+ far_depths = torch.quantile(gt_depth, 0.98, dim=0, keepdim=True)
353
+ loss_depth_in_range = 0
354
+ if (pred_depth < near_depths).any():
355
+ loss_depth_in_range += (near_depths - pred_depth)[
356
+ pred_depth < near_depths
357
+ ].mean()
358
+ if (pred_depth > far_depths).any():
359
+ loss_depth_in_range += (pred_depth - far_depths)[
360
+ pred_depth > far_depths
361
+ ].mean()
362
+
363
+ loss += loss_depth_in_range * w_smooth_func(i, 0.05, 0.5, 400)
364
+
365
+ motion_coef_sparse_loss = 1 - (coefs**2).sum(dim=-1).mean()
366
+ loss += motion_coef_sparse_loss * 0.01
367
+
368
+ # motion basis should be smooth.
369
+ w_smooth = w_smooth_func(i, 0.01, 0.1, 400)
370
+ small_acc_loss = compute_se3_smoothness_loss(
371
+ bases.params["rots"], bases.params["transls"]
372
+ )
373
+ loss += small_acc_loss * w_smooth
374
+
375
+ small_acc_loss_tracks = compute_accel_loss(positions)
376
+ loss += small_acc_loss_tracks * w_smooth * 0.5
377
+
378
+ transfms_nbs = bases.compute_transforms(ts_neighbors, coefs)
379
+ means_nbs = torch.einsum(
380
+ "pnij,pj->pni", transfms_nbs, F.pad(fg.params["means"], (0, 1), value=1.0)
381
+ ) # (G, 3n, 3)
382
+ means_nbs = means_nbs.reshape(means_nbs.shape[0], 3, -1, 3) # [G, 3, n, 3]
383
+ z_accel_loss = compute_z_acc_loss(means_nbs, w2cs)
384
+ loss += z_accel_loss * 0.1
385
+
386
+ optimizer.zero_grad()
387
+ loss.backward()
388
+ optimizer.step()
389
+ scheduler.step()
390
+
391
+ pbar.set_description(
392
+ f"{loss.item():.3f} "
393
+ f"{track_3d_loss.item():.3f} "
394
+ f"{motion_coef_sparse_loss.item():.3f} "
395
+ f"{small_acc_loss.item():.3f} "
396
+ f"{small_acc_loss_tracks.item():.3f} "
397
+ f"{z_accel_loss.item():.3f} "
398
+ )
399
+
400
+
401
+ def random_quats(N: int) -> torch.Tensor:
402
+ u = torch.rand(N, 1)
403
+ v = torch.rand(N, 1)
404
+ w = torch.rand(N, 1)
405
+ quats = torch.cat(
406
+ [
407
+ torch.sqrt(1.0 - u) * torch.sin(2.0 * np.pi * v),
408
+ torch.sqrt(1.0 - u) * torch.cos(2.0 * np.pi * v),
409
+ torch.sqrt(u) * torch.sin(2.0 * np.pi * w),
410
+ torch.sqrt(u) * torch.cos(2.0 * np.pi * w),
411
+ ],
412
+ -1,
413
+ )
414
+ return quats
415
+
416
+
417
+ def compute_means(ts, fg: GaussianParams, bases: MotionBases):
418
+ transfms = bases.compute_transforms(ts, fg.get_coefs())
419
+ means = torch.einsum(
420
+ "pnij,pj->pni",
421
+ transfms,
422
+ F.pad(fg.params["means"], (0, 1), value=1.0),
423
+ )
424
+ return means
425
+
426
+
427
+ def vis_init_params(
428
+ server,
429
+ fg: GaussianParams,
430
+ bases: MotionBases,
431
+ name="init_params",
432
+ num_vis: int = 100,
433
+ ):
434
+ idcs = np.random.choice(fg.num_gaussians, num_vis)
435
+ labels = np.linspace(0, 1, num_vis)
436
+ ts = torch.arange(bases.num_frames, device=bases.params["rots"].device)
437
+ with torch.no_grad():
438
+ pred_means = compute_means(ts, fg, bases)
439
+ vis_means = pred_means[idcs].detach().cpu().numpy()
440
+ vis_tracks_3d(server, vis_means, labels, name=name)
441
+
442
+
443
+ @torch.no_grad()
444
+ def vis_se3_init_3d(server, init_rots, init_ts, basis_centers):
445
+ """
446
+ :param init_rots: [num_bases, num_frames, 4|6]
447
+ :param init_ts: [num_bases, num_frames, 3]
448
+ :param basis_centers: [num_bases, 3]
449
+ """
450
+ # visualize the initial centers across time
451
+ rot_dim = init_rots.shape[-1]
452
+ assert rot_dim in [4, 6]
453
+ num_bases = init_rots.shape[0]
454
+ assert init_ts.shape[0] == num_bases
455
+ assert basis_centers.shape[0] == num_bases
456
+ labels = np.linspace(0, 1, num_bases)
457
+ if rot_dim == 4:
458
+ quats = F.normalize(init_rots, dim=-1, p=2)
459
+ rmats = roma.unitquat_to_rotmat(quats.roll(-1, dims=-1))
460
+ else:
461
+ rmats = cont_6d_to_rmat(init_rots)
462
+ transls = init_ts
463
+ transfms = rt_to_mat4(rmats, transls)
464
+ center_tracks3d = torch.einsum(
465
+ "bnij,bj->bni", transfms, F.pad(basis_centers, (0, 1), value=1.0)
466
+ )[..., :3]
467
+ vis_tracks_3d(server, center_tracks3d.cpu().numpy(), labels, name="se3_centers")
468
+
469
+
470
+ @torch.no_grad()
471
+ def vis_tracks_2d_video(
472
+ path,
473
+ imgs: np.ndarray,
474
+ tracks_3d: np.ndarray,
475
+ Ks: np.ndarray,
476
+ w2cs: np.ndarray,
477
+ occs=None,
478
+ radius: int = 3,
479
+ ):
480
+ num_tracks = tracks_3d.shape[0]
481
+ labels = np.linspace(0, 1, num_tracks)
482
+ cmap = get_cmap("gist_rainbow")
483
+ colors = cmap(labels)[:, :3]
484
+ tracks_2d = (
485
+ project_2d_tracks(tracks_3d.swapaxes(0, 1), Ks, w2cs).cpu().numpy() # type: ignore
486
+ )
487
+ frames = np.asarray(
488
+ draw_keypoints_video(imgs, tracks_2d, colors, occs, radius=radius)
489
+ )
490
+ iio.imwrite(path, frames, fps=15)
491
+
492
+
493
+ def vis_tracks_3d(
494
+ server: ViserServer,
495
+ vis_tracks: np.ndarray,
496
+ vis_label: np.ndarray | None = None,
497
+ name: str = "tracks",
498
+ ):
499
+ """
500
+ :param vis_tracks (np.ndarray): (N, T, 3)
501
+ :param vis_label (np.ndarray): (N)
502
+ """
503
+ cmap = get_cmap("gist_rainbow")
504
+ if vis_label is None:
505
+ vis_label = np.linspace(0, 1, len(vis_tracks))
506
+ colors = cmap(np.asarray(vis_label))[:, :3]
507
+ guru.info(f"{colors.shape=}, {vis_tracks.shape=}")
508
+ N, T = vis_tracks.shape[:2]
509
+ vis_tracks = np.asarray(vis_tracks)
510
+ for i in range(N):
511
+ server.scene.add_spline_catmull_rom(
512
+ f"/{name}/{i}/spline", vis_tracks[i], color=colors[i], segments=T - 1
513
+ )
514
+ server.scene.add_point_cloud(
515
+ f"/{name}/{i}/start",
516
+ vis_tracks[i, [0]],
517
+ colors=colors[i : i + 1],
518
+ point_size=0.05,
519
+ point_shape="circle",
520
+ )
521
+ server.scene.add_point_cloud(
522
+ f"/{name}/{i}/end",
523
+ vis_tracks[i, [-1]],
524
+ colors=colors[i : i + 1],
525
+ point_size=0.05,
526
+ point_shape="diamond",
527
+ )
528
+
529
+
530
+ def sample_initial_bases_centers(
531
+ mode: str, cano_t: int, tracks_3d: TrackObservations, num_bases: int
532
+ ):
533
+ """
534
+ :param mode: "farthest" | "hdbscan" | "kmeans"
535
+ :param tracks_3d: [G, T, 3]
536
+ :param cano_t: canonical index
537
+ :param num_bases: number of SE3 bases
538
+ """
539
+ assert mode in ["farthest", "hdbscan", "kmeans"]
540
+ means_canonical = tracks_3d.xyz[:, cano_t].clone()
541
+ # if mode == "farthest":
542
+ # vis_mask = tracks_3d.visibles[:, cano_t]
543
+ # sampled_centers, _ = sample_farthest_points(
544
+ # means_canonical[vis_mask][None],
545
+ # K=num_bases,
546
+ # random_start_point=True,
547
+ # ) # [1, num_bases, 3]
548
+ # dists2centers = torch.norm(means_canonical[:, None] - sampled_centers, dim=-1).T
549
+ # return sampled_centers, num_bases, dists2centers
550
+
551
+ # linearly interpolate missing 3d points
552
+ xyz = cp.asarray(tracks_3d.xyz)
553
+ print(f"{xyz.shape=}")
554
+ visibles = cp.asarray(tracks_3d.visibles)
555
+
556
+ num_tracks = xyz.shape[0]
557
+ xyz_interp = batched_interp_masked(xyz, visibles)
558
+
559
+ # num_vis = 50
560
+ # server = get_server(port=8890)
561
+ # idcs = np.random.choice(num_tracks, num_vis)
562
+ # labels = np.linspace(0, 1, num_vis)
563
+ # vis_tracks_3d(server, tracks_3d.xyz[idcs].get(), labels, name="raw_tracks")
564
+ # vis_tracks_3d(server, xyz_interp[idcs].get(), labels, name="interp_tracks")
565
+ # import ipdb; ipdb.set_trace()
566
+
567
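+ # Cluster tracks by the direction of their per-step velocities so points that move coherently end up under the same SE(3) basis.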
+ velocities = xyz_interp[:, 1:] - xyz_interp[:, :-1]
568
+ vel_dirs = (
569
+ velocities / (cp.linalg.norm(velocities, axis=-1, keepdims=True) + 1e-5)
570
+ ).reshape((num_tracks, -1))
571
+
572
+ # [num_bases, num_gaussians]
573
+ if mode == "kmeans":
574
+ model = KMeans(n_clusters=num_bases)
575
+ else:
576
+ model = HDBSCAN(min_cluster_size=20, max_cluster_size=num_tracks // 4)
577
+ model.fit(vel_dirs)
578
+ labels = model.labels_
579
+ num_bases = labels.max().item() + 1
580
+ sampled_centers = torch.stack(
581
+ [
582
+ means_canonical[torch.tensor(labels == i)].median(dim=0).values
583
+ for i in range(num_bases)
584
+ ]
585
+ )[None]
586
+ print("number of {} clusters: ".format(mode), num_bases)
587
+ return sampled_centers, num_bases, torch.tensor(labels)
588
+
589
+
590
+ def interp_masked(vals: cp.ndarray, mask: cp.ndarray, pad: int = 1) -> cp.ndarray:
591
+ """
592
+ hacky way to interpolate batched with cupy
593
+ by concatenating the batches and pad with dummy values
594
+ :param vals: [B, M, *]
595
+ :param mask: [B, M]
596
+ """
597
+ assert mask.ndim == 2
598
+ assert vals.shape[:2] == mask.shape
599
+
600
+ B, M = mask.shape
601
+
602
+ # get the first and last valid values for each track
603
+ sh = vals.shape[2:]
604
+ vals = vals.reshape((B, M, -1))
605
+ D = vals.shape[-1]
606
+ first_val_idcs = cp.argmax(mask, axis=-1)
607
+ last_val_idcs = M - 1 - cp.argmax(cp.flip(mask, axis=-1), axis=-1)
608
+ bidcs = cp.arange(B)
609
+
610
+ v0 = vals[bidcs, first_val_idcs][:, None]
611
+ v1 = vals[bidcs, last_val_idcs][:, None]
612
+ m0 = mask[bidcs, first_val_idcs][:, None]
613
+ m1 = mask[bidcs, last_val_idcs][:, None]
614
+ if pad > 1:
615
+ v0 = cp.tile(v0, [1, pad, 1])
616
+ v1 = cp.tile(v1, [1, pad, 1])
617
+ m0 = cp.tile(m0, [1, pad])
618
+ m1 = cp.tile(m1, [1, pad])
619
+
620
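+ # Pad every track with its first/last valid sample so that, after flattening all tracks into one 1D signal, cp.interp does not borrow values across track boundaries.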
+ vals_pad = cp.concatenate([v0, vals, v1], axis=1)
621
+ mask_pad = cp.concatenate([m0, mask, m1], axis=1)
622
+
623
+ M_pad = vals_pad.shape[1]
624
+ vals_flat = vals_pad.reshape((B * M_pad, -1))
625
+ mask_flat = mask_pad.reshape((B * M_pad,))
626
+ idcs = cp.where(mask_flat)[0]
627
+
628
+ cx = cp.arange(B * M_pad)
629
+ out = cp.zeros((B * M_pad, D), dtype=vals_flat.dtype)
630
+ for d in range(D):
631
+ out[:, d] = cp.interp(cx, idcs, vals_flat[idcs, d])
632
+
633
+ out = out.reshape((B, M_pad, *sh))[:, pad:-pad]
634
+ return out
635
+
636
+
637
+ def batched_interp_masked(
638
+ vals: cp.ndarray, mask: cp.ndarray, batch_num: int = 4096, batch_time: int = 64
639
+ ):
640
+ assert mask.ndim == 2
641
+ B, M = mask.shape
642
+ out = cp.zeros_like(vals)
643
+ for b in tqdm(range(0, B, batch_num), leave=False):
644
+ for m in tqdm(range(0, M, batch_time), leave=False):
645
+ x = interp_masked(
646
+ vals[b : b + batch_num, m : m + batch_time],
647
+ mask[b : b + batch_num, m : m + batch_time],
648
+ ) # (batch_num, batch_time, *)
649
+ out[b : b + batch_num, m : m + batch_time] = x
650
+ return out
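A minimal usage sketch for the masked interpolation above, assuming small made-up track arrays; `xyz` and `visibles` follow the [num_tracks, num_frames, 3] / [num_tracks, num_frames] convention used in this file, and every track is made visible in at least one frame so the interpolation has anchor points.

import cupy as cp

xyz = cp.random.rand(8, 12, 3).astype(cp.float32)   # 8 tracks over 12 frames
visibles = cp.random.rand(8, 12) > 0.3              # boolean visibility mask
visibles[:, 0] = True                               # ensure every track has a valid frame
xyz_filled = batched_interp_masked(xyz, visibles)   # same shape, gaps filled linearly
assert xyz_filled.shape == xyz.shape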
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/loss_utils.py ADDED
@@ -0,0 +1,157 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from sklearn.neighbors import NearestNeighbors
5
+
6
+
7
+ def masked_mse_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0):
8
+ if mask is None:
9
+ return trimmed_mse_loss(pred, gt, quantile)
10
+ else:
11
+ sum_loss = F.mse_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True)
12
+ quantile_mask = (
13
+ (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1)
14
+ if quantile < 1
15
+ else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)
16
+ )
17
+ ndim = sum_loss.shape[-1]
18
+ if normalize:
19
+ return torch.sum((sum_loss * mask)[quantile_mask]) / (
20
+ ndim * torch.sum(mask[quantile_mask]) + 1e-8
21
+ )
22
+ else:
23
+ return torch.mean((sum_loss * mask)[quantile_mask])
24
+
25
+
26
+ def masked_l1_loss(pred, gt, mask=None, normalize=True, quantile: float = 1.0):
27
+ if mask is None:
28
+ return trimmed_l1_loss(pred, gt, quantile)
29
+ else:
30
+ sum_loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1, keepdim=True)
31
+ quantile_mask = (
32
+ (sum_loss < torch.quantile(sum_loss, quantile)).squeeze(-1)
33
+ if quantile < 1
34
+ else torch.ones_like(sum_loss, dtype=torch.bool).squeeze(-1)
35
+ )
36
+ ndim = sum_loss.shape[-1]
37
+ if normalize:
38
+ return torch.sum((sum_loss * mask)[quantile_mask]) / (
39
+ ndim * torch.sum(mask[quantile_mask]) + 1e-8
40
+ )
41
+ else:
42
+ return torch.mean((sum_loss * mask)[quantile_mask])
43
+
44
+
45
+ def masked_huber_loss(pred, gt, delta, mask=None, normalize=True):
46
+ if mask is None:
47
+ return F.huber_loss(pred, gt, delta=delta)
48
+ else:
49
+ sum_loss = F.huber_loss(pred, gt, delta=delta, reduction="none")
50
+ ndim = sum_loss.shape[-1]
51
+ if normalize:
52
+ return torch.sum(sum_loss * mask) / (ndim * torch.sum(mask) + 1e-8)
53
+ else:
54
+ return torch.mean(sum_loss * mask)
55
+
56
+
57
+ def trimmed_mse_loss(pred, gt, quantile=0.9):
58
+ loss = F.mse_loss(pred, gt, reduction="none").mean(dim=-1)
59
+ loss_at_quantile = torch.quantile(loss, quantile)
60
+ trimmed_loss = loss[loss < loss_at_quantile].mean()
61
+ return trimmed_loss
62
+
63
+
64
+ def trimmed_l1_loss(pred, gt, quantile=0.9):
65
+ loss = F.l1_loss(pred, gt, reduction="none").mean(dim=-1)
66
+ loss_at_quantile = torch.quantile(loss, quantile)
67
+ trimmed_loss = loss[loss < loss_at_quantile].mean()
68
+ return trimmed_loss
69
+
70
+
71
+ def compute_gradient_loss(pred, gt, mask, quantile=0.98):
72
+ """
73
+ Compute gradient loss
74
+ pred: (batch_size, H, W, D) or (batch_size, H, W)
75
+ gt: (batch_size, H, W, D) or (batch_size, H, W)
76
+ mask: (batch_size, H, W), bool or float
77
+ """
78
+ # NOTE: messy need to be cleaned up
79
+ mask_x = mask[:, :, 1:] * mask[:, :, :-1]
80
+ mask_y = mask[:, 1:, :] * mask[:, :-1, :]
81
+ pred_grad_x = pred[:, :, 1:] - pred[:, :, :-1]
82
+ pred_grad_y = pred[:, 1:, :] - pred[:, :-1, :]
83
+ gt_grad_x = gt[:, :, 1:] - gt[:, :, :-1]
84
+ gt_grad_y = gt[:, 1:, :] - gt[:, :-1, :]
85
+ loss = masked_l1_loss(
86
+ pred_grad_x[mask_x][..., None], gt_grad_x[mask_x][..., None], quantile=quantile
87
+ ) + masked_l1_loss(
88
+ pred_grad_y[mask_y][..., None], gt_grad_y[mask_y][..., None], quantile=quantile
89
+ )
90
+ return loss
91
+
92
+
93
+ def knn(x: torch.Tensor, k: int) -> tuple[np.ndarray, np.ndarray]:
94
+ x = x.cpu().numpy()
95
+ knn_model = NearestNeighbors(
96
+ n_neighbors=k + 1, algorithm="auto", metric="euclidean"
97
+ ).fit(x)
98
+ distances, indices = knn_model.kneighbors(x)
99
+ return distances[:, 1:].astype(np.float32), indices[:, 1:].astype(np.float32)
100
+
101
+
102
+ def get_weights_for_procrustes(clusters, visibilities=None):
103
+ clusters_median = clusters.median(dim=-2, keepdim=True)[0]
104
+ dists2clusters_center = torch.norm(clusters - clusters_median, dim=-1)
105
+ dists2clusters_center /= dists2clusters_center.median(dim=-1, keepdim=True)[0]
106
+ weights = torch.exp(-dists2clusters_center)
107
+ weights /= weights.mean(dim=-1, keepdim=True) + 1e-6
108
+ if visibilities is not None:
109
+ weights *= visibilities.float() + 1e-6
110
+ invalid = dists2clusters_center > np.quantile(
111
+ dists2clusters_center.cpu().numpy(), 0.9
112
+ )
113
+ invalid |= torch.isnan(weights)
114
+ weights[invalid] = 0
115
+ return weights
116
+
117
+
118
+ def compute_z_acc_loss(means_ts_nb: torch.Tensor, w2cs: torch.Tensor):
119
+ """
120
+ :param means_ts (G, 3, B, 3)
121
+ :param w2cs (B, 4, 4)
122
+ return (float)
123
+ """
124
+ camera_center_t = torch.linalg.inv(w2cs)[:, :3, 3] # (B, 3)
125
+ ray_dir = F.normalize(
126
+ means_ts_nb[:, 1] - camera_center_t, p=2.0, dim=-1
127
+ ) # [G, B, 3]
128
+ # acc = 2 * means[:, 1] - means[:, 0] - means[:, 2] # [G, B, 3]
129
+ # acc_loss = (acc * ray_dir).sum(dim=-1).abs().mean()
130
+ acc_loss = (
131
+ ((means_ts_nb[:, 1] - means_ts_nb[:, 0]) * ray_dir).sum(dim=-1) ** 2
132
+ ).mean() + (
133
+ ((means_ts_nb[:, 2] - means_ts_nb[:, 1]) * ray_dir).sum(dim=-1) ** 2
134
+ ).mean()
135
+ return acc_loss
136
+
137
+
138
+ def compute_se3_smoothness_loss(
139
+ rots: torch.Tensor,
140
+ transls: torch.Tensor,
141
+ weight_rot: float = 1.0,
142
+ weight_transl: float = 2.0,
143
+ ):
144
+ """
145
+ central differences
146
+ :param motion_transls (K, T, 3)
147
+ :param motion_rots (K, T, 6)
148
+ """
149
+ r_accel_loss = compute_accel_loss(rots)
150
+ t_accel_loss = compute_accel_loss(transls)
151
+ return r_accel_loss * weight_rot + t_accel_loss * weight_transl
152
+
153
+
154
+ def compute_accel_loss(transls):
155
+ accel = 2 * transls[:, 1:-1] - transls[:, :-2] - transls[:, 2:]
156
+ loss = accel.norm(dim=-1).mean()
157
+ return loss
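A minimal sketch of the robust losses defined above on dummy tensors, assuming per-point predictions of shape (B, N, 3); the quantile argument discards the largest residuals before averaging.

import torch

pred = torch.rand(4, 128, 3)
gt = torch.rand(4, 128, 3)
mask = (torch.rand(4, 128, 1) > 0.2).float()

plain = masked_l1_loss(pred, gt, mask)                   # masked mean of |pred - gt|
robust = masked_l1_loss(pred, gt, mask, quantile=0.98)   # additionally drop the worst 2%
trimmed = trimmed_l1_loss(pred, gt, quantile=0.9)        # no mask, drop the worst 10%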
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/metrics.py ADDED
@@ -0,0 +1,313 @@
1
+ from typing import Literal
2
+
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torchmetrics.functional.image.lpips import _NoTrainLpips
7
+ from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
8
+ from torchmetrics.metric import Metric
9
+ from torchmetrics.utilities import dim_zero_cat
10
+ from torchmetrics.utilities.imports import _TORCHVISION_AVAILABLE
11
+
12
+
13
+ def compute_psnr(
14
+ preds: torch.Tensor,
15
+ targets: torch.Tensor,
16
+ masks: torch.Tensor | None = None,
17
+ ) -> float:
18
+ """
19
+ Args:
20
+ preds (torch.Tensor): (..., 3) predicted images in [0, 1].
21
+ targets (torch.Tensor): (..., 3) target images in [0, 1].
22
+ masks (torch.Tensor | None): (...,) optional binary masks where the
23
+ 1-regions will be taken into account.
24
+
25
+ Returns:
26
+ psnr (float): Peak signal-to-noise ratio.
27
+ """
28
+ if masks is None:
29
+ masks = torch.ones_like(preds[..., 0])
30
+ return (
31
+ -10.0
32
+ * torch.log(
33
+ F.mse_loss(
34
+ preds * masks[..., None],
35
+ targets * masks[..., None],
36
+ reduction="sum",
37
+ )
38
+ / masks.sum().clamp(min=1.0)
39
+ / 3.0
40
+ )
41
+ / np.log(10.0)
42
+ ).item()
43
+
44
+
45
+ def compute_pose_errors(
46
+ preds: torch.Tensor, targets: torch.Tensor
47
+ ) -> tuple[float, float, float]:
48
+ """
49
+ Args:
50
+ preds: (N, 4, 4) predicted camera poses.
51
+ targets: (N, 4, 4) target camera poses.
52
+
53
+ Returns:
54
+ ate (float): Absolute trajectory error.
55
+ rpe_t (float): Relative pose error in translation.
56
+ rpe_r (float): Relative pose error in rotation (degree).
57
+ """
58
+ # Compute ATE.
59
+ ate = torch.linalg.norm(preds[:, :3, -1] - targets[:, :3, -1], dim=-1).mean().item()
60
+ # Compute RPE_t and RPE_r.
61
+ # NOTE(hangg): It's important to use numpy here for the accuracy of RPE_r.
62
+ # torch has numerical issues for acos when the value is close to 1.0, i.e.
63
+ # RPE_r is supposed to be very small, and will result in artificially large
64
+ # error.
65
+ preds = preds.detach().cpu().numpy()
66
+ targets = targets.detach().cpu().numpy()
67
+ pred_rels = np.linalg.inv(preds[:-1]) @ preds[1:]
69
+ target_rels = np.linalg.inv(targets[:-1]) @ targets[1:]
70
+ error_rels = np.linalg.inv(target_rels) @ pred_rels
71
+ traces = error_rels[:, :3, :3].trace(axis1=-2, axis2=-1)
72
+ rpe_t = np.linalg.norm(error_rels[:, :3, -1], axis=-1).mean().item()
73
+ rpe_r = (
74
+ np.arccos(np.clip((traces - 1.0) / 2.0, -1.0, 1.0)).mean().item()
75
+ / np.pi
76
+ * 180.0
77
+ )
78
+ return ate, rpe_t, rpe_r
79
+
80
+
81
+ class mPSNR(PeakSignalNoiseRatio):
82
+ sum_squared_error: list[torch.Tensor]
83
+ total: list[torch.Tensor]
84
+
85
+ def __init__(self, **kwargs) -> None:
86
+ super().__init__(
87
+ data_range=1.0,
88
+ base=10.0,
89
+ dim=None,
90
+ reduction="elementwise_mean",
91
+ **kwargs,
92
+ )
93
+ self.add_state("sum_squared_error", default=[], dist_reduce_fx="cat")
94
+ self.add_state("total", default=[], dist_reduce_fx="cat")
95
+
96
+ def __len__(self) -> int:
97
+ return len(self.total)
98
+
99
+ def update(
100
+ self,
101
+ preds: torch.Tensor,
102
+ targets: torch.Tensor,
103
+ masks: torch.Tensor | None = None,
104
+ ):
105
+ """Update state with predictions and targets.
106
+
107
+ Args:
108
+ preds (torch.Tensor): (..., 3) float32 predicted images.
109
+ targets (torch.Tensor): (..., 3) float32 target images.
110
+ masks (torch.Tensor | None): (...,) optional binary masks where the
111
+ 1-regions will be taken into account.
112
+ """
113
+ if masks is None:
114
+ masks = torch.ones_like(preds[..., 0])
115
+ self.sum_squared_error.append(
116
+ torch.sum(torch.pow((preds - targets) * masks[..., None], 2))
117
+ )
118
+ self.total.append(masks.sum().to(torch.int64) * 3)
119
+
120
+ def compute(self) -> torch.Tensor:
121
+ """Compute peak signal-to-noise ratio over state."""
122
+ sum_squared_error = dim_zero_cat(self.sum_squared_error)
123
+ total = dim_zero_cat(self.total)
124
+ return -10.0 * torch.log(sum_squared_error / total).mean() / np.log(10.0)
125
+
126
+
127
+ class mSSIM(StructuralSimilarityIndexMeasure):
128
+ similarity: list
129
+
130
+ def __init__(self, **kwargs) -> None:
131
+ super().__init__(
132
+ reduction=None,
133
+ data_range=1.0,
134
+ return_full_image=False,
135
+ **kwargs,
136
+ )
137
+ assert isinstance(self.sigma, float)
138
+
139
+ def __len__(self) -> int:
140
+ return sum([s.shape[0] for s in self.similarity])
141
+
142
+ def update(
143
+ self,
144
+ preds: torch.Tensor,
145
+ targets: torch.Tensor,
146
+ masks: torch.Tensor | None = None,
147
+ ):
148
+ """Update state with predictions and targets.
149
+
150
+ Args:
151
+ preds (torch.Tensor): (B, H, W, 3) float32 predicted images.
152
+ targets (torch.Tensor): (B, H, W, 3) float32 target images.
153
+ masks (torch.Tensor | None): (B, H, W) optional binary masks where
154
+ the 1-regions will be taken into account.
155
+ """
156
+ if masks is None:
157
+ masks = torch.ones_like(preds[..., 0])
158
+
159
+ # Construct a 1D Gaussian blur filter.
160
+ assert isinstance(self.kernel_size, int)
161
+ hw = self.kernel_size // 2
162
+ shift = (2 * hw - self.kernel_size + 1) / 2
163
+ assert isinstance(self.sigma, float)
164
+ f_i = (
165
+ (torch.arange(self.kernel_size, device=preds.device) - hw + shift)
166
+ / self.sigma
167
+ ) ** 2
168
+ filt = torch.exp(-0.5 * f_i)
169
+ filt /= torch.sum(filt)
170
+
171
+ # Blur in x and y (faster than the 2D convolution).
172
+ def convolve2d(z, m, f):
173
+ # z: (B, H, W, C), m: (B, H, W), f: (Hf, Wf).
174
+ z = z.permute(0, 3, 1, 2)
175
+ m = m[:, None]
176
+ f = f[None, None].expand(z.shape[1], -1, -1, -1)
177
+ z_ = torch.nn.functional.conv2d(
178
+ z * m, f, padding="valid", groups=z.shape[1]
179
+ )
180
+ m_ = torch.nn.functional.conv2d(m, torch.ones_like(f[:1]), padding="valid")
181
+ return torch.where(
182
+ m_ != 0, z_ * torch.ones_like(f).sum() / (m_ * z.shape[1]), 0
183
+ ).permute(0, 2, 3, 1), (m_ != 0)[:, 0].to(z.dtype)
184
+
185
+ filt_fn1 = lambda z, m: convolve2d(z, m, filt[:, None])
186
+ filt_fn2 = lambda z, m: convolve2d(z, m, filt[None, :])
187
+ filt_fn = lambda z, m: filt_fn1(*filt_fn2(z, m))
188
+
189
+ mu0 = filt_fn(preds, masks)[0]
190
+ mu1 = filt_fn(targets, masks)[0]
191
+ mu00 = mu0 * mu0
192
+ mu11 = mu1 * mu1
193
+ mu01 = mu0 * mu1
194
+ sigma00 = filt_fn(preds**2, masks)[0] - mu00
195
+ sigma11 = filt_fn(targets**2, masks)[0] - mu11
196
+ sigma01 = filt_fn(preds * targets, masks)[0] - mu01
197
+
198
+ # Clip the variances and covariances to valid values.
199
+ # Variance must be non-negative:
200
+ sigma00 = sigma00.clamp(min=0.0)
201
+ sigma11 = sigma11.clamp(min=0.0)
202
+ sigma01 = torch.sign(sigma01) * torch.minimum(
203
+ torch.sqrt(sigma00 * sigma11), torch.abs(sigma01)
204
+ )
205
+
206
+ assert isinstance(self.data_range, float)
207
+ c1 = (self.k1 * self.data_range) ** 2
208
+ c2 = (self.k2 * self.data_range) ** 2
209
+ numer = (2 * mu01 + c1) * (2 * sigma01 + c2)
210
+ denom = (mu00 + mu11 + c1) * (sigma00 + sigma11 + c2)
211
+ ssim_map = numer / denom
212
+
213
+ self.similarity.append(ssim_map.mean(dim=(1, 2, 3)))
214
+
215
+ def compute(self) -> torch.Tensor:
216
+ """Compute final SSIM metric."""
217
+ return torch.cat(self.similarity).mean()
218
+
219
+
220
+ class mLPIPS(Metric):
221
+ sum_scores: list[torch.Tensor]
222
+ total: list[torch.Tensor]
223
+
224
+ def __init__(
225
+ self,
226
+ net_type: Literal["vgg", "alex", "squeeze"] = "alex",
227
+ **kwargs,
228
+ ):
229
+ super().__init__(**kwargs)
230
+
231
+ if not _TORCHVISION_AVAILABLE:
232
+ raise ModuleNotFoundError(
233
+ "LPIPS metric requires that torchvision is installed."
234
+ " Either install as `pip install torchmetrics[image]` or `pip install torchvision`."
235
+ )
236
+
237
+ valid_net_type = ("vgg", "alex", "squeeze")
238
+ if net_type not in valid_net_type:
239
+ raise ValueError(
240
+ f"Argument `net_type` must be one of {valid_net_type}, but got {net_type}."
241
+ )
242
+ self.net = _NoTrainLpips(net=net_type, spatial=True)
243
+
244
+ self.add_state("sum_scores", [], dist_reduce_fx="cat")
245
+ self.add_state("total", [], dist_reduce_fx="cat")
246
+
247
+ def __len__(self) -> int:
248
+ return len(self.total)
249
+
250
+ def update(
251
+ self,
252
+ preds: torch.Tensor,
253
+ targets: torch.Tensor,
254
+ masks: torch.Tensor | None = None,
255
+ ):
256
+ """Update internal states with lpips scores.
257
+
258
+ Args:
259
+ preds (torch.Tensor): (B, H, W, 3) float32 predicted images.
260
+ targets (torch.Tensor): (B, H, W, 3) float32 target images.
261
+ masks (torch.Tensor | None): (B, H, W) optional float32 binary
262
+ masks where the 1-regions will be taken into account.
263
+ """
264
+ if masks is None:
265
+ masks = torch.ones_like(preds[..., 0])
266
+ scores = self.net(
267
+ (preds * masks[..., None]).permute(0, 3, 1, 2),
268
+ (targets * masks[..., None]).permute(0, 3, 1, 2),
269
+ normalize=True,
270
+ )
271
+ self.sum_scores.append((scores * masks[:, None]).sum())
272
+ self.total.append(masks.sum().to(torch.int64))
273
+
274
+ def compute(self) -> torch.Tensor:
275
+ """Compute final perceptual similarity metric."""
276
+ return (
277
+ torch.tensor(self.sum_scores, device=self.device)
278
+ / torch.tensor(self.total, device=self.device)
279
+ ).mean()
280
+
281
+
282
+ class PCK(Metric):
283
+ correct: list[torch.Tensor]
284
+ total: list[int]
285
+
286
+ def __init__(self, **kwargs):
287
+ super().__init__(**kwargs)
288
+ self.add_state("correct", default=[], dist_reduce_fx="cat")
289
+ self.add_state("total", default=[], dist_reduce_fx="cat")
290
+
291
+ def __len__(self) -> int:
292
+ return len(self.total)
293
+
294
+ def update(self, preds: torch.Tensor, targets: torch.Tensor, threshold: float):
295
+ """Update internal states with PCK scores.
296
+
297
+ Args:
298
+ preds (torch.Tensor): (N, 2) predicted 2D keypoints.
299
+ targets (torch.Tensor): (N, 2) targets 2D keypoints.
300
+ threshold (float): PCK threshold.
301
+ """
302
+
303
+ self.correct.append(
304
+ (torch.linalg.norm(preds - targets, dim=-1) < threshold).sum()
305
+ )
306
+ self.total.append(preds.shape[0])
307
+
308
+ def compute(self) -> torch.Tensor:
309
+ """Compute PCK over state."""
310
+ return (
311
+ torch.tensor(self.correct, device=self.device)
312
+ / torch.clamp(torch.tensor(self.total, device=self.device), min=1e-8)
313
+ ).mean()
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/params.py ADDED
@@ -0,0 +1,193 @@
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ from flow3d.transforms import cont_6d_to_rmat
8
+
9
+
10
+ class GaussianParams(nn.Module):
11
+ def __init__(
12
+ self,
13
+ means: torch.Tensor,
14
+ quats: torch.Tensor,
15
+ scales: torch.Tensor,
16
+ colors: torch.Tensor,
17
+ features: torch.Tensor,
18
+ opacities: torch.Tensor,
19
+ motion_coefs: torch.Tensor | None = None,
20
+ scene_center: torch.Tensor | None = None,
21
+ scene_scale: torch.Tensor | float = 1.0,
22
+ ):
23
+ super().__init__()
24
+ if not check_gaussian_sizes(
25
+ means, quats, scales, colors, features, opacities, motion_coefs
26
+ ):
27
+ import ipdb
28
+
29
+ ipdb.set_trace()
30
+ params_dict = {
31
+ "means": nn.Parameter(means),
32
+ "quats": nn.Parameter(quats),
33
+ "scales": nn.Parameter(scales),
34
+ "colors": nn.Parameter(colors),
35
+ "features": nn.Parameter(features),
36
+ "opacities": nn.Parameter(opacities),
37
+ }
38
+ if motion_coefs is not None:
39
+ params_dict["motion_coefs"] = nn.Parameter(motion_coefs)
40
+ self.params = nn.ParameterDict(params_dict)
41
+ self.quat_activation = lambda x: F.normalize(x, dim=-1, p=2)
42
+ self.color_activation = torch.sigmoid
43
+ self.feature_activation = torch.sigmoid
44
+ self.scale_activation = torch.exp
45
+ self.opacity_activation = torch.sigmoid
46
+ self.motion_coef_activation = lambda x: F.softmax(x, dim=-1)
47
+
48
+ if scene_center is None:
49
+ scene_center = torch.zeros(3, device=means.device)
50
+ self.register_buffer("scene_center", scene_center)
51
+ self.register_buffer("scene_scale", torch.as_tensor(scene_scale))
52
+
53
+ @staticmethod
54
+ def init_from_state_dict(state_dict, prefix="params."):
55
+ req_keys = ["means", "quats", "scales", "colors", "features", "opacities"]
56
+ assert all(f"{prefix}{k}" in state_dict for k in req_keys)
57
+ args = {
58
+ "motion_coefs": None,
59
+ "scene_center": torch.zeros(3),
60
+ "scene_scale": torch.tensor(1.0),
61
+ }
62
+ for k in req_keys + list(args.keys()):
63
+ if f"{prefix}{k}" in state_dict:
64
+ args[k] = state_dict[f"{prefix}{k}"]
65
+ return GaussianParams(**args)
66
+
67
+ @property
68
+ def num_gaussians(self) -> int:
69
+ return self.params["means"].shape[0]
70
+
71
+ def get_colors(self) -> torch.Tensor:
72
+ return self.color_activation(self.params["colors"])
73
+
74
+ def get_features(self) -> torch.Tensor:
75
+ return self.feature_activation(self.params["features"])
76
+
77
+ def get_scales(self) -> torch.Tensor:
78
+ return self.scale_activation(self.params["scales"])
79
+
80
+ def get_opacities(self) -> torch.Tensor:
81
+ return self.opacity_activation(self.params["opacities"])
82
+
83
+ def get_quats(self) -> torch.Tensor:
84
+ return self.quat_activation(self.params["quats"])
85
+
86
+ def get_coefs(self) -> torch.Tensor:
87
+ assert "motion_coefs" in self.params
88
+ return self.motion_coef_activation(self.params["motion_coefs"])
89
+
90
+ def densify_params(self, should_split, should_dup):
91
+ """
92
+ densify gaussians
93
+ """
94
+ updated_params = {}
95
+ for name, x in self.params.items():
96
+ x_dup = x[should_dup]
97
+ x_split = x[should_split].repeat([2] + [1] * (x.ndim - 1))
98
+ if name == "scales":
99
+ x_split -= math.log(1.6)
100
+ x_new = nn.Parameter(torch.cat([x[~should_split], x_dup, x_split], dim=0))
101
+ updated_params[name] = x_new
102
+ self.params[name] = x_new
103
+ return updated_params
104
+
105
+ def cull_params(self, should_cull):
106
+ """
107
+ cull gaussians
108
+ """
109
+ updated_params = {}
110
+ for name, x in self.params.items():
111
+ x_new = nn.Parameter(x[~should_cull])
112
+ updated_params[name] = x_new
113
+ self.params[name] = x_new
114
+ return updated_params
115
+
116
+ def reset_opacities(self, new_val):
117
+ """
118
+ reset all opacities to new_val
119
+ """
120
+ self.params["opacities"].data.fill_(new_val)
121
+ updated_params = {"opacities": self.params["opacities"]}
122
+ return updated_params
123
+
124
+
125
+ class MotionBases(nn.Module):
126
+ def __init__(self, rots, transls):
127
+ super().__init__()
128
+ self.num_frames = rots.shape[1]
129
+ self.num_bases = rots.shape[0]
130
+ assert check_bases_sizes(rots, transls)
131
+ self.params = nn.ParameterDict(
132
+ {
133
+ "rots": nn.Parameter(rots),
134
+ "transls": nn.Parameter(transls),
135
+ }
136
+ )
137
+
138
+ @staticmethod
139
+ def init_from_state_dict(state_dict, prefix="params."):
140
+ param_keys = ["rots", "transls"]
141
+ assert all(f"{prefix}{k}" in state_dict for k in param_keys)
142
+ args = {k: state_dict[f"{prefix}{k}"] for k in param_keys}
143
+ return MotionBases(**args)
144
+
145
+ def compute_transforms(self, ts: torch.Tensor, coefs: torch.Tensor) -> torch.Tensor:
146
+ """
147
+ :param ts (B)
148
+ :param coefs (G, K)
149
+ returns transforms (G, B, 3, 4)
150
+ """
151
+ transls = self.params["transls"][:, ts] # (K, B, 3)
152
+ rots = self.params["rots"][:, ts] # (K, B, 6)
153
+ transls = torch.einsum("pk,kni->pni", coefs, transls)
154
+ rots = torch.einsum("pk,kni->pni", coefs, rots) # (G, B, 6)
155
+ rotmats = cont_6d_to_rmat(rots) # (K, B, 3, 3)
156
+ return torch.cat([rotmats, transls[..., None]], dim=-1)
157
+
158
+
159
+ def check_gaussian_sizes(
160
+ means: torch.Tensor,
161
+ quats: torch.Tensor,
162
+ scales: torch.Tensor,
163
+ colors: torch.Tensor,
164
+ features: torch.Tensor,
165
+ opacities: torch.Tensor,
166
+ motion_coefs: torch.Tensor | None = None,
167
+ ) -> bool:
168
+ dims = means.shape[:-1]
169
+ leading_dims_match = (
170
+ quats.shape[:-1] == dims
171
+ and scales.shape[:-1] == dims
172
+ and colors.shape[:-1] == dims
173
+ and features.shape[:-1] == dims
174
+ and opacities.shape == dims
175
+ )
176
+ if motion_coefs is not None and motion_coefs.numel() > 0:
177
+ leading_dims_match &= motion_coefs.shape[:-1] == dims
178
+ dims_correct = (
179
+ means.shape[-1] == 3
180
+ and (quats.shape[-1] == 4)
181
+ and (scales.shape[-1] == 3)
182
+ and (colors.shape[-1] == 3)
183
+ and (features.shape[-1] == 384)  # NOTE: hard-coded feature dimension; must be changed if the feature size differs
184
+ )
185
+ return leading_dims_match and dims_correct
186
+
187
+
188
+ def check_bases_sizes(motion_rots: torch.Tensor, motion_transls: torch.Tensor) -> bool:
189
+ return (
190
+ motion_rots.shape[-1] == 6
191
+ and motion_transls.shape[-1] == 3
192
+ and motion_rots.shape[:-2] == motion_transls.shape[:-2]
193
+ )
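A minimal sketch of how the motion bases above are blended into per-gaussian SE(3) transforms; the sizes (K bases, T frames, G gaussians) are made up, and `cont_6d_to_rmat` is the converter imported at the top of this file.

import torch
import torch.nn.functional as F

K, T, G = 4, 10, 100
bases = MotionBases(
    rots=torch.randn(K, T, 6),     # continuous 6D rotation per basis and frame
    transls=torch.randn(K, T, 3),  # translation per basis and frame
)
coefs = F.softmax(torch.randn(G, K), dim=-1)    # per-gaussian mixing weights, rows sum to 1
ts = torch.tensor([0, 5, 9])                    # query frames
transfms = bases.compute_transforms(ts, coefs)  # (G, 3, 3, 4): one 3x4 transform per gaussian and frame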
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/renderer.py ADDED
@@ -0,0 +1,89 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from loguru import logger as guru
5
+ from nerfview import CameraState
6
+
7
+ from flow3d.scene_model import SceneModel
8
+ from flow3d.vis.utils import draw_tracks_2d_th, get_server
9
+ from flow3d.vis.viewer import DynamicViewer
10
+
11
+
12
+ class Renderer:
13
+ def __init__(
14
+ self,
15
+ model: SceneModel,
16
+ device: torch.device,
17
+ # Logging.
18
+ work_dir: str,
19
+ port: int | None = None,
20
+ ):
21
+ self.device = device
22
+
23
+ self.model = model
24
+ self.num_frames = model.num_frames
25
+
26
+ self.work_dir = work_dir
27
+ self.global_step = 0
28
+ self.epoch = 0
29
+
30
+ self.viewer = None
31
+ if port is not None:
32
+ server = get_server(port=port)
33
+ self.viewer = DynamicViewer(
34
+ server, self.render_fn, model.num_frames, work_dir, mode="rendering"
35
+ )
36
+
37
+ self.tracks_3d = self.model.compute_poses_fg(
38
+ # torch.arange(max(0, t - 20), max(1, t), device=self.device),
39
+ torch.arange(self.num_frames, device=self.device),
40
+ inds=torch.arange(10, device=self.device),
41
+ )[0]
42
+
43
+ @staticmethod
44
+ def init_from_checkpoint(
45
+ path: str, device: torch.device, *args, **kwargs
46
+ ) -> "Renderer":
47
+ guru.info(f"Loading checkpoint from {path}")
48
+ ckpt = torch.load(path)
49
+ state_dict = ckpt["model"]
50
+ model = SceneModel.init_from_state_dict(state_dict)
51
+ model = model.to(device)
52
+ renderer = Renderer(model, device, *args, **kwargs)
53
+ renderer.global_step = ckpt.get("global_step", 0)
54
+ renderer.epoch = ckpt.get("epoch", 0)
55
+ return renderer
56
+
57
+ @torch.inference_mode()
58
+ def render_fn(self, camera_state: CameraState, img_wh: tuple[int, int]):
59
+ if self.viewer is None:
60
+ return np.full((img_wh[1], img_wh[0], 3), 255, dtype=np.uint8)
61
+
62
+ W, H = img_wh
63
+
64
+ focal = 0.5 * H / np.tan(0.5 * camera_state.fov).item()
65
+ K = torch.tensor(
66
+ [[focal, 0.0, W / 2.0], [0.0, focal, H / 2.0], [0.0, 0.0, 1.0]],
67
+ device=self.device,
68
+ )
69
+ w2c = torch.linalg.inv(
70
+ torch.from_numpy(camera_state.c2w.astype(np.float32)).to(self.device)
71
+ )
72
+ t = (
73
+ int(self.viewer._playback_guis[0].value)
74
+ if not self.viewer._canonical_checkbox.value
75
+ else None
76
+ )
77
+ self.model.training = False
78
+ img = self.model.render(t, w2c[None], K[None], img_wh)["img_color"][0]
79
+ if not self.viewer._render_track_checkbox.value:
80
+ img = (img.cpu().numpy() * 255.0).astype(np.uint8)
81
+ else:
82
+ assert t is not None
83
+ tracks_3d = self.tracks_3d[:, max(0, t - 20) : max(1, t)]
84
+ tracks_2d = torch.einsum(
85
+ "ij,jk,nbk->nbi", K, w2c[:3], F.pad(tracks_3d, (0, 1), value=1.0)
86
+ )
87
+ tracks_2d = tracks_2d[..., :2] / tracks_2d[..., 2:]
88
+ img = draw_tracks_2d_th(img, tracks_2d)
89
+ return img
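A minimal sketch of the projection used by `render_fn` above when drawing tracks: 3D points in world coordinates are mapped through the camera intrinsics and extrinsics, then perspective-divided into pixel coordinates. The intrinsics and point cloud here are dummy values.

import torch
import torch.nn.functional as F

K = torch.tensor([[500.0, 0.0, 320.0],
                  [0.0, 500.0, 240.0],
                  [0.0, 0.0, 1.0]])
w2c = torch.eye(4)
tracks_3d = torch.randn(10, 20, 3) + torch.tensor([0.0, 0.0, 5.0])  # keep points in front of the camera

tracks_2d = torch.einsum(
    "ij,jk,nbk->nbi", K, w2c[:3], F.pad(tracks_3d, (0, 1), value=1.0)
)
tracks_2d = tracks_2d[..., :2] / tracks_2d[..., 2:]  # pixel coordinates, shape (10, 20, 2)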
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/scene_model.py ADDED
@@ -0,0 +1,343 @@
1
+ import roma
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from gsplat.rendering import rasterization
7
+ from torch import Tensor
8
+
9
+ from flow3d.params import GaussianParams, MotionBases
10
+
11
+
12
+ class SceneModel(nn.Module):
13
+ def __init__(
14
+ self,
15
+ Ks: Tensor,
16
+ w2cs: Tensor,
17
+ fg_params: GaussianParams,
18
+ motion_bases: MotionBases,
19
+ bg_params: GaussianParams | None = None,
20
+ num_dim: int | None = 384,  # NOTE: hard-coded feature dimension; must be changed if the feature size differs
21
+ ):
22
+ super().__init__()
23
+ self.num_frames = motion_bases.num_frames
24
+ self.fg = fg_params
25
+ self.motion_bases = motion_bases
26
+ self.bg = bg_params
27
+ self.num_dim = num_dim
28
+ scene_scale = 1.0 if bg_params is None else bg_params.scene_scale
29
+ self.register_buffer("bg_scene_scale", torch.as_tensor(scene_scale))
30
+ self.register_buffer("Ks", Ks)
31
+ self.register_buffer("w2cs", w2cs)
32
+
33
+ self._current_xys = None
34
+ self._current_radii = None
35
+ self._current_img_wh = None
36
+
37
+ @property
38
+ def num_gaussians(self) -> int:
39
+ return self.num_bg_gaussians + self.num_fg_gaussians
40
+
41
+ @property
42
+ def num_bg_gaussians(self) -> int:
43
+ return self.bg.num_gaussians if self.bg is not None else 0
44
+
45
+ @property
46
+ def num_fg_gaussians(self) -> int:
47
+ return self.fg.num_gaussians
48
+
49
+ @property
50
+ def num_motion_bases(self) -> int:
51
+ return self.motion_bases.num_bases
52
+
53
+ @property
54
+ def has_bg(self) -> bool:
55
+ return self.bg is not None
56
+
57
+ def compute_poses_bg(self) -> tuple[torch.Tensor, torch.Tensor]:
58
+ """
59
+ Returns:
60
+ means: (G, B, 3)
61
+ quats: (G, B, 4)
62
+ """
63
+ assert self.bg is not None
64
+ return self.bg.params["means"], self.bg.get_quats()
65
+
66
+ def compute_transforms(
67
+ self, ts: torch.Tensor, inds: torch.Tensor | None = None
68
+ ) -> torch.Tensor:
69
+ coefs = self.fg.get_coefs() # (G, K)
70
+ if inds is not None:
71
+ coefs = coefs[inds]
72
+ transfms = self.motion_bases.compute_transforms(ts, coefs) # (G, B, 3, 4)
73
+ return transfms
74
+
75
+ def compute_poses_fg(
76
+ self, ts: torch.Tensor | None, inds: torch.Tensor | None = None
77
+ ) -> tuple[torch.Tensor, torch.Tensor]:
78
+ """
79
+ :returns means: (G, B, 3), quats: (G, B, 4)
80
+ """
81
+ means = self.fg.params["means"] # (G, 3)
82
+ quats = self.fg.get_quats() # (G, 4)
83
+ if inds is not None:
84
+ means = means[inds]
85
+ quats = quats[inds]
86
+ if ts is not None:
87
+ transfms = self.compute_transforms(ts, inds) # (G, B, 3, 4)
88
+ means = torch.einsum(
89
+ "pnij,pj->pni",
90
+ transfms,
91
+ F.pad(means, (0, 1), value=1.0),
92
+ )
93
+ quats = roma.quat_xyzw_to_wxyz(
94
+ (
95
+ roma.quat_product(
96
+ roma.rotmat_to_unitquat(transfms[..., :3, :3]),
97
+ roma.quat_wxyz_to_xyzw(quats[:, None]),
98
+ )
99
+ )
100
+ )
101
+ quats = F.normalize(quats, p=2, dim=-1)
102
+ else:
103
+ means = means[:, None]
104
+ quats = quats[:, None]
105
+ return means, quats
106
+
107
+ def compute_poses_all(
108
+ self, ts: torch.Tensor | None
109
+ ) -> tuple[torch.Tensor, torch.Tensor]:
110
+ means, quats = self.compute_poses_fg(ts)
111
+ if self.has_bg:
112
+ bg_means, bg_quats = self.compute_poses_bg()
113
+ means = torch.cat(
114
+ [means, bg_means[:, None].expand(-1, means.shape[1], -1)], dim=0
115
+ ).contiguous()
116
+ quats = torch.cat(
117
+ [quats, bg_quats[:, None].expand(-1, means.shape[1], -1)], dim=0
118
+ ).contiguous()
119
+ return means, quats
120
+
121
+ def get_colors_all(self) -> torch.Tensor:
122
+ colors = self.fg.get_colors()
123
+ if self.bg is not None:
124
+ colors = torch.cat([colors, self.bg.get_colors()], dim=0).contiguous()
125
+ return colors
126
+
127
+ def get_features_all(self) -> torch.Tensor:
128
+ features = self.fg.get_features()
129
+ if self.bg is not None:
130
+ features = torch.cat([features, self.bg.get_features()], dim=0).contiguous()
131
+ return features
132
+
133
+ def get_scales_all(self) -> torch.Tensor:
134
+ scales = self.fg.get_scales()
135
+ if self.bg is not None:
136
+ scales = torch.cat([scales, self.bg.get_scales()], dim=0).contiguous()
137
+ return scales
138
+
139
+ def get_opacities_all(self) -> torch.Tensor:
140
+ """
141
+ :returns colors: (G, 3), scales: (G, 3), opacities: (G, 1)
142
+ """
143
+ opacities = self.fg.get_opacities()
144
+ if self.bg is not None:
145
+ opacities = torch.cat(
146
+ [opacities, self.bg.get_opacities()], dim=0
147
+ ).contiguous()
148
+ return opacities
149
+
150
+ @staticmethod
151
+ def init_from_state_dict(state_dict, prefix=""):
152
+ fg = GaussianParams.init_from_state_dict(
153
+ state_dict, prefix=f"{prefix}fg.params."
154
+ )
155
+ bg = None
156
+ if any("bg." in k for k in state_dict):
157
+ bg = GaussianParams.init_from_state_dict(
158
+ state_dict, prefix=f"{prefix}bg.params."
159
+ )
160
+ motion_bases = MotionBases.init_from_state_dict(
161
+ state_dict, prefix=f"{prefix}motion_bases.params."
162
+ )
163
+ Ks = state_dict[f"{prefix}Ks"]
164
+ w2cs = state_dict[f"{prefix}w2cs"]
165
+ num_dim = 384
166
+ return SceneModel(Ks, w2cs, fg, motion_bases, bg, num_dim)
167
+
168
+ def render(
169
+ self,
170
+ # A single time instance for view rendering.
171
+ t: int | None,
172
+ w2cs: torch.Tensor, # (C, 4, 4)
173
+ Ks: torch.Tensor, # (C, 3, 3)
174
+ img_wh: tuple[int, int],
175
+ # Multiple time instances for track rendering: (B,).
176
+ target_ts: torch.Tensor | None = None, # (B)
177
+ target_w2cs: torch.Tensor | None = None, # (B, 4, 4)
178
+ bg_color: torch.Tensor | float = 1.0,
179
+ colors_override: torch.Tensor | None = None,
180
+ features_override: torch.Tensor | None = None,
181
+ means: torch.Tensor | None = None,
182
+ quats: torch.Tensor | None = None,
183
+ target_means: torch.Tensor | None = None,
184
+ return_color: bool = True,
185
+ return_feature: bool = True,
186
+ return_depth: bool = False,
187
+ return_mask: bool = False,
188
+ fg_only: bool = False,
189
+ filter_mask: torch.Tensor | None = None,
190
+ freeze_appearance: bool = False,
191
+ freeze_semantic: bool = False
192
+ ) -> dict:
193
+ device = w2cs.device
194
+ C = w2cs.shape[0]
195
+
196
+ W, H = img_wh
197
+ pose_fnc = self.compute_poses_fg if fg_only else self.compute_poses_all
198
+ N = self.num_fg_gaussians if fg_only else self.num_gaussians
199
+
200
+ if means is None or quats is None:
201
+ means, quats = pose_fnc(
202
+ torch.tensor([t], device=device) if t is not None else None
203
+ )
204
+ means = means[:, 0]
205
+ quats = quats[:, 0]
206
+
207
+ if colors_override is None:
208
+ if return_color:
209
+ colors_override = (
210
+ self.fg.get_colors() if fg_only else self.get_colors_all()
211
+ )
212
+ else:
213
+ colors_override = torch.zeros(N, 0, device=device)
214
+
215
+ D = colors_override.shape[-1]
216
+
217
+ if features_override is None:
218
+ if return_feature:
219
+ features_override = (
220
+ self.fg.get_features() if fg_only else self.get_features_all()
221
+ )
222
+ else:
223
+ features_override = torch.zeros(N, 0, device=device)
224
+
225
+ #features_override = features_override[:,0:10]
226
+ D_f = features_override.shape[-1]
227
+ bg_feature = torch.zeros(D_f).unsqueeze(0).to(device)
228
+
229
+ #colors_override = torch.cat([colors_override, features_override], dim=-1)
230
+
231
+ scales = self.fg.get_scales() if fg_only else self.get_scales_all()
232
+ opacities = self.fg.get_opacities() if fg_only else self.get_opacities_all()
233
+
234
+ if isinstance(bg_color, float):
235
+ bg_color = torch.full((C, D), bg_color, device=device)
236
+ assert isinstance(bg_color, torch.Tensor)
237
+
238
+ #bg_color = torch.cat([bg_color, bg_feature], dim=-1)
239
+
240
+ mode = "RGB"
241
+ ds_expected = {"img_color": D}
242
+
243
+ if return_mask:
244
+ if self.has_bg and not fg_only:
245
+ mask_values = torch.zeros((self.num_gaussians, 1), device=device)
246
+ mask_values[: self.num_fg_gaussians] = 1.0
247
+ else:
248
+ mask_values = torch.ones((self.num_fg_gaussians, 1), device=device)
249
+ colors_override = torch.cat([colors_override, mask_values], dim=-1)
250
+ bg_color = torch.cat([bg_color, torch.zeros(C, 1, device=device)], dim=-1)
251
+ ds_expected["mask"] = 1
252
+
253
+ B = 0
254
+ if target_ts is not None:
255
+ B = target_ts.shape[0]
256
+ if target_means is None:
257
+ target_means, _ = pose_fnc(target_ts) # [G, B, 3]
258
+ if target_w2cs is not None:
259
+ target_means = torch.einsum(
260
+ "bij,pbj->pbi",
261
+ target_w2cs[:, :3],
262
+ F.pad(target_means, (0, 1), value=1.0),
263
+ )
264
+ track_3d_vals = target_means.flatten(-2) # (G, B * 3)
265
+ d_track = track_3d_vals.shape[-1]
266
+ colors_override = torch.cat([colors_override, track_3d_vals], dim=-1)
267
+ bg_color = torch.cat(
268
+ [bg_color, torch.zeros(C, track_3d_vals.shape[-1], device=device)],
269
+ dim=-1,
270
+ )
271
+ ds_expected["tracks_3d"] = d_track
272
+
273
+ assert colors_override.shape[-1] == sum(ds_expected.values())
274
+ assert bg_color.shape[-1] == sum(ds_expected.values()), f"{bg_color.shape[-1]=} != {sum(ds_expected.values())=}"
275
+
276
+ if return_depth:
277
+ mode = "RGB+ED"
278
+ ds_expected["depth"] = 1
279
+
280
+ if filter_mask is not None:
281
+ assert filter_mask.shape == (N,)
282
+ means = means[filter_mask]
283
+ quats = quats[filter_mask]
284
+ scales = scales[filter_mask]
285
+ opacities = opacities[filter_mask]
286
+ colors_override = colors_override[filter_mask]
287
+ features_override = features_override[filter_mask]
288
+
289
+ render_colors, alphas, info = rasterization(
290
+ means=means,
291
+ quats=quats,
292
+ scales=scales,
293
+ opacities=opacities,
294
+ colors=colors_override,
295
+ backgrounds=bg_color,
296
+ viewmats=w2cs, # [C, 4, 4]
297
+ Ks=Ks, # [C, 3, 3]
298
+ width=W,
299
+ height=H,
300
+ packed=False,
301
+ render_mode=mode,
302
+ )
303
+
304
+ render_features, _, _ = rasterization(
305
+ means=means,
306
+ quats=quats,
307
+ scales=scales,
308
+ opacities=opacities,
309
+ colors=features_override,
310
+ backgrounds=bg_feature,
311
+ viewmats=w2cs, # [C, 4, 4]
312
+ Ks=Ks, # [C, 3, 3]
313
+ width=W,
314
+ height=H,
315
+ packed=False,
316
+ render_mode="RGB",
317
+ )
318
+
319
+ # Populate the current data for adaptive gaussian control.
320
+ if self.training and info["means2d"].requires_grad:
321
+ self._current_xys = info["means2d"]
322
+ self._current_radii = info["radii"]
323
+ self._current_img_wh = img_wh
324
+ # We want to be able to access to xys' gradients later in a
325
+ # torch.no_grad context.
326
+ self._current_xys.retain_grad()
327
+
328
+ assert render_colors.shape[-1] == sum(ds_expected.values())
329
+ outputs = torch.split(render_colors, list(ds_expected.values()), dim=-1)
330
+ out_dict = {}
331
+ for i, (name, dim) in enumerate(ds_expected.items()):
332
+ x = outputs[i]
333
+ if name == "img_color":
334
+ #x = x[:,:,:,0:3]
335
+ pass
336
+ #assert x.shape[-1] == dim, f"{x.shape[-1]=} != {dim=}"
337
+ if name == "tracks_3d":
338
+ x = x.reshape(C, H, W, B, 3)
339
+ out_dict[name] = x
340
+ out_dict["acc"] = alphas
341
+ out_dict["img_feature"] = render_features
342
+
343
+ return out_dict
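A minimal sketch of the channel-packing pattern `render` relies on: auxiliary per-pixel quantities (mask, 3D tracks) are concatenated onto the color channels, rasterized in a single pass, and then split back out by the widths recorded in `ds_expected`. The tensor below is a dummy standing in for the rasterizer output.

import torch

C, H, W, B = 1, 4, 4, 2                                         # cameras, height, width, target frames
ds_expected = {"img_color": 3, "mask": 1, "tracks_3d": B * 3}
render_colors = torch.rand(C, H, W, sum(ds_expected.values()))  # stand-in for rasterization output

outputs = torch.split(render_colors, list(ds_expected.values()), dim=-1)
out_dict = dict(zip(ds_expected, outputs))
out_dict["tracks_3d"] = out_dict["tracks_3d"].reshape(C, H, W, B, 3)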
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/tensor_dataclass.py ADDED
@@ -0,0 +1,96 @@
1
+ from dataclasses import dataclass
2
+ from typing import Callable, TypeVar
3
+
4
+ import torch
5
+ from typing_extensions import Self
6
+
7
+ TensorDataclassT = TypeVar("T", bound="TensorDataclass")
8
+
9
+
10
+ class TensorDataclass:
11
+ """A lighter version of nerfstudio's TensorDataclass:
12
+ https://github.com/nerfstudio-project/nerfstudio/blob/main/nerfstudio/utils/tensor_dataclass.py
13
+ """
14
+
15
+ def __getitem__(self, key) -> Self:
16
+ return self.map(lambda x: x[key])
17
+
18
+ def to(self, device: torch.device | str) -> Self:
19
+ """Move the tensors in the dataclass to the given device.
20
+
21
+ Args:
22
+ device: The device to move to.
23
+
24
+ Returns:
25
+ A new dataclass.
26
+ """
27
+ return self.map(lambda x: x.to(device))
28
+
29
+ def map(self, fn: Callable[[torch.Tensor], torch.Tensor]) -> Self:
30
+ """Apply a function to all tensors in the dataclass.
31
+
32
+ Also recurses into lists, tuples, and dictionaries.
33
+
34
+ Args:
35
+ fn: The function to apply to each tensor.
36
+
37
+ Returns:
38
+ A new dataclass.
39
+ """
40
+
41
+ MapT = TypeVar("MapT")
42
+
43
+ def _map_impl(
44
+ fn: Callable[[torch.Tensor], torch.Tensor],
45
+ val: MapT,
46
+ ) -> MapT:
47
+ if isinstance(val, torch.Tensor):
48
+ return fn(val)
49
+ elif isinstance(val, TensorDataclass):
50
+ return type(val)(**_map_impl(fn, vars(val)))
51
+ elif isinstance(val, (list, tuple)):
52
+ return type(val)(_map_impl(fn, v) for v in val)
53
+ elif isinstance(val, dict):
54
+ assert type(val) is dict # No subclass support.
55
+ return {k: _map_impl(fn, v) for k, v in val.items()} # type: ignore
56
+ else:
57
+ return val
58
+
59
+ return _map_impl(fn, self)
60
+
61
+
62
+ @dataclass
63
+ class TrackObservations(TensorDataclass):
64
+ xyz: torch.Tensor
65
+ visibles: torch.Tensor
66
+ invisibles: torch.Tensor
67
+ confidences: torch.Tensor
68
+ colors: torch.Tensor
69
+
70
+ def check_sizes(self) -> bool:
71
+ dims = self.xyz.shape[:-1]
72
+ return (
73
+ self.visibles.shape == dims
74
+ and self.invisibles.shape == dims
75
+ and self.confidences.shape == dims
76
+ and self.colors.shape[:-1] == dims[:-1]
77
+ and self.xyz.shape[-1] == 3
78
+ and self.colors.shape[-1] == 3
79
+ )
80
+
81
+ def filter_valid(self, valid_mask: torch.Tensor) -> Self:
82
+ return self.map(lambda x: x[valid_mask])
83
+
84
+
85
+ @dataclass
86
+ class StaticObservations(TensorDataclass):
87
+ xyz: torch.Tensor
88
+ normals: torch.Tensor
89
+ colors: torch.Tensor
90
+
91
+ def check_sizes(self) -> bool:
92
+ dims = self.xyz.shape
93
+ return self.normals.shape == dims and self.colors.shape == dims
94
+
95
+ def filter_valid(self, valid_mask: torch.Tensor) -> Self:
96
+ return self.map(lambda x: x[valid_mask])
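A minimal sketch of the dataclasses above: indexing, `.to()`, and `filter_valid()` all recurse over every tensor field via `map`, so a whole observation bundle can be sliced or moved in one call. The sizes are made up.

import torch

obs = StaticObservations(
    xyz=torch.randn(100, 3),
    normals=torch.randn(100, 3),
    colors=torch.rand(100, 3),
)
assert obs.check_sizes()

subset = obs[torch.arange(10)]              # slices xyz, normals, and colors together
subset = subset.to("cpu")                   # moves every field to the target device
kept = obs.filter_valid(obs.xyz[:, 2] > 0)  # boolean mask applied to all fields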
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/trainer.py ADDED
@@ -0,0 +1,827 @@
1
+ import functools
3
+ import time
4
+ from dataclasses import asdict
5
+ from typing import cast
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from loguru import logger as guru
11
+ from nerfview import CameraState
12
+ from pytorch_msssim import SSIM
13
+ from torch.utils.tensorboard import SummaryWriter # type: ignore
14
+
15
+ from flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig
16
+ from flow3d.loss_utils import (
17
+ compute_gradient_loss,
18
+ compute_se3_smoothness_loss,
19
+ compute_z_acc_loss,
20
+ masked_l1_loss,
21
+ )
22
+ from flow3d.metrics import PCK, mLPIPS, mPSNR, mSSIM
23
+ from flow3d.scene_model import SceneModel
24
+ from flow3d.vis.utils import get_server
25
+ from flow3d.vis.viewer import DynamicViewer
26
+
27
+
28
+ class Trainer:
29
+ def __init__(
30
+ self,
31
+ model: SceneModel,
32
+ device: torch.device,
33
+ lr_cfg: SceneLRConfig,
34
+ losses_cfg: LossesConfig,
35
+ optim_cfg: OptimizerConfig,
36
+ # Logging.
37
+ work_dir: str,
38
+ port: int | None = None,
39
+ log_every: int = 10,
40
+ checkpoint_every: int = 200,
41
+ validate_every: int = 500,
42
+ validate_video_every: int = 1000,
43
+ validate_viewer_assets_every: int = 100,
44
+ ):
45
+ self.device = device
46
+ self.log_every = log_every
47
+ self.checkpoint_every = checkpoint_every
48
+ self.validate_every = validate_every
49
+ self.validate_video_every = validate_video_every
50
+ self.validate_viewer_assets_every = validate_viewer_assets_every
51
+
52
+ self.model = model
53
+ self.num_frames = model.num_frames
54
+
55
+ self.lr_cfg = lr_cfg
56
+ self.losses_cfg = losses_cfg
57
+ self.optim_cfg = optim_cfg
58
+
59
+ self.reset_opacity_every = (
60
+ self.optim_cfg.reset_opacity_every_n_controls * self.optim_cfg.control_every
61
+ )
62
+ self.optimizers, self.scheduler = self.configure_optimizers()
63
+
64
+ # running stats for adaptive density control
65
+ self.running_stats = {
66
+ "xys_grad_norm_acc": torch.zeros(self.model.num_gaussians, device=device),
67
+ "vis_count": torch.zeros(
68
+ self.model.num_gaussians, device=device, dtype=torch.int64
69
+ ),
70
+ "max_radii": torch.zeros(self.model.num_gaussians, device=device),
71
+ }
72
+
73
+ self.work_dir = work_dir
74
+ self.writer = SummaryWriter(log_dir=work_dir)
75
+ self.global_step = 0
76
+ self.epoch = 0
77
+
78
+ self.viewer = None
79
+ if port is not None:
80
+ server = get_server(port=port)
81
+ self.viewer = DynamicViewer(
82
+ server, self.render_fn, model.num_frames, work_dir, mode="training"
83
+ )
84
+
85
+ # metrics
86
+ self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)
87
+ self.psnr_metric = mPSNR()
88
+ self.ssim_metric = mSSIM()
89
+ self.lpips_metric = mLPIPS()
90
+ self.pck_metric = PCK()
91
+ self.bg_psnr_metric = mPSNR()
92
+ self.fg_psnr_metric = mPSNR()
93
+ self.bg_ssim_metric = mSSIM()
94
+ self.fg_ssim_metric = mSSIM()
95
+ self.bg_lpips_metric = mLPIPS()
96
+ self.fg_lpips_metric = mLPIPS()
97
+
98
+ def set_epoch(self, epoch: int):
99
+ self.epoch = epoch
100
+
101
+ def save_checkpoint(self, path: str):
102
+ model_dict = self.model.state_dict()
103
+ optimizer_dict = {k: v.state_dict() for k, v in self.optimizers.items()}
104
+ scheduler_dict = {k: v.state_dict() for k, v in self.scheduler.items()}
105
+ ckpt = {
106
+ "model": model_dict,
107
+ "optimizers": optimizer_dict,
108
+ "schedulers": scheduler_dict,
109
+ "global_step": self.global_step,
110
+ "epoch": self.epoch,
111
+ }
112
+ torch.save(ckpt, path)
113
+ guru.info(f"Saved checkpoint at {self.global_step=} to {path}")
114
+
115
+ @staticmethod
116
+ def init_from_checkpoint(
117
+ path: str, device: torch.device, *args, **kwargs
118
+ ) -> tuple["Trainer", int]:
119
+ guru.info(f"Loading checkpoint from {path}")
120
+ ckpt = torch.load(path)
121
+ state_dict = ckpt["model"]
122
+ model = SceneModel.init_from_state_dict(state_dict)
123
+ model = model.to(device)
124
+ trainer = Trainer(model, device, *args, **kwargs)
125
+ if "optimizers" in ckpt:
126
+ trainer.load_checkpoint_optimizers(ckpt["optimizers"])
127
+ if "schedulers" in ckpt:
128
+ trainer.load_checkpoint_schedulers(ckpt["schedulers"])
129
+ trainer.global_step = ckpt.get("global_step", 0)
130
+ start_epoch = ckpt.get("epoch", 0)
131
+ trainer.set_epoch(start_epoch)
132
+ return trainer, start_epoch
133
+
134
+ def load_checkpoint_optimizers(self, opt_ckpt):
135
+ for k, v in self.optimizers.items():
136
+ v.load_state_dict(opt_ckpt[k])
137
+
138
+ def load_checkpoint_schedulers(self, sched_ckpt):
139
+ for k, v in self.scheduler.items():
140
+ v.load_state_dict(sched_ckpt[k])
141
+
142
+ @torch.inference_mode()
143
+ def render_fn(self, camera_state: CameraState, img_wh: tuple[int, int]):
144
+ W, H = img_wh
145
+
146
+ focal = 0.5 * H / np.tan(0.5 * camera_state.fov).item()
147
+ K = torch.tensor(
148
+ [[focal, 0.0, W / 2.0], [0.0, focal, H / 2.0], [0.0, 0.0, 1.0]],
149
+ device=self.device,
150
+ )
151
+ w2c = torch.linalg.inv(
152
+ torch.from_numpy(camera_state.c2w.astype(np.float32)).to(self.device)
153
+ )
154
+ t = 0
155
+ if self.viewer is not None:
156
+ t = (
157
+ int(self.viewer._playback_guis[0].value)
158
+ if not self.viewer._canonical_checkbox.value
159
+ else None
160
+ )
161
+ self.model.training = False
162
+ img = self.model.render(t, w2c[None], K[None], img_wh)["img_color"][0]
163
+ return (img.cpu().numpy() * 255.0).astype(np.uint8)
164
+
165
+ def train_step(self, batch, freeze_appearance, freeze_semantic):
166
+ if self.viewer is not None:
167
+ while self.viewer.state.status == "paused":
168
+ time.sleep(0.1)
169
+ self.viewer.lock.acquire()
170
+
171
+ loss, stats, num_rays_per_step, num_rays_per_sec = self.compute_losses(batch, freeze_appearance, freeze_semantic)
172
+ if loss.isnan():
173
+ guru.info(f"Loss is NaN at step {self.global_step}!!")
174
+ import ipdb
175
+
176
+ ipdb.set_trace()
177
+ loss.backward()
178
+
179
+ for opt in self.optimizers.values():
180
+ opt.step()
181
+ opt.zero_grad(set_to_none=True)
182
+ for sched in self.scheduler.values():
183
+ sched.step()
184
+
185
+ self.log_dict(stats)
186
+ self.global_step += 1
187
+ self.run_control_steps()
188
+
189
+ if self.viewer is not None:
190
+ self.viewer.lock.release()
191
+ self.viewer.state.num_train_rays_per_sec = num_rays_per_sec
192
+ if self.viewer.mode == "training":
193
+ self.viewer.update(self.global_step, num_rays_per_step)
194
+
195
+ if self.global_step % self.checkpoint_every == 0:
196
+ self.save_checkpoint(f"{self.work_dir}/checkpoints/last.ckpt")
197
+
198
+ return loss.item(), stats
199
+
200
+ def compute_losses(self, batch, freeze_appearance, freeze_semantic):
201
+ self.model.training = True
202
+ B = batch["imgs"].shape[0]
203
+ W, H = img_wh = batch["imgs"].shape[2:0:-1]
204
+ N = batch["target_ts"][0].shape[0]
205
+
206
+ # (B,).
207
+ ts = batch["ts"]
208
+ # (B, 4, 4).
209
+ w2cs = batch["w2cs"]
210
+ # (B, 3, 3).
211
+ Ks = batch["Ks"]
212
+ # (B, H, W, 3).
213
+ imgs = batch["imgs"]
214
+ # (B, H, W).
215
+ valid_masks = batch.get("valid_masks", torch.ones_like(batch["imgs"][..., 0]))
216
+ # (B, H, W).
217
+ masks = batch["masks"]
218
+ masks *= valid_masks
219
+ # (B, H, W).
220
+ depths = batch["depths"]
221
+ # [(P, 2), ...].
222
+ query_tracks_2d = batch["query_tracks_2d"]
223
+ # [(N,), ...].
224
+ target_ts = batch["target_ts"]
225
+ # [(N, 4, 4), ...].
226
+ target_w2cs = batch["target_w2cs"]
227
+ # [(N, 3, 3), ...].
228
+ target_Ks = batch["target_Ks"]
229
+ # [(N, P, 2), ...].
230
+ target_tracks_2d = batch["target_tracks_2d"]
231
+ # [(N, P), ...].
232
+ target_visibles = batch["target_visibles"]
233
+ # [(N, P), ...].
234
+ target_invisibles = batch["target_invisibles"]
235
+ # [(N, P), ...].
236
+ target_confidences = batch["target_confidences"]
237
+ # [(N, P), ...].
238
+ target_track_depths = batch["target_track_depths"]
239
+ # (B, H, W, 384)
240
+ feature_maps = batch["feature_maps"]
241
+
242
+ _tic = time.time()
243
+ # (B, G, 3).
244
+ means, quats = self.model.compute_poses_all(ts) # (G, B, 3), (G, B, 4)
245
+ device = means.device
246
+ means = means.transpose(0, 1)
247
+ quats = quats.transpose(0, 1)
248
+ # [(N, G, 3), ...].
249
+ target_ts_vec = torch.cat(target_ts)
250
+ # (B * N, G, 3).
251
+ target_means, _ = self.model.compute_poses_all(target_ts_vec)
252
+ target_means = target_means.transpose(0, 1)
253
+ target_mean_list = target_means.split(N)
254
+ num_frames = self.model.num_frames
255
+
256
+ loss = 0.0
257
+
258
+
259
+ bg_colors = []
260
+ rendered_all = []
261
+ self._batched_xys = []
262
+ self._batched_radii = []
263
+ self._batched_img_wh = []
264
+ for i in range(B):
265
+ bg_color = torch.ones(1, 3, device=device)
266
+ rendered = self.model.render(
267
+ ts[i].item(),
268
+ w2cs[None, i],
269
+ Ks[None, i],
270
+ img_wh,
271
+ target_ts=target_ts[i],
272
+ target_w2cs=target_w2cs[i],
273
+ bg_color=bg_color,
274
+ means=means[i],
275
+ quats=quats[i],
276
+ target_means=target_mean_list[i].transpose(0, 1),
277
+ return_depth=True,
278
+ return_mask=self.model.has_bg,
279
+ freeze_appearance=freeze_appearance,
280
+ freeze_semantic=freeze_semantic
281
+ )
282
+ rendered_all.append(rendered)
283
+ bg_colors.append(bg_color)
284
+ if (
285
+ self.model._current_xys is not None
286
+ and self.model._current_radii is not None
287
+ and self.model._current_img_wh is not None
288
+ ):
289
+ self._batched_xys.append(self.model._current_xys)
290
+ self._batched_radii.append(self.model._current_radii)
291
+ self._batched_img_wh.append(self.model._current_img_wh)
292
+
293
+ # Necessary to make viewer work.
294
+ num_rays_per_step = H * W * B
295
+ num_rays_per_sec = num_rays_per_step / (time.time() - _tic)
296
+
297
+ # (B, H, W, N, *).
298
+ rendered_all = {
299
+ key: (
300
+ torch.cat([out_dict[key] for out_dict in rendered_all], dim=0)
301
+ if rendered_all[0][key] is not None
302
+ else None
303
+ )
304
+ for key in rendered_all[0]
305
+ }
306
+ bg_colors = torch.cat(bg_colors, dim=0)
307
+
308
+ # Compute losses.
309
+ # (B * N).
310
+ frame_intervals = (ts.repeat_interleave(N) - target_ts_vec).abs()
311
+ if not self.model.has_bg:
312
+ imgs = (
313
+ imgs * masks[..., None]
314
+ + (1.0 - masks[..., None]) * bg_colors[:, None, None]
315
+ )
316
+ else:
317
+ imgs = (
318
+ imgs * valid_masks[..., None]
319
+ + (1.0 - valid_masks[..., None]) * bg_colors[:, None, None]
320
+ )
321
+ # (P_all, 2).
322
+ tracks_2d = torch.cat([x.reshape(-1, 2) for x in target_tracks_2d], dim=0)
323
+ # (P_all,)
324
+ visibles = torch.cat([x.reshape(-1) for x in target_visibles], dim=0)
325
+ # (P_all,)
326
+ confidences = torch.cat([x.reshape(-1) for x in target_confidences], dim=0)
327
+
328
+
329
+ rendered_imgs = cast(torch.Tensor, rendered_all["img_color"])
330
+ rendered_features = cast(torch.Tensor, rendered_all["img_feature"])
331
+
332
+ # RGB loss.
333
+ if self.model.has_bg:
334
+ rendered_imgs = (
335
+ rendered_imgs * valid_masks[..., None]
336
+ + (1.0 - valid_masks[..., None]) * bg_colors[:, None, None]
337
+ )
338
+ rgb_loss = 0.8 * F.l1_loss(rendered_imgs, imgs) + 0.2 * (
339
+ 1 - self.ssim(rendered_imgs.permute(0, 3, 1, 2), imgs.permute(0, 3, 1, 2))
340
+ )
341
+ loss += rgb_loss * self.losses_cfg.w_rgb
342
+
343
+ # Mask loss.
344
+ if not self.model.has_bg:
345
+ mask_loss = F.mse_loss(rendered_all["acc"], masks[..., None]) # type: ignore
346
+ else:
347
+ mask_loss = F.mse_loss(
348
+ rendered_all["acc"], torch.ones_like(rendered_all["acc"]) # type: ignore
349
+ ) + masked_l1_loss(
350
+ rendered_all["mask"],
351
+ masks[..., None],
352
+ quantile=0.98, # type: ignore
353
+ )
354
+ loss += mask_loss * self.losses_cfg.w_mask
355
+
356
+ # (B * N, H * W, 3).
357
+ pred_tracks_3d = (
358
+ rendered_all["tracks_3d"].permute(0, 3, 1, 2, 4).reshape(-1, H * W, 3) # type: ignore
359
+ )
360
+ pred_tracks_2d = torch.einsum(
361
+ "bij,bpj->bpi", torch.cat(target_Ks), pred_tracks_3d
362
+ )
363
+ # (B * N, H * W, 1).
364
+ mapped_depth = torch.clamp(pred_tracks_2d[..., 2:], min=1e-6)
365
+ # (B * N, H * W, 2).
366
+ pred_tracks_2d = pred_tracks_2d[..., :2] / mapped_depth
367
+
368
+ # (B * N).
369
+ w_interval = torch.exp(-2 * frame_intervals / num_frames)
370
+ # w_track_loss = min(1, (self.max_steps - self.global_step) / 6000)
371
+ track_weights = confidences[..., None] * w_interval
372
+
373
+ # (B, H, W).
374
+ masks_flatten = torch.zeros_like(masks)
375
+ for i in range(B):
376
+ # This takes advantage of the fact that the query 2D tracks are
377
+ # always on the grid.
378
+ query_pixels = query_tracks_2d[i].to(torch.int64)
379
+ masks_flatten[i, query_pixels[:, 1], query_pixels[:, 0]] = 1.0
380
+ # (B * N, H * W).
381
+ masks_flatten = (
382
+ masks_flatten.reshape(-1, H * W).tile(1, N).reshape(-1, H * W) > 0.5
383
+ )
384
+ ####
385
+ # # Tracking loss.
386
+ # track_2d_loss = masked_l1_loss(
387
+ # pred_tracks_2d[masks_flatten][visibles],
388
+ # tracks_2d[visibles],
389
+ # mask=track_weights[visibles],
390
+ # quantile=0.98,
391
+ # ) / max(H, W)
392
+ # loss += track_2d_loss * self.losses_cfg.w_track
393
+ ####
394
+ depth_masks = (
395
+ masks[..., None] if not self.model.has_bg else valid_masks[..., None]
396
+ )
397
+
398
+ pred_depth = cast(torch.Tensor, rendered_all["depth"])
399
+ pred_disp = 1.0 / (pred_depth + 1e-5)
400
+ tgt_disp = 1.0 / (depths[..., None] + 1e-5)
401
+ depth_loss = masked_l1_loss(
402
+ pred_disp,
403
+ tgt_disp,
404
+ mask=depth_masks,
405
+ quantile=0.98,
406
+ )
407
+ # depth_loss = cauchy_loss_with_uncertainty(
408
+ # pred_disp.squeeze(-1),
409
+ # tgt_disp.squeeze(-1),
410
+ # depth_masks.squeeze(-1),
411
+ # self.depth_uncertainty_activation(self.depth_uncertainties)[ts],
412
+ # bias=1e-3,
413
+ # )
414
+ loss += depth_loss * self.losses_cfg.w_depth_reg
415
+
416
+ # mapped depth loss (using cached depth with EMA)
417
+ # mapped_depth_loss = 0.0
418
+ mapped_depth_gt = torch.cat([x.reshape(-1) for x in target_track_depths], dim=0)
419
+ mapped_depth_loss = masked_l1_loss(
420
+ 1 / (mapped_depth[masks_flatten][visibles] + 1e-5),
421
+ 1 / (mapped_depth_gt[visibles, None] + 1e-5),
422
+ track_weights[visibles],
423
+ )
424
+
425
+ loss += mapped_depth_loss * self.losses_cfg.w_depth_const
426
+
427
+ # depth_gradient_loss = 0.0
428
+ depth_gradient_loss = compute_gradient_loss(
429
+ pred_disp,
430
+ tgt_disp,
431
+ mask=depth_masks > 0.5,
432
+ quantile=0.95,
433
+ )
434
+ # depth_gradient_loss = compute_gradient_loss(
435
+ # pred_disps,
436
+ # ref_disps,
437
+ # mask=depth_masks.squeeze(-1) > 0.5,
438
+ # c=depth_uncertainty.detach(),
439
+ # mode="l1",
440
+ # bias=1e-3,
441
+ # )
442
+ loss += depth_gradient_loss * self.losses_cfg.w_depth_grad
443
+
444
+ # bases should be smooth.
445
+ small_accel_loss = compute_se3_smoothness_loss(
446
+ self.model.motion_bases.params["rots"],
447
+ self.model.motion_bases.params["transls"],
448
+ )
449
+ loss += small_accel_loss * self.losses_cfg.w_smooth_bases
450
+
451
+ # tracks should be smooth
452
+ ts = torch.clamp(ts, min=1, max=num_frames - 2)
453
+ ts_neighbors = torch.cat((ts - 1, ts, ts + 1))
454
+ transfms_nbs = self.model.compute_transforms(ts_neighbors) # (G, 3n, 3, 4)
455
+ means_fg_nbs = torch.einsum(
456
+ "pnij,pj->pni",
457
+ transfms_nbs,
458
+ F.pad(self.model.fg.params["means"], (0, 1), value=1.0),
459
+ )
460
+ means_fg_nbs = means_fg_nbs.reshape(
461
+ means_fg_nbs.shape[0], 3, -1, 3
462
+ ) # [G, 3, n, 3]
463
+ if self.losses_cfg.w_smooth_tracks > 0:
464
+ small_accel_loss_tracks = 0.5 * (
465
+ (2 * means_fg_nbs[:, 1:-1] - means_fg_nbs[:, :-2] - means_fg_nbs[:, 2:])
466
+ .norm(dim=-1)
467
+ .mean()
468
+ )
469
+ loss += small_accel_loss_tracks * self.losses_cfg.w_smooth_tracks
470
+
471
+ # Constrain the std of scales.
472
+ # TODO: do we want to penalize before or after exp?
473
+ loss += (
474
+ self.losses_cfg.w_scale_var
475
+ * torch.var(self.model.fg.params["scales"], dim=-1).mean()
476
+ )
477
+ if self.model.bg is not None:
478
+ loss += (
479
+ self.losses_cfg.w_scale_var
480
+ * torch.var(self.model.bg.params["scales"], dim=-1).mean()
481
+ )
482
+
483
+ # # sparsity loss
484
+ # loss += 0.01 * self.opacity_activation(self.opacities).abs().mean()
485
+
486
+ # Acceleration along ray direction should be small.
487
+ z_accel_loss = compute_z_acc_loss(means_fg_nbs, w2cs)
488
+ loss += self.losses_cfg.w_z_accel * z_accel_loss
489
+
490
+ #print("Compute appearance loss", loss.item())
491
+
492
+ # Feature loss.
493
+ target_size = (feature_maps.shape[1], feature_maps.shape[2])
494
+ rendered_features_permuted = rendered_features.permute(0, 3, 1, 2) # Shape: [N, D, H, W]
495
+ rendered_features_downsampled = F.interpolate(rendered_features_permuted, size=target_size, mode='bilinear', align_corners=True)
496
+ rendered_features_downsampled = rendered_features_downsampled.permute(0, 2, 3, 1) # Shape: [N, H', W', D]
497
+
498
+ feature_loss = F.l1_loss(rendered_features_downsampled, feature_maps)
499
+ loss += feature_loss * self.losses_cfg.w_feature
500
+
501
+ # Prepare stats for logging.
502
+ stats = {
503
+ "train/loss": loss.item(),
504
+ "train/rgb_loss": rgb_loss.item(),
505
+ "train/mask_loss": mask_loss.item(),
506
+ "train/feature_loss": feature_loss.item(),
507
+ "train/depth_loss": depth_loss.item(),
508
+ "train/depth_gradient_loss": depth_gradient_loss.item(),
509
+ "train/mapped_depth_loss": mapped_depth_loss.item(),
510
+ #"train/track_2d_loss": track_2d_loss.item(),
511
+ "train/small_accel_loss": small_accel_loss.item(),
512
+ "train/z_acc_loss": z_accel_loss.item(),
513
+ "train/num_gaussians": self.model.num_gaussians,
514
+ "train/num_fg_gaussians": self.model.num_fg_gaussians,
515
+ "train/num_bg_gaussians": self.model.num_bg_gaussians,
516
+ }
517
+
518
+ # Compute metrics.
519
+ with torch.no_grad():
520
+ psnr = self.psnr_metric(
521
+ rendered_imgs, imgs, masks if not self.model.has_bg else valid_masks
522
+ )
523
+ self.psnr_metric.reset()
524
+ stats["train/psnr"] = psnr
525
+ if self.model.has_bg:
526
+ bg_psnr = self.bg_psnr_metric(rendered_imgs, imgs, 1.0 - masks)
527
+ fg_psnr = self.fg_psnr_metric(rendered_imgs, imgs, masks)
528
+ self.bg_psnr_metric.reset()
529
+ self.fg_psnr_metric.reset()
530
+ stats["train/bg_psnr"] = bg_psnr
531
+ stats["train/fg_psnr"] = fg_psnr
532
+
533
+ stats.update(
534
+ **{
535
+ "train/num_rays_per_sec": num_rays_per_sec,
536
+ "train/num_rays_per_step": float(num_rays_per_step),
537
+ }
538
+ )
539
+
540
+ return loss, stats, num_rays_per_step, num_rays_per_sec
541
+
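As a side note on the loss above: the per-track weights combine the tracker confidences with a frame-interval falloff, w_interval = exp(-2 * Δt / num_frames). A tiny self-contained sketch (values made up) of how that falloff behaves:

import torch

num_frames = 100
# Hypothetical |query_t - target_t| gaps, in frames.
frame_intervals = torch.tensor([0.0, 5.0, 25.0, 99.0])
w_interval = torch.exp(-2 * frame_intervals / num_frames)
# tensor([1.0000, 0.9048, 0.6065, 0.1381]): nearby target frames are weighted more.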
542
+ def log_dict(self, stats: dict):
543
+ for k, v in stats.items():
544
+ self.writer.add_scalar(k, v, self.global_step)
545
+
546
+ def run_control_steps(self):
547
+ global_step = self.global_step
548
+ # Adaptive gaussian control.
549
+ cfg = self.optim_cfg
550
+ num_frames = self.model.num_frames
551
+ ready = self._prepare_control_step()
552
+ if (
553
+ ready
554
+ and global_step > cfg.warmup_steps
555
+ and global_step % cfg.control_every == 0
556
+ and global_step < cfg.stop_control_steps
557
+ ):
558
+ if (
559
+ global_step < cfg.stop_densify_steps
560
+ and global_step % self.reset_opacity_every > num_frames
561
+ ):
562
+ self._densify_control_step(global_step)
563
+ if global_step % self.reset_opacity_every > min(3 * num_frames, 1000):
564
+ self._cull_control_step(global_step)
565
+ if global_step % self.reset_opacity_every == 0:
566
+ self._reset_opacity_control_step()
567
+
568
+ # Reset stats after every control.
569
+ for k in self.running_stats:
570
+ self.running_stats[k].zero_()
571
+
572
+ @torch.no_grad()
573
+ def _prepare_control_step(self) -> bool:
574
+ # Prepare for adaptive gaussian control based on the current stats.
575
+ if not (
576
+ self.model._current_radii is not None
577
+ and self.model._current_xys is not None
578
+ ):
579
+ guru.warning("Model not training, skipping control step preparation")
580
+ return False
581
+
582
+ batch_size = len(self._batched_xys)
583
+ # these quantities are for each rendered view and have shapes (C, G, *)
584
+ # must be aggregated over all views
585
+ for _current_xys, _current_radii, _current_img_wh in zip(
586
+ self._batched_xys, self._batched_radii, self._batched_img_wh
587
+ ):
588
+ sel = _current_radii > 0
589
+ gidcs = torch.where(sel)[1]
590
+ # normalize grads to [-1, 1] screen space
591
+ xys_grad = _current_xys.grad.clone()
592
+ xys_grad[..., 0] *= _current_img_wh[0] / 2.0 * batch_size
593
+ xys_grad[..., 1] *= _current_img_wh[1] / 2.0 * batch_size
594
+ self.running_stats["xys_grad_norm_acc"].index_add_(
595
+ 0, gidcs, xys_grad[sel].norm(dim=-1)
596
+ )
597
+ self.running_stats["vis_count"].index_add_(
598
+ 0, gidcs, torch.ones_like(gidcs, dtype=torch.int64)
599
+ )
600
+ max_radii = torch.maximum(
601
+ self.running_stats["max_radii"].index_select(0, gidcs),
602
+ _current_radii[sel] / max(_current_img_wh),
603
+ )
604
+ self.running_stats["max_radii"].index_put((gidcs,), max_radii)
605
+ return True
606
+
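A minimal sketch (shapes and values made up) of the screen-space gradient normalization done in `_prepare_control_step` above: the projected-mean gradients are rescaled by half the image size, so the densification threshold is effectively expressed in pixels, and by the batch size to undo per-batch averaging of the loss.

import torch

batch_size = 2
img_wh = (640, 480)
# Hypothetical gradients of the projected 2D means for 4 Gaussians.
xys_grad = torch.randn(4, 2) * 1e-3
xys_grad[..., 0] *= img_wh[0] / 2.0 * batch_size
xys_grad[..., 1] *= img_wh[1] / 2.0 * batch_size
grad_norm = xys_grad.norm(dim=-1)  # per-Gaussian pixel-space gradient magnitude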
607
+ @torch.no_grad()
608
+ def _densify_control_step(self, global_step):
609
+ assert (self.running_stats["vis_count"] > 0).any()
610
+
611
+ cfg = self.optim_cfg
612
+ xys_grad_avg = self.running_stats["xys_grad_norm_acc"] / self.running_stats[
613
+ "vis_count"
614
+ ].clamp_min(1)
615
+ is_grad_too_high = xys_grad_avg > cfg.densify_xys_grad_threshold
616
+ # Split gaussians.
617
+ scales = self.model.get_scales_all()
618
+ is_scale_too_big = scales.amax(dim=-1) > cfg.densify_scale_threshold
619
+ if global_step < cfg.stop_control_by_screen_steps:
620
+ is_radius_too_big = (
621
+ self.running_stats["max_radii"] > cfg.densify_screen_threshold
622
+ )
623
+ else:
624
+ is_radius_too_big = torch.zeros_like(is_grad_too_high, dtype=torch.bool)
625
+
626
+ should_split = is_grad_too_high & (is_scale_too_big | is_radius_too_big)
627
+ should_dup = is_grad_too_high & ~is_scale_too_big
628
+
629
+ num_fg = self.model.num_fg_gaussians
630
+ should_fg_split = should_split[:num_fg]
631
+ num_fg_splits = int(should_fg_split.sum().item())
632
+ should_fg_dup = should_dup[:num_fg]
633
+ num_fg_dups = int(should_fg_dup.sum().item())
634
+
635
+ should_bg_split = should_split[num_fg:]
636
+ num_bg_splits = int(should_bg_split.sum().item())
637
+ should_bg_dup = should_dup[num_fg:]
638
+ num_bg_dups = int(should_bg_dup.sum().item())
639
+
640
+ fg_param_map = self.model.fg.densify_params(should_fg_split, should_fg_dup)
641
+ for param_name, new_params in fg_param_map.items():
642
+ full_param_name = f"fg.params.{param_name}"
643
+ optimizer = self.optimizers[full_param_name]
644
+ dup_in_optim(
645
+ optimizer,
646
+ [new_params],
647
+ should_fg_split,
648
+ num_fg_splits * 2 + num_fg_dups,
649
+ )
650
+
651
+ if self.model.bg is not None:
652
+ bg_param_map = self.model.bg.densify_params(should_bg_split, should_bg_dup)
653
+ for param_name, new_params in bg_param_map.items():
654
+ full_param_name = f"bg.params.{param_name}"
655
+ optimizer = self.optimizers[full_param_name]
656
+ dup_in_optim(
657
+ optimizer,
658
+ [new_params],
659
+ should_bg_split,
660
+ num_bg_splits * 2 + num_bg_dups,
661
+ )
662
+
663
+ # update running stats
664
+ for k, v in self.running_stats.items():
665
+ v_fg, v_bg = v[:num_fg], v[num_fg:]
666
+ new_v = torch.cat(
667
+ [
668
+ v_fg[~should_fg_split],
669
+ v_fg[should_fg_dup],
670
+ v_fg[should_fg_split].repeat(2),
671
+ v_bg[~should_bg_split],
672
+ v_bg[should_bg_dup],
673
+ v_bg[should_bg_split].repeat(2),
674
+ ],
675
+ dim=0,
676
+ )
677
+ self.running_stats[k] = new_v
678
+ guru.info(
679
+ f"Split {should_split.sum().item()} gaussians, "
680
+ f"Duplicated {should_dup.sum().item()} gaussians, "
681
+ f"{self.model.num_gaussians} gaussians left"
682
+ )
683
+
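A toy, standalone sketch of the split/duplicate decision used in `_densify_control_step` above (the screen-radius criterion is omitted and the thresholds are illustrative, not the config defaults):

import torch

xys_grad_avg = torch.tensor([2e-4, 5e-3, 4e-3, 1e-5])  # accumulated grad norm / vis count
scales_max = torch.tensor([0.02, 0.30, 0.01, 0.50])     # largest axis of each Gaussian
is_grad_too_high = xys_grad_avg > 1e-3
is_scale_too_big = scales_max > 0.1
should_split = is_grad_too_high & is_scale_too_big      # big + under-fit   -> split
should_dup = is_grad_too_high & ~is_scale_too_big       # small + under-fit -> duplicate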
684
+ @torch.no_grad()
685
+ def _cull_control_step(self, global_step):
686
+ # Cull gaussians.
687
+ cfg = self.optim_cfg
688
+ opacities = self.model.get_opacities_all()
689
+ device = opacities.device
690
+ is_opacity_too_small = opacities < cfg.cull_opacity_threshold
691
+ is_radius_too_big = torch.zeros_like(is_opacity_too_small, dtype=torch.bool)
692
+ is_scale_too_big = torch.zeros_like(is_opacity_too_small, dtype=torch.bool)
693
+ cull_scale_threshold = (
694
+ torch.ones(len(is_scale_too_big), device=device) * cfg.cull_scale_threshold
695
+ )
696
+ num_fg = self.model.num_fg_gaussians
697
+ cull_scale_threshold[num_fg:] *= self.model.bg_scene_scale
698
+ if global_step > self.reset_opacity_every:
699
+ scales = self.model.get_scales_all()
700
+ is_scale_too_big = scales.amax(dim=-1) > cull_scale_threshold
701
+ if global_step < cfg.stop_control_by_screen_steps:
702
+ is_radius_too_big = (
703
+ self.running_stats["max_radii"] > cfg.cull_screen_threshold
704
+ )
705
+ should_cull = is_opacity_too_small | is_radius_too_big | is_scale_too_big
706
+ should_fg_cull = should_cull[:num_fg]
707
+ should_bg_cull = should_cull[num_fg:]
708
+
709
+ fg_param_map = self.model.fg.cull_params(should_fg_cull)
710
+ for param_name, new_params in fg_param_map.items():
711
+ full_param_name = f"fg.params.{param_name}"
712
+ optimizer = self.optimizers[full_param_name]
713
+ remove_from_optim(optimizer, [new_params], should_fg_cull)
714
+
715
+ if self.model.bg is not None:
716
+ bg_param_map = self.model.bg.cull_params(should_bg_cull)
717
+ for param_name, new_params in bg_param_map.items():
718
+ full_param_name = f"bg.params.{param_name}"
719
+ optimizer = self.optimizers[full_param_name]
720
+ remove_from_optim(optimizer, [new_params], should_bg_cull)
721
+
722
+ # update running stats
723
+ for k, v in self.running_stats.items():
724
+ self.running_stats[k] = v[~should_cull]
725
+
726
+ guru.info(
727
+ f"Culled {should_cull.sum().item()} gaussians, "
728
+ f"{self.model.num_gaussians} gaussians left"
729
+ )
730
+
731
+ @torch.no_grad()
732
+ def _reset_opacity_control_step(self):
733
+ # Reset gaussian opacities.
734
+ new_val = torch.logit(torch.tensor(0.8 * self.optim_cfg.cull_opacity_threshold))
735
+ for part in ["fg", "bg"]:
736
+ part_params = getattr(self.model, part).reset_opacities(new_val)
737
+ # Modify optimizer states by new assignment.
738
+ for param_name, new_params in part_params.items():
739
+ full_param_name = f"{part}.params.{param_name}"
740
+ optimizer = self.optimizers[full_param_name]
741
+ reset_in_optim(optimizer, [new_params])
742
+ guru.info("Reset opacities")
743
+
744
+ def configure_optimizers(self):
745
+ def _exponential_decay(step, *, lr_init, lr_final):
746
+ t = np.clip(step / self.optim_cfg.max_steps, 0.0, 1.0)
747
+ lr = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
748
+ return lr / lr_init
749
+
750
+ lr_dict = asdict(self.lr_cfg)
751
+ optimizers = {}
752
+ schedulers = {}
753
+ # named parameters will be [part].params.[field]
754
+ # e.g. fg.params.means
755
+ # lr config is a nested dict for each fg/bg part
756
+ for name, params in self.model.named_parameters():
757
+ part, _, field = name.split(".")
758
+ lr = lr_dict[part][field]
759
+ optim = torch.optim.Adam([{"params": params, "lr": lr, "name": name}])
760
+
761
+ if "scales" in name:
762
+ fnc = functools.partial(_exponential_decay, lr_final=0.1 * lr)
763
+ else:
764
+ fnc = lambda _, **__: 1.0
765
+
766
+ optimizers[name] = optim
767
+ schedulers[name] = torch.optim.lr_scheduler.LambdaLR(
768
+ optim, functools.partial(fnc, lr_init=lr)
769
+ )
770
+ return optimizers, schedulers
771
+
772
+
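For reference, a small sketch of the log-linear learning-rate decay used in `configure_optimizers` above; `LambdaLR` multiplies the base learning rate by the returned factor, hence the final division by `lr_init`. The numbers below are only an example of a 10x decay.

import numpy as np

def exp_decay_factor(step, max_steps, lr_init, lr_final):
    t = np.clip(step / max_steps, 0.0, 1.0)
    lr = np.exp(np.log(lr_init) * (1 - t) + np.log(lr_final) * t)
    return lr / lr_init

print(exp_decay_factor(0, 10_000, 1.6e-4, 1.6e-5))       # 1.0 at the start
print(exp_decay_factor(10_000, 10_000, 1.6e-4, 1.6e-5))  # 0.1 at the end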
773
+ def dup_in_optim(optimizer, new_params: list, should_dup: torch.Tensor, num_dups: int):
774
+ assert len(optimizer.param_groups) == len(new_params)
775
+ for i, p_new in enumerate(new_params):
776
+ old_params = optimizer.param_groups[i]["params"][0]
777
+ param_state = optimizer.state[old_params]
778
+ if len(param_state) == 0:
779
+ return
780
+ for key in param_state:
781
+ if key == "step":
782
+ continue
783
+ p = param_state[key]
784
+ param_state[key] = torch.cat(
785
+ [p[~should_dup], p.new_zeros(num_dups, *p.shape[1:])],
786
+ dim=0,
787
+ )
788
+ del optimizer.state[old_params]
789
+ optimizer.state[p_new] = param_state
790
+ optimizer.param_groups[i]["params"] = [p_new]
791
+ del old_params
792
+ torch.cuda.empty_cache()
793
+
794
+
795
+ def remove_from_optim(optimizer, new_params: list, _should_cull: torch.Tensor):
796
+ assert len(optimizer.param_groups) == len(new_params)
797
+ for i, p_new in enumerate(new_params):
798
+ old_params = optimizer.param_groups[i]["params"][0]
799
+ param_state = optimizer.state[old_params]
800
+ if len(param_state) == 0:
801
+ return
802
+ for key in param_state:
803
+ if key == "step":
804
+ continue
805
+ param_state[key] = param_state[key][~_should_cull]
806
+ del optimizer.state[old_params]
807
+ optimizer.state[p_new] = param_state
808
+ optimizer.param_groups[i]["params"] = [p_new]
809
+ del old_params
810
+ torch.cuda.empty_cache()
811
+
812
+
813
+ def reset_in_optim(optimizer, new_params: list):
814
+ assert len(optimizer.param_groups) == len(new_params)
815
+ for i, p_new in enumerate(new_params):
816
+ old_params = optimizer.param_groups[i]["params"][0]
817
+ param_state = optimizer.state[old_params]
818
+ if len(param_state) == 0:
819
+ return
820
+ for key in param_state:
821
+ param_state[key] = torch.zeros_like(param_state[key])
822
+ del optimizer.state[old_params]
823
+ optimizer.state[p_new] = param_state
824
+ optimizer.param_groups[i]["params"] = [p_new]
825
+ del old_params
826
+ torch.cuda.empty_cache()
827
+
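The three helpers above (`dup_in_optim`, `remove_from_optim`, `reset_in_optim`) keep Adam's per-parameter state (`exp_avg`, `exp_avg_sq`) aligned with the Gaussian parameters after densification, culling, or opacity resets. A minimal toy reproduction of the culling case, assuming a plain `torch.optim.Adam` with a single parameter group:

import torch

param = torch.nn.Parameter(torch.randn(5, 3))
optim = torch.optim.Adam([{"params": param, "lr": 1e-2, "name": "toy"}])
param.sum().backward()
optim.step()  # populates optim.state[param] with exp_avg / exp_avg_sq

keep = torch.tensor([True, False, True, True, False])  # cull Gaussians 1 and 4
new_param = torch.nn.Parameter(param.data[keep])
state = optim.state.pop(param)
for key in state:
    if key != "step":
        state[key] = state[key][keep]  # drop optimizer-state rows of culled Gaussians
optim.state[new_param] = state
optim.param_groups[0]["params"] = [new_param]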
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/trajectories.py ADDED
@@ -0,0 +1,200 @@
1
+ import numpy as np
2
+ import roma
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .transforms import rt_to_mat4
7
+
8
+
9
+ def get_avg_w2c(w2cs: torch.Tensor):
10
+ c2ws = torch.linalg.inv(w2cs)
11
+ # 1. Compute the center
12
+ center = c2ws[:, :3, -1].mean(0)
13
+ # 2. Compute the z axis
14
+ z = F.normalize(c2ws[:, :3, 2].mean(0), dim=-1)
15
+ # 3. Compute axis y' (no need to normalize as it's not the final output)
16
+ y_ = c2ws[:, :3, 1].mean(0) # (3)
17
+ # 4. Compute the x axis
18
+ x = F.normalize(torch.cross(y_, z, dim=-1), dim=-1) # (3)
19
+ # 5. Compute the y axis (as z and x are normalized, y is already of norm 1)
20
+ y = torch.cross(z, x, dim=-1) # (3)
21
+ avg_c2w = rt_to_mat4(torch.stack([x, y, z], 1), center)
22
+ avg_w2c = torch.linalg.inv(avg_c2w)
23
+ return avg_w2c
24
+
25
+
26
+ def get_lookat(origins: torch.Tensor, viewdirs: torch.Tensor) -> torch.Tensor:
27
+ """Triangulate a set of rays to find a single lookat point.
28
+
29
+ Args:
30
+ origins (torch.Tensor): A (N, 3) array of ray origins.
31
+ viewdirs (torch.Tensor): A (N, 3) array of ray view directions.
32
+
33
+ Returns:
34
+ torch.Tensor: A (3,) lookat point.
35
+ """
36
+
37
+ viewdirs = torch.nn.functional.normalize(viewdirs, dim=-1)
38
+ eye = torch.eye(3, device=origins.device, dtype=origins.dtype)[None]
39
+ # Calculate projection matrix I - rr^T
40
+ I_min_cov = eye - (viewdirs[..., None] * viewdirs[..., None, :])
41
+ # Compute sum of projections
42
+ sum_proj = I_min_cov.matmul(origins[..., None]).sum(dim=-3)
43
+ # Solve for the intersection point using least squares
44
+ lookat = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
45
+ # Check NaNs.
46
+ assert not torch.any(torch.isnan(lookat))
47
+ return lookat
48
+
49
+
50
+ def get_lookat_w2cs(positions: torch.Tensor, lookat: torch.Tensor, up: torch.Tensor):
51
+ """
52
+ Args:
53
+ positions: (N, 3) tensor of camera positions
54
+ lookat: (3,) tensor of lookat point
55
+ up: (3,) tensor of up vector
56
+
57
+ Returns:
58
+ w2cs: (N, 3, 3) tensor of world to camera rotation matrices
59
+ """
60
+ forward_vectors = F.normalize(lookat - positions, dim=-1)
61
+ right_vectors = F.normalize(torch.cross(forward_vectors, up[None], dim=-1), dim=-1)
62
+ down_vectors = F.normalize(
63
+ torch.cross(forward_vectors, right_vectors, dim=-1), dim=-1
64
+ )
65
+ Rs = torch.stack([right_vectors, down_vectors, forward_vectors], dim=-1)
66
+ w2cs = torch.linalg.inv(rt_to_mat4(Rs, positions))
67
+ return w2cs
68
+
69
+
70
+ def get_arc_w2cs(
71
+ ref_w2c: torch.Tensor,
72
+ lookat: torch.Tensor,
73
+ up: torch.Tensor,
74
+ num_frames: int,
75
+ degree: float,
76
+ **_,
77
+ ) -> torch.Tensor:
78
+ ref_position = torch.linalg.inv(ref_w2c)[:3, 3]
79
+ thetas = (
80
+ torch.sin(
81
+ torch.linspace(0.0, torch.pi * 2.0, num_frames + 1, device=ref_w2c.device)[
82
+ :-1
83
+ ]
84
+ )
85
+ * (degree / 2.0)
86
+ / 180.0
87
+ * torch.pi
88
+ )
89
+ positions = torch.einsum(
90
+ "nij,j->ni",
91
+ roma.rotvec_to_rotmat(thetas[:, None] * up[None]),
92
+ ref_position - lookat,
93
+ )
94
+ return get_lookat_w2cs(positions, lookat, up)
95
+
96
+
97
+ def get_lemniscate_w2cs(
98
+ ref_w2c: torch.Tensor,
99
+ lookat: torch.Tensor,
100
+ up: torch.Tensor,
101
+ num_frames: int,
102
+ degree: float,
103
+ **_,
104
+ ) -> torch.Tensor:
105
+ ref_c2w = torch.linalg.inv(ref_w2c)
106
+ a = torch.linalg.norm(ref_c2w[:3, 3] - lookat) * np.tan(degree / 360 * np.pi)
107
+ # Lemniscate curve in camera space. Starting at the origin.
108
+ thetas = (
109
+ torch.linspace(0, 2 * torch.pi, num_frames + 1, device=ref_w2c.device)[:-1]
110
+ + torch.pi / 2
111
+ )
112
+ positions = torch.stack(
113
+ [
114
+ a * torch.cos(thetas) / (1 + torch.sin(thetas) ** 2),
115
+ a * torch.cos(thetas) * torch.sin(thetas) / (1 + torch.sin(thetas) ** 2),
116
+ torch.zeros(num_frames, device=ref_w2c.device),
117
+ ],
118
+ dim=-1,
119
+ )
120
+ # Transform to world space.
121
+ positions = torch.einsum(
122
+ "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
123
+ )
124
+ return get_lookat_w2cs(positions, lookat, up)
125
+
126
+
127
+ def get_spiral_w2cs(
128
+ ref_w2c: torch.Tensor,
129
+ lookat: torch.Tensor,
130
+ up: torch.Tensor,
131
+ num_frames: int,
132
+ rads: float | torch.Tensor,
133
+ zrate: float,
134
+ rots: int,
135
+ **_,
136
+ ) -> torch.Tensor:
137
+ ref_c2w = torch.linalg.inv(ref_w2c)
138
+ thetas = torch.linspace(
139
+ 0, 2 * torch.pi * rots, num_frames + 1, device=ref_w2c.device
140
+ )[:-1]
141
+ # Spiral curve in camera space. Starting at the origin.
142
+ if isinstance(rads, torch.Tensor):
143
+ rads = rads.reshape(-1, 3).to(ref_w2c.device)
144
+ positions = (
145
+ torch.stack(
146
+ [
147
+ torch.cos(thetas),
148
+ -torch.sin(thetas),
149
+ -torch.sin(thetas * zrate),
150
+ ],
151
+ dim=-1,
152
+ )
153
+ * rads
154
+ )
155
+ # Transform to world space.
156
+ positions = torch.einsum(
157
+ "ij,nj->ni", ref_c2w[:3], F.pad(positions, (0, 1), value=1.0)
158
+ )
159
+ return get_lookat_w2cs(positions, lookat, up)
160
+
161
+
162
+ def get_wander_w2cs(ref_w2c, focal_length, num_frames, **_):
163
+ device = ref_w2c.device
164
+ c2w = np.linalg.inv(ref_w2c.detach().cpu().numpy())
165
+ max_disp = 48.0
166
+
167
+ max_trans = max_disp / focal_length
168
+ output_poses = []
169
+
170
+ for i in range(num_frames):
171
+ x_trans = max_trans * np.sin(2.0 * np.pi * float(i) / float(num_frames))
172
+ y_trans = 0.0
173
+ z_trans = max_trans * np.cos(2.0 * np.pi * float(i) / float(num_frames)) / 2.0
174
+
175
+ i_pose = np.concatenate(
176
+ [
177
+ np.concatenate(
178
+ [
179
+ np.eye(3),
180
+ np.array([x_trans, y_trans, z_trans])[:, np.newaxis],
181
+ ],
182
+ axis=1,
183
+ ),
184
+ np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :],
185
+ ],
186
+ axis=0,
187
+ )
188
+
189
+ i_pose = np.linalg.inv(i_pose)
190
+
191
+ ref_pose = np.concatenate(
192
+ [c2w[:3, :4], np.array([0.0, 0.0, 0.0, 1.0])[np.newaxis, :]], axis=0
193
+ )
194
+
195
+ render_pose = np.dot(ref_pose, i_pose)
196
+ output_poses.append(render_pose)
197
+ output_poses = torch.from_numpy(np.array(output_poses, dtype=np.float32)).to(device)
198
+ w2cs = torch.linalg.inv(output_poses)
199
+
200
+ return w2cs
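A hedged usage sketch for the trajectory helpers in this file, assuming training poses `w2cs` of shape (N, 4, 4); the choice of `up` below (negated mean camera y-axis, i.e. world-up under the OpenCV convention) is one common convention rather than something this module prescribes.

import torch
import torch.nn.functional as F
# Assuming the module is importable as flow3d.trajectories, as laid out in this commit.
from flow3d.trajectories import get_arc_w2cs, get_avg_w2c, get_lookat

def make_arc_path(w2cs: torch.Tensor, num_frames: int = 60, degree: float = 30.0):
    c2ws = torch.linalg.inv(w2cs)
    lookat = get_lookat(c2ws[:, :3, 3], c2ws[:, :3, 2])  # triangulate a scene center
    up = -F.normalize(c2ws[:, :3, 1].mean(0), dim=-1)
    ref_w2c = get_avg_w2c(w2cs)
    return get_arc_w2cs(ref_w2c, lookat, up, num_frames=num_frames, degree=degree)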
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/transforms.py ADDED
@@ -0,0 +1,129 @@
1
+ from typing import Literal
2
+
3
+ import roma
4
+ import torch
5
+ import torch.nn.functional as F
6
+
7
+
8
+ def rt_to_mat4(
9
+ R: torch.Tensor, t: torch.Tensor, s: torch.Tensor | None = None
10
+ ) -> torch.Tensor:
11
+ """
12
+ Args:
13
+ R (torch.Tensor): (..., 3, 3).
14
+ t (torch.Tensor): (..., 3).
15
+ s (torch.Tensor): (...,).
16
+
17
+ Returns:
18
+ torch.Tensor: (..., 4, 4)
19
+ """
20
+ mat34 = torch.cat([R, t[..., None]], dim=-1)
21
+ if s is None:
22
+ bottom = (
23
+ mat34.new_tensor([[0.0, 0.0, 0.0, 1.0]])
24
+ .reshape((1,) * (mat34.dim() - 2) + (1, 4))
25
+ .expand(mat34.shape[:-2] + (1, 4))
26
+ )
27
+ else:
28
+ bottom = F.pad(1.0 / s[..., None, None], (3, 0), value=0.0)
29
+ mat4 = torch.cat([mat34, bottom], dim=-2)
30
+ return mat4
31
+
32
+
33
+ def rmat_to_cont_6d(matrix):
34
+ """
35
+ :param matrix (*, 3, 3)
36
+ :returns 6d vector (*, 6)
37
+ """
38
+ return torch.cat([matrix[..., 0], matrix[..., 1]], dim=-1)
39
+
40
+
41
+ def cont_6d_to_rmat(cont_6d):
42
+ """
43
+ :param 6d vector (*, 6)
44
+ :returns matrix (*, 3, 3)
45
+ """
46
+ x1 = cont_6d[..., 0:3]
47
+ y1 = cont_6d[..., 3:6]
48
+
49
+ x = F.normalize(x1, dim=-1)
50
+ y = F.normalize(y1 - (y1 * x).sum(dim=-1, keepdim=True) * x, dim=-1)
51
+ z = torch.linalg.cross(x, y, dim=-1)
52
+
53
+ return torch.stack([x, y, z], dim=-1)
54
+
55
+
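A quick self-contained check (not part of the diff) that the continuous 6D representation above round-trips exactly for a valid rotation, drawing the rotation with `roma`:

import roma
import torch
from flow3d.transforms import cont_6d_to_rmat, rmat_to_cont_6d  # path as in this commit

R = roma.rotvec_to_rotmat(torch.tensor([0.3, -0.2, 0.1]))
R_rec = cont_6d_to_rmat(rmat_to_cont_6d(R))
assert torch.allclose(R, R_rec, atol=1e-5)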
56
+ def solve_procrustes(
57
+ src: torch.Tensor,
58
+ dst: torch.Tensor,
59
+ weights: torch.Tensor | None = None,
60
+ enforce_se3: bool = False,
61
+ rot_type: Literal["quat", "mat", "6d"] = "quat",
62
+ ):
63
+ """
64
+ Solve the Procrustes problem to align two point clouds, by solving the
65
+ following problem:
66
+
67
+ min_{s, R, t} || s * (src @ R.T + t) - dst ||_2, s.t. R.T @ R = I and det(R) = 1.
68
+
69
+ Args:
70
+ src (torch.Tensor): (N, 3).
71
+ dst (torch.Tensor): (N, 3).
72
+ weights (torch.Tensor | None): (N,), optional weights for alignment.
73
+ enforce_se3 (bool): Whether to enforce the transfm to be SE3.
74
+
75
+ Returns:
76
+ sim3 (tuple[torch.Tensor, torch.Tensor, torch.Tensor]):
77
+ q (torch.Tensor): (4,), rotation component in quaternion of WXYZ
78
+ format.
79
+ t (torch.Tensor): (3,), translation component.
80
+ s (torch.Tensor): (), scale component.
81
+ error (torch.Tensor): (), average L2 distance after alignment.
82
+ """
83
+ # Compute weights.
84
+ if weights is None:
85
+ weights = src.new_ones(src.shape[0])
86
+ weights = weights[:, None] / weights.sum()
87
+ # Normalize point positions.
88
+ src_mean = (src * weights).sum(dim=0)
89
+ dst_mean = (dst * weights).sum(dim=0)
90
+ src_cent = src - src_mean
91
+ dst_cent = dst - dst_mean
92
+ # Normalize point scales.
93
+ if not enforce_se3:
94
+ src_scale = (src_cent**2 * weights).sum(dim=-1).mean().sqrt()
95
+ dst_scale = (dst_cent**2 * weights).sum(dim=-1).mean().sqrt()
96
+ else:
97
+ src_scale = dst_scale = src.new_tensor(1.0)
98
+ src_scaled = src_cent / src_scale
99
+ dst_scaled = dst_cent / dst_scale
100
+ # Compute the matrix for the singular value decomposition (SVD).
101
+ matrix = (weights * dst_scaled).T @ src_scaled
102
+ U, _, Vh = torch.linalg.svd(matrix)
103
+ # Special reflection case.
104
+ S = torch.eye(3, device=src.device)
105
+ if torch.det(U) * torch.det(Vh) < 0:
106
+ S[2, 2] = -1
107
+ R = U @ S @ Vh
108
+ # Compute the transformation.
109
+ if rot_type == "quat":
110
+ rot = roma.rotmat_to_unitquat(R).roll(1, dims=-1)
111
+ elif rot_type == "6d":
112
+ rot = rmat_to_cont_6d(R)
113
+ else:
114
+ rot = R
115
+ s = dst_scale / src_scale
116
+ t = dst_mean / s - src_mean @ R.T
117
+ sim3 = rot, t, s
118
+ # Debug: error.
119
+ procrustes_dst = torch.einsum(
120
+ "ij,nj->ni", rt_to_mat4(R, t, s), F.pad(src, (0, 1), value=1.0)
121
+ )
122
+ procrustes_dst = procrustes_dst[:, :3] / procrustes_dst[:, 3:]
123
+ error_before = (torch.linalg.norm(dst - src, dim=-1) * weights[:, 0]).sum()
124
+ error = (torch.linalg.norm(dst - procrustes_dst, dim=-1) * weights[:, 0]).sum()
125
+ # print(f"Procrustes error: {error_before} -> {error}")
126
+ # if error_before < error:
127
+ # print("Something is wrong.")
128
+ # __import__("ipdb").set_trace()
129
+ return sim3, (error.item(), error_before.item())
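A hedged sanity-check sketch for `solve_procrustes`: build a known similarity transform, apply it with the convention implied by the error computation above (dst ≈ s * (src @ R.T + t)), and confirm it is recovered.

import roma
import torch
from flow3d.transforms import solve_procrustes  # path as in this commit

torch.manual_seed(0)
src = torch.randn(100, 3)
R_gt = roma.rotvec_to_rotmat(torch.tensor([0.4, 0.1, -0.3]))
s_gt, t_gt = torch.tensor(1.7), torch.tensor([0.5, -1.0, 2.0])
dst = s_gt * (src @ R_gt.T + t_gt)

(rot, t, s), (error, _error_before) = solve_procrustes(src, dst, rot_type="mat")
assert torch.allclose(s, s_gt, atol=1e-4)
assert torch.allclose(rot, R_gt, atol=1e-4)
assert error < 1e-3  # average alignment residual after the fit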
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/validator.py ADDED
@@ -0,0 +1,443 @@
1
+ import functools
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ from dataclasses import asdict
6
+ from typing import cast
7
+
8
+ import imageio as iio
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from loguru import logger as guru
13
+ from nerfview import CameraState, Viewer
14
+ from pytorch_msssim import SSIM
15
+ from torch.utils.data import DataLoader, Dataset
16
+ from torch.utils.tensorboard import SummaryWriter
17
+ from tqdm import tqdm
18
+ from sklearn.decomposition import PCA
19
+
20
+ from flow3d.configs import LossesConfig, OptimizerConfig, SceneLRConfig
21
+ from flow3d.data.utils import normalize_coords, to_device
22
+ from flow3d.metrics import PCK, mLPIPS, mPSNR, mSSIM
23
+ from flow3d.scene_model import SceneModel
24
+ from flow3d.vis.utils import (
25
+ apply_depth_colormap,
26
+ make_video_divisble,
27
+ plot_correspondences,
28
+ )
29
+
30
+
31
+ class Validator:
32
+ def __init__(
33
+ self,
34
+ model: SceneModel,
35
+ device: torch.device,
36
+ train_loader: DataLoader | None,
37
+ val_img_loader: DataLoader | None,
38
+ val_kpt_loader: DataLoader | None,
39
+ save_dir: str,
40
+ ):
41
+ self.model = model
42
+ self.device = device
43
+ self.train_loader = train_loader
44
+ self.val_img_loader = val_img_loader
45
+ self.val_kpt_loader = val_kpt_loader
46
+ self.save_dir = save_dir
47
+ self.has_bg = self.model.has_bg
48
+
49
+ # metrics
50
+ self.ssim = SSIM(data_range=1.0, size_average=True, channel=3)
51
+ self.psnr_metric = mPSNR()
52
+ self.ssim_metric = mSSIM()
53
+ self.lpips_metric = mLPIPS().to(device)
54
+ self.fg_psnr_metric = mPSNR()
55
+ self.fg_ssim_metric = mSSIM()
56
+ self.fg_lpips_metric = mLPIPS().to(device)
57
+ self.bg_psnr_metric = mPSNR()
58
+ self.bg_ssim_metric = mSSIM()
59
+ self.bg_lpips_metric = mLPIPS().to(device)
60
+ self.pck_metric = PCK()
61
+
62
+ def reset_metrics(self):
63
+ self.psnr_metric.reset()
64
+ self.ssim_metric.reset()
65
+ self.lpips_metric.reset()
66
+ self.fg_psnr_metric.reset()
67
+ self.fg_ssim_metric.reset()
68
+ self.fg_lpips_metric.reset()
69
+ self.bg_psnr_metric.reset()
70
+ self.bg_ssim_metric.reset()
71
+ self.bg_lpips_metric.reset()
72
+ self.pck_metric.reset()
73
+
74
+ @torch.no_grad()
75
+ def validate(self):
76
+ self.reset_metrics()
77
+ metric_imgs = self.validate_imgs() or {}
78
+ metric_kpts = self.validate_keypoints() or {}
79
+ return {**metric_imgs, **metric_kpts}
80
+
81
+ @torch.no_grad()
82
+ def validate_imgs(self):
83
+ guru.info("rendering validation images...")
84
+ if self.val_img_loader is None:
85
+ return
86
+
87
+ for batch in tqdm(self.val_img_loader, desc="render val images"):
88
+ batch = to_device(batch, self.device)
89
+ frame_name = batch["frame_names"][0]
90
+ t = batch["ts"][0]
91
+ # (1, 4, 4).
92
+ w2c = batch["w2cs"]
93
+ # (1, 3, 3).
94
+ K = batch["Ks"]
95
+ # (1, H, W, 3).
96
+ img = batch["imgs"]
97
+ # (1, H, W).
98
+ valid_mask = batch.get(
99
+ "valid_masks", torch.ones_like(batch["imgs"][..., 0])
100
+ )
101
+ # (1, H, W).
102
+ fg_mask = batch["masks"]
103
+
104
+ # (H, W).
105
+ covisible_mask = batch.get(
106
+ "covisible_masks",
107
+ torch.ones_like(fg_mask)[None],
108
+ )
109
+ W, H = img_wh = img[0].shape[-2::-1]
110
+ rendered = self.model.render(t, w2c, K, img_wh, return_depth=True)
111
+
112
+ # Compute metrics.
113
+ valid_mask *= covisible_mask
114
+ fg_valid_mask = fg_mask * valid_mask
115
+ bg_valid_mask = (1 - fg_mask) * valid_mask
116
+ main_valid_mask = valid_mask if self.has_bg else fg_valid_mask
117
+
118
+ self.psnr_metric.update(rendered["img"], img, main_valid_mask)
119
+ self.ssim_metric.update(rendered["img"], img, main_valid_mask)
120
+ self.lpips_metric.update(rendered["img"], img, main_valid_mask)
121
+
122
+ if self.has_bg:
123
+ self.fg_psnr_metric.update(rendered["img"], img, fg_valid_mask)
124
+ self.fg_ssim_metric.update(rendered["img"], img, fg_valid_mask)
125
+ self.fg_lpips_metric.update(rendered["img"], img, fg_valid_mask)
126
+
127
+ self.bg_psnr_metric.update(rendered["img"], img, bg_valid_mask)
128
+ self.bg_ssim_metric.update(rendered["img"], img, bg_valid_mask)
129
+ self.bg_lpips_metric.update(rendered["img"], img, bg_valid_mask)
130
+
131
+ # Dump results.
132
+ results_dir = osp.join(self.save_dir, "results", "rgb")
133
+ os.makedirs(results_dir, exist_ok=True)
134
+ iio.imwrite(
135
+ osp.join(results_dir, f"{frame_name}.png"),
136
+ (rendered["img"][0].cpu().numpy() * 255).astype(np.uint8),
137
+ )
138
+
139
+ return {
140
+ "val/psnr": self.psnr_metric.compute(),
141
+ "val/ssim": self.ssim_metric.compute(),
142
+ "val/lpips": self.lpips_metric.compute(),
143
+ "val/fg_psnr": self.fg_psnr_metric.compute(),
144
+ "val/fg_ssim": self.fg_ssim_metric.compute(),
145
+ "val/fg_lpips": self.fg_lpips_metric.compute(),
146
+ "val/bg_psnr": self.bg_psnr_metric.compute(),
147
+ "val/bg_ssim": self.bg_ssim_metric.compute(),
148
+ "val/bg_lpips": self.bg_lpips_metric.compute(),
149
+ }
150
+
151
+ @torch.no_grad()
152
+ def validate_keypoints(self):
153
+ if self.val_kpt_loader is None:
154
+ return
155
+ pred_keypoints_3d_all = []
156
+ time_ids = self.val_kpt_loader.dataset.time_ids.tolist()
157
+ h, w = self.val_kpt_loader.dataset.dataset.imgs.shape[1:3]
158
+ pred_train_depths = np.zeros((len(time_ids), h, w))
159
+
160
+ for batch in tqdm(self.val_kpt_loader, desc="render val keypoints"):
161
+ batch = to_device(batch, self.device)
162
+ # (2,).
163
+ ts = batch["ts"][0]
164
+ # (2, 4, 4).
165
+ w2cs = batch["w2cs"][0]
166
+ # (2, 3, 3).
167
+ Ks = batch["Ks"][0]
168
+ # (2, H, W, 3).
169
+ imgs = batch["imgs"][0]
170
+ # (2, P, 3).
171
+ keypoints = batch["keypoints"][0]
172
+ # (P,)
173
+ keypoint_masks = (keypoints[..., -1] > 0.5).all(dim=0)
174
+ src_keypoints, target_keypoints = keypoints[:, keypoint_masks, :2]
175
+ W, H = img_wh = imgs.shape[-2:0:-1]
176
+ rendered = self.model.render(
177
+ ts[0].item(),
178
+ w2cs[:1],
179
+ Ks[:1],
180
+ img_wh,
181
+ target_ts=ts[1:],
182
+ target_w2cs=w2cs[1:],
183
+ return_depth=True,
184
+ )
185
+ pred_tracks_3d = rendered["tracks_3d"][0, ..., 0, :]
186
+ pred_tracks_2d = torch.einsum("ij,hwj->hwi", Ks[1], pred_tracks_3d)
187
+ pred_tracks_2d = pred_tracks_2d[..., :2] / torch.clamp(
188
+ pred_tracks_2d[..., -1:], min=1e-6
189
+ )
190
+ pred_keypoints = F.grid_sample(
191
+ pred_tracks_2d[None].permute(0, 3, 1, 2),
192
+ normalize_coords(src_keypoints, H, W)[None, None],
193
+ align_corners=True,
194
+ ).permute(0, 2, 3, 1)[0, 0]
195
+
196
+ # Compute metrics.
197
+ self.pck_metric.update(pred_keypoints, target_keypoints, max(img_wh) * 0.05)
198
+
199
+ padded_keypoints_3d = torch.zeros_like(keypoints[0])
200
+ pred_keypoints_3d = F.grid_sample(
201
+ pred_tracks_3d[None].permute(0, 3, 1, 2),
202
+ normalize_coords(src_keypoints, H, W)[None, None],
203
+ align_corners=True,
204
+ ).permute(0, 2, 3, 1)[0, 0]
205
+ # Transform 3D keypoints back to world space.
206
+ pred_keypoints_3d = torch.einsum(
207
+ "ij,pj->pi",
208
+ torch.linalg.inv(w2cs[1])[:3],
209
+ F.pad(pred_keypoints_3d, (0, 1), value=1.0),
210
+ )
211
+ padded_keypoints_3d[keypoint_masks] = pred_keypoints_3d
212
+ # Cache predicted keypoints.
213
+ pred_keypoints_3d_all.append(padded_keypoints_3d.cpu().numpy())
214
+ pred_train_depths[time_ids.index(ts[0].item())] = (
215
+ rendered["depth"][0, ..., 0].cpu().numpy()
216
+ )
217
+
218
+ # Dump unified results.
219
+ all_Ks = self.val_kpt_loader.dataset.dataset.Ks
220
+ all_w2cs = self.val_kpt_loader.dataset.dataset.w2cs
221
+
222
+ keypoint_result_dict = {
223
+ "Ks": all_Ks[time_ids].cpu().numpy(),
224
+ "w2cs": all_w2cs[time_ids].cpu().numpy(),
225
+ "pred_keypoints_3d": np.stack(pred_keypoints_3d_all, 0),
226
+ "pred_train_depths": pred_train_depths,
227
+ }
228
+
229
+ results_dir = osp.join(self.save_dir, "results")
230
+ os.makedirs(results_dir, exist_ok=True)
231
+ np.savez(
232
+ osp.join(results_dir, "keypoints.npz"),
233
+ **keypoint_result_dict,
234
+ )
235
+ guru.info(
236
+ f"Dumped keypoint results to {results_dir=} {keypoint_result_dict['pred_keypoints_3d'].shape=}"
237
+ )
238
+
239
+ return {"val/pck": self.pck_metric.compute()}
240
+
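For reference, a minimal sketch of what a PCK-style metric computes (the fraction of predicted keypoints falling within a pixel threshold of the targets; above, the threshold is 5% of the longer image side). This is an illustration only, not the `flow3d.metrics.PCK` implementation.

import torch

def pck(pred: torch.Tensor, target: torch.Tensor, threshold: float) -> torch.Tensor:
    """pred, target: (P, 2) pixel coordinates. Returns the fraction within threshold."""
    dists = torch.linalg.norm(pred - target, dim=-1)
    return (dists < threshold).float().mean()

img_wh = (640, 360)
threshold = max(img_wh) * 0.05  # 32 px for a 640-pixel-wide image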
241
+ @torch.no_grad()
242
+ def save_train_videos(self, epoch: int):
243
+ pca = PCA(n_components = 3)
244
+ if self.train_loader is None:
245
+ return
246
+ video_dir = osp.join(self.save_dir, "videos", f"epoch_{epoch:04d}")
247
+ os.makedirs(video_dir, exist_ok=True)
248
+ fps = getattr(self.train_loader.dataset.dataset, "fps", 15.0)
249
+ # Render video.
250
+ video = []
251
+ video_PCA = []
252
+ ref_pred_depths = []
253
+ masks = []
254
+ depth_min, depth_max = 1e6, 0
255
+ for batch_idx, batch in enumerate(
256
+ tqdm(self.train_loader, desc="Rendering video", leave=False)
257
+ ):
258
+ batch = {
259
+ k: v.to(self.device) if isinstance(v, torch.Tensor) else v
260
+ for k, v in batch.items()
261
+ }
262
+ # ().
263
+ t = batch["ts"][0]
264
+ # (4, 4).
265
+ w2c = batch["w2cs"][0]
266
+ # (3, 3).
267
+ K = batch["Ks"][0]
268
+ # (H, W, 3).
269
+ img = batch["imgs"][0]
270
+ # (H, W).
271
+ depth = batch["depths"][0]
272
+
273
+ img_wh = img.shape[-2::-1]
274
+ rendered = self.model.render(
275
+ t, w2c[None], K[None], img_wh, return_depth=True, return_mask=True
276
+ )
277
+
278
+ feature_map = rendered["img_feature"][0].cpu()
279
+ H, W, D = feature_map.shape
280
+ flattened_feature_map = feature_map.view(H * W, D)
281
+ pca.fit(flattened_feature_map)
282
+ pca_features = pca.transform(flattened_feature_map)
283
+ for i in range(3):
284
+ pca_features[:, i] = (pca_features[:, i] - pca_features[:, i].min()) / (pca_features[:, i].max() - pca_features[:, i].min())
285
+ pca_features = pca_features.reshape(H, W, 3)
286
+ pca_features_tensor = torch.from_numpy(pca_features)
287
+
288
+ # Putting results onto CPU since it will consume unnecessarily
289
+ # large GPU memory for long sequences otherwise.
290
+ video.append(torch.cat([img, rendered["img_color"][0]], dim=1).cpu())
291
+ video_PCA.append(pca_features_tensor)
292
+ ref_pred_depth = torch.cat(
293
+ (depth[..., None], rendered["depth"][0]), dim=1
294
+ ).cpu()
295
+ ref_pred_depths.append(ref_pred_depth)
296
+ depth_min = min(depth_min, ref_pred_depth.min().item())
297
+ depth_max = max(depth_max, ref_pred_depth.quantile(0.99).item())
298
+ if rendered["mask"] is not None:
299
+ masks.append(rendered["mask"][0].cpu().squeeze(-1))
300
+
301
+ # rgb video
302
+ video = torch.stack(video, dim=0)
303
+ iio.mimwrite(
304
+ osp.join(video_dir, "rgbs.mp4"),
305
+ make_video_divisble((video.numpy() * 255).astype(np.uint8)),
306
+ fps=fps,
307
+ )
308
+ # PCA video
309
+ video_PCA = torch.stack(video_PCA, dim=0)
310
+ iio.mimwrite(
311
+ osp.join(video_dir, "PCA.mp4"),
312
+ make_video_divisble((video_PCA.numpy() * 255).astype(np.uint8)),
313
+ fps=fps,
314
+ )
315
+ # depth video
316
+ depth_video = torch.stack(
317
+ [
318
+ apply_depth_colormap(
319
+ ref_pred_depth, near_plane=depth_min, far_plane=depth_max
320
+ )
321
+ for ref_pred_depth in ref_pred_depths
322
+ ],
323
+ dim=0,
324
+ )
325
+ iio.mimwrite(
326
+ osp.join(video_dir, "depths.mp4"),
327
+ make_video_divisble((depth_video.numpy() * 255).astype(np.uint8)),
328
+ fps=fps,
329
+ )
330
+ if len(masks) > 0:
331
+ # mask video
332
+ mask_video = torch.stack(masks, dim=0)
333
+ iio.mimwrite(
334
+ osp.join(video_dir, "masks.mp4"),
335
+ make_video_divisble((mask_video.numpy() * 255).astype(np.uint8)),
336
+ fps=fps,
337
+ )
338
+
339
+ # Render 2D track video.
340
+ tracks_2d, target_imgs = [], []
341
+ sample_interval = 10
342
+ batch0 = {
343
+ k: v.to(self.device) if isinstance(v, torch.Tensor) else v
344
+ for k, v in self.train_loader.dataset[0].items()
345
+ }
346
+ # ().
347
+ t = batch0["ts"]
348
+ # (4, 4).
349
+ w2c = batch0["w2cs"]
350
+ # (3, 3).
351
+ K = batch0["Ks"]
352
+ # (H, W, 3).
353
+ img = batch0["imgs"]
354
+ # (H, W).
355
+ bool_mask = batch0["masks"] > 0.5
356
+ img_wh = img.shape[-2::-1]
357
+ for batch in tqdm(
358
+ self.train_loader, desc="Rendering 2D track video", leave=False
359
+ ):
360
+ batch = {
361
+ k: v.to(self.device) if isinstance(v, torch.Tensor) else v
362
+ for k, v in batch.items()
363
+ }
364
+ # Putting results onto CPU since it will consume unnecessarily
365
+ # large GPU memory for long sequences otherwise.
366
+ # (1, H, W, 3).
367
+ target_imgs.append(batch["imgs"].cpu())
368
+ # (1,).
369
+ target_ts = batch["ts"]
370
+ # (1, 4, 4).
371
+ target_w2cs = batch["w2cs"]
372
+ # (1, 3, 3).
373
+ target_Ks = batch["Ks"]
374
+ rendered = self.model.render(
375
+ t,
376
+ w2c[None],
377
+ K[None],
378
+ img_wh,
379
+ target_ts=target_ts,
380
+ target_w2cs=target_w2cs,
381
+ )
382
+ pred_tracks_3d = rendered["tracks_3d"][0][
383
+ ::sample_interval, ::sample_interval
384
+ ][bool_mask[::sample_interval, ::sample_interval]].swapaxes(0, 1)
385
+ pred_tracks_2d = torch.einsum("bij,bpj->bpi", target_Ks, pred_tracks_3d)
386
+ pred_tracks_2d = pred_tracks_2d[..., :2] / torch.clamp(
387
+ pred_tracks_2d[..., 2:], min=1e-6
388
+ )
389
+ tracks_2d.append(pred_tracks_2d.cpu())
390
+ tracks_2d = torch.cat(tracks_2d, dim=0)
391
+ target_imgs = torch.cat(target_imgs, dim=0)
392
+ track_2d_video = plot_correspondences(
393
+ target_imgs.numpy(),
394
+ tracks_2d.numpy(),
395
+ query_id=cast(int, t),
396
+ )
397
+ iio.mimwrite(
398
+ osp.join(video_dir, "tracks_2d.mp4"),
399
+ make_video_divisble(np.stack(track_2d_video, 0)),
400
+ fps=fps,
401
+ )
402
+ # Render motion coefficient video.
403
+ with torch.random.fork_rng():
404
+ torch.random.manual_seed(0)
405
+ motion_coef_colors = torch.pca_lowrank(
406
+ self.model.fg.get_coefs()[None],
407
+ q=3,
408
+ )[0][0]
409
+ motion_coef_colors = (motion_coef_colors - motion_coef_colors.min(0)[0]) / (
410
+ motion_coef_colors.max(0)[0] - motion_coef_colors.min(0)[0]
411
+ )
412
+ motion_coef_colors = F.pad(
413
+ motion_coef_colors, (0, 0, 0, self.model.bg.num_gaussians), value=0.5
414
+ )
415
+ video = []
416
+ for batch in tqdm(
417
+ self.train_loader, desc="Rendering motion coefficient video", leave=False
418
+ ):
419
+ batch = {
420
+ k: v.to(self.device) if isinstance(v, torch.Tensor) else v
421
+ for k, v in batch.items()
422
+ }
423
+ # ().
424
+ t = batch["ts"][0]
425
+ # (4, 4).
426
+ w2c = batch["w2cs"][0]
427
+ # (3, 3).
428
+ K = batch["Ks"][0]
429
+ # (H, W, 3).
430
+ img = batch["imgs"][0]
431
+ img_wh = img.shape[-2::-1]
432
+ rendered = self.model.render(
433
+ t, w2c[None], K[None], img_wh, colors_override=motion_coef_colors
434
+ )
435
+ # Putting results onto CPU since it will consume unnecessarily
436
+ # large GPU memory for long sequences otherwise.
437
+ video.append(torch.cat([img, rendered["img_color"][0]], dim=1).cpu())
438
+ video = torch.stack(video, dim=0)
439
+ iio.mimwrite(
440
+ osp.join(video_dir, "motion_coefs.mp4"),
441
+ make_video_divisble((video.numpy() * 255).astype(np.uint8)),
442
+ fps=fps,
443
+ )
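The PCA feature visualization inside `save_train_videos` can be factored into a small standalone helper; the sketch below mirrors the same steps (fit PCA per frame, per-channel min-max normalization, reshape to an RGB image) and is not a drop-in replacement.

import torch
from sklearn.decomposition import PCA

def feature_map_to_rgb(feature_map: torch.Tensor) -> torch.Tensor:
    """Project an (H, W, D) feature map to an (H, W, 3) pseudo-color image in [0, 1]."""
    H, W, D = feature_map.shape
    flat = feature_map.reshape(H * W, D).cpu().numpy()
    pca_features = PCA(n_components=3).fit_transform(flat)  # (H * W, 3)
    for i in range(3):
        c = pca_features[:, i]
        pca_features[:, i] = (c - c.min()) / (c.max() - c.min() + 1e-12)
    return torch.from_numpy(pca_features.reshape(H, W, 3)).float()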
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__init__.py ADDED
File without changes
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (148 Bytes).
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/playback_panel.cpython-311.pyc ADDED
Binary file (3.28 kB).
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/render_panel.cpython-311.pyc ADDED
Binary file (59.3 kB).
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/utils.cpython-311.pyc ADDED
Binary file (27.1 kB).
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/__pycache__/viewer.cpython-311.pyc ADDED
Binary file (4.83 kB).
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/playback_panel.py ADDED
@@ -0,0 +1,68 @@
1
+ import threading
2
+ import time
3
+
4
+ import viser
5
+
6
+
7
+ def add_gui_playback_group(
8
+ server: viser.ViserServer,
9
+ num_frames: int,
10
+ min_fps: float = 1.0,
11
+ max_fps: float = 60.0,
12
+ fps_step: float = 0.1,
13
+ initial_fps: float = 10.0,
14
+ ):
15
+ gui_timestep = server.gui.add_slider(
16
+ "Timestep",
17
+ min=0,
18
+ max=num_frames - 1,
19
+ step=1,
20
+ initial_value=0,
21
+ disabled=True,
22
+ )
23
+ gui_next_frame = server.gui.add_button("Next Frame")
24
+ gui_prev_frame = server.gui.add_button("Prev Frame")
25
+ gui_playing_pause = server.gui.add_button("Pause")
26
+ gui_playing_pause.visible = False
27
+ gui_playing_resume = server.gui.add_button("Resume")
28
+ gui_framerate = server.gui.add_slider(
29
+ "FPS", min=min_fps, max=max_fps, step=fps_step, initial_value=initial_fps
30
+ )
31
+
32
+ # Frame step buttons.
33
+ @gui_next_frame.on_click
34
+ def _(_) -> None:
35
+ gui_timestep.value = (gui_timestep.value + 1) % num_frames
36
+
37
+ @gui_prev_frame.on_click
38
+ def _(_) -> None:
39
+ gui_timestep.value = (gui_timestep.value - 1) % num_frames
40
+
41
+ # Disable frame controls when we're playing.
42
+ def _toggle_gui_playing(_):
43
+ gui_playing_pause.visible = not gui_playing_pause.visible
44
+ gui_playing_resume.visible = not gui_playing_resume.visible
45
+ gui_timestep.disabled = gui_playing_pause.visible
46
+ gui_next_frame.disabled = gui_playing_pause.visible
47
+ gui_prev_frame.disabled = gui_playing_pause.visible
48
+
49
+ gui_playing_pause.on_click(_toggle_gui_playing)
50
+ gui_playing_resume.on_click(_toggle_gui_playing)
51
+
52
+ # Create a thread to update the timestep indefinitely.
53
+ def _update_timestep():
54
+ while True:
55
+ if gui_playing_pause.visible:
56
+ gui_timestep.value = (gui_timestep.value + 1) % num_frames
57
+ time.sleep(1 / gui_framerate.value)
58
+
59
+ threading.Thread(target=_update_timestep, daemon=True).start()
60
+
61
+ return (
62
+ gui_timestep,
63
+ gui_next_frame,
64
+ gui_prev_frame,
65
+ gui_playing_pause,
66
+ gui_playing_resume,
67
+ gui_framerate,
68
+ )
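A hedged usage sketch for `add_gui_playback_group`, assuming a running `viser.ViserServer`; the body of the callback is hypothetical and would normally re-render the scene at the selected frame.

import time

import viser
# Assuming the module path used in this commit.
from flow3d.vis.playback_panel import add_gui_playback_group

server = viser.ViserServer()
gui_timestep, _next, _prev, _pause, _resume, gui_fps = add_gui_playback_group(
    server, num_frames=120, initial_fps=15.0
)

@gui_timestep.on_update
def _(_) -> None:
    # Hypothetical hook: update the displayed splats / image for gui_timestep.value.
    print(f"frame -> {gui_timestep.value}")

while True:
    time.sleep(1.0)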
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/render_panel.py ADDED
@@ -0,0 +1,1165 @@
1
+ # Copyright 2022 the Regents of the University of California, Nerfstudio Team and contributors. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import colorsys
18
+ import dataclasses
19
+ import datetime
20
+ import json
21
+ import threading
22
+ import time
23
+ from pathlib import Path
24
+ from typing import Dict, List, Literal, Optional, Tuple
25
+
26
+ import numpy as np
27
+ import scipy
28
+ import splines
29
+ import splines.quaternion
30
+ import viser
31
+ import viser.transforms as tf
32
+
33
+ VISER_SCALE_RATIO = 10.0
34
+
35
+
36
+ @dataclasses.dataclass
37
+ class Keyframe:
38
+ time: float
39
+ position: np.ndarray
40
+ wxyz: np.ndarray
41
+ override_fov_enabled: bool
42
+ override_fov_rad: float
43
+ aspect: float
44
+ override_transition_enabled: bool
45
+ override_transition_sec: Optional[float]
46
+
47
+ @staticmethod
48
+ def from_camera(time: float, camera: viser.CameraHandle, aspect: float) -> Keyframe:
49
+ return Keyframe(
50
+ time,
51
+ camera.position,
52
+ camera.wxyz,
53
+ override_fov_enabled=False,
54
+ override_fov_rad=camera.fov,
55
+ aspect=aspect,
56
+ override_transition_enabled=False,
57
+ override_transition_sec=None,
58
+ )
59
+
60
+
61
+ class CameraPath:
62
+ def __init__(
63
+ self, server: viser.ViserServer, duration_element: viser.GuiInputHandle[float]
64
+ ):
65
+ self._server = server
66
+ self._keyframes: Dict[int, Tuple[Keyframe, viser.CameraFrustumHandle]] = {}
67
+ self._keyframe_counter: int = 0
68
+ self._spline_nodes: List[viser.SceneNodeHandle] = []
69
+ self._camera_edit_panel: Optional[viser.Gui3dContainerHandle] = None
70
+
71
+ self._orientation_spline: Optional[splines.quaternion.KochanekBartels] = None
72
+ self._position_spline: Optional[splines.KochanekBartels] = None
73
+ self._fov_spline: Optional[splines.KochanekBartels] = None
74
+ self._time_spline: Optional[splines.KochanekBartels] = None
75
+
76
+ self._keyframes_visible: bool = True
77
+
78
+ self._duration_element = duration_element
79
+
80
+ # These parameters should be overridden externally.
81
+ self.loop: bool = False
82
+ self.framerate: float = 30.0
83
+ self.tension: float = 0.5 # Tension / alpha term.
84
+ self.default_fov: float = 0.0
85
+ self.default_transition_sec: float = 0.0
86
+ self.show_spline: bool = True
87
+
88
+ def set_keyframes_visible(self, visible: bool) -> None:
89
+ self._keyframes_visible = visible
90
+ for keyframe in self._keyframes.values():
91
+ keyframe[1].visible = visible
92
+
93
+ def add_camera(
94
+ self, keyframe: Keyframe, keyframe_index: Optional[int] = None
95
+ ) -> None:
96
+ """Add a new camera, or replace an old one if `keyframe_index` is passed in."""
97
+ server = self._server
98
+
99
+ # Add a keyframe if we aren't replacing an existing one.
100
+ if keyframe_index is None:
101
+ keyframe_index = self._keyframe_counter
102
+ self._keyframe_counter += 1
103
+
104
+ print(
105
+ f"{keyframe.wxyz=} {keyframe.position=} {keyframe_index=} {keyframe.aspect=}"
106
+ )
107
+ frustum_handle = server.scene.add_camera_frustum(
108
+ f"/render_cameras/{keyframe_index}",
109
+ fov=(
110
+ keyframe.override_fov_rad
111
+ if keyframe.override_fov_enabled
112
+ else self.default_fov
113
+ ),
114
+ aspect=keyframe.aspect,
115
+ scale=0.1,
116
+ color=(200, 10, 30),
117
+ wxyz=keyframe.wxyz,
118
+ position=keyframe.position,
119
+ visible=self._keyframes_visible,
120
+ )
121
+ self._server.scene.add_icosphere(
122
+ f"/render_cameras/{keyframe_index}/sphere",
123
+ radius=0.03,
124
+ color=(200, 10, 30),
125
+ )
126
+
127
+ @frustum_handle.on_click
128
+ def _(_) -> None:
129
+ if self._camera_edit_panel is not None:
130
+ self._camera_edit_panel.remove()
131
+ self._camera_edit_panel = None
132
+
133
+ with server.scene.add_3d_gui_container(
134
+ "/camera_edit_panel",
135
+ position=keyframe.position,
136
+ ) as camera_edit_panel:
137
+ self._camera_edit_panel = camera_edit_panel
138
+ override_fov = server.gui.add_checkbox(
139
+ "Override FOV", initial_value=keyframe.override_fov_enabled
140
+ )
141
+ override_fov_degrees = server.gui.add_slider(
142
+ "Override FOV (degrees)",
143
+ 5.0,
144
+ 175.0,
145
+ step=0.1,
146
+ initial_value=keyframe.override_fov_rad * 180.0 / np.pi,
147
+ disabled=not keyframe.override_fov_enabled,
148
+ )
149
+ delete_button = server.gui.add_button(
150
+ "Delete", color="red", icon=viser.Icon.TRASH
151
+ )
152
+ go_to_button = server.gui.add_button("Go to")
153
+ close_button = server.gui.add_button("Close")
154
+
155
+ @override_fov.on_update
156
+ def _(_) -> None:
157
+ keyframe.override_fov_enabled = override_fov.value
158
+ override_fov_degrees.disabled = not override_fov.value
159
+ self.add_camera(keyframe, keyframe_index)
160
+
161
+ @override_fov_degrees.on_update
162
+ def _(_) -> None:
163
+ keyframe.override_fov_rad = override_fov_degrees.value / 180.0 * np.pi
164
+ self.add_camera(keyframe, keyframe_index)
165
+
166
+ @delete_button.on_click
167
+ def _(event: viser.GuiEvent) -> None:
168
+ assert event.client is not None
169
+ with event.client.gui.add_modal("Confirm") as modal:
170
+ event.client.gui.add_markdown("Delete keyframe?")
171
+ confirm_button = event.client.gui.add_button(
172
+ "Yes", color="red", icon=viser.Icon.TRASH
173
+ )
174
+ exit_button = event.client.gui.add_button("Cancel")
175
+
176
+ @confirm_button.on_click
177
+ def _(_) -> None:
178
+ assert camera_edit_panel is not None
179
+
180
+ keyframe_id = None
181
+ for i, keyframe_tuple in self._keyframes.items():
182
+ if keyframe_tuple[1] is frustum_handle:
183
+ keyframe_id = i
184
+ break
185
+ assert keyframe_id is not None
186
+
187
+ self._keyframes.pop(keyframe_id)
188
+ frustum_handle.remove()
189
+ camera_edit_panel.remove()
190
+ self._camera_edit_panel = None
191
+ modal.close()
192
+ self.update_spline()
193
+
194
+ @exit_button.on_click
195
+ def _(_) -> None:
196
+ modal.close()
197
+
198
+ @go_to_button.on_click
199
+ def _(event: viser.GuiEvent) -> None:
200
+ assert event.client is not None
201
+ client = event.client
202
+ T_world_current = tf.SE3.from_rotation_and_translation(
203
+ tf.SO3(client.camera.wxyz), client.camera.position
204
+ )
205
+ T_world_target = tf.SE3.from_rotation_and_translation(
206
+ tf.SO3(keyframe.wxyz), keyframe.position
207
+ ) @ tf.SE3.from_translation(np.array([0.0, 0.0, -0.5]))
208
+
209
+ T_current_target = T_world_current.inverse() @ T_world_target
210
+
211
+ for j in range(10):
212
+ T_world_set = T_world_current @ tf.SE3.exp(
213
+ T_current_target.log() * j / 9.0
214
+ )
215
+
216
+ # Important bit: we atomically set both the orientation and the position
217
+ # of the camera.
218
+ with client.atomic():
219
+ client.camera.wxyz = T_world_set.rotation().wxyz
220
+ client.camera.position = T_world_set.translation()
221
+ time.sleep(1.0 / 30.0)
222
+
223
+ @close_button.on_click
224
+ def _(_) -> None:
225
+ assert camera_edit_panel is not None
226
+ camera_edit_panel.remove()
227
+ self._camera_edit_panel = None
228
+
229
+ self._keyframes[keyframe_index] = (keyframe, frustum_handle)
230
+
231
+ def update_aspect(self, aspect: float) -> None:
232
+ for keyframe_index, frame in self._keyframes.items():
233
+ frame = dataclasses.replace(frame[0], aspect=aspect)
234
+ self.add_camera(frame, keyframe_index=keyframe_index)
235
+
236
+ def get_aspect(self) -> float:
237
+ """Get W/H aspect ratio, which is shared across all keyframes."""
238
+ assert len(self._keyframes) > 0
239
+ return next(iter(self._keyframes.values()))[0].aspect
240
+
241
+ def reset(self) -> None:
242
+ for frame in self._keyframes.values():
243
+ print(f"removing {frame[1]}")
244
+ frame[1].remove()
245
+ self._keyframes.clear()
246
+ self.update_spline()
247
+ print("camera path reset")
248
+
249
+ def spline_t_from_t_sec(self, time: np.ndarray) -> np.ndarray:
250
+ """From a time value in seconds, compute a t value for our geometric
251
+ spline interpolation. An increment of 1 for the latter will move the
252
+ camera forward by one keyframe.
253
+
254
+ We use a PCHIP spline here to guarantee monotonicity.
255
+ """
256
+ transition_times_cumsum = self.compute_transition_times_cumsum()
257
+ spline_indices = np.arange(transition_times_cumsum.shape[0])
258
+
259
+ if self.loop:
260
+ # In the case of a loop, we pad the spline to match the start/end
261
+ # slopes.
262
+ interpolator = scipy.interpolate.PchipInterpolator(
263
+ x=np.concatenate(
264
+ [
265
+ [-(transition_times_cumsum[-1] - transition_times_cumsum[-2])],
266
+ transition_times_cumsum,
267
+ transition_times_cumsum[-1:] + transition_times_cumsum[1:2],
268
+ ],
269
+ axis=0,
270
+ ),
271
+ y=np.concatenate(
272
+ [[-1], spline_indices, [spline_indices[-1] + 1]], axis=0
273
+ ),
274
+ )
275
+ else:
276
+ interpolator = scipy.interpolate.PchipInterpolator(
277
+ x=transition_times_cumsum, y=spline_indices
278
+ )
279
+
280
+ # Clip to account for floating point error.
281
+ return np.clip(interpolator(time), 0, spline_indices[-1])
282
+
283
+ def interpolate_pose_and_fov_rad(
284
+ self, normalized_t: float
285
+ ) -> Optional[Tuple[tf.SE3, float, float]]:
286
+ if len(self._keyframes) < 2:
287
+ return None
288
+
289
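+ # Rebuild the time and FOV splines from the current keyframes; the orientation and position splines are maintained by update_spline().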
+ self._time_spline = splines.KochanekBartels(
290
+ [keyframe[0].time for keyframe in self._keyframes.values()],
291
+ tcb=(self.tension, 0.0, 0.0),
292
+ endconditions="closed" if self.loop else "natural",
293
+ )
294
+
295
+ self._fov_spline = splines.KochanekBartels(
296
+ [
297
+ (
298
+ keyframe[0].override_fov_rad
299
+ if keyframe[0].override_fov_enabled
300
+ else self.default_fov
301
+ )
302
+ for keyframe in self._keyframes.values()
303
+ ],
304
+ tcb=(self.tension, 0.0, 0.0),
305
+ endconditions="closed" if self.loop else "natural",
306
+ )
307
+
308
+ assert self._orientation_spline is not None
309
+ assert self._position_spline is not None
310
+ assert self._fov_spline is not None
311
+ assert self._time_spline is not None
312
+
313
+ max_t = self.compute_duration()
314
+ t = max_t * normalized_t
315
+ spline_t = float(self.spline_t_from_t_sec(np.array(t)))
316
+
317
+ quat = self._orientation_spline.evaluate(spline_t)
318
+ assert isinstance(quat, splines.quaternion.UnitQuaternion)
319
+ return (
320
+ tf.SE3.from_rotation_and_translation(
321
+ tf.SO3(np.array([quat.scalar, *quat.vector])),
322
+ self._position_spline.evaluate(spline_t),
323
+ ),
324
+ float(self._fov_spline.evaluate(spline_t)),
325
+ float(self._time_spline.evaluate(spline_t)),
326
+ )
327
+
328
+ def update_spline(self) -> None:
329
+ num_frames = int(self.compute_duration() * self.framerate)
330
+ keyframes = list(self._keyframes.values())
331
+
332
+ if num_frames <= 0 or not self.show_spline or len(keyframes) < 2:
333
+ for node in self._spline_nodes:
334
+ node.remove()
335
+ self._spline_nodes.clear()
336
+ return
337
+
338
+ transition_times_cumsum = self.compute_transition_times_cumsum()
339
+
340
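+ # Kochanek-Bartels splines over the keyframe orientations and positions; only tension is exposed (continuity and bias stay 0), and endpoints are closed when looping.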
+ self._orientation_spline = splines.quaternion.KochanekBartels(
341
+ [
342
+ splines.quaternion.UnitQuaternion.from_unit_xyzw(
343
+ np.roll(keyframe[0].wxyz, shift=-1)
344
+ )
345
+ for keyframe in keyframes
346
+ ],
347
+ tcb=(self.tension, 0.0, 0.0),
348
+ endconditions="closed" if self.loop else "natural",
349
+ )
350
+ self._position_spline = splines.KochanekBartels(
351
+ [keyframe[0].position for keyframe in keyframes],
352
+ tcb=(self.tension, 0.0, 0.0),
353
+ endconditions="closed" if self.loop else "natural",
354
+ )
355
+
356
+ # Update visualized spline.
357
+ points_array = self._position_spline.evaluate(
358
+ self.spline_t_from_t_sec(
359
+ np.linspace(0, transition_times_cumsum[-1], num_frames)
360
+ )
361
+ )
362
+ colors_array = np.array(
363
+ [
364
+ colorsys.hls_to_rgb(h, 0.5, 1.0)
365
+ for h in np.linspace(0.0, 1.0, len(points_array))
366
+ ]
367
+ )
368
+
369
+ # Clear prior spline nodes.
370
+ for node in self._spline_nodes:
371
+ node.remove()
372
+ self._spline_nodes.clear()
373
+
374
+ self._spline_nodes.append(
375
+ self._server.scene.add_spline_catmull_rom(
376
+ "/render_camera_spline",
377
+ positions=points_array,
378
+ color=(220, 220, 220),
379
+ closed=self.loop,
380
+ line_width=1.0,
381
+ segments=points_array.shape[0] + 1,
382
+ )
383
+ )
384
+ self._spline_nodes.append(
385
+ self._server.scene.add_point_cloud(
386
+ "/render_camera_spline/points",
387
+ points=points_array,
388
+ colors=colors_array,
389
+ point_size=0.04,
390
+ )
391
+ )
392
+
393
+ def make_transition_handle(i: int) -> None:
394
+ assert self._position_spline is not None
395
+ transition_pos = self._position_spline.evaluate(
396
+ float(
397
+ self.spline_t_from_t_sec(
398
+ (transition_times_cumsum[i] + transition_times_cumsum[i + 1])
399
+ / 2.0,
400
+ )
401
+ )
402
+ )
403
+ transition_sphere = self._server.scene.add_icosphere(
404
+ f"/render_camera_spline/transition_{i}",
405
+ radius=0.04,
406
+ color=(255, 0, 0),
407
+ position=transition_pos,
408
+ )
409
+ self._spline_nodes.append(transition_sphere)
410
+
411
+ @transition_sphere.on_click
412
+ def _(_) -> None:
413
+ server = self._server
414
+
415
+ if self._camera_edit_panel is not None:
416
+ self._camera_edit_panel.remove()
417
+ self._camera_edit_panel = None
418
+
419
+ keyframe_index = (i + 1) % len(self._keyframes)
420
+ keyframe = keyframes[keyframe_index][0]
421
+
422
+ with server.scene.add_3d_gui_container(
423
+ "/camera_edit_panel",
424
+ position=transition_pos,
425
+ ) as camera_edit_panel:
426
+ self._camera_edit_panel = camera_edit_panel
427
+ override_transition_enabled = server.gui.add_checkbox(
428
+ "Override transition",
429
+ initial_value=keyframe.override_transition_enabled,
430
+ )
431
+ override_transition_sec = server.gui.add_number(
432
+ "Override transition (sec)",
433
+ initial_value=(
434
+ keyframe.override_transition_sec
435
+ if keyframe.override_transition_sec is not None
436
+ else self.default_transition_sec
437
+ ),
438
+ min=0.001,
439
+ max=30.0,
440
+ step=0.001,
441
+ disabled=not override_transition_enabled.value,
442
+ )
443
+ close_button = server.gui.add_button("Close")
444
+
445
+ @override_transition_enabled.on_update
446
+ def _(_) -> None:
447
+ keyframe.override_transition_enabled = (
448
+ override_transition_enabled.value
449
+ )
450
+ override_transition_sec.disabled = (
451
+ not override_transition_enabled.value
452
+ )
453
+ self._duration_element.value = self.compute_duration()
454
+
455
+ @override_transition_sec.on_update
456
+ def _(_) -> None:
457
+ keyframe.override_transition_sec = override_transition_sec.value
458
+ self._duration_element.value = self.compute_duration()
459
+
460
+ @close_button.on_click
461
+ def _(_) -> None:
462
+ assert camera_edit_panel is not None
463
+ camera_edit_panel.remove()
464
+ self._camera_edit_panel = None
465
+
466
+ (num_transitions_plus_1,) = transition_times_cumsum.shape
467
+ for i in range(num_transitions_plus_1 - 1):
468
+ make_transition_handle(i)
469
+
470
+ # for i in range(transition_times.shape[0])
471
+
472
+ def compute_duration(self) -> float:
473
+ """Compute the total duration of the trajectory."""
474
+ total = 0.0
475
+ for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
476
+ if i == 0 and not self.loop:
477
+ continue
478
+ del frustum
479
+ total += (
480
+ keyframe.override_transition_sec
481
+ if keyframe.override_transition_enabled
482
+ and keyframe.override_transition_sec is not None
483
+ else self.default_transition_sec
484
+ )
485
+ return total
486
+
487
+ def compute_transition_times_cumsum(self) -> np.ndarray:
488
+ """Compute the total duration of the trajectory."""
489
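+ # out[i] is the accumulated transition time (seconds) up to keyframe i; the first entry is always 0.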
+ total = 0.0
490
+ out = [0.0]
491
+ for i, (keyframe, frustum) in enumerate(self._keyframes.values()):
492
+ if i == 0:
493
+ continue
494
+ del frustum
495
+ total += (
496
+ keyframe.override_transition_sec
497
+ if keyframe.override_transition_enabled
498
+ and keyframe.override_transition_sec is not None
499
+ else self.default_transition_sec
500
+ )
501
+ out.append(total)
502
+
503
+ if self.loop:
504
+ keyframe = next(iter(self._keyframes.values()))[0]
505
+ total += (
506
+ keyframe.override_transition_sec
507
+ if keyframe.override_transition_enabled
508
+ and keyframe.override_transition_sec is not None
509
+ else self.default_transition_sec
510
+ )
511
+ out.append(total)
512
+
513
+ return np.array(out)
514
+
515
+
516
+ @dataclasses.dataclass
517
+ class RenderTabState:
518
+ """Useful GUI handles exposed by the render tab."""
519
+
520
+ preview_render: bool
521
+ preview_fov: float
522
+ preview_aspect: float
523
+ preview_camera_type: Literal["Perspective", "Fisheye", "Equirectangular"]
524
+
525
+
526
+ def populate_render_tab(
527
+ server: viser.ViserServer,
528
+ datapath: Path,
529
+ gui_timestep_handle: viser.GuiInputHandle[int] | None,
530
+ ) -> RenderTabState:
531
+
532
+ render_tab_state = RenderTabState(
533
+ preview_render=False,
534
+ preview_fov=0.0,
535
+ preview_aspect=1.0,
536
+ preview_camera_type="Perspective",
537
+ )
538
+
539
+ fov_degrees = server.gui.add_slider(
540
+ "Default FOV",
541
+ initial_value=75.0,
542
+ min=0.1,
543
+ max=175.0,
544
+ step=0.01,
545
+ hint="Field-of-view for rendering, which can also be overridden on a per-keyframe basis.",
546
+ )
547
+
548
+ @fov_degrees.on_update
549
+ def _(_) -> None:
550
+ fov_radians = fov_degrees.value / 180.0 * np.pi
551
+ for client in server.get_clients().values():
552
+ client.camera.fov = fov_radians
553
+ camera_path.default_fov = fov_radians
554
+
555
+ # Updating the aspect ratio will also re-render the camera frustums.
556
+ # Could rethink this.
557
+ camera_path.update_aspect(resolution.value[0] / resolution.value[1])
558
+ compute_and_update_preview_camera_state()
559
+
560
+ resolution = server.gui.add_vector2(
561
+ "Resolution",
562
+ initial_value=(1920, 1080),
563
+ min=(50, 50),
564
+ max=(10_000, 10_000),
565
+ step=1,
566
+ hint="Render output resolution in pixels.",
567
+ )
568
+
569
+ @resolution.on_update
570
+ def _(_) -> None:
571
+ camera_path.update_aspect(resolution.value[0] / resolution.value[1])
572
+ compute_and_update_preview_camera_state()
573
+
574
+ camera_type = server.gui.add_dropdown(
575
+ "Camera type",
576
+ ("Perspective", "Fisheye", "Equirectangular"),
577
+ initial_value="Perspective",
578
+ hint="Camera model to render with. This is applied to all keyframes.",
579
+ )
580
+ add_button = server.gui.add_button(
581
+ "Add Keyframe",
582
+ icon=viser.Icon.PLUS,
583
+ hint="Add a new keyframe at the current pose.",
584
+ )
585
+
586
+ @add_button.on_click
587
+ def _(event: viser.GuiEvent) -> None:
588
+ assert event.client_id is not None
589
+ camera = server.get_clients()[event.client_id].camera
590
+ pose = tf.SE3.from_rotation_and_translation(
591
+ tf.SO3(camera.wxyz), camera.position
592
+ )
593
+ print(f"client {event.client_id} at {camera.position} {camera.wxyz}")
594
+ print(f"camera pose {pose.as_matrix()}")
595
+ if gui_timestep_handle is not None:
596
+ print(f"timestep {gui_timestep_handle.value}")
597
+
598
+ # Add this camera to the path.
599
+ time = 0
600
+ if gui_timestep_handle is not None:
601
+ time = gui_timestep_handle.value
602
+ camera_path.add_camera(
603
+ Keyframe.from_camera(
604
+ time,
605
+ camera,
606
+ aspect=resolution.value[0] / resolution.value[1],
607
+ ),
608
+ )
609
+ duration_number.value = camera_path.compute_duration()
610
+ camera_path.update_spline()
611
+
612
+ clear_keyframes_button = server.gui.add_button(
613
+ "Clear Keyframes",
614
+ icon=viser.Icon.TRASH,
615
+ hint="Remove all keyframes from the render path.",
616
+ )
617
+
618
+ @clear_keyframes_button.on_click
619
+ def _(event: viser.GuiEvent) -> None:
620
+ assert event.client_id is not None
621
+ client = server.get_clients()[event.client_id]
622
+ with client.atomic(), client.gui.add_modal("Confirm") as modal:
623
+ client.gui.add_markdown("Clear all keyframes?")
624
+ confirm_button = client.gui.add_button(
625
+ "Yes", color="red", icon=viser.Icon.TRASH
626
+ )
627
+ exit_button = client.gui.add_button("Cancel")
628
+
629
+ @confirm_button.on_click
630
+ def _(_) -> None:
631
+ camera_path.reset()
632
+ modal.close()
633
+
634
+ duration_number.value = camera_path.compute_duration()
635
+
636
+ # Clear move handles.
637
+ if len(transform_controls) > 0:
638
+ for t in transform_controls:
639
+ t.remove()
640
+ transform_controls.clear()
641
+ return
642
+
643
+ @exit_button.on_click
644
+ def _(_) -> None:
645
+ modal.close()
646
+
647
+ loop = server.gui.add_checkbox(
648
+ "Loop", False, hint="Add a segment between the first and last keyframes."
649
+ )
650
+
651
+ @loop.on_update
652
+ def _(_) -> None:
653
+ camera_path.loop = loop.value
654
+ duration_number.value = camera_path.compute_duration()
655
+
656
+ tension_slider = server.gui.add_slider(
657
+ "Spline tension",
658
+ min=0.0,
659
+ max=1.0,
660
+ initial_value=0.0,
661
+ step=0.01,
662
+ hint="Tension parameter for adjusting smoothness of spline interpolation.",
663
+ )
664
+
665
+ @tension_slider.on_update
666
+ def _(_) -> None:
667
+ camera_path.tension = tension_slider.value
668
+ camera_path.update_spline()
669
+
670
+ move_checkbox = server.gui.add_checkbox(
671
+ "Move keyframes",
672
+ initial_value=False,
673
+ hint="Toggle move handles for keyframes in the scene.",
674
+ )
675
+
676
+ transform_controls: List[viser.SceneNodeHandle] = []
677
+
678
+ @move_checkbox.on_update
679
+ def _(event: viser.GuiEvent) -> None:
680
+ # Clear move handles when toggled off.
681
+ if move_checkbox.value is False:
682
+ for t in transform_controls:
683
+ t.remove()
684
+ transform_controls.clear()
685
+ return
686
+
687
+ def _make_transform_controls_callback(
688
+ keyframe: Tuple[Keyframe, viser.SceneNodeHandle],
689
+ controls: viser.TransformControlsHandle,
690
+ ) -> None:
691
+ @controls.on_update
692
+ def _(_) -> None:
693
+ keyframe[0].wxyz = controls.wxyz
694
+ keyframe[0].position = controls.position
695
+
696
+ keyframe[1].wxyz = controls.wxyz
697
+ keyframe[1].position = controls.position
698
+
699
+ camera_path.update_spline()
700
+
701
+ # Show move handles.
702
+ assert event.client is not None
703
+ for keyframe_index, keyframe in camera_path._keyframes.items():
704
+ controls = event.client.scene.add_transform_controls(
705
+ f"/keyframe_move/{keyframe_index}",
706
+ scale=0.4,
707
+ wxyz=keyframe[0].wxyz,
708
+ position=keyframe[0].position,
709
+ )
710
+ transform_controls.append(controls)
711
+ _make_transform_controls_callback(keyframe, controls)
712
+
713
+ show_keyframe_checkbox = server.gui.add_checkbox(
714
+ "Show keyframes",
715
+ initial_value=True,
716
+ hint="Show keyframes in the scene.",
717
+ )
718
+
719
+ @show_keyframe_checkbox.on_update
720
+ def _(_: viser.GuiEvent) -> None:
721
+ camera_path.set_keyframes_visible(show_keyframe_checkbox.value)
722
+
723
+ show_spline_checkbox = server.gui.add_checkbox(
724
+ "Show spline",
725
+ initial_value=True,
726
+ hint="Show camera path spline in the scene.",
727
+ )
728
+
729
+ @show_spline_checkbox.on_update
730
+ def _(_) -> None:
731
+ camera_path.show_spline = show_spline_checkbox.value
732
+ camera_path.update_spline()
733
+
734
+ playback_folder = server.gui.add_folder("Playback")
735
+ with playback_folder:
736
+ play_button = server.gui.add_button("Play", icon=viser.Icon.PLAYER_PLAY)
737
+ pause_button = server.gui.add_button(
738
+ "Pause", icon=viser.Icon.PLAYER_PAUSE, visible=False
739
+ )
740
+ preview_render_button = server.gui.add_button(
741
+ "Preview Render", hint="Show a preview of the render in the viewport."
742
+ )
743
+ preview_render_stop_button = server.gui.add_button(
744
+ "Exit Render Preview", color="red", visible=False
745
+ )
746
+
747
+ transition_sec_number = server.gui.add_number(
748
+ "Transition (sec)",
749
+ min=0.001,
750
+ max=30.0,
751
+ step=0.001,
752
+ initial_value=2.0,
753
+ hint="Time in seconds between each keyframe, which can also be overridden on a per-transition basis.",
754
+ )
755
+ framerate_number = server.gui.add_number(
756
+ "FPS", min=0.1, max=240.0, step=1e-2, initial_value=30.0
757
+ )
758
+ framerate_buttons = server.gui.add_button_group("", ("24", "30", "60"))
759
+ duration_number = server.gui.add_number(
760
+ "Duration (sec)",
761
+ min=0.0,
762
+ max=1e8,
763
+ step=0.001,
764
+ initial_value=0.0,
765
+ disabled=True,
766
+ )
767
+
768
+ @framerate_buttons.on_click
769
+ def _(_) -> None:
770
+ framerate_number.value = float(framerate_buttons.value)
771
+
772
+ @transition_sec_number.on_update
773
+ def _(_) -> None:
774
+ camera_path.default_transition_sec = transition_sec_number.value
775
+ duration_number.value = camera_path.compute_duration()
776
+
777
+ def get_max_frame_index() -> int:
778
+ return max(1, int(framerate_number.value * duration_number.value) - 1)
779
+
780
+ preview_camera_handle: Optional[viser.SceneNodeHandle] = None
781
+
782
+ def remove_preview_camera() -> None:
783
+ nonlocal preview_camera_handle
784
+ if preview_camera_handle is not None:
785
+ preview_camera_handle.remove()
786
+ preview_camera_handle = None
787
+
788
+ def compute_and_update_preview_camera_state() -> (
789
+ Optional[Tuple[tf.SE3, float, float]]
790
+ ):
791
+ """Update the render tab state with the current preview camera pose.
792
+ Returns the current camera pose, FOV, and time if available."""
793
+
794
+ if preview_frame_slider is None:
795
+ return
796
+ maybe_pose_and_fov_rad_and_time = camera_path.interpolate_pose_and_fov_rad(
797
+ preview_frame_slider.value / get_max_frame_index()
798
+ )
799
+ if maybe_pose_and_fov_rad_and_time is None:
800
+ remove_preview_camera()
801
+ return
802
+ pose, fov_rad, time = maybe_pose_and_fov_rad_and_time
803
+ render_tab_state.preview_fov = fov_rad
804
+ render_tab_state.preview_aspect = camera_path.get_aspect()
805
+ render_tab_state.preview_camera_type = camera_type.value
806
+ if gui_timestep_handle is not None:
807
+ gui_timestep_handle.value = int(time)
808
+ return pose, fov_rad, time
809
+
810
+ def add_preview_frame_slider() -> Optional[viser.GuiInputHandle[int]]:
811
+ """Helper for creating the current frame # slider. This is removed and
812
+ re-added anytime the `max` value changes."""
813
+
814
+ with playback_folder:
815
+ preview_frame_slider = server.gui.add_slider(
816
+ "Preview frame",
817
+ min=0,
818
+ max=get_max_frame_index(),
819
+ step=1,
820
+ initial_value=0,
821
+ # Place right after the preview-render buttons.
822
+ order=preview_render_stop_button.order + 0.01,
823
+ disabled=get_max_frame_index() == 1,
824
+ )
825
+ play_button.disabled = preview_frame_slider.disabled
826
+ preview_render_button.disabled = preview_frame_slider.disabled
827
+
828
+ @preview_frame_slider.on_update
829
+ def _(_) -> None:
830
+ nonlocal preview_camera_handle
831
+ maybe_pose_and_fov_rad_and_time = compute_and_update_preview_camera_state()
832
+ if maybe_pose_and_fov_rad_and_time is None:
833
+ return
834
+ pose, fov_rad, time = maybe_pose_and_fov_rad_and_time
835
+
836
+ preview_camera_handle = server.scene.add_camera_frustum(
837
+ "/preview_camera",
838
+ fov=fov_rad,
839
+ aspect=resolution.value[0] / resolution.value[1],
840
+ scale=0.35,
841
+ wxyz=pose.rotation().wxyz,
842
+ position=pose.translation(),
843
+ color=(10, 200, 30),
844
+ )
845
+ if render_tab_state.preview_render:
846
+ for client in server.get_clients().values():
847
+ client.camera.wxyz = pose.rotation().wxyz
848
+ client.camera.position = pose.translation()
849
+ if gui_timestep_handle is not None:
850
+ gui_timestep_handle.value = int(time)
851
+
852
+ return preview_frame_slider
853
+
854
+ # We back up the camera poses before and after we start previewing renders.
855
+ camera_pose_backup_from_id: Dict[int, tuple] = {}
856
+
857
+ @preview_render_button.on_click
858
+ def _(_) -> None:
859
+ render_tab_state.preview_render = True
860
+ preview_render_button.visible = False
861
+ preview_render_stop_button.visible = True
862
+
863
+ maybe_pose_and_fov_rad_and_time = compute_and_update_preview_camera_state()
864
+ if maybe_pose_and_fov_rad_and_time is None:
865
+ remove_preview_camera()
866
+ return
867
+ pose, fov, time = maybe_pose_and_fov_rad_and_time
868
+ del fov
869
+
870
+ # Hide all scene nodes when we're previewing the render.
871
+ server.scene.set_global_visibility(False)
872
+
873
+ # Back up and then set camera poses.
874
+ for client in server.get_clients().values():
875
+ camera_pose_backup_from_id[client.client_id] = (
876
+ client.camera.position,
877
+ client.camera.look_at,
878
+ client.camera.up_direction,
879
+ )
880
+ client.camera.wxyz = pose.rotation().wxyz
881
+ client.camera.position = pose.translation()
882
+ if gui_timestep_handle is not None:
883
+ gui_timestep_handle.value = int(time)
884
+
885
+ @preview_render_stop_button.on_click
886
+ def _(_) -> None:
887
+ render_tab_state.preview_render = False
888
+ preview_render_button.visible = True
889
+ preview_render_stop_button.visible = False
890
+
891
+ # Revert camera poses.
892
+ for client in server.get_clients().values():
893
+ if client.client_id not in camera_pose_backup_from_id:
894
+ continue
895
+ cam_position, cam_look_at, cam_up = camera_pose_backup_from_id.pop(
896
+ client.client_id
897
+ )
898
+ client.camera.position = cam_position
899
+ client.camera.look_at = cam_look_at
900
+ client.camera.up_direction = cam_up
901
+ client.flush()
902
+
903
+ # Un-hide scene nodes.
904
+ server.scene.set_global_visibility(True)
905
+
906
+ preview_frame_slider = add_preview_frame_slider()
907
+
908
+ # Update the # of frames.
909
+ @duration_number.on_update
910
+ @framerate_number.on_update
911
+ def _(_) -> None:
912
+ remove_preview_camera() # Will be re-added when slider is updated.
913
+
914
+ nonlocal preview_frame_slider
915
+ old = preview_frame_slider
916
+ assert old is not None
917
+
918
+ preview_frame_slider = add_preview_frame_slider()
919
+ if preview_frame_slider is not None:
920
+ old.remove()
921
+ else:
922
+ preview_frame_slider = old
923
+
924
+ camera_path.framerate = framerate_number.value
925
+ camera_path.update_spline()
926
+
927
+ # Play the camera trajectory when the play button is pressed.
928
+ @play_button.on_click
929
+ def _(_) -> None:
930
+ play_button.visible = False
931
+ pause_button.visible = True
932
+
933
+ def play() -> None:
934
+ while not play_button.visible:
935
+ max_frame = int(framerate_number.value * duration_number.value)
936
+ if max_frame > 0:
937
+ assert preview_frame_slider is not None
938
+ preview_frame_slider.value = (
939
+ preview_frame_slider.value + 1
940
+ ) % max_frame
941
+ time.sleep(1.0 / framerate_number.value)
942
+
943
+ threading.Thread(target=play).start()
944
+
945
+ # Pause playback when the pause button is pressed.
946
+ @pause_button.on_click
947
+ def _(_) -> None:
948
+ play_button.visible = True
949
+ pause_button.visible = False
950
+
951
+ # add button for loading existing path
952
+ load_camera_path_button = server.gui.add_button(
953
+ "Load Path", icon=viser.Icon.FOLDER_OPEN, hint="Load an existing camera path."
954
+ )
955
+
956
+ @load_camera_path_button.on_click
957
+ def _(event: viser.GuiEvent) -> None:
958
+ assert event.client is not None
959
+ camera_path_dir = datapath.parent
960
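+ # Note: existing paths are listed from datapath.parent, but the selection is opened below as datapath / <name>, which is where Save writes.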
+ camera_path_dir.mkdir(parents=True, exist_ok=True)
961
+ preexisting_camera_paths = list(camera_path_dir.glob("*.json"))
962
+ preexisting_camera_filenames = [p.name for p in preexisting_camera_paths]
963
+
964
+ with event.client.gui.add_modal("Load Path") as modal:
965
+ if len(preexisting_camera_filenames) == 0:
966
+ event.client.gui.add_markdown("No existing paths found")
967
+ else:
968
+ event.client.gui.add_markdown("Select existing camera path:")
969
+ camera_path_dropdown = event.client.gui.add_dropdown(
970
+ label="Camera Path",
971
+ options=[str(p) for p in preexisting_camera_filenames],
972
+ initial_value=str(preexisting_camera_filenames[0]),
973
+ )
974
+ load_button = event.client.gui.add_button("Load")
975
+
976
+ @load_button.on_click
977
+ def _(_) -> None:
978
+ # load the json file
979
+ json_path = datapath / camera_path_dropdown.value
980
+ with open(json_path, "r") as f:
981
+ json_data = json.load(f)
982
+
983
+ keyframes = json_data["keyframes"]
984
+ camera_path.reset()
985
+ for i in range(len(keyframes)):
986
+ frame = keyframes[i]
987
+ pose = tf.SE3.from_matrix(
988
+ np.array(frame["matrix"]).reshape(4, 4)
989
+ )
990
+ # apply the x rotation by 180 deg
991
+ pose = tf.SE3.from_rotation_and_translation(
992
+ pose.rotation() @ tf.SO3.from_x_radians(np.pi),
993
+ pose.translation(),
994
+ )
995
+
996
+ camera_path.add_camera(
997
+ Keyframe(
998
+ frame["time"],
999
+ position=pose.translation(),
1000
+ wxyz=pose.rotation().wxyz,
1001
+ # There are some floating point conversions between degrees and radians, so the fov and
1002
+ # default_fov values will not match exactly.
1003
+ override_fov_enabled=abs(
1004
+ frame["fov"] - json_data.get("default_fov", 0.0)
1005
+ )
1006
+ > 1e-3,
1007
+ override_fov_rad=frame["fov"] / 180.0 * np.pi,
1008
+ aspect=frame["aspect"],
1009
+ override_transition_enabled=frame.get(
1010
+ "override_transition_enabled", None
1011
+ ),
1012
+ override_transition_sec=frame.get(
1013
+ "override_transition_sec", None
1014
+ ),
1015
+ )
1016
+ )
1017
+
1018
+ transition_sec_number.value = json_data.get(
1019
+ "default_transition_sec", 0.5
1020
+ )
1021
+
1022
+ # update the render name
1023
+ camera_path_name.value = json_path.stem
1024
+ camera_path.update_spline()
1025
+ modal.close()
1026
+
1027
+ cancel_button = event.client.gui.add_button("Cancel")
1028
+
1029
+ @cancel_button.on_click
1030
+ def _(_) -> None:
1031
+ modal.close()
1032
+
1033
+ # set the initial value to the current date-time string
1034
+ now = datetime.datetime.now()
1035
+ camera_path_name = server.gui.add_text(
1036
+ "Camera path name",
1037
+ initial_value=now.strftime("%Y-%m-%d %H:%M:%S"),
1038
+ hint="Name of the render",
1039
+ )
1040
+
1041
+ save_path_button = server.gui.add_button(
1042
+ "Save Camera Path",
1043
+ color="green",
1044
+ icon=viser.Icon.FILE_EXPORT,
1045
+ hint="Save the camera path to json.",
1046
+ )
1047
+
1048
+ reset_up_button = server.gui.add_button(
1049
+ "Reset Up Direction",
1050
+ icon=viser.Icon.ARROW_BIG_UP_LINES,
1051
+ color="gray",
1052
+ hint="Set the up direction of the camera orbit controls to the camera's current up direction.",
1053
+ )
1054
+
1055
+ @reset_up_button.on_click
1056
+ def _(event: viser.GuiEvent) -> None:
1057
+ assert event.client is not None
1058
+ event.client.camera.up_direction = tf.SO3(event.client.camera.wxyz) @ np.array(
1059
+ [0.0, -1.0, 0.0]
1060
+ )
1061
+
1062
+ @save_path_button.on_click
1063
+ def _(event: viser.GuiEvent) -> None:
1064
+ assert event.client is not None
1065
+ num_frames = int(framerate_number.value * duration_number.value)
1066
+ json_data = {}
1067
+ # json data has the properties:
1068
+ # keyframes: list of keyframes with
1069
+ # matrix : flattened 4x4 matrix
1070
+ # fov: float in degrees
1071
+ # aspect: float
1072
+ # camera_type: string of camera type
1073
+ # render_height: int
1074
+ # render_width: int
1075
+ # fps: int
1076
+ # seconds: float
1077
+ # is_cycle: bool
1078
+ # smoothness_value: float
1079
+ # camera_path: list of frames with properties
1080
+ # camera_to_world: flattened 4x4 matrix
1081
+ # fov: float in degrees
1082
+ # aspect: float
1083
+ # first populate the keyframes:
1084
+ keyframes = []
1085
+ for keyframe, dummy in camera_path._keyframes.values():
1086
+ pose = tf.SE3.from_rotation_and_translation(
1087
+ tf.SO3(keyframe.wxyz), keyframe.position
1088
+ )
1089
+ keyframes.append(
1090
+ {
1091
+ "matrix": pose.as_matrix().flatten().tolist(),
1092
+ "fov": (
1093
+ np.rad2deg(keyframe.override_fov_rad)
1094
+ if keyframe.override_fov_enabled
1095
+ else fov_degrees.value
1096
+ ),
1097
+ "aspect": keyframe.aspect,
1098
+ "override_transition_enabled": keyframe.override_transition_enabled,
1099
+ "override_transition_sec": keyframe.override_transition_sec,
1100
+ }
1101
+ )
1102
+ json_data["default_fov"] = fov_degrees.value
1103
+ json_data["default_transition_sec"] = transition_sec_number.value
1104
+ json_data["keyframes"] = keyframes
1105
+ json_data["camera_type"] = camera_type.value.lower()
1106
+ json_data["render_height"] = resolution.value[1]
1107
+ json_data["render_width"] = resolution.value[0]
1108
+ json_data["fps"] = framerate_number.value
1109
+ json_data["seconds"] = duration_number.value
1110
+ json_data["is_cycle"] = loop.value
1111
+ json_data["smoothness_value"] = tension_slider.value
1112
+
1113
+ def get_intrinsics(W, H, fov):
1114
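+ # Pinhole intrinsics assuming `fov` is the vertical field of view in radians: fx = fy = H / (2 * tan(fov / 2)), principal point at the image center.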
+ focal = 0.5 * H / np.tan(0.5 * fov)
1115
+ return np.array(
1116
+ [[focal, 0.0, 0.5 * W], [0.0, focal, 0.5 * H], [0.0, 0.0, 1.0]]
1117
+ )
1118
+
1119
+ # now populate the camera path:
1120
+ camera_path_list = []
1121
+ for i in range(num_frames):
1122
+ maybe_pose_and_fov_and_time = camera_path.interpolate_pose_and_fov_rad(
1123
+ i / num_frames
1124
+ )
1125
+ if maybe_pose_and_fov_and_time is None:
1126
+ return
1127
+ pose, fov, time = maybe_pose_and_fov_and_time
1128
+ H = resolution.value[1]
1129
+ W = resolution.value[0]
1130
+ K = get_intrinsics(W, H, fov)
1131
+ # rotate the axis of the camera 180 about x axis
1132
+ w2c = pose.inverse().as_matrix()
1133
+ camera_path_list.append(
1134
+ {
1135
+ "time": time,
1136
+ "w2c": w2c.flatten().tolist(),
1137
+ "K": K.flatten().tolist(),
1138
+ "img_wh": (W, H),
1139
+ }
1140
+ )
1141
+ json_data["camera_path"] = camera_path_list
1142
+
1143
+ # now write the json file
1144
+ out_name = camera_path_name.value
1145
+ json_outfile = datapath / f"{out_name}.json"
1146
+ datapath.mkdir(parents=True, exist_ok=True)
1147
+ print(f"writing to {json_outfile}")
1148
+ with open(json_outfile.absolute(), "w") as outfile:
1149
+ json.dump(json_data, outfile)
1150
+
1151
+ camera_path = CameraPath(server, duration_number)
1152
+ camera_path.default_fov = fov_degrees.value / 180.0 * np.pi
1153
+ camera_path.default_transition_sec = transition_sec_number.value
1154
+
1155
+ return render_tab_state
1156
+
1157
+
1158
+ if __name__ == "__main__":
1159
+ populate_render_tab(
1160
+ server=viser.ViserServer(),
1161
+ datapath=Path("."),
1162
+ gui_timestep_handle=None,
1163
+ )
1164
+ while True:
1165
+ time.sleep(10.0)
EXP6_SOMwoTrack_Fulld_EPOCHS200/code/2025-01-28-165740/flow3d/vis/utils.py ADDED
@@ -0,0 +1,544 @@
1
+ import colorsys
2
+ from typing import cast
3
+
4
+ import cv2
5
+ import numpy as np
6
+
7
+ # import nvdiffrast.torch as dr
8
+ import torch
9
+ import torch.nn.functional as F
10
+ from matplotlib import colormaps
11
+ from viser import ViserServer
12
+
13
+
14
+ class Singleton(type):
15
+ _instances = {}
16
+
17
+ def __call__(cls, *args, **kwargs):
18
+ if cls not in cls._instances:
19
+ cls._instances[cls] = super(Singleton, cls).__call__(*args, **kwargs)
20
+ return cls._instances[cls]
21
+
22
+
23
+ class VisManager(metaclass=Singleton):
24
+ _servers = {}
25
+
26
+
27
+ def get_server(port: int | None = None) -> ViserServer:
28
+ manager = VisManager()
29
+ if port is None:
30
+ avail_ports = list(manager._servers.keys())
31
+ port = avail_ports[0] if len(avail_ports) > 0 else 8890
32
+ if port not in manager._servers:
33
+ manager._servers[port] = ViserServer(port=port, verbose=False)
34
+ return manager._servers[port]
35
+
36
+
37
+ def project_2d_tracks(tracks_3d_w, Ks, T_cw, return_depth=False):
38
+ """
39
+ :param tracks_3d_w (torch.Tensor): (T, N, 3)
40
+ :param Ks (torch.Tensor): (T, 3, 3)
41
+ :param T_cw (torch.Tensor): (T, 4, 4)
42
+ :returns tracks_2d (torch.Tensor): (T, N, 2)
43
+ """
44
+ tracks_3d_c = torch.einsum(
45
+ "tij,tnj->tni", T_cw, F.pad(tracks_3d_w, (0, 1), value=1)
46
+ )[..., :3]
47
+ tracks_3d_v = torch.einsum("tij,tnj->tni", Ks, tracks_3d_c)
48
+ if return_depth:
49
+ return (
50
+ tracks_3d_v[..., :2] / torch.clamp(tracks_3d_v[..., 2:], min=1e-5),
51
+ tracks_3d_v[..., 2],
52
+ )
53
+ return tracks_3d_v[..., :2] / torch.clamp(tracks_3d_v[..., 2:], min=1e-5)
54
+
55
+
56
+ def draw_keypoints_video(
57
+ imgs, kps, colors=None, occs=None, cmap: str = "gist_rainbow", radius: int = 3
58
+ ):
59
+ """
60
+ :param imgs (np.ndarray): (T, H, W, 3) uint8 [0, 255]
61
+ :param kps (np.ndarray): (N, T, 2)
62
+ :param colors (np.ndarray): (N, 3) float [0, 1]
63
+ :param occs (np.ndarray): (N, T) bool
64
+ return out_frames (T, H, W, 3)
65
+ """
66
+ if colors is None:
67
+ label = np.linspace(0, 1, kps.shape[0])
68
+ colors = np.asarray(colormaps.get_cmap(cmap)(label))[..., :3]
69
+ out_frames = []
70
+ for t in range(len(imgs)):
71
+ occ = occs[:, t] if occs is not None else None
72
+ vis = draw_keypoints_cv2(imgs[t], kps[:, t], colors, occ, radius=radius)
73
+ out_frames.append(vis)
74
+ return out_frames
75
+
76
+
77
+ def draw_keypoints_cv2(img, kps, colors=None, occs=None, radius=3):
78
+ """
79
+ :param img (H, W, 3)
80
+ :param kps (N, 2)
81
+ :param occs (N)
82
+ :param colors (N, 3) from 0 to 1
83
+ """
84
+ out_img = img.copy()
85
+ kps = kps.round().astype("int").tolist()
86
+ if colors is not None:
87
+ colors = (255 * colors).astype("int").tolist()
88
+ for n in range(len(kps)):
89
+ kp = kps[n]
90
+ color = colors[n] if colors is not None else (255, 0, 0)
91
+ thickness = -1 if occs is None or occs[n] == 0 else 1
92
+ out_img = cv2.circle(out_img, kp, radius, color, thickness, cv2.LINE_AA)
93
+ return out_img
94
+
95
+
96
+ def draw_tracks_2d(
97
+ img: torch.Tensor,
98
+ tracks_2d: torch.Tensor,
99
+ track_point_size: int = 2,
100
+ track_line_width: int = 1,
101
+ cmap_name: str = "gist_rainbow",
102
+ ):
103
+ cmap = colormaps.get_cmap(cmap_name)
104
+ # (H, W, 3).
105
+ img_np = (img.cpu().numpy() * 255.0).astype(np.uint8)
106
+ # (P, N, 2).
107
+ tracks_2d_np = tracks_2d.cpu().numpy()
108
+
109
+ num_tracks, num_frames = tracks_2d_np.shape[:2]
110
+
111
+ canvas = img_np.copy()
112
+ for i in range(num_frames - 1):
113
+ alpha = max(1 - 0.9 * ((num_frames - 1 - i) / (num_frames * 0.99)), 0.1)
114
+ img_curr = canvas.copy()
115
+ for j in range(num_tracks):
116
+ color = tuple(np.array(cmap(j / max(1, float(num_tracks - 1)))[:3]) * 255)
117
+ color_alpha = 1
118
+ hsv = colorsys.rgb_to_hsv(color[0], color[1], color[2])
119
+ color = colorsys.hsv_to_rgb(hsv[0], hsv[1] * color_alpha, hsv[2])
120
+ pt1 = tracks_2d_np[j, i]
121
+ pt2 = tracks_2d_np[j, i + 1]
122
+ p1 = (int(round(pt1[0])), int(round(pt1[1])))
123
+ p2 = (int(round(pt2[0])), int(round(pt2[1])))
124
+ img_curr = cv2.line(
125
+ img_curr,
126
+ p1,
127
+ p2,
128
+ color,
129
+ thickness=track_line_width,
130
+ lineType=cv2.LINE_AA,
131
+ )
132
+ canvas = cv2.addWeighted(img_curr, alpha, canvas, 1 - alpha, 0)
133
+
134
+ for j in range(num_tracks):
135
+ color = tuple(np.array(cmap(j / max(1, float(num_tracks - 1)))[:3]) * 255)
136
+ pt = tracks_2d_np[j, -1]
137
+ pt = (int(round(pt[0])), int(round(pt[1])))
138
+ canvas = cv2.circle(
139
+ canvas,
140
+ pt,
141
+ track_point_size,
142
+ color,
143
+ thickness=-1,
144
+ lineType=cv2.LINE_AA,
145
+ )
146
+
147
+ return canvas
148
+
149
+
150
+ def generate_line_verts_faces(starts, ends, line_width):
151
+ """
152
+ Args:
153
+ starts: (P, N, 2).
154
+ ends: (P, N, 2).
155
+ line_width: int.
156
+
157
+ Returns:
158
+ verts: (P * N * 4, 2).
159
+ faces: (P * N * 2, 3).
160
+ """
161
+ P, N, _ = starts.shape
162
+
163
+ directions = F.normalize(ends - starts, dim=-1)
164
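+ # Perpendicular offset: rotate the unit direction by 90 degrees and scale by half the line width to get each rectangle's half-thickness.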
+ deltas = (
165
+ torch.cat([-directions[..., 1:], directions[..., :1]], dim=-1)
166
+ * line_width
167
+ / 2.0
168
+ )
169
+ v0 = starts + deltas
170
+ v1 = starts - deltas
171
+ v2 = ends + deltas
172
+ v3 = ends - deltas
173
+ verts = torch.stack([v0, v1, v2, v3], dim=-2)
174
+ verts = verts.reshape(-1, 2)
175
+
176
+ faces = []
177
+ for p in range(P):
178
+ for n in range(N):
179
+ base_index = p * N * 4 + n * 4
180
+ # Two triangles per rectangle: (0, 1, 2) and (2, 1, 3)
181
+ faces.append([base_index, base_index + 1, base_index + 2])
182
+ faces.append([base_index + 2, base_index + 1, base_index + 3])
183
+ faces = torch.as_tensor(faces, device=starts.device)
184
+
185
+ return verts, faces
186
+
187
+
188
+ def generate_point_verts_faces(points, point_size, num_segments=10):
189
+ """
190
+ Args:
191
+ points: (P, 2).
192
+ point_size: int.
193
+ num_segments: int.
194
+
195
+ Returns:
196
+ verts: (P * (num_segments + 1), 2).
197
+ faces: (P * num_segments, 3).
198
+ """
199
+ P, _ = points.shape
200
+
201
+ angles = torch.linspace(0, 2 * torch.pi, num_segments + 1, device=points.device)[
202
+ ..., :-1
203
+ ]
204
+ unit_circle = torch.stack([torch.cos(angles), torch.sin(angles)], dim=1)
205
+ scaled_circles = (point_size / 2.0) * unit_circle
206
+ scaled_circles = scaled_circles[None].repeat(P, 1, 1)
207
+ verts = points[:, None] + scaled_circles
208
+ verts = torch.cat([verts, points[:, None]], dim=1)
209
+ verts = verts.reshape(-1, 2)
210
+
211
+ faces = F.pad(
212
+ torch.as_tensor(
213
+ [[i, (i + 1) % num_segments] for i in range(num_segments)],
214
+ device=points.device,
215
+ ),
216
+ (0, 1),
217
+ value=num_segments,
218
+ )
219
+ faces = faces[None, :] + torch.arange(P, device=points.device)[:, None, None] * (
220
+ num_segments + 1
221
+ )
222
+ faces = faces.reshape(-1, 3)
223
+
224
+ return verts, faces
225
+
226
+
227
+ def pixel_to_verts_clip(pixels, img_wh, z: float | torch.Tensor = 0.0, w=1.0):
228
+ verts_clip = pixels / pixels.new_tensor(img_wh) * 2.0 - 1.0
229
+ w = torch.full_like(verts_clip[..., :1], w)
230
+ verts_clip = torch.cat([verts_clip, z * w, w], dim=-1)
231
+ return verts_clip
232
+
233
+
234
+ def draw_tracks_2d_th(
235
+ img: torch.Tensor,
236
+ tracks_2d: torch.Tensor,
237
+ track_point_size: int = 5,
238
+ track_point_segments: int = 16,
239
+ track_line_width: int = 2,
240
+ cmap_name: str = "gist_rainbow",
241
+ ):
242
+ cmap = colormaps.get_cmap(cmap_name)
243
+ CTX = dr.RasterizeCudaContext()
244
+
245
+ W, H = img.shape[1], img.shape[0]
246
+ if W % 8 != 0 or H % 8 != 0:
247
+ # Make sure img is divisible by 8.
248
+ img = F.pad(
249
+ img,
250
+ (
251
+ 0,
252
+ 0,
253
+ 0,
254
+ 8 - W % 8 if W % 8 != 0 else 0,
255
+ 0,
256
+ 8 - H % 8 if H % 8 != 0 else 0,
257
+ ),
258
+ value=0.0,
259
+ )
260
+ num_tracks, num_frames = tracks_2d.shape[:2]
261
+
262
+ track_colors = torch.tensor(
263
+ [cmap(j / max(1, float(num_tracks - 1)))[:3] for j in range(num_tracks)],
264
+ device=img.device,
265
+ ).float()
266
+
267
+ # Generate line verts.
268
+ verts_l, faces_l = generate_line_verts_faces(
269
+ tracks_2d[:, :-1], tracks_2d[:, 1:], track_line_width
270
+ )
271
+ # Generate point verts.
272
+ verts_p, faces_p = generate_point_verts_faces(
273
+ tracks_2d[:, -1], track_point_size, track_point_segments
274
+ )
275
+
276
+ verts = torch.cat([verts_l, verts_p], dim=0)
277
+ faces = torch.cat([faces_l, faces_p + len(verts_l)], dim=0)
278
+ vert_colors = torch.cat(
279
+ [
280
+ (
281
+ track_colors[:, None]
282
+ .repeat_interleave(4 * (num_frames - 1), dim=1)
283
+ .reshape(-1, 3)
284
+ ),
285
+ (
286
+ track_colors[:, None]
287
+ .repeat_interleave(track_point_segments + 1, dim=1)
288
+ .reshape(-1, 3)
289
+ ),
290
+ ],
291
+ dim=0,
292
+ )
293
+ track_zs = torch.linspace(0.0, 1.0, num_tracks, device=img.device)[:, None]
294
+ vert_zs = torch.cat(
295
+ [
296
+ (
297
+ track_zs[:, None]
298
+ .repeat_interleave(4 * (num_frames - 1), dim=1)
299
+ .reshape(-1, 1)
300
+ ),
301
+ (
302
+ track_zs[:, None]
303
+ .repeat_interleave(track_point_segments + 1, dim=1)
304
+ .reshape(-1, 1)
305
+ ),
306
+ ],
307
+ dim=0,
308
+ )
309
+ track_alphas = torch.linspace(
310
+ max(0.1, 1.0 - (num_frames - 1) * 0.1), 1.0, num_frames, device=img.device
311
+ )
312
+ vert_alphas = torch.cat(
313
+ [
314
+ (
315
+ track_alphas[None, :-1, None]
316
+ .repeat_interleave(num_tracks, dim=0)
317
+ .repeat_interleave(4, dim=-2)
318
+ .reshape(-1, 1)
319
+ ),
320
+ (
321
+ track_alphas[None, -1:, None]
322
+ .repeat_interleave(num_tracks, dim=0)
323
+ .repeat_interleave(track_point_segments + 1, dim=-2)
324
+ .reshape(-1, 1)
325
+ ),
326
+ ],
327
+ dim=0,
328
+ )
329
+
330
+ # Small trick to always render one track in front of the other.
331
+ verts_clip = pixel_to_verts_clip(verts, (img.shape[1], img.shape[0]), vert_zs)
332
+ faces_int32 = faces.to(torch.int32)
333
+
334
+ rast, _ = cast(
335
+ tuple,
336
+ dr.rasterize(CTX, verts_clip[None], faces_int32, (img.shape[0], img.shape[1])),
337
+ )
338
+ rgba = cast(
339
+ torch.Tensor,
340
+ dr.interpolate(
341
+ torch.cat([vert_colors, vert_alphas], dim=-1).contiguous(),
342
+ rast,
343
+ faces_int32,
344
+ ),
345
+ )[0]
346
+ rgba = cast(torch.Tensor, dr.antialias(rgba, rast, verts_clip, faces_int32))[
347
+ 0
348
+ ].clamp(0, 1)
349
+ # Compose.
350
+ color = rgba[..., :-1] * rgba[..., -1:] + (1.0 - rgba[..., -1:]) * img
351
+
352
+ # Unpad.
353
+ color = color[:H, :W]
354
+
355
+ return (color.cpu().numpy() * 255.0).astype(np.uint8)
356
+
357
+
358
+ def make_video_divisble(
359
+ video: torch.Tensor | np.ndarray, block_size=16
360
+ ) -> torch.Tensor | np.ndarray:
361
+ H, W = video.shape[1:3]
362
+ H_new = H - H % block_size
363
+ W_new = W - W % block_size
364
+ return video[:, :H_new, :W_new]
365
+
366
+
367
+ def apply_float_colormap(img: torch.Tensor, colormap: str = "turbo") -> torch.Tensor:
368
+ """Convert single channel to a color img.
369
+
370
+ Args:
371
+ img (torch.Tensor): (..., 1) float32 single channel image.
372
+ colormap (str): Colormap for img.
373
+
374
+ Returns:
375
+ (..., 3) colored img with colors in [0, 1].
376
+ """
377
+ img = torch.nan_to_num(img, 0)
378
+ if colormap == "gray":
379
+ return img.repeat(1, 1, 3)
380
+ img_long = (img * 255).long()
381
+ img_long_min = torch.min(img_long)
382
+ img_long_max = torch.max(img_long)
383
+ assert img_long_min >= 0, f"the min value is {img_long_min}"
384
+ assert img_long_max <= 255, f"the max value is {img_long_max}"
385
+ return torch.tensor(
386
+ colormaps[colormap].colors, # type: ignore
387
+ device=img.device,
388
+ )[img_long[..., 0]]
389
+
390
+
391
+ def apply_depth_colormap(
392
+ depth: torch.Tensor,
393
+ acc: torch.Tensor | None = None,
394
+ near_plane: float | None = None,
395
+ far_plane: float | None = None,
396
+ ) -> torch.Tensor:
397
+ """Converts a depth image to color for easier analysis.
398
+
399
+ Args:
400
+ depth (torch.Tensor): (..., 1) float32 depth.
401
+ acc (torch.Tensor | None): (..., 1) optional accumulation mask.
402
+ near_plane: Closest depth to consider. If None, use min image value.
403
+ far_plane: Furthest depth to consider. If None, use max image value.
404
+
405
+ Returns:
406
+ (..., 3) colored depth image with colors in [0, 1].
407
+ """
408
+ near_plane = near_plane or float(torch.min(depth))
409
+ far_plane = far_plane or float(torch.max(depth))
410
+ depth = (depth - near_plane) / (far_plane - near_plane + 1e-10)
411
+ depth = torch.clip(depth, 0.0, 1.0)
412
+ img = apply_float_colormap(depth, colormap="turbo")
413
+ if acc is not None:
414
+ img = img * acc + (1.0 - acc)
415
+ return img
416
+
417
+
418
+ def float2uint8(x):
419
+ return (255.0 * x).astype(np.uint8)
420
+
421
+
422
+ def uint82float(img):
423
+ return np.ascontiguousarray(img) / 255.0
424
+
425
+
426
+ def drawMatches(
427
+ img1,
428
+ img2,
429
+ kp1,
430
+ kp2,
431
+ num_vis=200,
432
+ center=None,
433
+ idx_vis=None,
434
+ radius=2,
435
+ seed=1234,
436
+ mask=None,
437
+ ):
438
+ num_pts = len(kp1)
439
+ if idx_vis is None:
440
+ if num_vis < num_pts:
441
+ rng = np.random.RandomState(seed)
442
+ idx_vis = rng.choice(num_pts, num_vis, replace=False)
443
+ else:
444
+ idx_vis = np.arange(num_pts)
445
+
446
+ kp1_vis = kp1[idx_vis]
447
+ kp2_vis = kp2[idx_vis]
448
+
449
+ h1, w1 = img1.shape[:2]
450
+ h2, w2 = img2.shape[:2]
451
+
452
+ kp1_vis[:, 0] = np.clip(kp1_vis[:, 0], a_min=0, a_max=w1 - 1)
453
+ kp1_vis[:, 1] = np.clip(kp1_vis[:, 1], a_min=0, a_max=h1 - 1)
454
+
455
+ kp2_vis[:, 0] = np.clip(kp2_vis[:, 0], a_min=0, a_max=w2 - 1)
456
+ kp2_vis[:, 1] = np.clip(kp2_vis[:, 1], a_min=0, a_max=h2 - 1)
457
+
458
+ img1 = float2uint8(img1)
459
+ img2 = float2uint8(img2)
460
+
461
+ if center is None:
462
+ center = np.median(kp1, axis=0)
463
+
464
+ set_max = range(128)
465
+ colors = {m: i for i, m in enumerate(set_max)}
466
+ hsv = colormaps.get_cmap("hsv")
467
+ colors = {
468
+ m: (255 * np.array(hsv(i / float(len(colors))))[:3][::-1]).astype(np.int32)
469
+ for m, i in colors.items()
470
+ }
471
+
472
+ if mask is not None:
473
+ ind = np.argsort(mask)[::-1]
474
+ kp1_vis = kp1_vis[ind]
475
+ kp2_vis = kp2_vis[ind]
476
+ mask = mask[ind]
477
+
478
+ for i, (pt1, pt2) in enumerate(zip(kp1_vis, kp2_vis)):
479
+ # random_color = tuple(np.random.randint(low=0, high=255, size=(3,)).tolist())
480
+ coord_angle = np.arctan2(pt1[1] - center[1], pt1[0] - center[0])
481
+ corr_color = np.int32(64 * coord_angle / np.pi) % 128
482
+ color = tuple(colors[corr_color].tolist())
483
+
484
+ if (
485
+ (pt1[0] <= w1 - 1)
486
+ and (pt1[0] >= 0)
487
+ and (pt1[1] <= h1 - 1)
488
+ and (pt1[1] >= 0)
489
+ ):
490
+ img1 = cv2.circle(
491
+ img1, (int(pt1[0]), int(pt1[1])), radius, color, -1, cv2.LINE_AA
492
+ )
493
+ if (
494
+ (pt2[0] <= w2 - 1)
495
+ and (pt2[0] >= 0)
496
+ and (pt2[1] <= h2 - 1)
497
+ and (pt2[1] >= 0)
498
+ ):
499
+ if mask is not None and mask[i]:
500
+ continue
501
+ # img2 = cv2.drawMarker(img2, (int(pt2[0]), int(pt2[1])), color, markerType=cv2.MARKER_CROSS,
502
+ # markerSize=int(5*radius), thickness=int(radius/2), line_type=cv2.LINE_AA)
503
+ else:
504
+ img2 = cv2.circle(
505
+ img2, (int(pt2[0]), int(pt2[1])), radius, color, -1, cv2.LINE_AA
506
+ )
507
+
508
+ out = np.concatenate([img1, img2], axis=1)
509
+ return out
510
+
511
+
512
+ def plot_correspondences(
513
+ rgbs, kpts, query_id=0, masks=None, num_vis=1000000, radius=3, seed=1234
514
+ ):
515
+ num_rgbs = len(rgbs)
516
+ rng = np.random.RandomState(seed)
517
+ permutation = rng.permutation(kpts.shape[1])
518
+ kpts = kpts[:, permutation, :][:, :num_vis]
519
+ if masks is not None:
520
+ masks = masks[:, permutation][:, :num_vis]
521
+
522
+ rgbq = rgbs[query_id] # [h, w, 3]
523
+ kptsq = kpts[query_id] # [n, 2]
524
+
525
+ frames = []
526
+ for i in range(num_rgbs):
527
+ rgbi = rgbs[i]
528
+ kptsi = kpts[i]
529
+ if masks is not None:
530
+ maski = masks[i]
531
+ else:
532
+ maski = None
533
+ frame = drawMatches(
534
+ rgbq,
535
+ rgbi,
536
+ kptsq,
537
+ kptsi,
538
+ mask=maski,
539
+ num_vis=num_vis,
540
+ radius=radius,
541
+ seed=seed,
542
+ )
543
+ frames.append(frame)
544
+ return frames