xiaoyuxi commited on
Commit
c8d9d42
·
0 Parent(s):

Cleaned history, reset to current state

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +36 -0
  2. .gitignore +69 -0
  3. README.md +14 -0
  4. _viz/viz_template.html +1778 -0
  5. app.py +1118 -0
  6. app_3rd/README.md +12 -0
  7. app_3rd/sam_utils/hf_sam_predictor.py +129 -0
  8. app_3rd/sam_utils/inference.py +123 -0
  9. app_3rd/spatrack_utils/infer_track.py +194 -0
  10. config/__init__.py +0 -0
  11. config/magic_infer_moge.yaml +48 -0
  12. examples/backpack.mp4 +3 -0
  13. examples/ball.mp4 +3 -0
  14. examples/basketball.mp4 +3 -0
  15. examples/biker.mp4 +3 -0
  16. examples/cinema_0.mp4 +3 -0
  17. examples/cinema_1.mp4 +3 -0
  18. examples/drifting.mp4 +3 -0
  19. examples/ego_kc1.mp4 +3 -0
  20. examples/ego_teaser.mp4 +3 -0
  21. examples/handwave.mp4 +3 -0
  22. examples/hockey.mp4 +3 -0
  23. examples/ken_block_0.mp4 +3 -0
  24. examples/kiss.mp4 +3 -0
  25. examples/kitchen.mp4 +3 -0
  26. examples/kitchen_egocentric.mp4 +3 -0
  27. examples/pillow.mp4 +3 -0
  28. examples/protein.mp4 +3 -0
  29. examples/pusht.mp4 +3 -0
  30. examples/robot1.mp4 +3 -0
  31. examples/robot2.mp4 +3 -0
  32. examples/robot_3.mp4 +3 -0
  33. examples/robot_unitree.mp4 +3 -0
  34. examples/running.mp4 +3 -0
  35. examples/teleop2.mp4 +3 -0
  36. examples/vertical_place.mp4 +3 -0
  37. models/SpaTrackV2/models/SpaTrack.py +759 -0
  38. models/SpaTrackV2/models/__init__.py +0 -0
  39. models/SpaTrackV2/models/blocks.py +519 -0
  40. models/SpaTrackV2/models/camera_transform.py +248 -0
  41. models/SpaTrackV2/models/depth_refiner/backbone.py +472 -0
  42. models/SpaTrackV2/models/depth_refiner/decode_head.py +619 -0
  43. models/SpaTrackV2/models/depth_refiner/depth_refiner.py +115 -0
  44. models/SpaTrackV2/models/depth_refiner/network.py +429 -0
  45. models/SpaTrackV2/models/depth_refiner/stablilization_attention.py +1187 -0
  46. models/SpaTrackV2/models/depth_refiner/stablizer.py +342 -0
  47. models/SpaTrackV2/models/predictor.py +153 -0
  48. models/SpaTrackV2/models/tracker3D/TrackRefiner.py +1478 -0
  49. models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py +418 -0
  50. models/SpaTrackV2/models/tracker3D/co_tracker/utils.py +929 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignore the multi media
2
+ checkpoints
3
+ **/checkpoints/
4
+ **/temp/
5
+ temp
6
+ assets_dev
7
+ assets/example0/results
8
+ assets/example0/snowboard.npz
9
+ assets/example1/results
10
+ assets/davis_eval
11
+ assets/*/results
12
+ *gradio*
13
+ #
14
+ models/monoD/zoeDepth/ckpts/*
15
+ models/monoD/depth_anything/ckpts/*
16
+ vis_results
17
+ dist_encrypted
18
+ # remove the dependencies
19
+ deps
20
+
21
+ # filter the __pycache__ files
22
+ __pycache__/
23
+ /**/**/__pycache__
24
+ /**/__pycache__
25
+
26
+ outputs
27
+ scripts/lauch_exp/config
28
+ scripts/lauch_exp/submit_job.log
29
+ scripts/lauch_exp/hydra_output
30
+ scripts/lauch_wulan
31
+ scripts/custom_video
32
+ # ignore the visualizer
33
+ viser
34
+ viser_result
35
+ benchmark/results
36
+ benchmark
37
+
38
+ ossutil_output
39
+
40
+ prev_version
41
+ spat_ceres
42
+ wandb
43
+ *.log
44
+ seg_target.py
45
+
46
+ eval_davis.py
47
+ eval_multiple_gpu.py
48
+ eval_pose_scan.py
49
+ eval_single_gpu.py
50
+
51
+ infer_cam.py
52
+ infer_stream.py
53
+
54
+ *.egg-info/
55
+ **/*.egg-info
56
+
57
+ eval_kinectics.py
58
+ models/SpaTrackV2/datasets
59
+
60
+ scripts
61
+ config/fix_2d.yaml
62
+
63
+ models/SpaTrackV2/datasets
64
+ scripts/
65
+
66
+ models/**/build
67
+ models/**/dist
68
+
69
+ temp_local
README.md ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: SpatialTrackerV2
3
+ emoji: ⚡️
4
+ colorFrom: yellow
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 5.31.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: mit
11
+ short_description: Official Space for SpatialTrackerV2
12
+ ---
13
+
14
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
_viz/viz_template.html ADDED
@@ -0,0 +1,1778 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!DOCTYPE html>
2
+ <html lang="en">
3
+ <head>
4
+ <meta charset="UTF-8">
5
+ <meta name="viewport" content="width=device-width, initial-scale=1.0">
6
+ <title>3D Point Cloud Visualizer</title>
7
+ <style>
8
+ :root {
9
+ --primary: #9b59b6; /* Brighter purple for dark mode */
10
+ --primary-light: #3a2e4a;
11
+ --secondary: #a86add;
12
+ --accent: #ff6e6e;
13
+ --bg: #1a1a1a;
14
+ --surface: #2c2c2c;
15
+ --text: #e0e0e0;
16
+ --text-secondary: #a0a0a0;
17
+ --border: #444444;
18
+ --shadow: rgba(0, 0, 0, 0.2);
19
+ --shadow-hover: rgba(0, 0, 0, 0.3);
20
+
21
+ --space-sm: 16px;
22
+ --space-md: 24px;
23
+ --space-lg: 32px;
24
+ }
25
+
26
+ body {
27
+ margin: 0;
28
+ overflow: hidden;
29
+ background: var(--bg);
30
+ color: var(--text);
31
+ font-family: 'Inter', sans-serif;
32
+ -webkit-font-smoothing: antialiased;
33
+ }
34
+
35
+ #canvas-container {
36
+ position: absolute;
37
+ width: 100%;
38
+ height: 100%;
39
+ }
40
+
41
+ #ui-container {
42
+ position: absolute;
43
+ top: 0;
44
+ left: 0;
45
+ width: 100%;
46
+ height: 100%;
47
+ pointer-events: none;
48
+ z-index: 10;
49
+ }
50
+
51
+ #status-bar {
52
+ position: absolute;
53
+ top: 16px;
54
+ left: 16px;
55
+ background: rgba(30, 30, 30, 0.9);
56
+ padding: 8px 16px;
57
+ border-radius: 8px;
58
+ pointer-events: auto;
59
+ box-shadow: 0 4px 6px var(--shadow);
60
+ backdrop-filter: blur(4px);
61
+ border: 1px solid var(--border);
62
+ color: var(--text);
63
+ transition: opacity 0.5s ease, transform 0.5s ease;
64
+ font-weight: 500;
65
+ }
66
+
67
+ #status-bar.hidden {
68
+ opacity: 0;
69
+ transform: translateY(-20px);
70
+ pointer-events: none;
71
+ }
72
+
73
+ #control-panel {
74
+ position: absolute;
75
+ bottom: 16px;
76
+ left: 50%;
77
+ transform: translateX(-50%);
78
+ background: rgba(44, 44, 44, 0.95);
79
+ padding: 6px 8px;
80
+ border-radius: 6px;
81
+ display: flex;
82
+ gap: 8px;
83
+ align-items: center;
84
+ justify-content: space-between;
85
+ pointer-events: auto;
86
+ box-shadow: 0 4px 10px var(--shadow);
87
+ backdrop-filter: blur(4px);
88
+ border: 1px solid var(--border);
89
+ }
90
+
91
+ #timeline {
92
+ width: 150px;
93
+ height: 4px;
94
+ background: rgba(255, 255, 255, 0.1);
95
+ border-radius: 2px;
96
+ position: relative;
97
+ cursor: pointer;
98
+ }
99
+
100
+ #progress {
101
+ position: absolute;
102
+ height: 100%;
103
+ background: var(--primary);
104
+ border-radius: 2px;
105
+ width: 0%;
106
+ }
107
+
108
+ #playback-controls {
109
+ display: flex;
110
+ gap: 4px;
111
+ align-items: center;
112
+ }
113
+
114
+ button {
115
+ background: rgba(255, 255, 255, 0.08);
116
+ border: 1px solid var(--border);
117
+ color: var(--text);
118
+ padding: 4px 6px;
119
+ border-radius: 3px;
120
+ cursor: pointer;
121
+ display: flex;
122
+ align-items: center;
123
+ justify-content: center;
124
+ transition: background 0.2s, transform 0.2s;
125
+ font-family: 'Inter', sans-serif;
126
+ font-weight: 500;
127
+ font-size: 6px;
128
+ }
129
+
130
+ button:hover {
131
+ background: rgba(255, 255, 255, 0.15);
132
+ transform: translateY(-1px);
133
+ }
134
+
135
+ button.active {
136
+ background: var(--primary);
137
+ color: white;
138
+ box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
139
+ }
140
+
141
+ select, input {
142
+ background: rgba(255, 255, 255, 0.08);
143
+ border: 1px solid var(--border);
144
+ color: var(--text);
145
+ padding: 4px 6px;
146
+ border-radius: 3px;
147
+ cursor: pointer;
148
+ font-family: 'Inter', sans-serif;
149
+ font-size: 6px;
150
+ }
151
+
152
+ .icon {
153
+ width: 10px;
154
+ height: 10px;
155
+ fill: currentColor;
156
+ }
157
+
158
+ .tooltip {
159
+ position: absolute;
160
+ bottom: 100%;
161
+ left: 50%;
162
+ transform: translateX(-50%);
163
+ background: var(--surface);
164
+ color: var(--text);
165
+ padding: 3px 6px;
166
+ border-radius: 3px;
167
+ font-size: 7px;
168
+ white-space: nowrap;
169
+ margin-bottom: 4px;
170
+ opacity: 0;
171
+ transition: opacity 0.2s;
172
+ pointer-events: none;
173
+ box-shadow: 0 2px 4px var(--shadow);
174
+ border: 1px solid var(--border);
175
+ }
176
+
177
+ button:hover .tooltip {
178
+ opacity: 1;
179
+ }
180
+
181
+ #settings-panel {
182
+ position: absolute;
183
+ top: 16px;
184
+ right: 16px;
185
+ background: rgba(44, 44, 44, 0.98);
186
+ padding: 10px;
187
+ border-radius: 6px;
188
+ width: 195px;
189
+ max-height: calc(100vh - 40px);
190
+ overflow-y: auto;
191
+ pointer-events: auto;
192
+ box-shadow: 0 4px 15px var(--shadow);
193
+ backdrop-filter: blur(4px);
194
+ border: 1px solid var(--border);
195
+ display: block;
196
+ opacity: 1;
197
+ scrollbar-width: thin;
198
+ scrollbar-color: var(--primary-light) transparent;
199
+ transition: transform 0.35s ease-in-out, opacity 0.3s ease-in-out;
200
+ }
201
+
202
+ #settings-panel.is-hidden {
203
+ transform: translateX(calc(100% + 20px));
204
+ opacity: 0;
205
+ pointer-events: none;
206
+ }
207
+
208
+ #settings-panel::-webkit-scrollbar {
209
+ width: 3px;
210
+ }
211
+
212
+ #settings-panel::-webkit-scrollbar-track {
213
+ background: transparent;
214
+ }
215
+
216
+ #settings-panel::-webkit-scrollbar-thumb {
217
+ background-color: var(--primary-light);
218
+ border-radius: 3px;
219
+ }
220
+
221
+ @media (max-height: 700px) {
222
+ #settings-panel {
223
+ max-height: calc(100vh - 40px);
224
+ }
225
+ }
226
+
227
+ @media (max-width: 768px) {
228
+ #control-panel {
229
+ width: 90%;
230
+ flex-wrap: wrap;
231
+ justify-content: center;
232
+ }
233
+
234
+ #timeline {
235
+ width: 100%;
236
+ order: 3;
237
+ margin-top: 10px;
238
+ }
239
+
240
+ #settings-panel {
241
+ width: 140px;
242
+ right: 10px;
243
+ top: 10px;
244
+ max-height: calc(100vh - 20px);
245
+ }
246
+ }
247
+
248
+ .settings-group {
249
+ margin-bottom: 8px;
250
+ }
251
+
252
+ .settings-group h3 {
253
+ margin: 0 0 6px 0;
254
+ font-size: 10px;
255
+ font-weight: 500;
256
+ color: var(--text-secondary);
257
+ }
258
+
259
+ .slider-container {
260
+ display: flex;
261
+ align-items: center;
262
+ gap: 6px;
263
+ width: 100%;
264
+ }
265
+
266
+ .slider-container label {
267
+ min-width: 60px;
268
+ font-size: 10px;
269
+ flex-shrink: 0;
270
+ }
271
+
272
+ input[type="range"] {
273
+ flex: 1;
274
+ height: 2px;
275
+ -webkit-appearance: none;
276
+ background: rgba(255, 255, 255, 0.1);
277
+ border-radius: 1px;
278
+ min-width: 0;
279
+ }
280
+
281
+ input[type="range"]::-webkit-slider-thumb {
282
+ -webkit-appearance: none;
283
+ width: 8px;
284
+ height: 8px;
285
+ border-radius: 50%;
286
+ background: var(--primary);
287
+ cursor: pointer;
288
+ }
289
+
290
+ .toggle-switch {
291
+ position: relative;
292
+ display: inline-block;
293
+ width: 20px;
294
+ height: 10px;
295
+ }
296
+
297
+ .toggle-switch input {
298
+ opacity: 0;
299
+ width: 0;
300
+ height: 0;
301
+ }
302
+
303
+ .toggle-slider {
304
+ position: absolute;
305
+ cursor: pointer;
306
+ top: 0;
307
+ left: 0;
308
+ right: 0;
309
+ bottom: 0;
310
+ background: rgba(255, 255, 255, 0.1);
311
+ transition: .4s;
312
+ border-radius: 10px;
313
+ }
314
+
315
+ .toggle-slider:before {
316
+ position: absolute;
317
+ content: "";
318
+ height: 8px;
319
+ width: 8px;
320
+ left: 1px;
321
+ bottom: 1px;
322
+ background: var(--surface);
323
+ border: 1px solid var(--border);
324
+ transition: .4s;
325
+ border-radius: 50%;
326
+ }
327
+
328
+ input:checked + .toggle-slider {
329
+ background: var(--primary);
330
+ }
331
+
332
+ input:checked + .toggle-slider:before {
333
+ transform: translateX(10px);
334
+ }
335
+
336
+ .checkbox-container {
337
+ display: flex;
338
+ align-items: center;
339
+ gap: 4px;
340
+ margin-bottom: 4px;
341
+ }
342
+
343
+ .checkbox-container label {
344
+ font-size: 10px;
345
+ cursor: pointer;
346
+ }
347
+
348
+ #loading-overlay {
349
+ position: absolute;
350
+ top: 0;
351
+ left: 0;
352
+ width: 100%;
353
+ height: 100%;
354
+ background: var(--bg);
355
+ display: flex;
356
+ flex-direction: column;
357
+ align-items: center;
358
+ justify-content: center;
359
+ z-index: 100;
360
+ transition: opacity 0.5s;
361
+ }
362
+
363
+ #loading-overlay.fade-out {
364
+ opacity: 0;
365
+ pointer-events: none;
366
+ }
367
+
368
+ .spinner {
369
+ width: 50px;
370
+ height: 50px;
371
+ border: 5px solid rgba(155, 89, 182, 0.2);
372
+ border-radius: 50%;
373
+ border-top-color: var(--primary);
374
+ animation: spin 1s ease-in-out infinite;
375
+ margin-bottom: 16px;
376
+ }
377
+
378
+ @keyframes spin {
379
+ to { transform: rotate(360deg); }
380
+ }
381
+
382
+ #loading-text {
383
+ margin-top: 16px;
384
+ font-size: 18px;
385
+ color: var(--text);
386
+ font-weight: 500;
387
+ }
388
+
389
+ #frame-counter {
390
+ color: var(--text-secondary);
391
+ font-size: 7px;
392
+ font-weight: 500;
393
+ min-width: 60px;
394
+ text-align: center;
395
+ padding: 0 4px;
396
+ }
397
+
398
+ .control-btn {
399
+ background: rgba(255, 255, 255, 0.08);
400
+ border: 1px solid var(--border);
401
+ padding: 4px 6px;
402
+ border-radius: 3px;
403
+ cursor: pointer;
404
+ display: flex;
405
+ align-items: center;
406
+ justify-content: center;
407
+ transition: all 0.2s ease;
408
+ font-size: 6px;
409
+ }
410
+
411
+ .control-btn:hover {
412
+ background: rgba(255, 255, 255, 0.15);
413
+ transform: translateY(-1px);
414
+ }
415
+
416
+ .control-btn.active {
417
+ background: var(--primary);
418
+ color: white;
419
+ }
420
+
421
+ .control-btn.active:hover {
422
+ background: var(--primary);
423
+ box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
424
+ }
425
+
426
+ #settings-toggle-btn {
427
+ position: relative;
428
+ border-radius: 6px;
429
+ z-index: 20;
430
+ }
431
+
432
+ #settings-toggle-btn.active {
433
+ background: var(--primary);
434
+ color: white;
435
+ }
436
+
437
+ #status-bar,
438
+ #control-panel,
439
+ #settings-panel,
440
+ button,
441
+ input,
442
+ select,
443
+ .toggle-switch {
444
+ pointer-events: auto;
445
+ }
446
+
447
+ h2 {
448
+ font-size: 0.9rem;
449
+ font-weight: 600;
450
+ margin-top: 0;
451
+ margin-bottom: 12px;
452
+ color: var(--primary);
453
+ cursor: move;
454
+ user-select: none;
455
+ display: flex;
456
+ align-items: center;
457
+ }
458
+
459
+ .drag-handle {
460
+ font-size: 10px;
461
+ margin-right: 4px;
462
+ opacity: 0.6;
463
+ }
464
+
465
+ h2:hover .drag-handle {
466
+ opacity: 1;
467
+ }
468
+
469
+ .loading-subtitle {
470
+ font-size: 7px;
471
+ color: var(--text-secondary);
472
+ margin-top: 4px;
473
+ }
474
+
475
+ #reset-view-btn {
476
+ background: var(--primary-light);
477
+ color: var(--primary);
478
+ border: 1px solid rgba(155, 89, 182, 0.2);
479
+ font-weight: 600;
480
+ font-size: 9px;
481
+ padding: 4px 6px;
482
+ transition: all 0.2s;
483
+ }
484
+
485
+ #reset-view-btn:hover {
486
+ background: var(--primary);
487
+ color: white;
488
+ transform: translateY(-2px);
489
+ box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
490
+ }
491
+
492
+ #show-settings-btn {
493
+ position: absolute;
494
+ top: 16px;
495
+ right: 16px;
496
+ z-index: 15;
497
+ display: none;
498
+ }
499
+
500
+ #settings-panel.visible {
501
+ display: block;
502
+ opacity: 1;
503
+ animation: slideIn 0.3s ease forwards;
504
+ }
505
+
506
+ @keyframes slideIn {
507
+ from {
508
+ transform: translateY(20px);
509
+ opacity: 0;
510
+ }
511
+ to {
512
+ transform: translateY(0);
513
+ opacity: 1;
514
+ }
515
+ }
516
+
517
+ .dragging {
518
+ opacity: 0.9;
519
+ box-shadow: 0 8px 20px rgba(0, 0, 0, 0.15) !important;
520
+ transition: none !important;
521
+ }
522
+
523
+ /* Tooltip for draggable element */
524
+ .tooltip-drag {
525
+ position: absolute;
526
+ left: 50%;
527
+ transform: translateX(-50%);
528
+ background: var(--primary);
529
+ color: white;
530
+ font-size: 9px;
531
+ padding: 2px 4px;
532
+ border-radius: 2px;
533
+ opacity: 0;
534
+ pointer-events: none;
535
+ transition: opacity 0.3s;
536
+ white-space: nowrap;
537
+ bottom: 100%;
538
+ margin-bottom: 4px;
539
+ }
540
+
541
+ h2:hover .tooltip-drag {
542
+ opacity: 1;
543
+ }
544
+
545
+ .btn-group {
546
+ display: flex;
547
+ margin-top: 8px;
548
+ }
549
+
550
+ #reset-settings-btn {
551
+ background: var(--primary-light);
552
+ color: var(--primary);
553
+ border: 1px solid rgba(155, 89, 182, 0.2);
554
+ font-weight: 600;
555
+ font-size: 9px;
556
+ padding: 4px 6px;
557
+ transition: all 0.2s;
558
+ }
559
+
560
+ #reset-settings-btn:hover {
561
+ background: var(--primary);
562
+ color: white;
563
+ transform: translateY(-2px);
564
+ box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
565
+ }
566
+ </style>
567
+ </head>
568
+ <body>
569
+ <link rel="preconnect" href="https://fonts.googleapis.com">
570
+ <link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
571
+ <link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
572
+
573
+ <div id="canvas-container"></div>
574
+
575
+ <div id="ui-container">
576
+ <div id="status-bar">Initializing...</div>
577
+
578
+ <div id="control-panel">
579
+ <button id="play-pause-btn" class="control-btn">
580
+ <svg class="icon" viewBox="0 0 24 24">
581
+ <path id="play-icon" d="M8 5v14l11-7z"/>
582
+ <path id="pause-icon" d="M6 19h4V5H6v14zm8-14v14h4V5h-4z" style="display: none;"/>
583
+ </svg>
584
+ <span class="tooltip">Play/Pause</span>
585
+ </button>
586
+
587
+ <div id="timeline">
588
+ <div id="progress"></div>
589
+ </div>
590
+
591
+ <div id="frame-counter">Frame 0 / 0</div>
592
+
593
+ <div id="playback-controls">
594
+ <button id="speed-btn" class="control-btn">1x</button>
595
+ </div>
596
+ </div>
597
+
598
+ <div id="settings-panel">
599
+ <h2>
600
+ <span class="drag-handle">☰</span>
601
+ Visualization Settings
602
+ <button id="hide-settings-btn" class="control-btn" style="margin-left: auto; padding: 2px;" title="Hide Panel">
603
+ <svg class="icon" viewBox="0 0 24 24" style="width: 9px; height: 9px;">
604
+ <path d="M14.59 7.41L18.17 11H4v2h14.17l-3.58 3.59L16 18l6-6-6-6-1.41 1.41z"/>
605
+ </svg>
606
+ </button>
607
+ </h2>
608
+
609
+ <div class="settings-group">
610
+ <h3>Point Cloud</h3>
611
+ <div class="slider-container">
612
+ <label for="point-size">Size</label>
613
+ <input type="range" id="point-size" min="0.005" max="0.1" step="0.005" value="0.03">
614
+ </div>
615
+ <div class="slider-container">
616
+ <label for="point-opacity">Opacity</label>
617
+ <input type="range" id="point-opacity" min="0.1" max="1" step="0.05" value="1">
618
+ </div>
619
+ <div class="slider-container">
620
+ <label for="max-depth">Max Depth</label>
621
+ <input type="range" id="max-depth" min="0.1" max="10" step="0.2" value="100">
622
+ </div>
623
+ </div>
624
+
625
+ <div class="settings-group">
626
+ <h3>Trajectory</h3>
627
+ <div class="checkbox-container">
628
+ <label class="toggle-switch">
629
+ <input type="checkbox" id="show-trajectory" checked>
630
+ <span class="toggle-slider"></span>
631
+ </label>
632
+ <label for="show-trajectory">Show Trajectory</label>
633
+ </div>
634
+ <div class="checkbox-container">
635
+ <label class="toggle-switch">
636
+ <input type="checkbox" id="enable-rich-trail">
637
+ <span class="toggle-slider"></span>
638
+ </label>
639
+ <label for="enable-rich-trail">Visual-Rich Trail</label>
640
+ </div>
641
+ <div class="slider-container">
642
+ <label for="trajectory-line-width">Line Width</label>
643
+ <input type="range" id="trajectory-line-width" min="0.5" max="5" step="0.5" value="1.5">
644
+ </div>
645
+ <div class="slider-container">
646
+ <label for="trajectory-ball-size">Ball Size</label>
647
+ <input type="range" id="trajectory-ball-size" min="0.005" max="0.05" step="0.001" value="0.02">
648
+ </div>
649
+ <div class="slider-container">
650
+ <label for="trajectory-history">History Frames</label>
651
+ <input type="range" id="trajectory-history" min="1" max="500" step="1" value="30">
652
+ </div>
653
+ <div class="slider-container" id="tail-opacity-container" style="display: none;">
654
+ <label for="trajectory-fade">Tail Opacity</label>
655
+ <input type="range" id="trajectory-fade" min="0" max="1" step="0.05" value="0.0">
656
+ </div>
657
+ </div>
658
+
659
+ <div class="settings-group">
660
+ <h3>Camera</h3>
661
+ <div class="checkbox-container">
662
+ <label class="toggle-switch">
663
+ <input type="checkbox" id="show-camera-frustum" checked>
664
+ <span class="toggle-slider"></span>
665
+ </label>
666
+ <label for="show-camera-frustum">Show Camera Frustum</label>
667
+ </div>
668
+ <div class="slider-container">
669
+ <label for="frustum-size">Size</label>
670
+ <input type="range" id="frustum-size" min="0.02" max="0.5" step="0.01" value="0.2">
671
+ </div>
672
+ </div>
673
+
674
+ <div class="settings-group">
675
+ <div class="btn-group">
676
+ <button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
677
+ <button id="reset-settings-btn" style="flex: 1; margin-left: 5px;">Reset Settings</button>
678
+ </div>
679
+ </div>
680
+ </div>
681
+
682
+ <button id="show-settings-btn" class="control-btn" title="Show Settings">
683
+ <svg class="icon" viewBox="0 0 24 24">
684
+ <path d="M19.14,12.94c0.04-0.3,0.06-0.61,0.06-0.94c0-0.32-0.02-0.64-0.07-0.94l2.03-1.58c0.18-0.14,0.23-0.41,0.12-0.61 l-1.92-3.32c-0.12-0.22-0.37-0.29-0.59-0.22l-2.39,0.96c-0.5-0.38-1.03-0.7-1.62-0.94L14.4,2.81c-0.04-0.24-0.24-0.41-0.48-0.41 h-3.84c-0.24,0-0.43,0.17-0.47,0.41L9.25,5.35C8.66,5.59,8.12,5.92,7.63,6.29L5.24,5.33c-0.22-0.08-0.47,0-0.59,0.22L2.74,8.87 C2.62,9.08,2.66,9.34,2.86,9.48l2.03,1.58C4.84,11.36,4.8,11.69,4.8,12s0.02,0.64,0.07,0.94l-2.03,1.58 c-0.18,0.14-0.23,0.41-0.12,0.61l1.92,3.32c0.12,0.22,0.37,0.29,0.59,0.22l2.39-0.96c0.5,0.38,1.03,0.7,1.62,0.94l0.36,2.54 c0.04,0.24,0.24,0.41,0.48,0.41h3.84c0.24,0,0.44-0.17,0.47-0.41l0.36-2.54c0.59-0.24,1.13-0.56,1.62-0.94l2.39,0.96 c0.22,0.08,0.47,0,0.59-0.22l1.92-3.32c0.12-0.22,0.07-0.47-0.12-0.61L19.14,12.94z M12,15.6c-1.98,0-3.6-1.62-3.6-3.6 s1.62-3.6,3.6-3.6s3.6,1.62,3.6,3.6S13.98,15.6,12,15.6z"/>
685
+ </svg>
686
+ </button>
687
+ </div>
688
+
689
+ <div id="loading-overlay">
690
+ <!-- <div class="spinner"></div> -->
691
+ <div id="loading-text"></div>
692
+ <div class="loading-subtitle" style="font-size: medium;">Interactive Viewer of 3D Tracking</div>
693
+ </div>
694
+
695
+ <!-- Libraries -->
696
+ <script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
697
+ <script src="https://cdn.jsdelivr.net/npm/three@0.132.2/build/three.min.js"></script>
698
+ <script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/controls/OrbitControls.js"></script>
699
+ <script src="https://cdn.jsdelivr.net/npm/dat.gui@0.7.7/build/dat.gui.min.js"></script>
700
+ <script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineSegmentsGeometry.js"></script>
701
+ <script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineGeometry.js"></script>
702
+ <script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineMaterial.js"></script>
703
+ <script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineSegments2.js"></script>
704
+ <script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/Line2.js"></script>
705
+
706
+ <script>
707
+ class PointCloudVisualizer {
708
+ constructor() {
709
+ this.data = null;
710
+ this.config = {};
711
+ this.currentFrame = 0;
712
+ this.isPlaying = false;
713
+ this.playbackSpeed = 1;
714
+ this.lastFrameTime = 0;
715
+ this.defaultSettings = null;
716
+
717
+ this.ui = {
718
+ statusBar: document.getElementById('status-bar'),
719
+ playPauseBtn: document.getElementById('play-pause-btn'),
720
+ speedBtn: document.getElementById('speed-btn'),
721
+ timeline: document.getElementById('timeline'),
722
+ progress: document.getElementById('progress'),
723
+ settingsPanel: document.getElementById('settings-panel'),
724
+ loadingOverlay: document.getElementById('loading-overlay'),
725
+ loadingText: document.getElementById('loading-text'),
726
+ settingsToggleBtn: document.getElementById('settings-toggle-btn'),
727
+ frameCounter: document.getElementById('frame-counter'),
728
+ pointSize: document.getElementById('point-size'),
729
+ pointOpacity: document.getElementById('point-opacity'),
730
+ maxDepth: document.getElementById('max-depth'),
731
+ showTrajectory: document.getElementById('show-trajectory'),
732
+ enableRichTrail: document.getElementById('enable-rich-trail'),
733
+ trajectoryLineWidth: document.getElementById('trajectory-line-width'),
734
+ trajectoryBallSize: document.getElementById('trajectory-ball-size'),
735
+ trajectoryHistory: document.getElementById('trajectory-history'),
736
+ trajectoryFade: document.getElementById('trajectory-fade'),
737
+ tailOpacityContainer: document.getElementById('tail-opacity-container'),
738
+ resetViewBtn: document.getElementById('reset-view-btn'),
739
+ showCameraFrustum: document.getElementById('show-camera-frustum'),
740
+ frustumSize: document.getElementById('frustum-size'),
741
+ hideSettingsBtn: document.getElementById('hide-settings-btn'),
742
+ showSettingsBtn: document.getElementById('show-settings-btn')
743
+ };
744
+
745
+ this.scene = null;
746
+ this.camera = null;
747
+ this.renderer = null;
748
+ this.controls = null;
749
+ this.pointCloud = null;
750
+ this.trajectories = [];
751
+ this.cameraFrustum = null;
752
+
753
+ this.initThreeJS();
754
+ this.loadDefaultSettings().then(() => {
755
+ this.initEventListeners();
756
+ this.loadData();
757
+ });
758
+ }
759
+
760
+ async loadDefaultSettings() {
761
+ try {
762
+ const urlParams = new URLSearchParams(window.location.search);
763
+ const dataPath = urlParams.get('data') || '';
764
+
765
+ const defaultSettings = {
766
+ pointSize: 0.03,
767
+ pointOpacity: 1.0,
768
+ showTrajectory: true,
769
+ trajectoryLineWidth: 2.5,
770
+ trajectoryBallSize: 0.015,
771
+ trajectoryHistory: 0,
772
+ showCameraFrustum: true,
773
+ frustumSize: 0.2
774
+ };
775
+
776
+ if (!dataPath) {
777
+ this.defaultSettings = defaultSettings;
778
+ this.applyDefaultSettings();
779
+ return;
780
+ }
781
+
782
+ // Try to extract dataset and videoId from the data path
783
+ // Expected format: demos/datasetname/videoid.bin
784
+ const pathParts = dataPath.split('/');
785
+ if (pathParts.length < 3) {
786
+ this.defaultSettings = defaultSettings;
787
+ this.applyDefaultSettings();
788
+ return;
789
+ }
790
+
791
+ const datasetName = pathParts[pathParts.length - 2];
792
+ let videoId = pathParts[pathParts.length - 1].replace('.bin', '');
793
+
794
+ // Load settings from data.json
795
+ const response = await fetch('./data.json');
796
+ if (!response.ok) {
797
+ this.defaultSettings = defaultSettings;
798
+ this.applyDefaultSettings();
799
+ return;
800
+ }
801
+
802
+ const settingsData = await response.json();
803
+
804
+ // Check if this dataset and video exist
805
+ if (settingsData[datasetName] && settingsData[datasetName][videoId]) {
806
+ this.defaultSettings = settingsData[datasetName][videoId];
807
+ } else {
808
+ this.defaultSettings = defaultSettings;
809
+ }
810
+
811
+ this.applyDefaultSettings();
812
+ } catch (error) {
813
+ console.error("Error loading default settings:", error);
814
+
815
+ this.defaultSettings = {
816
+ pointSize: 0.03,
817
+ pointOpacity: 1.0,
818
+ showTrajectory: true,
819
+ trajectoryLineWidth: 2.5,
820
+ trajectoryBallSize: 0.015,
821
+ trajectoryHistory: 0,
822
+ showCameraFrustum: true,
823
+ frustumSize: 0.2
824
+ };
825
+
826
+ this.applyDefaultSettings();
827
+ }
828
+ }
829
+
830
+ applyDefaultSettings() {
831
+ if (!this.defaultSettings) return;
832
+
833
+ if (this.ui.pointSize) {
834
+ this.ui.pointSize.value = this.defaultSettings.pointSize;
835
+ }
836
+
837
+ if (this.ui.pointOpacity) {
838
+ this.ui.pointOpacity.value = this.defaultSettings.pointOpacity;
839
+ }
840
+
841
+ if (this.ui.maxDepth) {
842
+ this.ui.maxDepth.value = this.defaultSettings.maxDepth || 100.0;
843
+ }
844
+
845
+ if (this.ui.showTrajectory) {
846
+ this.ui.showTrajectory.checked = this.defaultSettings.showTrajectory;
847
+ }
848
+
849
+ if (this.ui.trajectoryLineWidth) {
850
+ this.ui.trajectoryLineWidth.value = this.defaultSettings.trajectoryLineWidth;
851
+ }
852
+
853
+ if (this.ui.trajectoryBallSize) {
854
+ this.ui.trajectoryBallSize.value = this.defaultSettings.trajectoryBallSize;
855
+ }
856
+
857
+ if (this.ui.trajectoryHistory) {
858
+ this.ui.trajectoryHistory.value = this.defaultSettings.trajectoryHistory;
859
+ }
860
+
861
+ if (this.ui.showCameraFrustum) {
862
+ this.ui.showCameraFrustum.checked = this.defaultSettings.showCameraFrustum;
863
+ }
864
+
865
+ if (this.ui.frustumSize) {
866
+ this.ui.frustumSize.value = this.defaultSettings.frustumSize;
867
+ }
868
+ }
869
+
870
+ initThreeJS() {
871
+ this.scene = new THREE.Scene();
872
+ this.scene.background = new THREE.Color(0x1a1a1a);
873
+
874
+ this.camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 10000);
875
+ this.camera.position.set(0, 0, 0);
876
+
877
+ this.renderer = new THREE.WebGLRenderer({ antialias: true });
878
+ this.renderer.setPixelRatio(window.devicePixelRatio);
879
+ this.renderer.setSize(window.innerWidth, window.innerHeight);
880
+ document.getElementById('canvas-container').appendChild(this.renderer.domElement);
881
+
882
+ this.controls = new THREE.OrbitControls(this.camera, this.renderer.domElement);
883
+ this.controls.enableDamping = true;
884
+ this.controls.dampingFactor = 0.05;
885
+ this.controls.target.set(0, 0, 0);
886
+ this.controls.minDistance = 0.1;
887
+ this.controls.maxDistance = 1000;
888
+ this.controls.update();
889
+
890
+ const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
891
+ this.scene.add(ambientLight);
892
+
893
+ const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
894
+ directionalLight.position.set(1, 1, 1);
895
+ this.scene.add(directionalLight);
896
+ }
897
+
898
+ initEventListeners() {
899
+ window.addEventListener('resize', () => this.onWindowResize());
900
+
901
+ this.ui.playPauseBtn.addEventListener('click', () => this.togglePlayback());
902
+
903
+ this.ui.timeline.addEventListener('click', (e) => {
904
+ const rect = this.ui.timeline.getBoundingClientRect();
905
+ const pos = (e.clientX - rect.left) / rect.width;
906
+ this.seekTo(pos);
907
+ });
908
+
909
+ this.ui.speedBtn.addEventListener('click', () => this.cyclePlaybackSpeed());
910
+
911
+ this.ui.pointSize.addEventListener('input', () => this.updatePointCloudSettings());
912
+ this.ui.pointOpacity.addEventListener('input', () => this.updatePointCloudSettings());
913
+ this.ui.maxDepth.addEventListener('input', () => this.updatePointCloudSettings());
914
+ this.ui.showTrajectory.addEventListener('change', () => {
915
+ this.trajectories.forEach(trajectory => {
916
+ trajectory.visible = this.ui.showTrajectory.checked;
917
+ });
918
+ });
919
+
920
+ this.ui.enableRichTrail.addEventListener('change', () => {
921
+ this.ui.tailOpacityContainer.style.display = this.ui.enableRichTrail.checked ? 'flex' : 'none';
922
+ this.updateTrajectories(this.currentFrame);
923
+ });
924
+
925
+ this.ui.trajectoryLineWidth.addEventListener('input', () => this.updateTrajectorySettings());
926
+ this.ui.trajectoryBallSize.addEventListener('input', () => this.updateTrajectorySettings());
927
+ this.ui.trajectoryHistory.addEventListener('input', () => {
928
+ this.updateTrajectories(this.currentFrame);
929
+ });
930
+ this.ui.trajectoryFade.addEventListener('input', () => {
931
+ this.updateTrajectories(this.currentFrame);
932
+ });
933
+
934
+ this.ui.resetViewBtn.addEventListener('click', () => this.resetView());
935
+
936
+ const resetSettingsBtn = document.getElementById('reset-settings-btn');
937
+ if (resetSettingsBtn) {
938
+ resetSettingsBtn.addEventListener('click', () => this.resetSettings());
939
+ }
940
+
941
+ document.addEventListener('keydown', (e) => {
942
+ if (e.key === 'Escape' && this.ui.settingsPanel.classList.contains('visible')) {
943
+ this.ui.settingsPanel.classList.remove('visible');
944
+ this.ui.settingsToggleBtn.classList.remove('active');
945
+ }
946
+ });
947
+
948
+ if (this.ui.settingsToggleBtn) {
949
+ this.ui.settingsToggleBtn.addEventListener('click', () => {
950
+ const isVisible = this.ui.settingsPanel.classList.toggle('visible');
951
+ this.ui.settingsToggleBtn.classList.toggle('active', isVisible);
952
+
953
+ if (isVisible) {
954
+ const panelRect = this.ui.settingsPanel.getBoundingClientRect();
955
+ const viewportHeight = window.innerHeight;
956
+
957
+ if (panelRect.bottom > viewportHeight) {
958
+ this.ui.settingsPanel.style.bottom = 'auto';
959
+ this.ui.settingsPanel.style.top = '80px';
960
+ }
961
+ }
962
+ });
963
+ }
964
+
965
+ if (this.ui.frustumSize) {
966
+ this.ui.frustumSize.addEventListener('input', () => this.updateFrustumDimensions());
967
+ }
968
+
969
+ if (this.ui.hideSettingsBtn && this.ui.showSettingsBtn && this.ui.settingsPanel) {
970
+ this.ui.hideSettingsBtn.addEventListener('click', () => {
971
+ this.ui.settingsPanel.classList.add('is-hidden');
972
+ this.ui.showSettingsBtn.style.display = 'flex';
973
+ });
974
+
975
+ this.ui.showSettingsBtn.addEventListener('click', () => {
976
+ this.ui.settingsPanel.classList.remove('is-hidden');
977
+ this.ui.showSettingsBtn.style.display = 'none';
978
+ });
979
+ }
980
+ }
981
+
982
+ makeElementDraggable(element) {
983
+ let pos1 = 0, pos2 = 0, pos3 = 0, pos4 = 0;
984
+
985
+ const dragHandle = element.querySelector('h2');
986
+
987
+ if (dragHandle) {
988
+ dragHandle.onmousedown = dragMouseDown;
989
+ dragHandle.title = "Drag to move panel";
990
+ } else {
991
+ element.onmousedown = dragMouseDown;
992
+ }
993
+
994
+ function dragMouseDown(e) {
995
+ e = e || window.event;
996
+ e.preventDefault();
997
+ pos3 = e.clientX;
998
+ pos4 = e.clientY;
999
+ document.onmouseup = closeDragElement;
1000
+ document.onmousemove = elementDrag;
1001
+
1002
+ element.classList.add('dragging');
1003
+ }
1004
+
1005
+ function elementDrag(e) {
1006
+ e = e || window.event;
1007
+ e.preventDefault();
1008
+ pos1 = pos3 - e.clientX;
1009
+ pos2 = pos4 - e.clientY;
1010
+ pos3 = e.clientX;
1011
+ pos4 = e.clientY;
1012
+
1013
+ const newTop = element.offsetTop - pos2;
1014
+ const newLeft = element.offsetLeft - pos1;
1015
+
1016
+ const viewportWidth = window.innerWidth;
1017
+ const viewportHeight = window.innerHeight;
1018
+
1019
+ const panelRect = element.getBoundingClientRect();
1020
+
1021
+ const maxTop = viewportHeight - 50;
1022
+ const maxLeft = viewportWidth - 50;
1023
+
1024
+ element.style.top = Math.min(Math.max(newTop, 0), maxTop) + "px";
1025
+ element.style.left = Math.min(Math.max(newLeft, 0), maxLeft) + "px";
1026
+
1027
+ // Remove bottom/right settings when dragging
1028
+ element.style.bottom = 'auto';
1029
+ element.style.right = 'auto';
1030
+ }
1031
+
1032
+ function closeDragElement() {
1033
+ document.onmouseup = null;
1034
+ document.onmousemove = null;
1035
+
1036
+ element.classList.remove('dragging');
1037
+ }
1038
+ }
1039
+
1040
+ async loadData() {
1041
+ try {
1042
+ // this.ui.loadingText.textContent = "Loading binary data...";
1043
+
1044
+ let arrayBuffer;
1045
+
1046
+ if (window.embeddedBase64) {
1047
+ // Base64 embedded path
1048
+ const binaryString = atob(window.embeddedBase64);
1049
+ const len = binaryString.length;
1050
+ const bytes = new Uint8Array(len);
1051
+ for (let i = 0; i < len; i++) {
1052
+ bytes[i] = binaryString.charCodeAt(i);
1053
+ }
1054
+ arrayBuffer = bytes.buffer;
1055
+ } else {
1056
+ // Default fetch path (fallback)
1057
+ const urlParams = new URLSearchParams(window.location.search);
1058
+ const dataPath = urlParams.get('data') || 'data.bin';
1059
+
1060
+ const response = await fetch(dataPath);
1061
+ if (!response.ok) throw new Error(`Failed to load ${dataPath}`);
1062
+ arrayBuffer = await response.arrayBuffer();
1063
+ }
1064
+
1065
+ const dataView = new DataView(arrayBuffer);
1066
+ const headerLen = dataView.getUint32(0, true);
1067
+
1068
+ const headerText = new TextDecoder("utf-8").decode(arrayBuffer.slice(4, 4 + headerLen));
1069
+ const header = JSON.parse(headerText);
1070
+
1071
+ const compressedBlob = new Uint8Array(arrayBuffer, 4 + headerLen);
1072
+ const decompressed = pako.inflate(compressedBlob).buffer;
1073
+
1074
+ const arrays = {};
1075
+ for (const key in header) {
1076
+ if (key === "meta") continue;
1077
+
1078
+ const meta = header[key];
1079
+ const { dtype, shape, offset, length } = meta;
1080
+ const slice = decompressed.slice(offset, offset + length);
1081
+
1082
+ let typedArray;
1083
+ switch (dtype) {
1084
+ case "uint8": typedArray = new Uint8Array(slice); break;
1085
+ case "uint16": typedArray = new Uint16Array(slice); break;
1086
+ case "float32": typedArray = new Float32Array(slice); break;
1087
+ case "float64": typedArray = new Float64Array(slice); break;
1088
+ default: throw new Error(`Unknown dtype: ${dtype}`);
1089
+ }
1090
+
1091
+ arrays[key] = { data: typedArray, shape: shape };
1092
+ }
1093
+
1094
+ this.data = arrays;
1095
+ this.config = header.meta;
1096
+
1097
+ this.initCameraWithCorrectFOV();
1098
+ this.ui.loadingText.textContent = "Creating point cloud...";
1099
+
1100
+ this.initPointCloud();
1101
+ this.initTrajectories();
1102
+
1103
+ setTimeout(() => {
1104
+ this.ui.loadingOverlay.classList.add('fade-out');
1105
+ this.ui.statusBar.classList.add('hidden');
1106
+ this.startAnimation();
1107
+ }, 500);
1108
+ } catch (error) {
1109
+ console.error("Error loading data:", error);
1110
+ this.ui.statusBar.textContent = `Error: ${error.message}`;
1111
+ // this.ui.loadingText.textContent = `Error loading data: ${error.message}`;
1112
+ }
1113
+ }
1114
+
1115
+ initPointCloud() {
1116
+ const numPoints = this.config.resolution[0] * this.config.resolution[1];
1117
+ const positions = new Float32Array(numPoints * 3);
1118
+ const colors = new Float32Array(numPoints * 3);
1119
+
1120
+ const geometry = new THREE.BufferGeometry();
1121
+ geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3).setUsage(THREE.DynamicDrawUsage));
1122
+ geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3).setUsage(THREE.DynamicDrawUsage));
1123
+
1124
+ const pointSize = parseFloat(this.ui.pointSize.value) || this.defaultSettings.pointSize;
1125
+ const pointOpacity = parseFloat(this.ui.pointOpacity.value) || this.defaultSettings.pointOpacity;
1126
+
1127
+ const material = new THREE.PointsMaterial({
1128
+ size: pointSize,
1129
+ vertexColors: true,
1130
+ transparent: true,
1131
+ opacity: pointOpacity,
1132
+ sizeAttenuation: true
1133
+ });
1134
+
1135
+ this.pointCloud = new THREE.Points(geometry, material);
1136
+ this.scene.add(this.pointCloud);
1137
+ }
1138
+
1139
+ initTrajectories() {
1140
+ if (!this.data.trajectories) return;
1141
+
1142
+ this.trajectories.forEach(trajectory => {
1143
+ if (trajectory.userData.lineSegments) {
1144
+ trajectory.userData.lineSegments.forEach(segment => {
1145
+ segment.geometry.dispose();
1146
+ segment.material.dispose();
1147
+ });
1148
+ }
1149
+ this.scene.remove(trajectory);
1150
+ });
1151
+ this.trajectories = [];
1152
+
1153
+ const shape = this.data.trajectories.shape;
1154
+ if (!shape || shape.length < 2) return;
1155
+
1156
+ const [totalFrames, numTrajectories] = shape;
1157
+ const palette = this.createColorPalette(numTrajectories);
1158
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1159
+ const maxHistory = 500; // Max value of the history slider, for the object pool
1160
+
1161
+ for (let i = 0; i < numTrajectories; i++) {
1162
+ const trajectoryGroup = new THREE.Group();
1163
+
1164
+ const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
1165
+ const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
1166
+ const sphereMaterial = new THREE.MeshBasicMaterial({ color: palette[i], transparent: true });
1167
+ const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
1168
+ trajectoryGroup.add(positionMarker);
1169
+
1170
+ // High-Performance Line (default)
1171
+ const simpleLineGeometry = new THREE.BufferGeometry();
1172
+ const simpleLinePositions = new Float32Array(maxHistory * 3);
1173
+ simpleLineGeometry.setAttribute('position', new THREE.BufferAttribute(simpleLinePositions, 3).setUsage(THREE.DynamicDrawUsage));
1174
+ const simpleLine = new THREE.Line(simpleLineGeometry, new THREE.LineBasicMaterial({ color: palette[i] }));
1175
+ simpleLine.frustumCulled = false;
1176
+ trajectoryGroup.add(simpleLine);
1177
+
1178
+ // High-Quality Line Segments (for rich trail)
1179
+ const lineSegments = [];
1180
+ const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
1181
+
1182
+ // Create a pool of line segment objects
1183
+ for (let j = 0; j < maxHistory - 1; j++) {
1184
+ const lineGeometry = new THREE.LineGeometry();
1185
+ lineGeometry.setPositions([0, 0, 0, 0, 0, 0]);
1186
+ const lineMaterial = new THREE.LineMaterial({
1187
+ color: palette[i],
1188
+ linewidth: lineWidth,
1189
+ resolution: resolution,
1190
+ transparent: true,
1191
+ depthWrite: false, // Correctly handle transparency
1192
+ opacity: 0
1193
+ });
1194
+ const segment = new THREE.Line2(lineGeometry, lineMaterial);
1195
+ segment.frustumCulled = false;
1196
+ segment.visible = false; // Start with all segments hidden
1197
+ trajectoryGroup.add(segment);
1198
+ lineSegments.push(segment);
1199
+ }
1200
+
1201
+ trajectoryGroup.userData = {
1202
+ marker: positionMarker,
1203
+ simpleLine: simpleLine,
1204
+ lineSegments: lineSegments,
1205
+ color: palette[i]
1206
+ };
1207
+
1208
+ this.scene.add(trajectoryGroup);
1209
+ this.trajectories.push(trajectoryGroup);
1210
+ }
1211
+
1212
+ const showTrajectory = this.ui.showTrajectory.checked;
1213
+ this.trajectories.forEach(trajectory => trajectory.visible = showTrajectory);
1214
+ }
1215
+
1216
+ createColorPalette(count) {
1217
+ const colors = [];
1218
+ const hueStep = 360 / count;
1219
+
1220
+ for (let i = 0; i < count; i++) {
1221
+ const hue = (i * hueStep) % 360;
1222
+ const color = new THREE.Color().setHSL(hue / 360, 0.8, 0.6);
1223
+ colors.push(color);
1224
+ }
1225
+
1226
+ return colors;
1227
+ }
1228
+
1229
+ updatePointCloud(frameIndex) {
1230
+ if (!this.data || !this.pointCloud) return;
1231
+
1232
+ const positions = this.pointCloud.geometry.attributes.position.array;
1233
+ const colors = this.pointCloud.geometry.attributes.color.array;
1234
+
1235
+ const rgbVideo = this.data.rgb_video;
1236
+ const depthsRgb = this.data.depths_rgb;
1237
+ const intrinsics = this.data.intrinsics;
1238
+ const invExtrinsics = this.data.inv_extrinsics;
1239
+
1240
+ const width = this.config.resolution[0];
1241
+ const height = this.config.resolution[1];
1242
+ const numPoints = width * height;
1243
+
1244
+ const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
1245
+ const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
1246
+
1247
+ const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
1248
+ const transform = this.getTransformElements(invExtrMat);
1249
+
1250
+ const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
1251
+ const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
1252
+
1253
+ const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
1254
+
1255
+ let validPointCount = 0;
1256
+
1257
+ for (let i = 0; i < numPoints; i++) {
1258
+ const xPix = i % width;
1259
+ const yPix = Math.floor(i / width);
1260
+
1261
+ const d0 = depthFrame[i * 3];
1262
+ const d1 = depthFrame[i * 3 + 1];
1263
+ const depthEncoded = d0 | (d1 << 8);
1264
+ const depthValue = (depthEncoded / ((1 << 16) - 1)) *
1265
+ (this.config.depthRange[1] - this.config.depthRange[0]) +
1266
+ this.config.depthRange[0];
1267
+
1268
+ if (depthValue === 0 || depthValue > maxDepth) {
1269
+ continue;
1270
+ }
1271
+
1272
+ const X = ((xPix - cx) * depthValue) / fx;
1273
+ const Y = ((yPix - cy) * depthValue) / fy;
1274
+ const Z = depthValue;
1275
+
1276
+ const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
1277
+ const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
1278
+ const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
1279
+
1280
+ const index = validPointCount * 3;
1281
+ positions[index] = tx;
1282
+ positions[index + 1] = -ty;
1283
+ positions[index + 2] = -tz;
1284
+
1285
+ colors[index] = rgbFrame[i * 3] / 255;
1286
+ colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
1287
+ colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
1288
+
1289
+ validPointCount++;
1290
+ }
1291
+
1292
+ this.pointCloud.geometry.setDrawRange(0, validPointCount);
1293
+ this.pointCloud.geometry.attributes.position.needsUpdate = true;
1294
+ this.pointCloud.geometry.attributes.color.needsUpdate = true;
1295
+ this.pointCloud.geometry.computeBoundingSphere(); // Important for camera culling
1296
+
1297
+ this.updateTrajectories(frameIndex);
1298
+
1299
+ const progress = (frameIndex + 1) / this.config.totalFrames;
1300
+ this.ui.progress.style.width = `${progress * 100}%`;
1301
+
1302
+ if (this.ui.frameCounter && this.config.totalFrames) {
1303
+ this.ui.frameCounter.textContent = `Frame ${frameIndex} / ${this.config.totalFrames - 1}`;
1304
+ }
1305
+
1306
+ this.updateCameraFrustum(frameIndex);
1307
+ }
1308
+
1309
+ updateTrajectories(frameIndex) {
1310
+ if (!this.data.trajectories || this.trajectories.length === 0) return;
1311
+
1312
+ const trajectoryData = this.data.trajectories.data;
1313
+ const [totalFrames, numTrajectories] = this.data.trajectories.shape;
1314
+ const historyFrames = parseInt(this.ui.trajectoryHistory.value);
1315
+ const tailOpacity = parseFloat(this.ui.trajectoryFade.value);
1316
+
1317
+ const isRichMode = this.ui.enableRichTrail.checked;
1318
+
1319
+ for (let i = 0; i < numTrajectories; i++) {
1320
+ const trajectoryGroup = this.trajectories[i];
1321
+ const { marker, simpleLine, lineSegments } = trajectoryGroup.userData;
1322
+
1323
+ const currentPos = new THREE.Vector3();
1324
+ const currentOffset = (frameIndex * numTrajectories + i) * 3;
1325
+
1326
+ currentPos.x = trajectoryData[currentOffset];
1327
+ currentPos.y = -trajectoryData[currentOffset + 1];
1328
+ currentPos.z = -trajectoryData[currentOffset + 2];
1329
+
1330
+ marker.position.copy(currentPos);
1331
+ marker.material.opacity = 1.0;
1332
+
1333
+ const historyToShow = Math.min(historyFrames, frameIndex + 1);
1334
+
1335
+ if (isRichMode) {
1336
+ // --- High-Quality Mode ---
1337
+ simpleLine.visible = false;
1338
+
1339
+ for (let j = 0; j < lineSegments.length; j++) {
1340
+ const segment = lineSegments[j];
1341
+ if (j < historyToShow - 1) {
1342
+ const headFrame = frameIndex - j;
1343
+ const tailFrame = frameIndex - j - 1;
1344
+ const headOffset = (headFrame * numTrajectories + i) * 3;
1345
+ const tailOffset = (tailFrame * numTrajectories + i) * 3;
1346
+ const positions = [
1347
+ trajectoryData[headOffset], -trajectoryData[headOffset + 1], -trajectoryData[headOffset + 2],
1348
+ trajectoryData[tailOffset], -trajectoryData[tailOffset + 1], -trajectoryData[tailOffset + 2]
1349
+ ];
1350
+ segment.geometry.setPositions(positions);
1351
+ const headOpacity = 1.0;
1352
+ const normalizedAge = j / Math.max(1, historyToShow - 2);
1353
+ const alpha = headOpacity - (headOpacity - tailOpacity) * normalizedAge;
1354
+ segment.material.opacity = Math.max(0, alpha);
1355
+ segment.visible = true;
1356
+ } else {
1357
+ segment.visible = false;
1358
+ }
1359
+ }
1360
+ } else {
1361
+ // --- Performance Mode ---
1362
+ lineSegments.forEach(s => s.visible = false);
1363
+ simpleLine.visible = true;
1364
+
1365
+ const positions = simpleLine.geometry.attributes.position.array;
1366
+ for (let j = 0; j < historyToShow; j++) {
1367
+ const historyFrame = Math.max(0, frameIndex - j);
1368
+ const offset = (historyFrame * numTrajectories + i) * 3;
1369
+ positions[j * 3] = trajectoryData[offset];
1370
+ positions[j * 3 + 1] = -trajectoryData[offset + 1];
1371
+ positions[j * 3 + 2] = -trajectoryData[offset + 2];
1372
+ }
1373
+ simpleLine.geometry.setDrawRange(0, historyToShow);
1374
+ simpleLine.geometry.attributes.position.needsUpdate = true;
1375
+ }
1376
+ }
1377
+ }
1378
+
1379
+ updateTrajectorySettings() {
1380
+ if (!this.trajectories || this.trajectories.length === 0) return;
1381
+
1382
+ const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
1383
+ const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
1384
+
1385
+ this.trajectories.forEach(trajectoryGroup => {
1386
+ const { marker, lineSegments } = trajectoryGroup.userData;
1387
+
1388
+ marker.geometry.dispose();
1389
+ marker.geometry = new THREE.SphereGeometry(ballSize, 16, 16);
1390
+
1391
+ // Line width only affects rich mode
1392
+ lineSegments.forEach(segment => {
1393
+ if (segment.material) {
1394
+ segment.material.linewidth = lineWidth;
1395
+ }
1396
+ });
1397
+ });
1398
+
1399
+ this.updateTrajectories(this.currentFrame);
1400
+ }
1401
+
1402
+ getDepthColor(normalizedDepth) {
1403
+ const hue = (1 - normalizedDepth) * 240 / 360;
1404
+ const color = new THREE.Color().setHSL(hue, 1.0, 0.5);
1405
+ return color;
1406
+ }
1407
+
1408
+ getFrame(typedArray, shape, frameIndex) {
1409
+ const [T, H, W, C] = shape;
1410
+ const frameSize = H * W * C;
1411
+ const offset = frameIndex * frameSize;
1412
+ return typedArray.subarray(offset, offset + frameSize);
1413
+ }
1414
+
1415
+ get3x3Matrix(typedArray, shape, frameIndex) {
1416
+ const frameSize = 9;
1417
+ const offset = frameIndex * frameSize;
1418
+ const K = [];
1419
+ for (let i = 0; i < 3; i++) {
1420
+ const row = [];
1421
+ for (let j = 0; j < 3; j++) {
1422
+ row.push(typedArray[offset + i * 3 + j]);
1423
+ }
1424
+ K.push(row);
1425
+ }
1426
+ return K;
1427
+ }
1428
+
1429
+ get4x4Matrix(typedArray, shape, frameIndex) {
1430
+ const frameSize = 16;
1431
+ const offset = frameIndex * frameSize;
1432
+ const M = [];
1433
+ for (let i = 0; i < 4; i++) {
1434
+ const row = [];
1435
+ for (let j = 0; j < 4; j++) {
1436
+ row.push(typedArray[offset + i * 4 + j]);
1437
+ }
1438
+ M.push(row);
1439
+ }
1440
+ return M;
1441
+ }
1442
+
1443
+ getTransformElements(matrix) {
1444
+ return {
1445
+ m11: matrix[0][0], m12: matrix[0][1], m13: matrix[0][2], m14: matrix[0][3],
1446
+ m21: matrix[1][0], m22: matrix[1][1], m23: matrix[1][2], m24: matrix[1][3],
1447
+ m31: matrix[2][0], m32: matrix[2][1], m33: matrix[2][2], m34: matrix[2][3]
1448
+ };
1449
+ }
1450
+
1451
+ togglePlayback() {
1452
+ this.isPlaying = !this.isPlaying;
1453
+
1454
+ const playIcon = document.getElementById('play-icon');
1455
+ const pauseIcon = document.getElementById('pause-icon');
1456
+
1457
+ if (this.isPlaying) {
1458
+ playIcon.style.display = 'none';
1459
+ pauseIcon.style.display = 'block';
1460
+ this.lastFrameTime = performance.now();
1461
+ } else {
1462
+ playIcon.style.display = 'block';
1463
+ pauseIcon.style.display = 'none';
1464
+ }
1465
+ }
1466
+
1467
+ cyclePlaybackSpeed() {
1468
+ const speeds = [0.5, 1, 2, 4, 8];
1469
+ const speedRates = speeds.map(s => s * this.config.baseFrameRate);
1470
+
1471
+ let currentIndex = 0;
1472
+ const normalizedSpeed = this.playbackSpeed / this.config.baseFrameRate;
1473
+
1474
+ for (let i = 0; i < speeds.length; i++) {
1475
+ if (Math.abs(normalizedSpeed - speeds[i]) < Math.abs(normalizedSpeed - speeds[currentIndex])) {
1476
+ currentIndex = i;
1477
+ }
1478
+ }
1479
+
1480
+ const nextIndex = (currentIndex + 1) % speeds.length;
1481
+ this.playbackSpeed = speedRates[nextIndex];
1482
+ this.ui.speedBtn.textContent = `${speeds[nextIndex]}x`;
1483
+
1484
+ if (speeds[nextIndex] === 1) {
1485
+ this.ui.speedBtn.classList.remove('active');
1486
+ } else {
1487
+ this.ui.speedBtn.classList.add('active');
1488
+ }
1489
+ }
1490
+
1491
+ seekTo(position) {
1492
+ const frameIndex = Math.floor(position * this.config.totalFrames);
1493
+ this.currentFrame = Math.max(0, Math.min(frameIndex, this.config.totalFrames - 1));
1494
+ this.updatePointCloud(this.currentFrame);
1495
+ }
1496
+
1497
+ updatePointCloudSettings() {
1498
+ if (!this.pointCloud) return;
1499
+
1500
+ const size = parseFloat(this.ui.pointSize.value);
1501
+ const opacity = parseFloat(this.ui.pointOpacity.value);
1502
+
1503
+ this.pointCloud.material.size = size;
1504
+ this.pointCloud.material.opacity = opacity;
1505
+ this.pointCloud.material.needsUpdate = true;
1506
+
1507
+ this.updatePointCloud(this.currentFrame);
1508
+ }
1509
+
1510
+ updateControls() {
1511
+ if (!this.controls) return;
1512
+ this.controls.update();
1513
+ }
1514
+
1515
+ resetView() {
1516
+ if (!this.camera || !this.controls) return;
1517
+
1518
+ // Reset camera position
1519
+ this.camera.position.set(0, 0, this.config.cameraZ || 0);
1520
+
1521
+ // Reset controls
1522
+ this.controls.reset();
1523
+
1524
+ // Set target slightly in front of camera
1525
+ this.controls.target.set(0, 0, -1);
1526
+ this.controls.update();
1527
+
1528
+ // Show status message
1529
+ this.ui.statusBar.textContent = "View reset";
1530
+ this.ui.statusBar.classList.remove('hidden');
1531
+
1532
+ // Hide status message after a few seconds
1533
+ setTimeout(() => {
1534
+ this.ui.statusBar.classList.add('hidden');
1535
+ }, 3000);
1536
+ }
1537
+
1538
+ onWindowResize() {
1539
+ if (!this.camera || !this.renderer) return;
1540
+
1541
+ const windowAspect = window.innerWidth / window.innerHeight;
1542
+ this.camera.aspect = windowAspect;
1543
+ this.camera.updateProjectionMatrix();
1544
+ this.renderer.setSize(window.innerWidth, window.innerHeight);
1545
+
1546
+ if (this.trajectories && this.trajectories.length > 0) {
1547
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1548
+ this.trajectories.forEach(trajectory => {
1549
+ const { lineSegments } = trajectory.userData;
1550
+ if (lineSegments && lineSegments.length > 0) {
1551
+ lineSegments.forEach(segment => {
1552
+ if (segment.material && segment.material.resolution) {
1553
+ segment.material.resolution.copy(resolution);
1554
+ }
1555
+ });
1556
+ }
1557
+ });
1558
+ }
1559
+
1560
+ if (this.cameraFrustum) {
1561
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1562
+ this.cameraFrustum.children.forEach(line => {
1563
+ if (line.material && line.material.resolution) {
1564
+ line.material.resolution.copy(resolution);
1565
+ }
1566
+ });
1567
+ }
1568
+ }
1569
+
1570
+ startAnimation() {
1571
+ this.isPlaying = true;
1572
+ this.lastFrameTime = performance.now();
1573
+
1574
+ this.camera.position.set(0, 0, this.config.cameraZ || 0);
1575
+ this.controls.target.set(0, 0, -1);
1576
+ this.controls.update();
1577
+
1578
+ this.playbackSpeed = this.config.baseFrameRate;
1579
+
1580
+ document.getElementById('play-icon').style.display = 'none';
1581
+ document.getElementById('pause-icon').style.display = 'block';
1582
+
1583
+ this.animate();
1584
+ }
1585
+
1586
+ animate() {
1587
+ requestAnimationFrame(() => this.animate());
1588
+
1589
+ if (this.controls) {
1590
+ this.controls.update();
1591
+ }
1592
+
1593
+ if (this.isPlaying && this.data) {
1594
+ const now = performance.now();
1595
+ const delta = (now - this.lastFrameTime) / 1000;
1596
+
1597
+ const framesToAdvance = Math.floor(delta * this.config.baseFrameRate * this.playbackSpeed);
1598
+ if (framesToAdvance > 0) {
1599
+ this.currentFrame = (this.currentFrame + framesToAdvance) % this.config.totalFrames;
1600
+ this.lastFrameTime = now;
1601
+ this.updatePointCloud(this.currentFrame);
1602
+ }
1603
+ }
1604
+
1605
+ if (this.renderer && this.scene && this.camera) {
1606
+ this.renderer.render(this.scene, this.camera);
1607
+ }
1608
+ }
1609
+
1610
+ initCameraWithCorrectFOV() {
1611
+ const fov = this.config.fov || 60;
1612
+
1613
+ const windowAspect = window.innerWidth / window.innerHeight;
1614
+
1615
+ this.camera = new THREE.PerspectiveCamera(
1616
+ fov,
1617
+ windowAspect,
1618
+ 0.1,
1619
+ 10000
1620
+ );
1621
+
1622
+ this.controls.object = this.camera;
1623
+ this.controls.update();
1624
+
1625
+ this.initCameraFrustum();
1626
+ }
1627
+
1628
+ initCameraFrustum() {
1629
+ this.cameraFrustum = new THREE.Group();
1630
+
1631
+ this.scene.add(this.cameraFrustum);
1632
+
1633
+ this.initCameraFrustumGeometry();
1634
+
1635
+ const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : (this.defaultSettings ? this.defaultSettings.showCameraFrustum : false);
1636
+
1637
+ this.cameraFrustum.visible = showCameraFrustum;
1638
+ }
1639
+
1640
+ initCameraFrustumGeometry() {
1641
+ const fov = this.config.fov || 60;
1642
+ const originalAspect = this.config.original_aspect_ratio || 1.33;
1643
+
1644
+ const size = parseFloat(this.ui.frustumSize.value) || this.defaultSettings.frustumSize;
1645
+
1646
+ const halfHeight = Math.tan(THREE.MathUtils.degToRad(fov / 2)) * size;
1647
+ const halfWidth = halfHeight * originalAspect;
1648
+
1649
+ const vertices = [
1650
+ new THREE.Vector3(0, 0, 0),
1651
+ new THREE.Vector3(-halfWidth, -halfHeight, size),
1652
+ new THREE.Vector3(halfWidth, -halfHeight, size),
1653
+ new THREE.Vector3(halfWidth, halfHeight, size),
1654
+ new THREE.Vector3(-halfWidth, halfHeight, size)
1655
+ ];
1656
+
1657
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1658
+
1659
+ const linePairs = [
1660
+ [1, 2], [2, 3], [3, 4], [4, 1],
1661
+ [0, 1], [0, 2], [0, 3], [0, 4]
1662
+ ];
1663
+
1664
+ const colors = {
1665
+ edge: new THREE.Color(0x3366ff),
1666
+ ray: new THREE.Color(0x33cc66)
1667
+ };
1668
+
1669
+ linePairs.forEach((pair, index) => {
1670
+ const positions = [
1671
+ vertices[pair[0]].x, vertices[pair[0]].y, vertices[pair[0]].z,
1672
+ vertices[pair[1]].x, vertices[pair[1]].y, vertices[pair[1]].z
1673
+ ];
1674
+
1675
+ const lineGeometry = new THREE.LineGeometry();
1676
+ lineGeometry.setPositions(positions);
1677
+
1678
+ let color = index < 4 ? colors.edge : colors.ray;
1679
+
1680
+ const lineMaterial = new THREE.LineMaterial({
1681
+ color: color,
1682
+ linewidth: 2,
1683
+ resolution: resolution,
1684
+ dashed: false
1685
+ });
1686
+
1687
+ const line = new THREE.Line2(lineGeometry, lineMaterial);
1688
+ this.cameraFrustum.add(line);
1689
+ });
1690
+ }
1691
+
1692
+ updateCameraFrustum(frameIndex) {
1693
+ if (!this.cameraFrustum || !this.data) return;
1694
+
1695
+ const invExtrinsics = this.data.inv_extrinsics;
1696
+ if (!invExtrinsics) return;
1697
+
1698
+ const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
1699
+
1700
+ const matrix = new THREE.Matrix4();
1701
+ matrix.set(
1702
+ invExtrMat[0][0], invExtrMat[0][1], invExtrMat[0][2], invExtrMat[0][3],
1703
+ invExtrMat[1][0], invExtrMat[1][1], invExtrMat[1][2], invExtrMat[1][3],
1704
+ invExtrMat[2][0], invExtrMat[2][1], invExtrMat[2][2], invExtrMat[2][3],
1705
+ invExtrMat[3][0], invExtrMat[3][1], invExtrMat[3][2], invExtrMat[3][3]
1706
+ );
1707
+
1708
+ const position = new THREE.Vector3();
1709
+ position.setFromMatrixPosition(matrix);
1710
+
1711
+ const rotMatrix = new THREE.Matrix4().extractRotation(matrix);
1712
+
1713
+ const coordinateCorrection = new THREE.Matrix4().makeRotationX(Math.PI);
1714
+
1715
+ const finalRotation = new THREE.Matrix4().multiplyMatrices(coordinateCorrection, rotMatrix);
1716
+
1717
+ const quaternion = new THREE.Quaternion();
1718
+ quaternion.setFromRotationMatrix(finalRotation);
1719
+
1720
+ position.y = -position.y;
1721
+ position.z = -position.z;
1722
+
1723
+ this.cameraFrustum.position.copy(position);
1724
+ this.cameraFrustum.quaternion.copy(quaternion);
1725
+
1726
+ const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : this.defaultSettings.showCameraFrustum;
1727
+
1728
+ if (this.cameraFrustum.visible !== showCameraFrustum) {
1729
+ this.cameraFrustum.visible = showCameraFrustum;
1730
+ }
1731
+
1732
+ const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
1733
+ this.cameraFrustum.children.forEach(line => {
1734
+ if (line.material && line.material.resolution) {
1735
+ line.material.resolution.copy(resolution);
1736
+ }
1737
+ });
1738
+ }
1739
+
1740
+ updateFrustumDimensions() {
1741
+ if (!this.cameraFrustum) return;
1742
+
1743
+ while(this.cameraFrustum.children.length > 0) {
1744
+ const child = this.cameraFrustum.children[0];
1745
+ if (child.geometry) child.geometry.dispose();
1746
+ if (child.material) child.material.dispose();
1747
+ this.cameraFrustum.remove(child);
1748
+ }
1749
+
1750
+ this.initCameraFrustumGeometry();
1751
+
1752
+ this.updateCameraFrustum(this.currentFrame);
1753
+ }
1754
+
1755
+ resetSettings() {
1756
+ if (!this.defaultSettings) return;
1757
+
1758
+ this.applyDefaultSettings();
1759
+
1760
+ this.updatePointCloudSettings();
1761
+ this.updateTrajectorySettings();
1762
+ this.updateFrustumDimensions();
1763
+
1764
+ this.ui.statusBar.textContent = "Settings reset to defaults";
1765
+ this.ui.statusBar.classList.remove('hidden');
1766
+
1767
+ setTimeout(() => {
1768
+ this.ui.statusBar.classList.add('hidden');
1769
+ }, 3000);
1770
+ }
1771
+ }
1772
+
1773
+ window.addEventListener('DOMContentLoaded', () => {
1774
+ new PointCloudVisualizer();
1775
+ });
1776
+ </script>
1777
+ </body>
1778
+ </html>
app.py ADDED
@@ -0,0 +1,1118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import json
4
+ import numpy as np
5
+ import cv2
6
+ import base64
7
+ import time
8
+ import tempfile
9
+ import shutil
10
+ import glob
11
+ import threading
12
+ import subprocess
13
+ import struct
14
+ import zlib
15
+ from pathlib import Path
16
+ from einops import rearrange
17
+ from typing import List, Tuple, Union
18
+ try:
19
+ import spaces
20
+ except ImportError:
21
+ # Fallback for local development
22
+ def spaces(func):
23
+ return func
24
+ import torch
25
+ import logging
26
+ from concurrent.futures import ThreadPoolExecutor
27
+ import atexit
28
+ import uuid
29
+
30
+ # Configure logging
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger = logging.getLogger(__name__)
33
+
34
+ # Import custom modules with error handling
35
+ try:
36
+ from app_3rd.sam_utils.inference import SamPredictor, get_sam_predictor, run_inference
37
+ from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker, get_points_on_a_grid
38
+ except ImportError as e:
39
+ logger.error(f"Failed to import custom modules: {e}")
40
+ raise
41
+
42
+ # Constants
43
+ MAX_FRAMES = 80
44
+ COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
45
+ MARKERS = [1, 5] # Cross for negative, Star for positive
46
+ MARKER_SIZE = 8
47
+
48
+ # Thread pool for delayed deletion
49
+ thread_pool_executor = ThreadPoolExecutor(max_workers=2)
50
+
51
+ def delete_later(path: Union[str, os.PathLike], delay: int = 600):
52
+ """Delete file or directory after specified delay (default 10 minutes)"""
53
+ def _delete():
54
+ try:
55
+ if os.path.isfile(path):
56
+ os.remove(path)
57
+ elif os.path.isdir(path):
58
+ shutil.rmtree(path)
59
+ except Exception as e:
60
+ logger.warning(f"Failed to delete {path}: {e}")
61
+
62
+ def _wait_and_delete():
63
+ time.sleep(delay)
64
+ _delete()
65
+
66
+ thread_pool_executor.submit(_wait_and_delete)
67
+ atexit.register(_delete)
68
+
69
+ def create_user_temp_dir():
70
+ """Create a unique temporary directory for each user session"""
71
+ session_id = str(uuid.uuid4())[:8] # Short unique ID
72
+ temp_dir = os.path.join("temp_local", f"session_{session_id}")
73
+ os.makedirs(temp_dir, exist_ok=True)
74
+
75
+ # Schedule deletion after 10 minutes
76
+ delete_later(temp_dir, delay=600)
77
+
78
+ return temp_dir
79
+
80
+ from huggingface_hub import hf_hub_download
81
+ # init the model
82
+ os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
83
+
84
+ if os.environ.get("VGGT_DIR", None) is not None:
85
+ from models.vggt.vggt.models.vggt_moe import VGGT_MoE
86
+ from models.vggt.vggt.utils.load_fn import preprocess_image
87
+ vggt_model = VGGT_MoE()
88
+ vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
89
+ vggt_model.eval()
90
+ vggt_model = vggt_model.to("cuda")
91
+
92
+ # Global model initialization
93
+ print("🚀 Initializing local models...")
94
+ tracker_model, _ = get_tracker_predictor(".", vo_points=756)
95
+ predictor = get_sam_predictor()
96
+ print("✅ Models loaded successfully!")
97
+
98
+ gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
99
+
100
+ @spaces.GPU
101
+ def gpu_run_inference(predictor_arg, image, points, boxes):
102
+ """GPU-accelerated SAM inference"""
103
+ if predictor_arg is None:
104
+ print("Initializing SAM predictor inside GPU function...")
105
+ predictor_arg = get_sam_predictor(predictor=predictor)
106
+
107
+ # Ensure predictor is on GPU
108
+ try:
109
+ if hasattr(predictor_arg, 'model'):
110
+ predictor_arg.model = predictor_arg.model.cuda()
111
+ elif hasattr(predictor_arg, 'sam'):
112
+ predictor_arg.sam = predictor_arg.sam.cuda()
113
+ elif hasattr(predictor_arg, 'to'):
114
+ predictor_arg = predictor_arg.to('cuda')
115
+
116
+ if hasattr(image, 'cuda'):
117
+ image = image.cuda()
118
+
119
+ except Exception as e:
120
+ print(f"Warning: Could not move predictor to GPU: {e}")
121
+
122
+ return run_inference(predictor_arg, image, points, boxes)
123
+
124
+ @spaces.GPU
125
+ def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
126
+ """GPU-accelerated tracking"""
127
+ import torchvision.transforms as T
128
+ import decord
129
+
130
+ if tracker_model_arg is None or tracker_viser_arg is None:
131
+ print("Initializing tracker models inside GPU function...")
132
+ out_dir = os.path.join(temp_dir, "results")
133
+ os.makedirs(out_dir, exist_ok=True)
134
+ tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
135
+
136
+ # Setup paths
137
+ video_path = os.path.join(temp_dir, f"{video_name}.mp4")
138
+ mask_path = os.path.join(temp_dir, f"{video_name}.png")
139
+ out_dir = os.path.join(temp_dir, "results")
140
+ os.makedirs(out_dir, exist_ok=True)
141
+
142
+ # Load video using decord
143
+ video_reader = decord.VideoReader(video_path)
144
+ video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)
145
+
146
+ # Resize to ensure minimum side is 336
147
+ h, w = video_tensor.shape[2:]
148
+ scale = max(224 / h, 224 / w)
149
+ if scale < 1:
150
+ new_h, new_w = int(h * scale), int(w * scale)
151
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
152
+ video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
153
+
154
+ # Move to GPU
155
+ video_tensor = video_tensor.cuda()
156
+ print(f"Video tensor shape: {video_tensor.shape}, device: {video_tensor.device}")
157
+
158
+ depth_tensor = None
159
+ intrs = None
160
+ extrs = None
161
+ data_npz_load = {}
162
+
163
+ # run vggt
164
+ if os.environ.get("VGGT_DIR", None) is not None:
165
+ # process the image tensor
166
+ video_tensor = preprocess_image(video_tensor)[None]
167
+ with torch.no_grad():
168
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
169
+ # Predict attributes including cameras, depth maps, and point maps.
170
+ predictions = vggt_model(video_tensor.cuda()/255)
171
+ extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
172
+ depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
173
+
174
+ depth_tensor = depth_map.squeeze().cpu().numpy()
175
+ extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
176
+ extrs = extrinsic.squeeze().cpu().numpy()
177
+ intrs = intrinsic.squeeze().cpu().numpy()
178
+ video_tensor = video_tensor.squeeze()
179
+ #NOTE: 20% of the depth is not reliable
180
+ # threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
181
+ unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
182
+
183
+ # Load and process mask
184
+ if os.path.exists(mask_path):
185
+ mask = cv2.imread(mask_path)
186
+ mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
187
+ mask = mask.sum(axis=-1)>0
188
+ else:
189
+ mask = np.ones_like(video_tensor[0,0].cpu().numpy())>0
190
+ grid_size = 10
191
+
192
+ # Get frame dimensions and create grid points
193
+ frame_H, frame_W = video_tensor.shape[2:]
194
+ grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cuda")
195
+
196
+ # Sample mask values at grid points and filter
197
+ if os.path.exists(mask_path):
198
+ grid_pts_int = grid_pts[0].long()
199
+ mask_values = mask[grid_pts_int.cpu()[...,1], grid_pts_int.cpu()[...,0]]
200
+ grid_pts = grid_pts[:, mask_values]
201
+
202
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
203
+ print(f"Query points shape: {query_xyt.shape}")
204
+
205
+ # Run model inference
206
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
207
+ (
208
+ c2w_traj, intrs, point_map, conf_depth,
209
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video
210
+ ) = tracker_model_arg.forward(video_tensor, depth=depth_tensor,
211
+ intrs=intrs, extrs=extrs,
212
+ queries=query_xyt,
213
+ fps=1, full_point=False, iters_track=4,
214
+ query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
215
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
216
+
217
+ # Resize results to avoid large I/O
218
+ max_size = 224
219
+ h, w = video.shape[2:]
220
+ scale = min(max_size / h, max_size / w)
221
+ if scale < 1:
222
+ new_h, new_w = int(h * scale), int(w * scale)
223
+ video = T.Resize((new_h, new_w))(video)
224
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
225
+ point_map = T.Resize((new_h, new_w))(point_map)
226
+ track2d_pred[...,:2] = track2d_pred[...,:2] * scale
227
+ intrs[:,:2,:] = intrs[:,:2,:] * scale
228
+ conf_depth = T.Resize((new_h, new_w))(conf_depth)
229
+
230
+ # Visualize tracks
231
+ tracker_viser_arg.visualize(video=video[None],
232
+ tracks=track2d_pred[None][...,:2],
233
+ visibility=vis_pred[None],filename="test")
234
+
235
+ # Save in tapip3d format
236
+ data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
237
+ data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
238
+ data_npz_load["intrinsics"] = intrs.cpu().numpy()
239
+ data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
240
+ data_npz_load["video"] = (video_tensor).cpu().numpy()/255
241
+ data_npz_load["visibs"] = vis_pred.cpu().numpy()
242
+ data_npz_load["confs"] = conf_pred.cpu().numpy()
243
+ data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
244
+ np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
245
+
246
+ return None
247
+
248
+ def compress_and_write(filename, header, blob):
249
+ header_bytes = json.dumps(header).encode("utf-8")
250
+ header_len = struct.pack("<I", len(header_bytes))
251
+ with open(filename, "wb") as f:
252
+ f.write(header_len)
253
+ f.write(header_bytes)
254
+ f.write(blob)
255
+
256
+ def process_point_cloud_data(npz_file, width=256, height=192, fps=4):
257
+ fixed_size = (width, height)
258
+
259
+ data = np.load(npz_file)
260
+ extrinsics = data["extrinsics"]
261
+ intrinsics = data["intrinsics"]
262
+ trajs = data["coords"]
263
+ T, C, H, W = data["video"].shape
264
+
265
+ fx = intrinsics[0, 0, 0]
266
+ fy = intrinsics[0, 1, 1]
267
+ fov_y = 2 * np.arctan(H / (2 * fy)) * (180 / np.pi)
268
+ fov_x = 2 * np.arctan(W / (2 * fx)) * (180 / np.pi)
269
+ original_aspect_ratio = (W / fx) / (H / fy)
270
+
271
+ rgb_video = (rearrange(data["video"], "T C H W -> T H W C") * 255).astype(np.uint8)
272
+ rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
273
+ for frame in rgb_video])
274
+
275
+ depth_video = data["depths"].astype(np.float32)
276
+ if "confs_depth" in data.keys():
277
+ confs = (data["confs_depth"].astype(np.float32) > 0.5).astype(np.float32)
278
+ depth_video = depth_video * confs
279
+ depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
280
+ for frame in depth_video])
281
+
282
+ scale_x = fixed_size[0] / W
283
+ scale_y = fixed_size[1] / H
284
+ intrinsics = intrinsics.copy()
285
+ intrinsics[:, 0, :] *= scale_x
286
+ intrinsics[:, 1, :] *= scale_y
287
+
288
+ min_depth = float(depth_video.min()) * 0.8
289
+ max_depth = float(depth_video.max()) * 1.5
290
+
291
+ depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
292
+ depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)
293
+
294
+ depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
295
+ depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
296
+ depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)
297
+
298
+ first_frame_inv = np.linalg.inv(extrinsics[0])
299
+ normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])
300
+
301
+ normalized_trajs = np.zeros_like(trajs)
302
+ for t in range(T):
303
+ homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
304
+ transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
305
+ normalized_trajs[t] = transformed_trajs[:, :3]
306
+
307
+ arrays = {
308
+ "rgb_video": rgb_video,
309
+ "depths_rgb": depths_rgb,
310
+ "intrinsics": intrinsics,
311
+ "extrinsics": normalized_extrinsics,
312
+ "inv_extrinsics": np.linalg.inv(normalized_extrinsics),
313
+ "trajectories": normalized_trajs.astype(np.float32),
314
+ "cameraZ": 0.0
315
+ }
316
+
317
+ header = {}
318
+ blob_parts = []
319
+ offset = 0
320
+ for key, arr in arrays.items():
321
+ arr = np.ascontiguousarray(arr)
322
+ arr_bytes = arr.tobytes()
323
+ header[key] = {
324
+ "dtype": str(arr.dtype),
325
+ "shape": arr.shape,
326
+ "offset": offset,
327
+ "length": len(arr_bytes)
328
+ }
329
+ blob_parts.append(arr_bytes)
330
+ offset += len(arr_bytes)
331
+
332
+ raw_blob = b"".join(blob_parts)
333
+ compressed_blob = zlib.compress(raw_blob, level=9)
334
+
335
+ header["meta"] = {
336
+ "depthRange": [min_depth, max_depth],
337
+ "totalFrames": int(T),
338
+ "resolution": fixed_size,
339
+ "baseFrameRate": fps,
340
+ "numTrajectoryPoints": normalized_trajs.shape[1],
341
+ "fov": float(fov_y),
342
+ "fov_x": float(fov_x),
343
+ "original_aspect_ratio": float(original_aspect_ratio),
344
+ "fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1])
345
+ }
346
+
347
+ compress_and_write('./_viz/data.bin', header, compressed_blob)
348
+ with open('./_viz/data.bin', "rb") as f:
349
+ encoded_blob = base64.b64encode(f.read()).decode("ascii")
350
+ os.unlink('./_viz/data.bin')
351
+
352
+ random_path = f'./_viz/_{time.time()}.html'
353
+ with open('./_viz/viz_template.html') as f:
354
+ html_template = f.read()
355
+ html_out = html_template.replace(
356
+ "<head>",
357
+ f"<head>\n<script>window.embeddedBase64 = `{encoded_blob}`;</script>"
358
+ )
359
+ with open(random_path,'w') as f:
360
+ f.write(html_out)
361
+
362
+ return random_path
363
+
364
+ def numpy_to_base64(arr):
365
+ """Convert numpy array to base64 string"""
366
+ return base64.b64encode(arr.tobytes()).decode('utf-8')
367
+
368
+ def base64_to_numpy(b64_str, shape, dtype):
369
+ """Convert base64 string back to numpy array"""
370
+ return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)
371
+
372
+ def get_video_name(video_path):
373
+ """Extract video name without extension"""
374
+ return os.path.splitext(os.path.basename(video_path))[0]
375
+
376
+ def extract_first_frame(video_path):
377
+ """Extract first frame from video file"""
378
+ try:
379
+ cap = cv2.VideoCapture(video_path)
380
+ ret, frame = cap.read()
381
+ cap.release()
382
+
383
+ if ret:
384
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
385
+ return frame_rgb
386
+ else:
387
+ return None
388
+ except Exception as e:
389
+ print(f"Error extracting first frame: {e}")
390
+ return None
391
+
392
+ def handle_video_upload(video):
393
+ """Handle video upload and extract first frame"""
394
+ if video is None:
395
+ return (None, None, [],
396
+ gr.update(value=50),
397
+ gr.update(value=756),
398
+ gr.update(value=3))
399
+
400
+ # Create user-specific temporary directory
401
+ user_temp_dir = create_user_temp_dir()
402
+
403
+ # Get original video name and copy to temp directory
404
+ if isinstance(video, str):
405
+ video_name = get_video_name(video)
406
+ video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
407
+ shutil.copy(video, video_path)
408
+ else:
409
+ video_name = get_video_name(video.name)
410
+ video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
411
+ with open(video_path, 'wb') as f:
412
+ f.write(video.read())
413
+
414
+ print(f"📁 Video saved to: {video_path}")
415
+
416
+ # Extract first frame
417
+ frame = extract_first_frame(video_path)
418
+ if frame is None:
419
+ return (None, None, [],
420
+ gr.update(value=50),
421
+ gr.update(value=756),
422
+ gr.update(value=3))
423
+
424
+ # Resize frame to have minimum side length of 336
425
+ h, w = frame.shape[:2]
426
+ scale = 336 / min(h, w)
427
+ new_h, new_w = int(h * scale)//2*2, int(w * scale)//2*2
428
+ frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
429
+
430
+ # Store frame data with temp directory info
431
+ frame_data = {
432
+ 'data': numpy_to_base64(frame),
433
+ 'shape': frame.shape,
434
+ 'dtype': str(frame.dtype),
435
+ 'temp_dir': user_temp_dir,
436
+ 'video_name': video_name,
437
+ 'video_path': video_path
438
+ }
439
+
440
+ # Get video-specific settings
441
+ print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
442
+ grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
443
+ print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
444
+
445
+ return (json.dumps(frame_data), frame, [],
446
+ gr.update(value=grid_size_val),
447
+ gr.update(value=vo_points_val),
448
+ gr.update(value=fps_val))
449
+
450
+ def save_masks(o_masks, video_name, temp_dir):
451
+ """Save binary masks to files in user-specific temp directory"""
452
+ o_files = []
453
+ for mask, _ in o_masks:
454
+ o_mask = np.uint8(mask.squeeze() * 255)
455
+ o_file = os.path.join(temp_dir, f"{video_name}.png")
456
+ cv2.imwrite(o_file, o_mask)
457
+ o_files.append(o_file)
458
+ return o_files
459
+
460
+ def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
461
+ """Handle point selection for SAM"""
462
+ if original_img is None:
463
+ return None, []
464
+
465
+ try:
466
+ # Convert stored image data back to numpy array
467
+ frame_data = json.loads(original_img)
468
+ original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
469
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
470
+ video_name = frame_data.get('video_name', 'video')
471
+
472
+ # Create a display image for visualization
473
+ display_img = original_img_array.copy()
474
+ new_sel_pix = sel_pix.copy() if sel_pix else []
475
+ new_sel_pix.append((evt.index, 1 if point_type == 'positive_point' else 0))
476
+
477
+ print(f"🎯 Running SAM inference for point: {evt.index}, type: {point_type}")
478
+ # Run SAM inference
479
+ o_masks = gpu_run_inference(None, original_img_array, new_sel_pix, [])
480
+
481
+ # Draw points on display image
482
+ for point, label in new_sel_pix:
483
+ cv2.drawMarker(display_img, point, COLORS[label], markerType=MARKERS[label], markerSize=MARKER_SIZE, thickness=2)
484
+
485
+ # Draw mask overlay on display image
486
+ if o_masks:
487
+ mask = o_masks[0][0]
488
+ overlay = display_img.copy()
489
+ overlay[mask.squeeze()!=0] = [20, 60, 200] # Light blue
490
+ display_img = cv2.addWeighted(overlay, 0.6, display_img, 0.4, 0)
491
+
492
+ # Save mask for tracking
493
+ save_masks(o_masks, video_name, temp_dir)
494
+ print(f"✅ Mask saved for video: {video_name}")
495
+
496
+ return display_img, new_sel_pix
497
+
498
+ except Exception as e:
499
+ print(f"❌ Error in select_point: {e}")
500
+ return None, []
501
+
502
+ def reset_points(original_img: str, sel_pix):
503
+ """Reset all points and clear the mask"""
504
+ if original_img is None:
505
+ return None, []
506
+
507
+ try:
508
+ # Convert stored image data back to numpy array
509
+ frame_data = json.loads(original_img)
510
+ original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
511
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
512
+
513
+ # Create a display image (just the original image)
514
+ display_img = original_img_array.copy()
515
+
516
+ # Clear all points
517
+ new_sel_pix = []
518
+
519
+ # Clear any existing masks
520
+ for mask_file in glob.glob(os.path.join(temp_dir, "*.png")):
521
+ try:
522
+ os.remove(mask_file)
523
+ except Exception as e:
524
+ logger.warning(f"Failed to remove mask file {mask_file}: {e}")
525
+
526
+ print("🔄 Points and masks reset")
527
+ return display_img, new_sel_pix
528
+
529
+ except Exception as e:
530
+ print(f"❌ Error in reset_points: {e}")
531
+ return None, []
532
+
533
+ def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
534
+ """Launch visualization with user-specific temp directory"""
535
+ if original_image_state is None:
536
+ return None, None, None
537
+
538
+ try:
539
+ # Get user's temp directory from stored frame data
540
+ frame_data = json.loads(original_image_state)
541
+ temp_dir = frame_data.get('temp_dir', 'temp_local')
542
+ video_name = frame_data.get('video_name', 'video')
543
+
544
+ print(f"🚀 Starting tracking for video: {video_name}")
545
+ print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
546
+
547
+ # Check for mask files
548
+ mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
549
+ video_files = glob.glob(os.path.join(temp_dir, "*.mp4"))
550
+
551
+ if not video_files:
552
+ print("❌ No video file found")
553
+ return "❌ Error: No video file found", None, None
554
+
555
+ video_path = video_files[0]
556
+ mask_path = mask_files[0] if mask_files else None
557
+
558
+ # Run tracker
559
+ print("🎯 Running tracker...")
560
+ out_dir = os.path.join(temp_dir, "results")
561
+ os.makedirs(out_dir, exist_ok=True)
562
+
563
+ gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
564
+
565
+ # Process results
566
+ npz_path = os.path.join(out_dir, "result.npz")
567
+ track2d_video = os.path.join(out_dir, "test_pred_track.mp4")
568
+
569
+ if os.path.exists(npz_path):
570
+ print("📊 Processing 3D visualization...")
571
+ html_path = process_point_cloud_data(npz_path)
572
+
573
+ # Schedule deletion of generated files
574
+ delete_later(html_path, delay=600)
575
+ if os.path.exists(track2d_video):
576
+ delete_later(track2d_video, delay=600)
577
+ delete_later(npz_path, delay=600)
578
+
579
+ # Create iframe HTML
580
+ iframe_html = f"""
581
+ <div style='border: 3px solid #667eea; border-radius: 10px;
582
+ background: #f8f9ff; height: 650px; width: 100%;
583
+ box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
584
+ margin: 0; padding: 0; box-sizing: border-box; overflow: hidden;'>
585
+ <iframe id="viz_iframe" src="/gradio_api/file={html_path}"
586
+ width="100%" height="650" frameborder="0"
587
+ style="border: none; display: block; width: 100%; height: 650px;
588
+ margin: 0; padding: 0; border-radius: 7px;">
589
+ </iframe>
590
+ </div>
591
+ """
592
+
593
+ print("✅ Tracking completed successfully!")
594
+ return iframe_html, track2d_video if os.path.exists(track2d_video) else None, html_path
595
+ else:
596
+ print("❌ Tracking failed - no results generated")
597
+ return "❌ Error: Tracking failed to generate results", None, None
598
+
599
+ except Exception as e:
600
+ print(f"❌ Error in launch_viz: {e}")
601
+ return f"❌ Error: {str(e)}", None, None
602
+
603
+ def clear_all():
604
+ """Clear all buffers and temporary files"""
605
+ return (None, None, [],
606
+ gr.update(value=50),
607
+ gr.update(value=756),
608
+ gr.update(value=3))
609
+
610
+ def clear_all_with_download():
611
+ """Clear all buffers including both download components"""
612
+ return (None, None, [],
613
+ gr.update(value=50),
614
+ gr.update(value=756),
615
+ gr.update(value=3),
616
+ None, # tracking_video_download
617
+ None) # HTML download component
618
+
619
+ def get_video_settings(video_name):
620
+ """Get video-specific settings based on video name"""
621
+ video_settings = {
622
+ "running": (50, 512, 2),
623
+ "backpack": (40, 600, 2),
624
+ "kitchen": (60, 800, 3),
625
+ "pillow": (35, 500, 2),
626
+ "handwave": (35, 500, 8),
627
+ "hockey": (45, 700, 2),
628
+ "drifting": (35, 1000, 6),
629
+ "basketball": (45, 1500, 5),
630
+ "ken_block_0": (45, 700, 2),
631
+ "ego_kc1": (45, 500, 4),
632
+ "vertical_place": (45, 500, 3),
633
+ "ego_teaser": (45, 1200, 10),
634
+ "robot_unitree": (45, 500, 4),
635
+ "robot_3": (35, 400, 5),
636
+ "teleop2": (45, 256, 7),
637
+ "pusht": (45, 256, 10),
638
+ "cinema_0": (45, 356, 5),
639
+ "cinema_1": (45, 756, 3),
640
+ "robot1": (45, 600, 2),
641
+ "robot2": (45, 600, 2),
642
+ "protein": (45, 600, 2),
643
+ "kitchen_egocentric": (45, 600, 2),
644
+ }
645
+
646
+ return video_settings.get(video_name, (50, 756, 3))
647
+
648
+ # Create the Gradio interface
649
+ print("🎨 Creating Gradio interface...")
650
+
651
+ with gr.Blocks(
652
+ theme=gr.themes.Soft(),
653
+ title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)",
654
+ css="""
655
+ .gradio-container {
656
+ max-width: 1200px !important;
657
+ margin: auto !important;
658
+ }
659
+ .gr-button {
660
+ margin: 5px;
661
+ }
662
+ .gr-form {
663
+ background: white;
664
+ border-radius: 10px;
665
+ padding: 20px;
666
+ box-shadow: 0 2px 10px rgba(0,0,0,0.1);
667
+ }
668
+ /* 移除 gr.Group 的默认灰色背景 */
669
+ .gr-form {
670
+ background: transparent !important;
671
+ border: none !important;
672
+ box-shadow: none !important;
673
+ padding: 0 !important;
674
+ }
675
+ /* 固定3D可视化器尺寸 */
676
+ #viz_container {
677
+ height: 650px !important;
678
+ min-height: 650px !important;
679
+ max-height: 650px !important;
680
+ width: 100% !important;
681
+ margin: 0 !important;
682
+ padding: 0 !important;
683
+ overflow: hidden !important;
684
+ }
685
+ #viz_container > div {
686
+ height: 650px !important;
687
+ min-height: 650px !important;
688
+ max-height: 650px !important;
689
+ width: 100% !important;
690
+ margin: 0 !important;
691
+ padding: 0 !important;
692
+ box-sizing: border-box !important;
693
+ }
694
+ #viz_container iframe {
695
+ height: 650px !important;
696
+ min-height: 650px !important;
697
+ max-height: 650px !important;
698
+ width: 100% !important;
699
+ border: none !important;
700
+ display: block !important;
701
+ margin: 0 !important;
702
+ padding: 0 !important;
703
+ box-sizing: border-box !important;
704
+ }
705
+ /* 固定视频上传组件高度 */
706
+ .gr-video {
707
+ height: 300px !important;
708
+ min-height: 300px !important;
709
+ max-height: 300px !important;
710
+ }
711
+ .gr-video video {
712
+ height: 260px !important;
713
+ max-height: 260px !important;
714
+ object-fit: contain !important;
715
+ background: #f8f9fa;
716
+ }
717
+ .gr-video .gr-video-player {
718
+ height: 260px !important;
719
+ max-height: 260px !important;
720
+ }
721
+ /* 强力移除examples的灰色背景 - 使用更通用的选择器 */
722
+ .horizontal-examples,
723
+ .horizontal-examples > *,
724
+ .horizontal-examples * {
725
+ background: transparent !important;
726
+ background-color: transparent !important;
727
+ border: none !important;
728
+ }
729
+
730
+ /* Examples组件水平滚动样式 */
731
+ .horizontal-examples [data-testid="examples"] {
732
+ background: transparent !important;
733
+ background-color: transparent !important;
734
+ }
735
+
736
+ .horizontal-examples [data-testid="examples"] > div {
737
+ background: transparent !important;
738
+ background-color: transparent !important;
739
+ overflow-x: auto !important;
740
+ overflow-y: hidden !important;
741
+ scrollbar-width: thin;
742
+ scrollbar-color: #667eea transparent;
743
+ padding: 0 !important;
744
+ margin-top: 10px;
745
+ border: none !important;
746
+ }
747
+
748
+ .horizontal-examples [data-testid="examples"] table {
749
+ display: flex !important;
750
+ flex-wrap: nowrap !important;
751
+ min-width: max-content !important;
752
+ gap: 15px !important;
753
+ padding: 10px 0;
754
+ background: transparent !important;
755
+ border: none !important;
756
+ }
757
+
758
+ .horizontal-examples [data-testid="examples"] tbody {
759
+ display: flex !important;
760
+ flex-direction: row !important;
761
+ flex-wrap: nowrap !important;
762
+ gap: 15px !important;
763
+ background: transparent !important;
764
+ }
765
+
766
+ .horizontal-examples [data-testid="examples"] tr {
767
+ display: flex !important;
768
+ flex-direction: column !important;
769
+ min-width: 160px !important;
770
+ max-width: 160px !important;
771
+ margin: 0 !important;
772
+ background: white !important;
773
+ border-radius: 12px;
774
+ box-shadow: 0 3px 12px rgba(0,0,0,0.12);
775
+ transition: all 0.3s ease;
776
+ cursor: pointer;
777
+ overflow: hidden;
778
+ border: none !important;
779
+ }
780
+
781
+ .horizontal-examples [data-testid="examples"] tr:hover {
782
+ transform: translateY(-4px);
783
+ box-shadow: 0 8px 20px rgba(102, 126, 234, 0.25);
784
+ }
785
+
786
+ .horizontal-examples [data-testid="examples"] td {
787
+ text-align: center !important;
788
+ padding: 0 !important;
789
+ border: none !important;
790
+ background: transparent !important;
791
+ }
792
+
793
+ .horizontal-examples [data-testid="examples"] td:first-child {
794
+ padding: 0 !important;
795
+ background: transparent !important;
796
+ }
797
+
798
+ .horizontal-examples [data-testid="examples"] video {
799
+ border-radius: 8px 8px 0 0 !important;
800
+ width: 100% !important;
801
+ height: 90px !important;
802
+ object-fit: cover !important;
803
+ background: #f8f9fa !important;
804
+ }
805
+
806
+ .horizontal-examples [data-testid="examples"] td:last-child {
807
+ font-size: 11px !important;
808
+ font-weight: 600 !important;
809
+ color: #333 !important;
810
+ padding: 8px 12px !important;
811
+ background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%) !important;
812
+ border-radius: 0 0 8px 8px;
813
+ }
814
+
815
+ /* 滚动条样式 */
816
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar {
817
+ height: 8px;
818
+ }
819
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-track {
820
+ background: transparent;
821
+ border-radius: 4px;
822
+ }
823
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb {
824
+ background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
825
+ border-radius: 4px;
826
+ }
827
+ .horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb:hover {
828
+ background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%);
829
+ }
830
+ """
831
+ ) as demo:
832
+
833
+ # Add prominent main title
834
+
835
+ gr.Markdown("""
836
+ # ✨ SpatialTrackerV2
837
+
838
+ Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
839
+
840
+ **⚡ Quick Start:** Upload video → Click "Start Tracking Now!"
841
+
842
+ **🔬 Advanced Usage with SAM:**
843
+ 1. Upload a video file or select from examples below
844
+ 2. Expand "Manual Point Selection" to click on specific objects for SAM-guided tracking
845
+ 3. Adjust tracking parameters for optimal performance
846
+ 4. Click "Start Tracking Now!" to begin 3D tracking with SAM guidance
847
+
848
+ """)
849
+
850
+ # Status indicator
851
+ gr.Markdown("**Status:** 🟢 Local Processing Mode")
852
+
853
+ # Main content area - video upload left, 3D visualization right
854
+ with gr.Row():
855
+ with gr.Column(scale=1):
856
+ # Video upload section
857
+ gr.Markdown("### 📂 Select Video")
858
+
859
+ # Define video_input here so it can be referenced in examples
860
+ video_input = gr.Video(
861
+ label="Upload Video or Select Example",
862
+ format="mp4",
863
+ height=250 # Matched height with 3D viz
864
+ )
865
+
866
+
867
+ # Traditional examples but with horizontal scroll styling
868
+ gr.Markdown("🎨**Examples:** (scroll horizontally to see all videos)")
869
+ with gr.Row(elem_classes=["horizontal-examples"]):
870
+ # Horizontal video examples with slider
871
+ # gr.HTML("<div style='margin-top: 5px;'></div>")
872
+ gr.Examples(
873
+ examples=[
874
+ ["./examples/robot1.mp4"],
875
+ ["./examples/robot2.mp4"],
876
+ ["./examples/protein.mp4"],
877
+ ["./examples/kitchen_egocentric.mp4"],
878
+ ["./examples/hockey.mp4"],
879
+ ["./examples/running.mp4"],
880
+ ["./examples/robot_3.mp4"],
881
+ ["./examples/backpack.mp4"],
882
+ ["./examples/kitchen.mp4"],
883
+ ["./examples/pillow.mp4"],
884
+ ["./examples/handwave.mp4"],
885
+ ["./examples/drifting.mp4"],
886
+ ["./examples/basketball.mp4"],
887
+ ["./examples/ken_block_0.mp4"],
888
+ ["./examples/ego_kc1.mp4"],
889
+ ["./examples/vertical_place.mp4"],
890
+ ["./examples/ego_teaser.mp4"],
891
+ ["./examples/robot_unitree.mp4"],
892
+ ["./examples/teleop2.mp4"],
893
+ ["./examples/pusht.mp4"],
894
+ ["./examples/cinema_0.mp4"],
895
+ ["./examples/cinema_1.mp4"],
896
+ ],
897
+ inputs=[video_input],
898
+ outputs=[video_input],
899
+ fn=None,
900
+ cache_examples=False,
901
+ label="",
902
+ examples_per_page=6 # Show 6 examples per page so they can wrap to multiple rows
903
+ )
904
+
905
+ with gr.Column(scale=2):
906
+ # 3D Visualization - wider and taller to match left side
907
+ with gr.Group():
908
+ gr.Markdown("### 🌐 3D Trajectory Visualization")
909
+ viz_html = gr.HTML(
910
+ label="3D Trajectory Visualization",
911
+ value="""
912
+ <div style='border: 3px solid #667eea; border-radius: 10px;
913
+ background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
914
+ text-align: center; height: 650px; display: flex;
915
+ flex-direction: column; justify-content: center; align-items: center;
916
+ box-shadow: 0 4px 16px rgba(102, 126, 234, 0.15);
917
+ margin: 0; padding: 20px; box-sizing: border-box;'>
918
+ <div style='font-size: 56px; margin-bottom: 25px;'>🌐</div>
919
+ <h3 style='color: #667eea; margin-bottom: 18px; font-size: 28px; font-weight: 600;'>
920
+ 3D Trajectory Visualization
921
+ </h3>
922
+ <p style='color: #666; font-size: 18px; line-height: 1.6; max-width: 550px; margin-bottom: 30px;'>
923
+ Track any pixels in 3D space with camera motion
924
+ </p>
925
+ <div style='background: rgba(102, 126, 234, 0.1); border-radius: 30px;
926
+ padding: 15px 30px; border: 1px solid rgba(102, 126, 234, 0.2);'>
927
+ <span style='color: #667eea; font-weight: 600; font-size: 16px;'>
928
+ ⚡ Powered by SpatialTracker V2
929
+ </span>
930
+ </div>
931
+ </div>
932
+ """,
933
+ elem_id="viz_container"
934
+ )
935
+
936
+ # Start button section - below video area
937
+ with gr.Row():
938
+ with gr.Column(scale=3):
939
+ launch_btn = gr.Button("🚀 Start Tracking Now!", variant="primary", size="lg")
940
+ with gr.Column(scale=1):
941
+ clear_all_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
942
+
943
+ # Tracking parameters section
944
+ with gr.Row():
945
+ gr.Markdown("### ⚙️ Tracking Parameters")
946
+ with gr.Row():
947
+ grid_size = gr.Slider(
948
+ minimum=10, maximum=100, step=10, value=50,
949
+ label="Grid Size", info="Tracking detail level"
950
+ )
951
+ vo_points = gr.Slider(
952
+ minimum=100, maximum=2000, step=50, value=756,
953
+ label="VO Points", info="Motion accuracy"
954
+ )
955
+ fps = gr.Slider(
956
+ minimum=1, maximum=20, step=1, value=3,
957
+ label="FPS", info="Processing speed"
958
+ )
959
+
960
+ # Advanced Point Selection with SAM - Collapsed by default
961
+ with gr.Row():
962
+ gr.Markdown("### 🎯 Advanced: Manual Point Selection with SAM")
963
+ with gr.Accordion("🔬 SAM Point Selection Controls", open=False):
964
+ gr.HTML("""
965
+ <div style='margin-bottom: 15px;'>
966
+ <ul style='color: #4a5568; font-size: 14px; line-height: 1.6; margin: 0; padding-left: 20px;'>
967
+ <li>Click on target objects in the image for SAM-guided segmentation</li>
968
+ <li>Positive points: include these areas | Negative points: exclude these areas</li>
969
+ <li>Get more accurate 3D tracking results with SAM's powerful segmentation</li>
970
+ </ul>
971
+ </div>
972
+ """)
973
+
974
+ with gr.Row():
975
+ with gr.Column():
976
+ interactive_frame = gr.Image(
977
+ label="Click to select tracking points with SAM guidance",
978
+ type="numpy",
979
+ interactive=True,
980
+ height=300
981
+ )
982
+
983
+ with gr.Row():
984
+ point_type = gr.Radio(
985
+ choices=["positive_point", "negative_point"],
986
+ value="positive_point",
987
+ label="Point Type",
988
+ info="Positive: track these areas | Negative: avoid these areas"
989
+ )
990
+
991
+ with gr.Row():
992
+ reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")
993
+
994
+ # Downloads section - hidden but still functional for local processing
995
+ with gr.Row(visible=False):
996
+ with gr.Column(scale=1):
997
+ tracking_video_download = gr.File(
998
+ label="📹 Download 2D Tracking Video",
999
+ interactive=False,
1000
+ visible=False
1001
+ )
1002
+ with gr.Column(scale=1):
1003
+ html_download = gr.File(
1004
+ label="📄 Download 3D Visualization HTML",
1005
+ interactive=False,
1006
+ visible=False
1007
+ )
1008
+
1009
+ # GitHub Star Section
1010
+ gr.HTML("""
1011
+ <div style='background: linear-gradient(135deg, #e8eaff 0%, #f0f2ff 100%);
1012
+ border-radius: 8px; padding: 20px; margin: 15px 0;
1013
+ box-shadow: 0 2px 8px rgba(102, 126, 234, 0.1);
1014
+ border: 1px solid rgba(102, 126, 234, 0.15);'>
1015
+ <div style='text-align: center;'>
1016
+ <h3 style='color: #4a5568; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
1017
+ ⭐ Love SpatialTracker? Give us a Star! ⭐
1018
+ </h3>
1019
+ <p style='color: #666; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
1020
+ Help us grow by starring our repository on GitHub! Your support means a lot to the community. 🚀
1021
+ </p>
1022
+ <a href="https://github.com/henry123-boy/SpaTrackerV2" target="_blank"
1023
+ style='display: inline-flex; align-items: center; gap: 8px;
1024
+ background: rgba(102, 126, 234, 0.1); color: #4a5568;
1025
+ padding: 10px 20px; border-radius: 25px; text-decoration: none;
1026
+ font-weight: bold; font-size: 14px; border: 1px solid rgba(102, 126, 234, 0.2);
1027
+ transition: all 0.3s ease;'
1028
+ onmouseover="this.style.background='rgba(102, 126, 234, 0.15)'; this.style.transform='translateY(-2px)'"
1029
+ onmouseout="this.style.background='rgba(102, 126, 234, 0.1)'; this.style.transform='translateY(0)'">
1030
+ <span style='font-size: 16px;'>⭐</span>
1031
+ Star SpatialTracker V2 on GitHub
1032
+ </a>
1033
+ </div>
1034
+ </div>
1035
+ """)
1036
+
1037
+ # Acknowledgments Section
1038
+ gr.HTML("""
1039
+ <div style='background: linear-gradient(135deg, #fff8e1 0%, #fffbf0 100%);
1040
+ border-radius: 8px; padding: 20px; margin: 15px 0;
1041
+ box-shadow: 0 2px 8px rgba(255, 193, 7, 0.1);
1042
+ border: 1px solid rgba(255, 193, 7, 0.2);'>
1043
+ <div style='text-align: center;'>
1044
+ <h3 style='color: #5d4037; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
1045
+ 📚 Acknowledgments
1046
+ </h3>
1047
+ <p style='color: #5d4037; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
1048
+ Our 3D visualizer is adapted from <strong>TAPIP3D</strong>. We thank the authors for their excellent work and contribution to the computer vision community!
1049
+ </p>
1050
+ <a href="https://github.com/zbw001/TAPIP3D" target="_blank"
1051
+ style='display: inline-flex; align-items: center; gap: 8px;
1052
+ background: rgba(255, 193, 7, 0.15); color: #5d4037;
1053
+ padding: 10px 20px; border-radius: 25px; text-decoration: none;
1054
+ font-weight: bold; font-size: 14px; border: 1px solid rgba(255, 193, 7, 0.3);
1055
+ transition: all 0.3s ease;'
1056
+ onmouseover="this.style.background='rgba(255, 193, 7, 0.25)'; this.style.transform='translateY(-2px)'"
1057
+ onmouseout="this.style.background='rgba(255, 193, 7, 0.15)'; this.style.transform='translateY(0)'">
1058
+ 📚 Visit TAPIP3D Repository
1059
+ </a>
1060
+ </div>
1061
+ </div>
1062
+ """)
1063
+
1064
+ # Footer
1065
+ gr.HTML("""
1066
+ <div style='text-align: center; margin: 20px 0 10px 0;'>
1067
+ <span style='font-size: 12px; color: #888; font-style: italic;'>
1068
+ Powered by SpatialTracker V2 | Built with ❤️ for the Computer Vision Community
1069
+ </span>
1070
+ </div>
1071
+ """)
1072
+
1073
+ # Hidden state variables
1074
+ original_image_state = gr.State(None)
1075
+ selected_points = gr.State([])
1076
+
1077
+ # Event handlers
1078
+ video_input.change(
1079
+ fn=handle_video_upload,
1080
+ inputs=[video_input],
1081
+ outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
1082
+ )
1083
+
1084
+ interactive_frame.select(
1085
+ fn=select_point,
1086
+ inputs=[original_image_state, selected_points, point_type],
1087
+ outputs=[interactive_frame, selected_points]
1088
+ )
1089
+
1090
+ reset_points_btn.click(
1091
+ fn=reset_points,
1092
+ inputs=[original_image_state, selected_points],
1093
+ outputs=[interactive_frame, selected_points]
1094
+ )
1095
+
1096
+ clear_all_btn.click(
1097
+ fn=clear_all_with_download,
1098
+ outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
1099
+ )
1100
+
1101
+ launch_btn.click(
1102
+ fn=launch_viz,
1103
+ inputs=[grid_size, vo_points, fps, original_image_state],
1104
+ outputs=[viz_html, tracking_video_download, html_download]
1105
+ )
1106
+
1107
+ # Launch the interface
1108
+ if __name__ == "__main__":
1109
+ print("🌟 Launching SpatialTracker V2 Local Version...")
1110
+ print("🔗 Running in Local Processing Mode")
1111
+
1112
+ demo.launch(
1113
+ server_name="0.0.0.0",
1114
+ server_port=7860,
1115
+ share=True,
1116
+ debug=True,
1117
+ show_error=True
1118
+ )
app_3rd/README.md ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🌟 SpatialTrackerV2 Integrated with SAM 🌟
2
+ SAM receives a point prompt and generates a mask for the target object, facilitating easy interaction to obtain the object's 3D trajectories with SpaTrack2.
3
+
4
+ ## Installation
5
+ ```
6
+
7
+ python -m pip install git+https://github.com/facebookresearch/segment-anything.git
8
+ cd app_3rd/sam_utils
9
+ mkdir checkpoints
10
+ cd checkpoints
11
+ wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
12
+ ```
app_3rd/sam_utils/hf_sam_predictor.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+ import numpy as np
3
+ import torch
4
+ from typing import Optional, Tuple, List, Union
5
+ import warnings
6
+ import cv2
7
+ try:
8
+ from transformers import SamModel, SamProcessor
9
+ from huggingface_hub import hf_hub_download
10
+ HF_AVAILABLE = True
11
+ except ImportError:
12
+ HF_AVAILABLE = False
13
+ warnings.warn("transformers or huggingface_hub not available. HF SAM models will not work.")
14
+
15
+ # Hugging Face model mapping
16
+ HF_MODELS = {
17
+ 'vit_b': 'facebook/sam-vit-base',
18
+ 'vit_l': 'facebook/sam-vit-large',
19
+ 'vit_h': 'facebook/sam-vit-huge'
20
+ }
21
+
22
+ class HFSamPredictor:
23
+ """
24
+ Hugging Face version of SamPredictor that wraps the transformers SAM models.
25
+ This class provides the same interface as the original SamPredictor for seamless integration.
26
+ """
27
+
28
+ def __init__(self, model: SamModel, processor: SamProcessor, device: Optional[str] = None):
29
+ """
30
+ Initialize the HF SAM predictor.
31
+
32
+ Args:
33
+ model: The SAM model from transformers
34
+ processor: The SAM processor from transformers
35
+ device: Device to run the model on ('cuda', 'cpu', etc.)
36
+ """
37
+ self.model = model
38
+ self.processor = processor
39
+ self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
40
+ self.model.to(self.device)
41
+ self.model.eval()
42
+
43
+ # Store the current image and its features
44
+ self.original_size = None
45
+ self.input_size = None
46
+ self.features = None
47
+ self.image = None
48
+
49
+ @classmethod
50
+ def from_pretrained(cls, model_name: str, device: Optional[str] = None) -> 'HFSamPredictor':
51
+ """
52
+ Load a SAM model from Hugging Face Hub.
53
+
54
+ Args:
55
+ model_name: Model name from HF_MODELS or direct HF model path
56
+ device: Device to load the model on
57
+
58
+ Returns:
59
+ HFSamPredictor instance
60
+ """
61
+ if not HF_AVAILABLE:
62
+ raise ImportError("transformers and huggingface_hub are required for HF SAM models")
63
+
64
+ # Map model type to HF model name if needed
65
+ if model_name in HF_MODELS:
66
+ model_name = HF_MODELS[model_name]
67
+
68
+ print(f"Loading SAM model from Hugging Face: {model_name}")
69
+
70
+ # Load model and processor
71
+ model = SamModel.from_pretrained(model_name)
72
+ processor = SamProcessor.from_pretrained(model_name)
73
+ return cls(model, processor, device)
74
+
75
+ def preprocess(self, image: np.ndarray,
76
+ input_points: List[List[float]], input_labels: List[int]) -> None:
77
+ """
78
+ Set the image for prediction. This preprocesses the image and extracts features.
79
+
80
+ Args:
81
+ image: Input image as numpy array (H, W, C) in RGB format
82
+ """
83
+ if image.dtype != np.uint8:
84
+ image = (image * 255).astype(np.uint8)
85
+
86
+ self.image = image
87
+ self.original_size = image.shape[:2]
88
+
89
+ # Use dummy point to ensure processor returns original_sizes & reshaped_input_sizes
90
+ inputs = self.processor(
91
+ images=image,
92
+ input_points=input_points,
93
+ input_labels=input_labels,
94
+ return_tensors="pt"
95
+ )
96
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
97
+
98
+ self.input_size = inputs['pixel_values'].shape[-2:]
99
+ self.features = inputs
100
+ return inputs
101
+
102
+
103
+ def get_hf_sam_predictor(model_type: str = 'vit_h', device: Optional[str] = None,
104
+ image: Optional[np.ndarray] = None) -> HFSamPredictor:
105
+ """
106
+ Get a Hugging Face SAM predictor with the same interface as the original get_sam_predictor.
107
+
108
+ Args:
109
+ model_type: Model type ('vit_b', 'vit_l', 'vit_h')
110
+ device: Device to run the model on
111
+ image: Optional image to set immediately
112
+
113
+ Returns:
114
+ HFSamPredictor instance
115
+ """
116
+ if not HF_AVAILABLE:
117
+ raise ImportError("transformers and huggingface_hub are required for HF SAM models")
118
+
119
+ if device is None:
120
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
121
+
122
+ # Load the predictor
123
+ predictor = HFSamPredictor.from_pretrained(model_type, device)
124
+
125
+ # Set image if provided
126
+ if image is not None:
127
+ predictor.set_image(image)
128
+
129
+ return predictor
app_3rd/sam_utils/inference.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gc
2
+
3
+ import numpy as np
4
+ import torch
5
+ from segment_anything import SamPredictor, sam_model_registry
6
+
7
+ # Try to import HF SAM support
8
+ try:
9
+ from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor, HFSamPredictor
10
+ HF_AVAILABLE = True
11
+ except ImportError:
12
+ HF_AVAILABLE = False
13
+
14
+ models = {
15
+ 'vit_b': 'app_3rd/sam_utils/checkpoints/sam_vit_b_01ec64.pth',
16
+ 'vit_l': 'app_3rd/sam_utils/checkpoints/sam_vit_l_0b3195.pth',
17
+ 'vit_h': 'app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth'
18
+ }
19
+
20
+
21
+ def get_sam_predictor(model_type='vit_b', device=None, image=None, use_hf=True, predictor=None):
22
+ """
23
+ Get SAM predictor with option to use HuggingFace version
24
+
25
+ Args:
26
+ model_type: Model type ('vit_b', 'vit_l', 'vit_h')
27
+ device: Device to run on
28
+ image: Optional image to set immediately
29
+ use_hf: Whether to use HuggingFace SAM instead of original SAM
30
+ """
31
+ if predictor is not None:
32
+ return predictor
33
+ if use_hf:
34
+ if not HF_AVAILABLE:
35
+ raise ImportError("HuggingFace SAM not available. Install transformers and huggingface_hub.")
36
+ return get_hf_sam_predictor(model_type, device, image)
37
+
38
+ # Original SAM logic
39
+ if device is None and torch.cuda.is_available():
40
+ device = 'cuda'
41
+ elif device is None:
42
+ device = 'cpu'
43
+ # sam model
44
+ sam = sam_model_registry[model_type](checkpoint=models[model_type])
45
+ sam = sam.to(device)
46
+
47
+ predictor = SamPredictor(sam)
48
+ if image is not None:
49
+ predictor.set_image(image)
50
+ return predictor
51
+
52
+
53
+ def run_inference(predictor, input_x, selected_points, multi_object: bool = False):
54
+ """
55
+ Run inference with either original SAM or HF SAM predictor
56
+
57
+ Args:
58
+ predictor: SamPredictor or HFSamPredictor instance
59
+ input_x: Input image
60
+ selected_points: List of (point, label) tuples
61
+ multi_object: Whether to handle multiple objects
62
+ """
63
+ if len(selected_points) == 0:
64
+ return []
65
+
66
+ # Check if using HF SAM
67
+ if isinstance(predictor, HFSamPredictor):
68
+ return _run_hf_inference(predictor, input_x, selected_points, multi_object)
69
+ else:
70
+ return _run_original_inference(predictor, input_x, selected_points, multi_object)
71
+
72
+
73
+ def _run_original_inference(predictor: SamPredictor, input_x, selected_points, multi_object: bool = False):
74
+ """Run inference with original SAM"""
75
+ points = torch.Tensor(
76
+ [p for p, _ in selected_points]
77
+ ).to(predictor.device).unsqueeze(1)
78
+
79
+ labels = torch.Tensor(
80
+ [int(l) for _, l in selected_points]
81
+ ).to(predictor.device).unsqueeze(1)
82
+
83
+ transformed_points = predictor.transform.apply_coords_torch(
84
+ points, input_x.shape[:2])
85
+
86
+ masks, scores, logits = predictor.predict_torch(
87
+ point_coords=transformed_points[:,0][None],
88
+ point_labels=labels[:,0][None],
89
+ multimask_output=False,
90
+ )
91
+ masks = masks[0].cpu().numpy() # N 1 H W N is the number of points
92
+
93
+ gc.collect()
94
+ torch.cuda.empty_cache()
95
+
96
+ return [(masks, 'final_mask')]
97
+
98
+
99
+ def _run_hf_inference(predictor: HFSamPredictor, input_x, selected_points, multi_object: bool = False):
100
+ """Run inference with HF SAM"""
101
+ # Prepare points and labels for HF SAM
102
+ select_pts = [[list(p) for p, _ in selected_points]]
103
+ select_lbls = [[int(l) for _, l in selected_points]]
104
+
105
+ # Preprocess inputs
106
+ inputs = predictor.preprocess(input_x, select_pts, select_lbls)
107
+
108
+ # Run inference
109
+ with torch.no_grad():
110
+ outputs = predictor.model(**inputs)
111
+
112
+ # Post-process masks
113
+ masks = predictor.processor.image_processor.post_process_masks(
114
+ outputs.pred_masks.cpu(),
115
+ inputs["original_sizes"].cpu(),
116
+ inputs["reshaped_input_sizes"].cpu(),
117
+ )
118
+ masks = masks[0][:,:1,...].cpu().numpy()
119
+
120
+ gc.collect()
121
+ torch.cuda.empty_cache()
122
+
123
+ return [(masks, 'final_mask')]
app_3rd/spatrack_utils/infer_track.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.SpaTrackV2.models.predictor import Predictor
2
+ import yaml
3
+ import easydict
4
+ import os
5
+ import numpy as np
6
+ import cv2
7
+ import torch
8
+ import torchvision.transforms as T
9
+ from PIL import Image
10
+ import io
11
+ import moviepy.editor as mp
12
+ from models.SpaTrackV2.utils.visualizer import Visualizer
13
+ import tqdm
14
+ from models.SpaTrackV2.models.utils import get_points_on_a_grid
15
+ import glob
16
+ from rich import print
17
+ import argparse
18
+ import decord
19
+ from huggingface_hub import hf_hub_download
20
+
21
+ config = {
22
+ "ckpt_dir": "Yuxihenry/SpatialTrackerCkpts", # HuggingFace repo ID
23
+ "cfg_dir": "config/magic_infer_moge.yaml",
24
+ }
25
+
26
+ def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
27
+ """
28
+ Initialize and return the tracker predictor and visualizer
29
+ Args:
30
+ output_dir: Directory to save visualization results
31
+ vo_points: Number of points for visual odometry
32
+ Returns:
33
+ Tuple of (tracker_predictor, visualizer)
34
+ """
35
+ viz = True
36
+ os.makedirs(output_dir, exist_ok=True)
37
+
38
+ with open(config["cfg_dir"], "r") as f:
39
+ cfg = yaml.load(f, Loader=yaml.FullLoader)
40
+ cfg = easydict.EasyDict(cfg)
41
+ cfg.out_dir = output_dir
42
+ cfg.model.track_num = vo_points
43
+
44
+ # Check if it's a local path or HuggingFace repo
45
+ if tracker_model is not None:
46
+ model = tracker_model
47
+ model.spatrack.track_num = vo_points
48
+ else:
49
+ if os.path.exists(config["ckpt_dir"]):
50
+ # Local file
51
+ model = Predictor.from_pretrained(config["ckpt_dir"], model_cfg=cfg["model"])
52
+ else:
53
+ # HuggingFace repo - download the model
54
+ print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
55
+ checkpoint_path = hf_hub_download(
56
+ repo_id=config["ckpt_dir"],
57
+ repo_type="model",
58
+ filename="SpaTrack3_offline.pth"
59
+ )
60
+ model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
61
+ model.eval()
62
+ model.to("cuda")
63
+
64
+ viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
65
+ fps=10, pad_value=0, tracks_leave_trace=5)
66
+
67
+ return model, viser
68
+
69
+ def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
70
+ """
71
+ Run tracking on a video sequence
72
+ Args:
73
+ model: Tracker predictor instance
74
+ viser: Visualizer instance
75
+ temp_dir: Directory containing temporary files
76
+ video_name: Name of the video file (without extension)
77
+ grid_size: Size of the tracking grid
78
+ vo_points: Number of points for visual odometry
79
+ fps: Frames per second for visualization
80
+ """
81
+ # Setup paths
82
+ video_path = os.path.join(temp_dir, f"{video_name}.mp4")
83
+ mask_path = os.path.join(temp_dir, f"{video_name}.png")
84
+ out_dir = os.path.join(temp_dir, "results")
85
+ os.makedirs(out_dir, exist_ok=True)
86
+
87
+ # Load video using decord
88
+ video_reader = decord.VideoReader(video_path)
89
+ video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
90
+
91
+ # resize make sure the shortest side is 336
92
+ h, w = video_tensor.shape[2:]
93
+ scale = max(336 / h, 336 / w)
94
+ if scale < 1:
95
+ new_h, new_w = int(h * scale), int(w * scale)
96
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
97
+ video_tensor = video_tensor[::fps].float()
98
+ depth_tensor = None
99
+ intrs = None
100
+ extrs = None
101
+ data_npz_load = {}
102
+
103
+ # Load and process mask
104
+ if os.path.exists(mask_path):
105
+ mask = cv2.imread(mask_path)
106
+ mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
107
+ mask = mask.sum(axis=-1)>0
108
+ else:
109
+ mask = np.ones_like(video_tensor[0,0].numpy())>0
110
+
111
+ # Get frame dimensions and create grid points
112
+ frame_H, frame_W = video_tensor.shape[2:]
113
+ grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
114
+
115
+ # Sample mask values at grid points and filter out points where mask=0
116
+ if os.path.exists(mask_path):
117
+ grid_pts_int = grid_pts[0].long()
118
+ mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
119
+ grid_pts = grid_pts[:, mask_values]
120
+
121
+ query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
122
+
123
+ # run vggt
124
+ if os.environ.get("VGGT_DIR", None) is not None:
125
+ vggt_model = VGGT()
126
+ vggt_model.load_state_dict(torch.load(VGGT_DIR))
127
+ vggt_model.eval()
128
+ vggt_model = vggt_model.to("cuda")
129
+ # process the image tensor
130
+ video_tensor = preprocess_image(video_tensor)[None]
131
+ with torch.cuda.amp.autocast(dtype=torch.bfloat16):
132
+ # Predict attributes including cameras, depth maps, and point maps.
133
+ aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
134
+ pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
135
+ # Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
136
+ extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, video_tensor.shape[-2:])
137
+ # Predict Depth Maps
138
+ depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_tensor.cuda()/255, ps_idx)
139
+ # clear the cache
140
+ del vggt_model, aggregated_tokens_list, ps_idx, pose_enc
141
+ torch.cuda.empty_cache()
142
+ depth_tensor = depth_map.squeeze().cpu().numpy()
143
+ extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
144
+ extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
145
+ intrs = intrinsic.squeeze().cpu().numpy()
146
+ video_tensor = video_tensor.squeeze()
147
+ #NOTE: 20% of the depth is not reliable
148
+ # threshold = depth_conf.squeeze().view(-1).quantile(0.5)
149
+ unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
150
+
151
+ # Run model inference
152
+ with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
153
+ (
154
+ c2w_traj, intrs, point_map, conf_depth,
155
+ track3d_pred, track2d_pred, vis_pred, conf_pred, video
156
+ ) = model.forward(video_tensor, depth=depth_tensor,
157
+ intrs=intrs, extrs=extrs,
158
+ queries=query_xyt,
159
+ fps=1, full_point=False, iters_track=4,
160
+ query_no_BA=True, fixed_cam=False, stage=1,
161
+ support_frame=len(video_tensor)-1, replace_ratio=0.2)
162
+
163
+ # Resize results to avoid too large I/O Burden
164
+ max_size = 336
165
+ h, w = video.shape[2:]
166
+ scale = min(max_size / h, max_size / w)
167
+ if scale < 1:
168
+ new_h, new_w = int(h * scale), int(w * scale)
169
+ video = T.Resize((new_h, new_w))(video)
170
+ video_tensor = T.Resize((new_h, new_w))(video_tensor)
171
+ point_map = T.Resize((new_h, new_w))(point_map)
172
+ track2d_pred[...,:2] = track2d_pred[...,:2] * scale
173
+ intrs[:,:2,:] = intrs[:,:2,:] * scale
174
+ if depth_tensor is not None:
175
+ depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
176
+ conf_depth = T.Resize((new_h, new_w))(conf_depth)
177
+
178
+ # Visualize tracks
179
+ viser.visualize(video=video[None],
180
+ tracks=track2d_pred[None][...,:2],
181
+ visibility=vis_pred[None],filename="test")
182
+
183
+ # Save in tapip3d format
184
+ data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
185
+ data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
186
+ data_npz_load["intrinsics"] = intrs.cpu().numpy()
187
+ data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
188
+ data_npz_load["video"] = (video_tensor).cpu().numpy()/255
189
+ data_npz_load["visibs"] = vis_pred.cpu().numpy()
190
+ data_npz_load["confs"] = conf_pred.cpu().numpy()
191
+ data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
192
+ np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
193
+
194
+ print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
config/__init__.py ADDED
File without changes
config/magic_infer_moge.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ seed: 0
2
+ # config the hydra logger, only in hydra `$` can be decoded as cite
3
+ data: ./assets/room
4
+ vis_track: false
5
+ hydra:
6
+ run:
7
+ dir: .
8
+ output_subdir: null
9
+ job_logging: {}
10
+ hydra_logging: {}
11
+ mixed_precision: bf16
12
+ visdom:
13
+ viz_ip: "localhost"
14
+ port: 6666
15
+ relax_load: false
16
+ res_all: 336
17
+ # config the ckpt path
18
+ # ckpts: "/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/checkpoints/new_base.pth"
19
+ ckpts: "Yuxihenry/SpatialTracker_Files"
20
+ batch_size: 1
21
+ input:
22
+ type: image
23
+ fps: 1
24
+ model_wind_size: 32
25
+ model:
26
+ backbone_cfg:
27
+ ckpt_dir: "checkpoints/model.pt"
28
+ chunk_size: 24 # downsample factor for patchified features
29
+ ckpt_fwd: true
30
+ ft_cfg:
31
+ mode: "fix"
32
+ paras_name: []
33
+ resolution: 336
34
+ max_len: 512
35
+ Track_cfg:
36
+ base_ckpt: "checkpoints/scaled_offline.pth"
37
+ base:
38
+ stride: 4
39
+ corr_radius: 3
40
+ window_len: 60
41
+ stablizer: True
42
+ mode: "online"
43
+ s_wind: 200
44
+ overlap: 4
45
+ track_num: 0
46
+
47
+ dist_train:
48
+ num_nodes: 1
examples/backpack.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b5ac6b2285ffb48e3a740e419e38c781df9c963589a5fd894e5b4e13dd6a8b8
3
+ size 1208738
examples/ball.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:31f6e3bf875a85284b376c05170b4c08b546b7d5e95106848b1e3818a9d0db91
3
+ size 3030268
examples/basketball.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0df3b429d5fd64c298f2d79b2d818a4044e7341a71d70b957f60b24e313c3760
3
+ size 2487837
examples/biker.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fba880c24bdb8fa3b84b1b491d52f2c1f426fb09e34c3013603e5a549cf3b22b
3
+ size 249196
examples/cinema_0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a68a5643c14f61c05d48e25a98ddf5cf0344d3ffcda08ad4a0adc989d49d7a9c
3
+ size 1774022
examples/cinema_1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99624e2d0fb2e9f994e46aefb904e884de37a6d78e7f6b6670e286eaa397e515
3
+ size 2370749
examples/drifting.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f3937871117d3cc5d7da3ef31d1edf5626fc8372126b73590f75f05713fe97c
3
+ size 4695804
examples/ego_kc1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:22fe64e458e329e8b3c3e20b3725ffd85c3a2e725fd03909cf883d3fd02c80b3
3
+ size 1365980
examples/ego_teaser.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8780b291b48046b1c7dea90712c1c3f59d60c03216df1c489f6f03e8d61fae5c
3
+ size 7365665
examples/handwave.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e6dde7cf4ffa7c66b6861bb5abdedc49dfc4b5b4dd9dd46ee8415dd4953935b6
3
+ size 2099369
examples/hockey.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c3be095777b442dc401e7d1f489b749611ffade3563a01e4e3d1e511311bd86
3
+ size 1795810
examples/ken_block_0.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7b788faeb4d3206fa604d622a05268f1321ad6a229178fe12319d20c9438deb1
3
+ size 196343
examples/kiss.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f78fffc5108d95d4e5837d7607226f3dd9796615ea3481f2629c69ccd2ccb12f
3
+ size 1073570
examples/kitchen.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3120e942a9b3d7b300928e43113b000fb5ccc209012a2c560ec26b8a04c2d5f9
3
+ size 543970
examples/kitchen_egocentric.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5468ab10d8d39b68b51fa616adc3d099dab7543e38dd221a0a7a20a2401824a2
3
+ size 2176685
examples/pillow.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f05818f586d7b0796fcd4714ea4be489c93701598cadc86ce7973fc24655fee
3
+ size 1407147
examples/protein.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2dc9cfceb0984b61ebc62fda4c826135ebe916c8c966a8123dcc3315d43b73f
3
+ size 2002300
examples/pusht.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:996d1923e36811a1069e4d6b5e8c0338d9068c0870ea09c4c04e13e9fbcd207a
3
+ size 5256495
examples/robot1.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a3b9e4449572129fdd96a751938e211241cdd86bcc56ffd33bfd23fc4d6e9c0
3
+ size 1178671
examples/robot2.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:188b2d8824ce345c86a603bff210639a6158d72cf6119cc1d3f79d409ac68bb3
3
+ size 867261
examples/robot_3.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:784a0f9c36a316d0da5745075dbc8cefd9ce60c25b067d3d80a1d52830df8a37
3
+ size 1153015
examples/robot_unitree.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:99bc274f7613a665c6135085fe01691ebfaa9033101319071f37c550ab21d1ea
3
+ size 1964268
examples/running.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9ceb96b287fefb1c090dcd2f5db7634f808d2079413500beeb7b33023dfae51b
3
+ size 7307897
examples/teleop2.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:59ea006a18227da8cf5db1fa50cd48e71ec7eb66fef48ea2158c325088bd9fee
3
+ size 1077267
examples/vertical_place.mp4 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6c8061ae449f986113c2ecb17aefc2c13f737aecbcd41d6c057c88e6d41ac3ee
3
+ size 719810
models/SpaTrackV2/models/SpaTrack.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #python
2
+ """
3
+ SpaTrackerV2, which is an unified model to estimate 'intrinsic',
4
+ 'video depth', 'extrinsic' and '3D Tracking' from casual video frames.
5
+
6
+ Contact: DM yuxixiao@zju.edu.cn
7
+ """
8
+
9
+ import os
10
+ import numpy as np
11
+ from typing import Literal, Union, List, Tuple, Dict
12
+ import cv2
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ # from depth anything v2
17
+ from huggingface_hub import PyTorchModelHubMixin # used for model hub
18
+ from einops import rearrange
19
+ from models.monoD.depth_anything_v2.dpt import DepthAnythingV2
20
+ from models.moge.model.v1 import MoGeModel
21
+ import copy
22
+ from functools import partial
23
+ from models.SpaTrackV2.models.tracker3D.TrackRefiner import TrackRefiner3D
24
+ import kornia
25
+ from models.SpaTrackV2.utils.model_utils import sample_features5d
26
+ import utils3d
27
+ from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
28
+ from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
29
+ import random
30
+
31
+ class SpaTrack2(nn.Module, PyTorchModelHubMixin):
32
+ def __init__(
33
+ self,
34
+ loggers: list, # include [ viz, logger_tf, logger]
35
+ backbone_cfg,
36
+ Track_cfg=None,
37
+ chunk_size=24,
38
+ ckpt_fwd: bool = False,
39
+ ft_cfg=None,
40
+ resolution=518,
41
+ max_len=600, # the maximum video length we can preprocess,
42
+ track_num=768,
43
+ ):
44
+
45
+ self.chunk_size = chunk_size
46
+ self.max_len = max_len
47
+ self.resolution = resolution
48
+ # config the T-Lora Dinov2
49
+ #NOTE: initial the base model
50
+ base_cfg = copy.deepcopy(backbone_cfg)
51
+ backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
52
+
53
+ super(SpaTrack2, self).__init__()
54
+ if os.path.exists(backbone_ckpt_dir)==False:
55
+ base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
56
+ else:
57
+ checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
58
+ base_model = MoGeModel(**checkpoint["model_config"])
59
+ base_model.load_state_dict(checkpoint['model'])
60
+ # avoid the base_model is a member of SpaTrack2
61
+ object.__setattr__(self, 'base_model', base_model)
62
+
63
+ # Tracker model
64
+ self.Track3D = TrackRefiner3D(Track_cfg)
65
+ track_base_ckpt_dir = Track_cfg.base_ckpt
66
+ if os.path.exists(track_base_ckpt_dir):
67
+ track_pretrain = torch.load(track_base_ckpt_dir)
68
+ self.Track3D.load_state_dict(track_pretrain, strict=False)
69
+
70
+ # wrap the function of make lora trainable
71
+ self.make_paras_trainable = partial(self.make_paras_trainable,
72
+ mode=ft_cfg.mode,
73
+ paras_name=ft_cfg.paras_name)
74
+ self.track_num = track_num
75
+
76
+ def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
77
+ # gradient required for the lora_experts and gate
78
+ for name, param in self.named_parameters():
79
+ if any(x in name for x in paras_name):
80
+ if mode == 'fix':
81
+ param.requires_grad = False
82
+ else:
83
+ param.requires_grad = True
84
+ else:
85
+ if mode == 'fix':
86
+ param.requires_grad = True
87
+ else:
88
+ param.requires_grad = False
89
+ total_params = sum(p.numel() for p in self.parameters())
90
+ trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
91
+ print(f"Total parameters: {total_params}")
92
+ print(f"Trainable parameters: {trainable_params/total_params*100:.2f}%")
93
+
94
+ def ProcVid(self,
95
+ x: torch.Tensor):
96
+ """
97
+ split the video into several overlapped windows.
98
+
99
+ args:
100
+ x: the input video frames. [B, T, C, H, W]
101
+ outputs:
102
+ patch_size: the patch size of the video features
103
+ raises:
104
+ ValueError: if the input video is longer than `max_len`.
105
+
106
+ """
107
+ # normalize the input images
108
+ num_types = x.dtype
109
+ x = normalize_rgb(x, input_size=self.resolution)
110
+ x = x.to(num_types)
111
+ # get the video features
112
+ B, T, C, H, W = x.size()
113
+ if T > self.max_len:
114
+ raise ValueError(f"the video length should no more than {self.max_len}.")
115
+ # get the video features
116
+ patch_h, patch_w = H // 14, W // 14
117
+ patch_size = (patch_h, patch_w)
118
+ # resize and get the video features
119
+ x = x.view(B * T, C, H, W)
120
+ # operate the temporal encoding
121
+ return patch_size, x
122
+
123
+ def forward_stream(
124
+ self,
125
+ video: torch.Tensor,
126
+ queries: torch.Tensor = None,
127
+ T_org: int = None,
128
+ depth: torch.Tensor|np.ndarray|str=None,
129
+ unc_metric_in: torch.Tensor|np.ndarray|str=None,
130
+ intrs: torch.Tensor|np.ndarray|str=None,
131
+ extrs: torch.Tensor|np.ndarray|str=None,
132
+ queries_3d: torch.Tensor = None,
133
+ window_len: int = 16,
134
+ overlap_len: int = 4,
135
+ full_point: bool = False,
136
+ track2d_gt: torch.Tensor = None,
137
+ fixed_cam: bool = False,
138
+ query_no_BA: bool = False,
139
+ stage: int = 0,
140
+ support_frame: int = 0,
141
+ replace_ratio: float = 0.6,
142
+ annots_train: Dict = None,
143
+ iters_track=4,
144
+ **kwargs,
145
+ ):
146
+ # step 1 allocate the query points on the grid
147
+ T, C, H, W = video.shape
148
+
149
+ if annots_train is not None:
150
+ vis_gt = annots_train["vis"]
151
+ _, _, N = vis_gt.shape
152
+ number_visible = vis_gt.sum(dim=1)
153
+ ratio_rand = torch.rand(1, N, device=vis_gt.device)
154
+ first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
155
+ assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
156
+
157
+ first_positive_inds = first_positive_inds.long()
158
+ gather = torch.gather(
159
+ annots_train["traj_3d"][...,:2], 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
160
+ )
161
+ xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
162
+ queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)[0].cpu().numpy()
163
+
164
+
165
+ # Unfold video into segments of window_len with overlap_len
166
+ step_slide = window_len - overlap_len
167
+ if T < window_len:
168
+ video_unf = video.unsqueeze(0)
169
+ if depth is not None:
170
+ depth_unf = depth.unsqueeze(0)
171
+ else:
172
+ depth_unf = None
173
+ if unc_metric_in is not None:
174
+ unc_metric_unf = unc_metric_in.unsqueeze(0)
175
+ else:
176
+ unc_metric_unf = None
177
+ if intrs is not None:
178
+ intrs_unf = intrs.unsqueeze(0)
179
+ else:
180
+ intrs_unf = None
181
+ if extrs is not None:
182
+ extrs_unf = extrs.unsqueeze(0)
183
+ else:
184
+ extrs_unf = None
185
+ else:
186
+ video_unf = video.unfold(0, window_len, step_slide).permute(0, 4, 1, 2, 3) # [B, S, C, H, W]
187
+ if depth is not None:
188
+ depth_unf = depth.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
189
+ intrs_unf = intrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
190
+ else:
191
+ depth_unf = None
192
+ intrs_unf = None
193
+ if extrs is not None:
194
+ extrs_unf = extrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
195
+ else:
196
+ extrs_unf = None
197
+ if unc_metric_in is not None:
198
+ unc_metric_unf = unc_metric_in.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
199
+ else:
200
+ unc_metric_unf = None
201
+
202
+ # parallel
203
+ # Get number of segments
204
+ B = video_unf.shape[0]
205
+ #TODO: Process each segment in parallel using torch.nn.DataParallel
206
+ c2w_traj = torch.eye(4, 4)[None].repeat(T, 1, 1)
207
+ intrs_out = torch.eye(3, 3)[None].repeat(T, 1, 1)
208
+ point_map = torch.zeros(T, 3, H, W).cuda()
209
+ unc_metric = torch.zeros(T, H, W).cuda()
210
+ # set the queries
211
+ N, _ = queries.shape
212
+ track3d_pred = torch.zeros(T, N, 6).cuda()
213
+ track2d_pred = torch.zeros(T, N, 3).cuda()
214
+ vis_pred = torch.zeros(T, N, 1).cuda()
215
+ conf_pred = torch.zeros(T, N, 1).cuda()
216
+ dyn_preds = torch.zeros(T, N, 1).cuda()
217
+ # sort the queries by time
218
+ sorted_indices = np.argsort(queries[...,0])
219
+ sorted_inv_indices = np.argsort(sorted_indices)
220
+ sort_query = queries[sorted_indices]
221
+ sort_query = torch.from_numpy(sort_query).cuda()
222
+ if queries_3d is not None:
223
+ sort_query_3d = queries_3d[sorted_indices]
224
+ sort_query_3d = torch.from_numpy(sort_query_3d).cuda()
225
+
226
+ queries_len = 0
227
+ overlap_d = None
228
+ cache = None
229
+ loss = 0.0
230
+
231
+ for i in range(B):
232
+ segment = video_unf[i:i+1].cuda()
233
+ # Forward pass through model
234
+ # detect the key points for each frames
235
+
236
+ queries_new_mask = (sort_query[...,0] < i * step_slide + window_len) * (sort_query[...,0] >= (i * step_slide + overlap_len if i > 0 else 0))
237
+ if queries_3d is not None:
238
+ queries_new_3d = sort_query_3d[queries_new_mask]
239
+ queries_new_3d = queries_new_3d.float()
240
+ else:
241
+ queries_new_3d = None
242
+ queries_new = sort_query[queries_new_mask.bool()]
243
+ queries_new = queries_new.float()
244
+ if i > 0:
245
+ overlap2d = track2d_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
246
+ overlapvis = vis_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
247
+ overlapconf = conf_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
248
+ overlap_query = (overlapvis * overlapconf).max(dim=0)[1][None, ...]
249
+ overlap_xy = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,2))
250
+ overlap_d = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,3))[...,2].detach()
251
+ overlap_query = torch.cat([overlap_query[...,:1], overlap_xy], dim=-1)[0]
252
+ queries_new[...,0] -= i*step_slide
253
+ queries_new = torch.cat([overlap_query, queries_new], dim=0).detach()
254
+
255
+ if annots_train is None:
256
+ annots = {}
257
+ else:
258
+ annots = copy.deepcopy(annots_train)
259
+ annots["traj_3d"] = annots["traj_3d"][:, i*step_slide:i*step_slide+window_len, sorted_indices,:][...,:len(queries_new),:]
260
+ annots["vis"] = annots["vis"][:, i*step_slide:i*step_slide+window_len, sorted_indices][...,:len(queries_new)]
261
+ annots["poses_gt"] = annots["poses_gt"][:, i*step_slide:i*step_slide+window_len]
262
+ annots["depth_gt"] = annots["depth_gt"][:, i*step_slide:i*step_slide+window_len]
263
+ annots["intrs"] = annots["intrs"][:, i*step_slide:i*step_slide+window_len]
264
+ annots["traj_mat"] = annots["traj_mat"][:,i*step_slide:i*step_slide+window_len]
265
+
266
+ if depth is not None:
267
+ annots["depth_gt"] = depth_unf[i:i+1].to(segment.device).to(segment.dtype)
268
+ if unc_metric_in is not None:
269
+ annots["unc_metric"] = unc_metric_unf[i:i+1].to(segment.device).to(segment.dtype)
270
+ if intrs is not None:
271
+ intr_seg = intrs_unf[i:i+1].to(segment.device).to(segment.dtype)[0].clone()
272
+ focal = (intr_seg[:,0,0] / segment.shape[-1] + intr_seg[:,1,1]/segment.shape[-2]) / 2
273
+ pose_fake = torch.zeros(1, 8).to(depth.device).to(depth.dtype).repeat(segment.shape[1], 1)
274
+ pose_fake[:, -1] = focal
275
+ pose_fake[:,3]=1
276
+ annots["intrs_gt"] = intr_seg
277
+ if extrs is not None:
278
+ extrs_unf_norm = extrs_unf[i:i+1][0].clone()
279
+ extrs_unf_norm = torch.inverse(extrs_unf_norm[:1,...]) @ extrs_unf[i:i+1][0]
280
+ rot_vec = matrix_to_quaternion(extrs_unf_norm[:,:3,:3])
281
+ annots["poses_gt"] = torch.zeros(1, rot_vec.shape[0], 7).to(segment.device).to(segment.dtype)
282
+ annots["poses_gt"][:, :, 3:7] = rot_vec.to(segment.device).to(segment.dtype)[None]
283
+ annots["poses_gt"][:, :, :3] = extrs_unf_norm[:,:3,3].to(segment.device).to(segment.dtype)[None]
284
+ annots["use_extr"] = True
285
+
286
+ kwargs.update({"stage": stage})
287
+
288
+ #TODO: DEBUG
289
+ out = self.forward(segment, pts_q=queries_new,
290
+ pts_q_3d=queries_new_3d, overlap_d=overlap_d,
291
+ full_point=full_point,
292
+ fixed_cam=fixed_cam, query_no_BA=query_no_BA,
293
+ support_frame=segment.shape[1]-1,
294
+ cache=cache, replace_ratio=replace_ratio,
295
+ iters_track=iters_track,
296
+ **kwargs, annots=annots)
297
+ if self.training:
298
+ loss += out["loss"].squeeze()
299
+ # from models.SpaTrackV2.utils.visualizer import Visualizer
300
+ # vis_track = Visualizer(grayscale=False,
301
+ # fps=10, pad_value=50, tracks_leave_trace=0)
302
+ # vis_track.visualize(video=segment,
303
+ # tracks=out["traj_est"][...,:2],
304
+ # visibility=out["vis_est"],
305
+ # save_video=True)
306
+ # # visualize 4d
307
+ # import os, json
308
+ # import os.path as osp
309
+ # viser4d_dir = os.path.join("viser_4d_results")
310
+ # os.makedirs(viser4d_dir, exist_ok=True)
311
+ # depth_est = annots["depth_gt"][0]
312
+ # unc_metric = out["unc_metric"]
313
+ # mask = (unc_metric > 0.5).squeeze(1)
314
+ # # pose_est = out["poses_pred"].squeeze(0)
315
+ # pose_est = annots["traj_mat"][0]
316
+ # rgb_tracks = out["rgb_tracks"].squeeze(0)
317
+ # intrinsics = out["intrs"].squeeze(0)
318
+ # for i_k in range(out["depth"].shape[0]):
319
+ # img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
320
+ # img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
321
+ # cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
322
+ # if stage == 1:
323
+ # depth = depth_est[i_k].squeeze().cpu().numpy()
324
+ # np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
325
+ # else:
326
+ # point_map_vis = out["points_map"][i_k].cpu().numpy()
327
+ # np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
328
+ # np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
329
+ # np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
330
+ # np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
331
+ # np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
332
+
333
+ queries_len = len(queries_new)
334
+ # update the track3d and track2d
335
+ left_len = len(track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :])
336
+ track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["rgb_tracks"][0,:left_len,:queries_len,:]
337
+ track2d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["traj_est"][0,:left_len,:queries_len,:3]
338
+ vis_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["vis_est"][0,:left_len,:queries_len,None]
339
+ conf_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["conf_pred"][0,:left_len,:queries_len,None]
340
+ dyn_preds[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["dyn_preds"][0,:left_len,:queries_len,None]
341
+
342
+ # process the output for each segment
343
+ seg_c2w = out["poses_pred"][0]
344
+ seg_intrs = out["intrs"][0]
345
+ seg_point_map = out["points_map"]
346
+ seg_conf_depth = out["unc_metric"]
347
+
348
+ # cache management
349
+ cache = out["cache"]
350
+ for k in cache.keys():
351
+ if "_pyramid" in k:
352
+ for j in range(len(cache[k])):
353
+ if len(cache[k][j].shape) == 5:
354
+ cache[k][j] = cache[k][j][:,:,:,:queries_len,:]
355
+ elif len(cache[k][j].shape) == 4:
356
+ cache[k][j] = cache[k][j][:,:1,:queries_len,:]
357
+ elif "_pred_cache" in k:
358
+ cache[k] = cache[k][-overlap_len:,:queries_len,:]
359
+ else:
360
+ cache[k] = cache[k][-overlap_len:]
361
+
362
+ # update the results
363
+ idx_glob = i * step_slide
364
+ # refine part
365
+ # mask_update = sort_query[..., 0] < i * step_slide + window_len
366
+ # sort_query_pick = sort_query[mask_update]
367
+ intrs_out[idx_glob:idx_glob+window_len] = seg_intrs
368
+ point_map[idx_glob:idx_glob+window_len] = seg_point_map
369
+ unc_metric[idx_glob:idx_glob+window_len] = seg_conf_depth
370
+ # update the camera poses
371
+
372
+ # if using the ground truth pose
373
+ # if extrs_unf is not None:
374
+ # c2w_traj[idx_glob:idx_glob+window_len] = extrs_unf[i:i+1][0].to(c2w_traj.device).to(c2w_traj.dtype)
375
+ # else:
376
+ prev_c2w = c2w_traj[idx_glob:idx_glob+window_len][:1]
377
+ c2w_traj[idx_glob:idx_glob+window_len] = prev_c2w@seg_c2w.to(c2w_traj.device).to(c2w_traj.dtype)
378
+
379
+ track2d_pred = track2d_pred[:T_org,sorted_inv_indices,:]
380
+ track3d_pred = track3d_pred[:T_org,sorted_inv_indices,:]
381
+ vis_pred = vis_pred[:T_org,sorted_inv_indices,:]
382
+ conf_pred = conf_pred[:T_org,sorted_inv_indices,:]
383
+ dyn_preds = dyn_preds[:T_org,sorted_inv_indices,:]
384
+ unc_metric = unc_metric[:T_org,:]
385
+ point_map = point_map[:T_org,:]
386
+ intrs_out = intrs_out[:T_org,:]
387
+ c2w_traj = c2w_traj[:T_org,:]
388
+ if self.training:
389
+ ret = {
390
+ "loss": loss,
391
+ "depth_loss": 0.0,
392
+ "ab_loss": 0.0,
393
+ "vis_loss": out["vis_loss"],
394
+ "track_loss": out["track_loss"],
395
+ "conf_loss": out["conf_loss"],
396
+ "dyn_loss": out["dyn_loss"],
397
+ "sync_loss": out["sync_loss"],
398
+ "poses_pred": c2w_traj[None],
399
+ "intrs": intrs_out[None],
400
+ "points_map": point_map,
401
+ "track3d_pred": track3d_pred[None],
402
+ "rgb_tracks": track3d_pred[None],
403
+ "track2d_pred": track2d_pred[None],
404
+ "traj_est": track2d_pred[None],
405
+ "vis_est": vis_pred[None], "conf_pred": conf_pred[None],
406
+ "dyn_preds": dyn_preds[None],
407
+ "imgs_raw": video[None],
408
+ "unc_metric": unc_metric,
409
+ }
410
+
411
+ return ret
412
+ else:
413
+ return c2w_traj, intrs_out, point_map, unc_metric, track3d_pred, track2d_pred, vis_pred, conf_pred
414
+ def forward(self,
415
+ x: torch.Tensor,
416
+ annots: Dict = {},
417
+ pts_q: torch.Tensor = None,
418
+ pts_q_3d: torch.Tensor = None,
419
+ overlap_d: torch.Tensor = None,
420
+ full_point = False,
421
+ fixed_cam = False,
422
+ support_frame = 0,
423
+ query_no_BA = False,
424
+ cache = None,
425
+ replace_ratio = 0.6,
426
+ iters_track=4,
427
+ **kwargs):
428
+ """
429
+ forward the video camera model, which predict (
430
+ `intr` `camera poses` `video depth`
431
+ )
432
+
433
+ args:
434
+ x: the input video frames. [B, T, C, H, W]
435
+ annots: the annotations for video frames.
436
+ {
437
+ "poses_gt": the pose encoding for the video frames. [B, T, 7]
438
+ "depth_gt": the ground truth depth for the video frames. [B, T, 1, H, W],
439
+ "metric": bool, whether to calculate the metric for the video frames.
440
+ }
441
+ """
442
+ self.support_frame = support_frame
443
+
444
+ #TODO: to adjust a little bit
445
+ track_loss=ab_loss=vis_loss=track_loss=conf_loss=dyn_loss=0.0
446
+ B, T, _, H, W = x.shape
447
+ imgs_raw = x.clone()
448
+ # get the video split and features for each segment
449
+ patch_size, x_resize = self.ProcVid(x)
450
+ x_resize = rearrange(x_resize, "(b t) c h w -> b t c h w", b=B)
451
+ H_resize, W_resize = x_resize.shape[-2:]
452
+
453
+ prec_fx = W / W_resize
454
+ prec_fy = H / H_resize
455
+ # get patch size
456
+ P_H, P_W = patch_size
457
+
458
+ # get the depth, pointmap and mask
459
+ #TODO: Release DepthAnything Version
460
+ points_map_gt = None
461
+ with torch.no_grad():
462
+ if_gt_depth = (("depth_gt" in annots.keys())) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3)
463
+ if if_gt_depth==False:
464
+ if cache is not None:
465
+ T_cache = cache["points_map"].shape[0]
466
+ T_new = T - T_cache
467
+ x_resize_new = x_resize[:, T_cache:]
468
+ else:
469
+ T_new = T
470
+ x_resize_new = x_resize
471
+ # infer with chunk
472
+ chunk_size = self.chunk_size
473
+ metric_depth = []
474
+ intrs = []
475
+ unc_metric = []
476
+ mask = []
477
+ points_map = []
478
+ normals = []
479
+ normals_mask = []
480
+ for i in range(0, B*T_new, chunk_size):
481
+ output = self.base_model.infer(x_resize_new.view(B*T_new, -1, H_resize, W_resize)[i:i+chunk_size])
482
+ metric_depth.append(output['depth'])
483
+ intrs.append(output['intrinsics'])
484
+ unc_metric.append(output['mask_prob'])
485
+ mask.append(output['mask'])
486
+ points_map.append(output['points'])
487
+ normals_i, normals_mask_i = utils3d.torch.points_to_normals(output['points'], mask=output['mask'])
488
+ normals.append(normals_i)
489
+ normals_mask.append(normals_mask_i)
490
+
491
+ metric_depth = torch.cat(metric_depth, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
492
+ intrs = torch.cat(intrs, dim=0).view(B, T_new, 3, 3).to(x.dtype)
493
+ intrs[:,:,0,:] *= W_resize
494
+ intrs[:,:,1,:] *= H_resize
495
+ # points_map = torch.cat(points_map, dim=0)
496
+ mask = torch.cat(mask, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
497
+ # cat the normals
498
+ normals = torch.cat(normals, dim=0)
499
+ normals_mask = torch.cat(normals_mask, dim=0)
500
+
501
+ metric_depth = metric_depth.clone()
502
+ metric_depth[metric_depth == torch.inf] = 0
503
+ _depths = metric_depth[metric_depth > 0].reshape(-1)
504
+ q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
505
+ q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
506
+ iqr = q75 - q25
507
+ upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
508
+ _depth_roi = torch.tensor(
509
+ [1e-1, upper_bound.item()],
510
+ dtype=metric_depth.dtype,
511
+ device=metric_depth.device
512
+ )
513
+ mask_roi = (metric_depth > _depth_roi[0]) & (metric_depth < _depth_roi[1])
514
+ mask = mask * mask_roi
515
+ mask = mask * (~(utils3d.torch.depth_edge(metric_depth, rtol=0.03, mask=mask.bool()))) * normals_mask[:,None,...]
516
+ points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T_new, 3, 3))
517
+ unc_metric = torch.cat(unc_metric, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
518
+ unc_metric *= mask
519
+ if full_point:
520
+ unc_metric = (~(utils3d.torch.depth_edge(metric_depth, rtol=0.1, mask=torch.ones_like(metric_depth).bool()))).float() * (metric_depth != 0)
521
+ if cache is not None:
522
+ assert B==1, "only support batch size 1 right now."
523
+ unc_metric = torch.cat([cache["unc_metric"], unc_metric], dim=0)
524
+ intrs = torch.cat([cache["intrs"][None], intrs], dim=1)
525
+ points_map = torch.cat([cache["points_map"].permute(0,2,3,1), points_map], dim=0)
526
+ metric_depth = torch.cat([cache["metric_depth"], metric_depth], dim=0)
527
+
528
+ if "poses_gt" in annots.keys():
529
+ intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
530
+ H_resize, W_resize, self.resolution)
531
+ else:
532
+ c2w_traj_gt = None
533
+
534
+ if "intrs_gt" in annots.keys():
535
+ intrs = annots["intrs_gt"].view(B, T, 3, 3)
536
+ fx_factor = W_resize / W
537
+ fy_factor = H_resize / H
538
+ intrs[:,:,0,:] *= fx_factor
539
+ intrs[:,:,1,:] *= fy_factor
540
+
541
+ if "depth_gt" in annots.keys():
542
+
543
+ metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
544
+ metric_depth_gt = F.interpolate(metric_depth_gt,
545
+ size=(H_resize, W_resize), mode='nearest')
546
+
547
+ _depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
548
+ q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
549
+ q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
550
+ iqr = q75 - q25
551
+ upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
552
+ _depth_roi = torch.tensor(
553
+ [1e-1, upper_bound.item()],
554
+ dtype=metric_depth_gt.dtype,
555
+ device=metric_depth_gt.device
556
+ )
557
+ mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
558
+ # if (upper_bound > 200).any():
559
+ # import pdb; pdb.set_trace()
560
+ if (kwargs.get('stage', 0) == 2):
561
+ unc_metric = ((metric_depth_gt > 0)*(mask_roi) * (unc_metric > 0.5)).float()
562
+ metric_depth_gt[metric_depth_gt > 10*q25] = 0
563
+ else:
564
+ unc_metric = ((metric_depth_gt > 0)*(mask_roi)).float()
565
+ unc_metric *= (~(utils3d.torch.depth_edge(metric_depth_gt, rtol=0.03, mask=mask_roi.bool()))).float()
566
+ # filter the sky
567
+ metric_depth_gt[metric_depth_gt > 10*q25] = 0
568
+ if "unc_metric" in annots.keys():
569
+ unc_metric_ = F.interpolate(annots["unc_metric"].permute(1,0,2,3),
570
+ size=(H_resize, W_resize), mode='nearest')
571
+ unc_metric = unc_metric * unc_metric_
572
+ if if_gt_depth:
573
+ points_map = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
574
+ metric_depth = metric_depth_gt
575
+ points_map_gt = points_map
576
+ else:
577
+ points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
578
+
579
+ # track the 3d points
580
+ ret_track = None
581
+ regular_track = True
582
+ dyn_preds, final_tracks = None, None
583
+
584
+ if "use_extr" in annots.keys():
585
+ init_pose = True
586
+ else:
587
+ init_pose = False
588
+ # set the custom vid and valid only
589
+ custom_vid = annots.get("custom_vid", False)
590
+ valid_only = annots.get("data_dir", [""])[0] == "stereo4d"
591
+ if self.training:
592
+ if (annots["vis"].sum() > 0) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3):
593
+ traj3d = annots['traj_3d']
594
+ if (kwargs.get('stage', 0)==1) and (annots.get("custom_vid", False)==False):
595
+ support_pts_q = get_track_points(H_resize, W_resize,
596
+ T, x.device, query_size=self.track_num // 2,
597
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
598
+ else:
599
+ support_pts_q = get_track_points(H_resize, W_resize,
600
+ T, x.device, query_size=random.randint(32, 256),
601
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
602
+ if pts_q is not None:
603
+ pts_q = pts_q[None,None]
604
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
605
+ metric_depth,
606
+ unc_metric.detach(), points_map, pts_q,
607
+ intrs=intrs.clone(), cache=cache,
608
+ prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
609
+ vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
610
+ cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
611
+ init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
612
+ points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
613
+ else:
614
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
615
+ metric_depth,
616
+ unc_metric.detach(), points_map, traj3d[..., :2],
617
+ intrs=intrs.clone(), cache=cache,
618
+ prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
619
+ vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
620
+ cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
621
+ init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
622
+ points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
623
+ regular_track = False
624
+
625
+
626
+ if regular_track:
627
+ if pts_q is None:
628
+ pts_q = get_track_points(H_resize, W_resize,
629
+ T, x.device, query_size=self.track_num,
630
+ support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental" if self.training else "incremental")[None]
631
+ support_pts_q = None
632
+ else:
633
+ pts_q = pts_q[None,None]
634
+ # resize the query points
635
+ pts_q[...,1] *= W_resize / W
636
+ pts_q[...,2] *= H_resize / H
637
+
638
+ if pts_q_3d is not None:
639
+ pts_q_3d = pts_q_3d[None,None]
640
+ # resize the query points
641
+ pts_q_3d[...,1] *= W_resize / W
642
+ pts_q_3d[...,2] *= H_resize / H
643
+ else:
644
+ # adjust the query with uncertainty
645
+ if (full_point==False) and (overlap_d is None):
646
+ pts_q_unc = sample_features5d(unc_metric[None], pts_q).squeeze()
647
+ pts_q = pts_q[:,:,pts_q_unc>0.5,:]
648
+ if (pts_q_unc<0.5).sum() > 0:
649
+ # pad the query points
650
+ pad_num = pts_q_unc.shape[0] - pts_q.shape[2]
651
+ # pick the random indices
652
+ indices = torch.randint(0, pts_q.shape[2], (pad_num,), device=pts_q.device)
653
+ pad_pts = indices
654
+ pts_q = torch.cat([pts_q, pts_q[:,:,pad_pts,:]], dim=-2)
655
+
656
+ support_pts_q = get_track_points(H_resize, W_resize,
657
+ T, x.device, query_size=self.track_num,
658
+ support_frame=self.support_frame,
659
+ unc_metric=unc_metric, mode="mixed")[None]
660
+
661
+ points_map[points_map>1e3] = 0
662
+ points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T, 3, 3))
663
+ ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
664
+ metric_depth,
665
+ unc_metric.detach(), points_map, pts_q,
666
+ pts_q_3d=pts_q_3d, intrs=intrs.clone(),cache=cache,
667
+ overlap_d=overlap_d, cam_gt=c2w_traj_gt if kwargs.get('stage', 0)==1 else None,
668
+ prec_fx=prec_fx, prec_fy=prec_fy, support_pts_q=support_pts_q, custom_vid=custom_vid, valid_only=valid_only,
669
+ fixed_cam=fixed_cam, query_no_BA=query_no_BA, init_pose=init_pose, iters=iters_track,
670
+ stage=kwargs.get('stage', 0), points_map_gt=points_map_gt, replace_ratio=replace_ratio)
671
+ intrs = intrs_org
672
+ points_map = point_map_org_refined
673
+ c2w_traj = ret_track["cam_pred"]
674
+
675
+ if ret_track is not None:
676
+ if ret_track["loss"] is not None:
677
+ track_loss, conf_loss, dyn_loss, vis_loss, point_map_loss, scale_loss, shift_loss, sync_loss= ret_track["loss"]
678
+
679
+ # update the cache
680
+ cache.update({"metric_depth": metric_depth, "unc_metric": unc_metric, "points_map": points_map, "intrs": intrs[0]})
681
+ # output
682
+ depth = F.interpolate(metric_depth,
683
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
684
+ points_map = F.interpolate(points_map,
685
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
686
+ unc_metric = F.interpolate(unc_metric,
687
+ size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
688
+
689
+ if self.training:
690
+
691
+ loss = track_loss + conf_loss + dyn_loss + sync_loss + vis_loss + point_map_loss + (scale_loss + shift_loss)*50
692
+ ret = {"loss": loss,
693
+ "depth_loss": point_map_loss,
694
+ "ab_loss": (scale_loss + shift_loss)*50,
695
+ "vis_loss": vis_loss, "track_loss": track_loss,
696
+ "poses_pred": c2w_traj, "dyn_preds": dyn_preds, "traj_est": final_tracks, "conf_loss": conf_loss,
697
+ "imgs_raw": imgs_raw, "rgb_tracks": rgb_tracks, "vis_est": ret_track['vis_pred'],
698
+ "depth": depth, "points_map": points_map, "unc_metric": unc_metric, "intrs": intrs, "dyn_loss": dyn_loss,
699
+ "sync_loss": sync_loss, "conf_pred": ret_track['conf_pred'], "cache": cache,
700
+ }
701
+
702
+ else:
703
+
704
+ if ret_track is not None:
705
+ traj_est = ret_track['preds']
706
+ traj_est[..., 0] *= W / W_resize
707
+ traj_est[..., 1] *= H / H_resize
708
+ vis_est = ret_track['vis_pred']
709
+ else:
710
+ traj_est = torch.zeros(B, self.track_num // 2, 3).to(x.device)
711
+ vis_est = torch.zeros(B, self.track_num // 2).to(x.device)
712
+
713
+ if intrs is not None:
714
+ intrs[..., 0, :] *= W / W_resize
715
+ intrs[..., 1, :] *= H / H_resize
716
+ ret = {"poses_pred": c2w_traj, "dyn_preds": dyn_preds,
717
+ "depth": depth, "traj_est": traj_est, "vis_est": vis_est, "imgs_raw": imgs_raw,
718
+ "rgb_tracks": rgb_tracks, "intrs": intrs, "unc_metric": unc_metric, "points_map": points_map,
719
+ "conf_pred": ret_track['conf_pred'], "cache": cache,
720
+ }
721
+
722
+ return ret
723
+
724
+
725
+
726
+
727
+ # three stages of training
728
+
729
+ # stage 1:
730
+ # gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
731
+ # Tracking and Pose as well -> based on gt depth and intrinsics
732
+ # (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
733
+
734
+ # stage 2: fixed 3D tracking
735
+ # Joint depth refiner
736
+ # input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
737
+ # estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
738
+ # ongoing two days
739
+
740
+ # stage 3: train multi windows by propagation
741
+ # 4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
742
+
743
+ # types of scenarioes:
744
+ # 1. auto driving (waymo open dataset)
745
+ # 2. robot
746
+ # 3. internet ego video
747
+
748
+
749
+
750
+ # Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
751
+ # Update Variables:
752
+ # 1. 3D tracks B T N 3 xyz.
753
+ # 2. 2D tracks B T N 2 x y.
754
+ # 3. Dynamic Mask B T H W.
755
+ # 4. Camera Pose B T 4 4.
756
+ # 5. Video Depth.
757
+
758
+ # (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
759
+ # Campatiablity by product.
models/SpaTrackV2/models/__init__.py ADDED
File without changes
models/SpaTrackV2/models/blocks.py ADDED
@@ -0,0 +1,519 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torch.cuda.amp import autocast
11
+ from einops import rearrange
12
+ import collections
13
+ from functools import partial
14
+ from itertools import repeat
15
+ import torchvision.models as tvm
16
+ from torch.utils.checkpoint import checkpoint
17
+ from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
18
+ from typing import Union, Tuple
19
+ from torch import Tensor
20
+
21
+ # From PyTorch internals
22
+ def _ntuple(n):
23
+ def parse(x):
24
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
25
+ return tuple(x)
26
+ return tuple(repeat(x, n))
27
+
28
+ return parse
29
+
30
+
31
+ def exists(val):
32
+ return val is not None
33
+
34
+
35
+ def default(val, d):
36
+ return val if exists(val) else d
37
+
38
+
39
+ to_2tuple = _ntuple(2)
40
+
41
+ class LayerScale(nn.Module):
42
+ def __init__(
43
+ self,
44
+ dim: int,
45
+ init_values: Union[float, Tensor] = 1e-5,
46
+ inplace: bool = False,
47
+ ) -> None:
48
+ super().__init__()
49
+ self.inplace = inplace
50
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
51
+
52
+ def forward(self, x: Tensor) -> Tensor:
53
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
54
+
55
+ class Mlp(nn.Module):
56
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
57
+
58
+ def __init__(
59
+ self,
60
+ in_features,
61
+ hidden_features=None,
62
+ out_features=None,
63
+ act_layer=nn.GELU,
64
+ norm_layer=None,
65
+ bias=True,
66
+ drop=0.0,
67
+ use_conv=False,
68
+ ):
69
+ super().__init__()
70
+ out_features = out_features or in_features
71
+ hidden_features = hidden_features or in_features
72
+ bias = to_2tuple(bias)
73
+ drop_probs = to_2tuple(drop)
74
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
75
+
76
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
77
+ self.act = act_layer()
78
+ self.drop1 = nn.Dropout(drop_probs[0])
79
+ self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
80
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
81
+ self.drop2 = nn.Dropout(drop_probs[1])
82
+
83
+ def forward(self, x):
84
+ x = self.fc1(x)
85
+ x = self.act(x)
86
+ x = self.drop1(x)
87
+ x = self.fc2(x)
88
+ x = self.drop2(x)
89
+ return x
90
+
91
+ class Attention(nn.Module):
92
+ def __init__(self, query_dim, context_dim=None,
93
+ num_heads=8, dim_head=48, qkv_bias=False, flash=False):
94
+ super().__init__()
95
+ inner_dim = self.inner_dim = dim_head * num_heads
96
+ context_dim = default(context_dim, query_dim)
97
+ self.scale = dim_head**-0.5
98
+ self.heads = num_heads
99
+ self.flash = flash
100
+
101
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
102
+ self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
103
+ self.to_out = nn.Linear(inner_dim, query_dim)
104
+
105
+ def forward(self, x, context=None, attn_bias=None):
106
+ B, N1, _ = x.shape
107
+ C = self.inner_dim
108
+ h = self.heads
109
+ q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
110
+ context = default(context, x)
111
+ k, v = self.to_kv(context).chunk(2, dim=-1)
112
+
113
+ N2 = context.shape[1]
114
+ k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
115
+ v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
116
+
117
+ with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
118
+ if self.flash==False:
119
+ sim = (q @ k.transpose(-2, -1)) * self.scale
120
+ if attn_bias is not None:
121
+ sim = sim + attn_bias
122
+ if sim.abs().max()>1e2:
123
+ import pdb; pdb.set_trace()
124
+ attn = sim.softmax(dim=-1)
125
+ x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
126
+ else:
127
+ input_args = [x.contiguous() for x in [q, k, v]]
128
+ x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
129
+
130
+ if self.to_out.bias.dtype != x.dtype:
131
+ x = x.to(self.to_out.bias.dtype)
132
+
133
+ return self.to_out(x)
134
+
135
+
136
+ class VGG19(nn.Module):
137
+ def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
138
+ super().__init__()
139
+ self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
140
+ self.amp = amp
141
+ self.amp_dtype = amp_dtype
142
+
143
+ def forward(self, x, **kwargs):
144
+ with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
145
+ feats = {}
146
+ scale = 1
147
+ for layer in self.layers:
148
+ if isinstance(layer, nn.MaxPool2d):
149
+ feats[scale] = x
150
+ scale = scale*2
151
+ x = layer(x)
152
+ return feats
153
+
154
+ class CNNandDinov2(nn.Module):
155
+ def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
156
+ super().__init__()
157
+ # in case the Internet connection is not stable, please load the DINOv2 locally
158
+ self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
159
+ 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
160
+
161
+ state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
162
+ self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
163
+
164
+
165
+ cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
166
+ self.cnn = VGG19(**cnn_kwargs)
167
+ self.amp = amp
168
+ self.amp_dtype = amp_dtype
169
+ if self.amp:
170
+ dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
171
+ self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
172
+
173
+
174
+ def train(self, mode: bool = True):
175
+ return self.cnn.train(mode)
176
+
177
+ def forward(self, x, upsample = False):
178
+ B,C,H,W = x.shape
179
+ feature_pyramid = self.cnn(x)
180
+
181
+ if not upsample:
182
+ with torch.no_grad():
183
+ if self.dinov2_vitl14[0].device != x.device:
184
+ self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
185
+ dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
186
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
187
+ del dinov2_features_16
188
+ feature_pyramid[16] = features_16
189
+ return feature_pyramid
190
+
191
+ class Dinov2(nn.Module):
192
+ def __init__(self, amp = True, amp_dtype = torch.float16):
193
+ super().__init__()
194
+ # in case the Internet connection is not stable, please load the DINOv2 locally
195
+ self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
196
+ 'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
197
+
198
+ state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
199
+ self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
200
+
201
+ self.amp = amp
202
+ self.amp_dtype = amp_dtype
203
+ if self.amp:
204
+ self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
205
+
206
+ def forward(self, x, upsample = False):
207
+ B,C,H,W = x.shape
208
+ mean_ = torch.tensor([0.485, 0.456, 0.406],
209
+ device=x.device).view(1, 3, 1, 1)
210
+ std_ = torch.tensor([0.229, 0.224, 0.225],
211
+ device=x.device).view(1, 3, 1, 1)
212
+ x = (x+1)/2
213
+ x = (x - mean_)/std_
214
+ h_re, w_re = 560, 560
215
+ x_resize = F.interpolate(x, size=(h_re, w_re),
216
+ mode='bilinear', align_corners=True)
217
+ if not upsample:
218
+ with torch.no_grad():
219
+ dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
220
+ features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
221
+ del dinov2_features_16
222
+ features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
223
+ return features_16
224
+
225
+ class AttnBlock(nn.Module):
226
+ """
227
+ A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
228
+ """
229
+
230
+ def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
231
+ flash=False, ckpt_fwd=False, debug=False, **block_kwargs):
232
+ super().__init__()
233
+ self.debug=debug
234
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
235
+ self.flash=flash
236
+
237
+ self.attn = Attention(
238
+ hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
239
+ **block_kwargs
240
+ )
241
+ self.ls = LayerScale(hidden_size, init_values=0.005)
242
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
243
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
244
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
245
+ self.mlp = Mlp(
246
+ in_features=hidden_size,
247
+ hidden_features=mlp_hidden_dim,
248
+ act_layer=approx_gelu,
249
+ )
250
+ self.ckpt_fwd = ckpt_fwd
251
+ def forward(self, x):
252
+ if self.debug:
253
+ print(x.max(), x.min(), x.mean())
254
+ if self.ckpt_fwd:
255
+ x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
256
+ else:
257
+ x = x + self.attn(self.norm1(x))
258
+
259
+ x = x + self.ls(self.mlp(self.norm2(x)))
260
+ return x
261
+
262
+ class CrossAttnBlock(nn.Module):
263
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, head_dim=48,
264
+ flash=False, ckpt_fwd=False, **block_kwargs):
265
+ super().__init__()
266
+ self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
267
+ self.norm_context = nn.LayerNorm(hidden_size)
268
+
269
+ self.cross_attn = Attention(
270
+ hidden_size, context_dim=context_dim, dim_head=head_dim,
271
+ num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash,
272
+ )
273
+
274
+ self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
275
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
276
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
277
+ self.mlp = Mlp(
278
+ in_features=hidden_size,
279
+ hidden_features=mlp_hidden_dim,
280
+ act_layer=approx_gelu,
281
+ drop=0,
282
+ )
283
+ self.ckpt_fwd = ckpt_fwd
284
+ def forward(self, x, context):
285
+ if self.ckpt_fwd:
286
+ with autocast():
287
+ x = x + checkpoint(self.cross_attn,
288
+ self.norm1(x), self.norm_context(context), use_reentrant=False)
289
+ else:
290
+ with autocast():
291
+ x = x + self.cross_attn(
292
+ self.norm1(x), self.norm_context(context)
293
+ )
294
+ x = x + self.mlp(self.norm2(x))
295
+ return x
296
+
297
+
298
+ def bilinear_sampler(img, coords, mode="bilinear", mask=False):
299
+ """Wrapper for grid_sample, uses pixel coordinates"""
300
+ H, W = img.shape[-2:]
301
+ xgrid, ygrid = coords.split([1, 1], dim=-1)
302
+ # go to 0,1 then 0,2 then -1,1
303
+ xgrid = 2 * xgrid / (W - 1) - 1
304
+ ygrid = 2 * ygrid / (H - 1) - 1
305
+
306
+ grid = torch.cat([xgrid, ygrid], dim=-1)
307
+ img = F.grid_sample(img, grid, align_corners=True, mode=mode)
308
+
309
+ if mask:
310
+ mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
311
+ return img, mask.float()
312
+
313
+ return img
314
+
315
+
316
+ class CorrBlock:
317
+ def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
318
+ B, S, C, H_prev, W_prev = fmaps.shape
319
+ self.S, self.C, self.H, self.W = S, C, H_prev, W_prev
320
+
321
+ self.num_levels = num_levels
322
+ self.radius = radius
323
+ self.fmaps_pyramid = []
324
+ self.depth_pyramid = []
325
+ self.fmaps_pyramid.append(fmaps)
326
+ if depths_dnG is not None:
327
+ self.depth_pyramid.append(depths_dnG)
328
+ for i in range(self.num_levels - 1):
329
+ if depths_dnG is not None:
330
+ depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
331
+ depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
332
+ _, _, H, W = depths_dnG_.shape
333
+ depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
334
+ self.depth_pyramid.append(depths_dnG)
335
+ fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
336
+ fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
337
+ _, _, H, W = fmaps_.shape
338
+ fmaps = fmaps_.reshape(B, S, C, H, W)
339
+ H_prev = H
340
+ W_prev = W
341
+ self.fmaps_pyramid.append(fmaps)
342
+
343
+ def sample(self, coords):
344
+ r = self.radius
345
+ B, S, N, D = coords.shape
346
+ assert D == 2
347
+
348
+ H, W = self.H, self.W
349
+ out_pyramid = []
350
+ for i in range(self.num_levels):
351
+ corrs = self.corrs_pyramid[i] # B, S, N, H, W
352
+ _, _, _, H, W = corrs.shape
353
+
354
+ dx = torch.linspace(-r, r, 2 * r + 1)
355
+ dy = torch.linspace(-r, r, 2 * r + 1)
356
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
357
+ coords.device
358
+ )
359
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
360
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
361
+ coords_lvl = centroid_lvl + delta_lvl
362
+ corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
363
+ corrs = corrs.view(B, S, N, -1)
364
+ out_pyramid.append(corrs)
365
+
366
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
367
+ return out.contiguous().float()
368
+
369
+ def corr(self, targets):
370
+ B, S, N, C = targets.shape
371
+ assert C == self.C
372
+ assert S == self.S
373
+
374
+ fmap1 = targets
375
+
376
+ self.corrs_pyramid = []
377
+ for fmaps in self.fmaps_pyramid:
378
+ _, _, _, H, W = fmaps.shape
379
+ fmap2s = fmaps.view(B, S, C, H * W)
380
+ corrs = torch.matmul(fmap1, fmap2s)
381
+ corrs = corrs.view(B, S, N, H, W)
382
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
383
+ self.corrs_pyramid.append(corrs)
384
+
385
+ def corr_sample(self, targets, coords, coords_dp=None):
386
+ B, S, N, C = targets.shape
387
+ r = self.radius
388
+ Dim_c = (2*r+1)**2
389
+ assert C == self.C
390
+ assert S == self.S
391
+
392
+ out_pyramid = []
393
+ out_pyramid_dp = []
394
+ for i in range(self.num_levels):
395
+ dx = torch.linspace(-r, r, 2 * r + 1)
396
+ dy = torch.linspace(-r, r, 2 * r + 1)
397
+ delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
398
+ coords.device
399
+ )
400
+ centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
401
+ delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
402
+ coords_lvl = centroid_lvl + delta_lvl
403
+ fmaps = self.fmaps_pyramid[i]
404
+ _, _, _, H, W = fmaps.shape
405
+ fmap2s = fmaps.view(B*S, C, H, W)
406
+ if len(self.depth_pyramid)>0:
407
+ depths_dnG_i = self.depth_pyramid[i]
408
+ depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
409
+ dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
410
+ dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
411
+ out_pyramid_dp.append(dp_corrs)
412
+ fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
413
+ fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
414
+ corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
415
+ corrs = corrs / torch.sqrt(torch.tensor(C).float())
416
+ corrs = corrs.view(B, S, N, -1)
417
+ out_pyramid.append(corrs)
418
+
419
+ out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
420
+ if len(self.depth_pyramid)>0:
421
+ out_dp = torch.cat(out_pyramid_dp, dim=-1)
422
+ self.fcorrD = out_dp.contiguous().float()
423
+ else:
424
+ self.fcorrD = torch.zeros_like(out).contiguous().float()
425
+ return out.contiguous().float()
426
+
427
+
428
+ class EUpdateFormer(nn.Module):
429
+ """
430
+ Transformer model that updates track estimates.
431
+ """
432
+
433
+ def __init__(
434
+ self,
435
+ space_depth=12,
436
+ time_depth=12,
437
+ input_dim=320,
438
+ hidden_size=384,
439
+ num_heads=8,
440
+ output_dim=130,
441
+ mlp_ratio=4.0,
442
+ vq_depth=3,
443
+ add_space_attn=True,
444
+ add_time_attn=True,
445
+ flash=True
446
+ ):
447
+ super().__init__()
448
+ self.out_channels = 2
449
+ self.num_heads = num_heads
450
+ self.hidden_size = hidden_size
451
+ self.add_space_attn = add_space_attn
452
+ self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
453
+ self.flash = flash
454
+ self.flow_head = nn.Sequential(
455
+ nn.Linear(hidden_size, output_dim, bias=True),
456
+ nn.ReLU(inplace=True),
457
+ nn.Linear(output_dim, output_dim, bias=True),
458
+ nn.ReLU(inplace=True),
459
+ nn.Linear(output_dim, output_dim, bias=True)
460
+ )
461
+ self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
462
+ cfg = xLSTMBlockStackConfig(
463
+ mlstm_block=mLSTMBlockConfig(
464
+ mlstm=mLSTMLayerConfig(
465
+ conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
466
+ )
467
+ ),
468
+ slstm_block=sLSTMBlockConfig(
469
+ slstm=sLSTMLayerConfig(
470
+ backend="cuda",
471
+ num_heads=4,
472
+ conv1d_kernel_size=4,
473
+ bias_init="powerlaw_blockdependent",
474
+ ),
475
+ feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
476
+ ),
477
+ context_length=50,
478
+ num_blocks=7,
479
+ embedding_dim=384,
480
+ slstm_at=[1],
481
+
482
+ )
483
+ self.xlstm_fwd = xLSTMBlockStack(cfg)
484
+ self.xlstm_bwd = xLSTMBlockStack(cfg)
485
+
486
+ self.initialize_weights()
487
+
488
+ def initialize_weights(self):
489
+ def _basic_init(module):
490
+ if isinstance(module, nn.Linear):
491
+ torch.nn.init.xavier_uniform_(module.weight)
492
+ if module.bias is not None:
493
+ nn.init.constant_(module.bias, 0)
494
+
495
+ self.apply(_basic_init)
496
+
497
+ def forward(self,
498
+ input_tensor,
499
+ track_mask=None):
500
+ """ Updating with Transformer
501
+
502
+ Args:
503
+ input_tensor: B, N, T, C
504
+ arap_embed: B, N, T, C
505
+ """
506
+ B, N, T, C = input_tensor.shape
507
+ x = self.input_transform(input_tensor)
508
+
509
+ track_mask = track_mask.permute(0,2,1,3).float()
510
+ fwd_x = x*track_mask
511
+ bwd_x = x.flip(2)*track_mask.flip(2)
512
+ feat_fwd = self.xlstm_fwd(self.norm(fwd_x.view(B*N, T, -1)))
513
+ feat_bwd = self.xlstm_bwd(self.norm(bwd_x.view(B*N, T, -1)))
514
+ feat = (feat_bwd.flip(1) + feat_fwd).view(B, N, T, -1)
515
+
516
+ flow = self.flow_head(feat)
517
+
518
+ return flow[..., :2], flow[..., 2:]
519
+
models/SpaTrackV2/models/camera_transform.py ADDED
@@ -0,0 +1,248 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Adapted from https://github.com/amyxlase/relpose-plus-plus
9
+
10
+ import torch
11
+ import numpy as np
12
+ import math
13
+
14
+
15
+
16
+
17
+ def bbox_xyxy_to_xywh(xyxy):
18
+ wh = xyxy[2:] - xyxy[:2]
19
+ xywh = np.concatenate([xyxy[:2], wh])
20
+ return xywh
21
+
22
+
23
+ def adjust_camera_to_bbox_crop_(fl, pp, image_size_wh: torch.Tensor, clamp_bbox_xywh: torch.Tensor):
24
+ focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, image_size_wh)
25
+
26
+ principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
27
+
28
+ focal_length, principal_point_cropped = _convert_pixels_to_ndc(
29
+ focal_length_px, principal_point_px_cropped, clamp_bbox_xywh[2:]
30
+ )
31
+
32
+ return focal_length, principal_point_cropped
33
+
34
+
35
+ def adjust_camera_to_image_scale_(fl, pp, original_size_wh: torch.Tensor, new_size_wh: torch.LongTensor):
36
+ focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, original_size_wh)
37
+
38
+ # now scale and convert from pixels to NDC
39
+ image_size_wh_output = new_size_wh.float()
40
+ scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
41
+ focal_length_px_scaled = focal_length_px * scale
42
+ principal_point_px_scaled = principal_point_px * scale
43
+
44
+ focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
45
+ focal_length_px_scaled, principal_point_px_scaled, image_size_wh_output
46
+ )
47
+ return focal_length_scaled, principal_point_scaled
48
+
49
+
50
+ def _convert_ndc_to_pixels(focal_length: torch.Tensor, principal_point: torch.Tensor, image_size_wh: torch.Tensor):
51
+ half_image_size = image_size_wh / 2
52
+ rescale = half_image_size.min()
53
+ principal_point_px = half_image_size - principal_point * rescale
54
+ focal_length_px = focal_length * rescale
55
+ return focal_length_px, principal_point_px
56
+
57
+
58
+ def _convert_pixels_to_ndc(
59
+ focal_length_px: torch.Tensor, principal_point_px: torch.Tensor, image_size_wh: torch.Tensor
60
+ ):
61
+ half_image_size = image_size_wh / 2
62
+ rescale = half_image_size.min()
63
+ principal_point = (half_image_size - principal_point_px) / rescale
64
+ focal_length = focal_length_px / rescale
65
+ return focal_length, principal_point
66
+
67
+
68
+ def normalize_cameras(
69
+ cameras, compute_optical=True, first_camera=True, normalize_trans=True, scale=1.0, points=None, max_norm=False,
70
+ pose_mode="C2W"
71
+ ):
72
+ """
73
+ Normalizes cameras such that
74
+ (1) the optical axes point to the origin and the average distance to the origin is 1
75
+ (2) the first camera is the origin
76
+ (3) the translation vector is normalized
77
+
78
+ TODO: some transforms overlap with others. no need to do so many transforms
79
+ Args:
80
+ cameras (List[camera]).
81
+ """
82
+ # Let distance from first camera to origin be unit
83
+ new_cameras = cameras.clone()
84
+ scale = 1.0
85
+
86
+ if compute_optical:
87
+ new_cameras, points = compute_optical_transform(new_cameras, points=points)
88
+ if first_camera:
89
+ new_cameras, points = first_camera_transform(new_cameras, points=points, pose_mode=pose_mode)
90
+ if normalize_trans:
91
+ new_cameras, points, scale = normalize_translation(new_cameras,
92
+ points=points, max_norm=max_norm)
93
+ return new_cameras, points, scale
94
+
95
+
96
+ def compute_optical_transform(new_cameras, points=None):
97
+ """
98
+ adapted from https://github.com/amyxlase/relpose-plus-plus
99
+ """
100
+
101
+ new_transform = new_cameras.get_world_to_view_transform()
102
+ p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(new_cameras)
103
+ t = Translate(p_intersect)
104
+ scale = dist.squeeze()[0]
105
+
106
+ if points is not None:
107
+ points = t.inverse().transform_points(points)
108
+ points = points / scale
109
+
110
+ # Degenerate case
111
+ if scale == 0:
112
+ scale = torch.norm(new_cameras.T, dim=(0, 1))
113
+ scale = torch.sqrt(scale)
114
+ new_cameras.T = new_cameras.T / scale
115
+ else:
116
+ new_matrix = t.compose(new_transform).get_matrix()
117
+ new_cameras.R = new_matrix[:, :3, :3]
118
+ new_cameras.T = new_matrix[:, 3, :3] / scale
119
+
120
+ return new_cameras, points
121
+
122
+
123
+ def compute_optical_axis_intersection(cameras):
124
+ centers = cameras.get_camera_center()
125
+ principal_points = cameras.principal_point
126
+
127
+ one_vec = torch.ones((len(cameras), 1))
128
+ optical_axis = torch.cat((principal_points, one_vec), -1)
129
+
130
+ pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
131
+
132
+ pp2 = pp[torch.arange(pp.shape[0]), torch.arange(pp.shape[0])]
133
+
134
+ directions = pp2 - centers
135
+ centers = centers.unsqueeze(0).unsqueeze(0)
136
+ directions = directions.unsqueeze(0).unsqueeze(0)
137
+
138
+ p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(p=centers, r=directions, mask=None)
139
+
140
+ p_intersect = p_intersect.squeeze().unsqueeze(0)
141
+ dist = (p_intersect - centers).norm(dim=-1)
142
+
143
+ return p_intersect, dist, p_line_intersect, pp2, r
144
+
145
+
146
+ def intersect_skew_line_groups(p, r, mask):
147
+ # p, r both of shape (B, N, n_intersected_lines, 3)
148
+ # mask of shape (B, N, n_intersected_lines)
149
+ p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
150
+ _, p_line_intersect = _point_line_distance(p, r, p_intersect[..., None, :].expand_as(p))
151
+ intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(dim=-1)
152
+ return p_intersect, p_line_intersect, intersect_dist_squared, r
153
+
154
+
155
+ def intersect_skew_lines_high_dim(p, r, mask=None):
156
+ # Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
157
+ dim = p.shape[-1]
158
+ # make sure the heading vectors are l2-normed
159
+ if mask is None:
160
+ mask = torch.ones_like(p[..., 0])
161
+ r = torch.nn.functional.normalize(r, dim=-1)
162
+
163
+ eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
164
+ I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
165
+ sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
166
+ p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
167
+
168
+ if torch.any(torch.isnan(p_intersect)):
169
+ print(p_intersect)
170
+ raise ValueError(f"p_intersect is NaN")
171
+
172
+ return p_intersect, r
173
+
174
+
175
+ def _point_line_distance(p1, r1, p2):
176
+ df = p2 - p1
177
+ proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
178
+ line_pt_nearest = p2 - proj_vector
179
+ d = (proj_vector).norm(dim=-1)
180
+ return d, line_pt_nearest
181
+
182
+
183
+ def first_camera_transform(cameras, rotation_only=False,
184
+ points=None, pose_mode="C2W"):
185
+ """
186
+ Transform so that the first camera is the origin
187
+ """
188
+
189
+ new_cameras = cameras.clone()
190
+ # new_transform = new_cameras.get_world_to_view_transform()
191
+
192
+ R = cameras.R
193
+ T = cameras.T
194
+ Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B, 3, 4]
195
+ Tran_M = torch.cat([Tran_M,
196
+ torch.tensor([[[0, 0, 0, 1]]], device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)], dim=1)
197
+ if pose_mode == "C2W":
198
+ Tran_M_new = (Tran_M[:1,...].inverse())@Tran_M
199
+ elif pose_mode == "W2C":
200
+ Tran_M_new = Tran_M@(Tran_M[:1,...].inverse())
201
+
202
+ if False:
203
+ tR = Rotate(new_cameras.R[0].unsqueeze(0))
204
+ if rotation_only:
205
+ t = tR.inverse()
206
+ else:
207
+ tT = Translate(new_cameras.T[0].unsqueeze(0))
208
+ t = tR.compose(tT).inverse()
209
+
210
+ if points is not None:
211
+ points = t.inverse().transform_points(points)
212
+
213
+ if pose_mode == "C2W":
214
+ new_matrix = new_transform.compose(t).get_matrix()
215
+ else:
216
+ import ipdb; ipdb.set_trace()
217
+ new_matrix = t.compose(new_transform).get_matrix()
218
+
219
+ new_cameras.R = Tran_M_new[:, :3, :3]
220
+ new_cameras.T = Tran_M_new[:, :3, 3]
221
+
222
+ return new_cameras, points
223
+
224
+
225
+ def normalize_translation(new_cameras, points=None, max_norm=False):
226
+ t_gt = new_cameras.T.clone()
227
+ t_gt = t_gt[1:, :]
228
+
229
+ if max_norm:
230
+ t_gt_norm = torch.norm(t_gt, dim=(-1))
231
+ t_gt_scale = t_gt_norm.max()
232
+ if t_gt_norm.max() < 0.001:
233
+ t_gt_scale = torch.ones_like(t_gt_scale)
234
+ t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
235
+ else:
236
+ t_gt_norm = torch.norm(t_gt, dim=(0, 1))
237
+ t_gt_scale = t_gt_norm / math.sqrt(len(t_gt))
238
+ t_gt_scale = t_gt_scale / 2
239
+ if t_gt_norm.max() < 0.001:
240
+ t_gt_scale = torch.ones_like(t_gt_scale)
241
+ t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
242
+
243
+ new_cameras.T = new_cameras.T / t_gt_scale
244
+
245
+ if points is not None:
246
+ points = points / t_gt_scale
247
+
248
+ return new_cameras, points, t_gt_scale
models/SpaTrackV2/models/depth_refiner/backbone.py ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ---------------------------------------------------------------
2
+ # Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
3
+ #
4
+ # This work is licensed under the NVIDIA Source Code License
5
+ # ---------------------------------------------------------------
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from functools import partial
10
+
11
+ from timm.layers import DropPath, to_2tuple, trunc_normal_
12
+ from timm.models import register_model
13
+ from timm.models.vision_transformer import _cfg
14
+ import math
15
+
16
+
17
+ class Mlp(nn.Module):
18
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
19
+ super().__init__()
20
+ out_features = out_features or in_features
21
+ hidden_features = hidden_features or in_features
22
+ self.fc1 = nn.Linear(in_features, hidden_features)
23
+ self.dwconv = DWConv(hidden_features)
24
+ self.act = act_layer()
25
+ self.fc2 = nn.Linear(hidden_features, out_features)
26
+ self.drop = nn.Dropout(drop)
27
+
28
+ self.apply(self._init_weights)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+ elif isinstance(m, nn.LayerNorm):
36
+ nn.init.constant_(m.bias, 0)
37
+ nn.init.constant_(m.weight, 1.0)
38
+ elif isinstance(m, nn.Conv2d):
39
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
40
+ fan_out //= m.groups
41
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
42
+ if m.bias is not None:
43
+ m.bias.data.zero_()
44
+
45
+ def forward(self, x, H, W):
46
+ x = self.fc1(x)
47
+ x = self.dwconv(x, H, W)
48
+ x = self.act(x)
49
+ x = self.drop(x)
50
+ x = self.fc2(x)
51
+ x = self.drop(x)
52
+ return x
53
+
54
+
55
+ class Attention(nn.Module):
56
+ def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
57
+ super().__init__()
58
+ assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
59
+
60
+ self.dim = dim
61
+ self.num_heads = num_heads
62
+ head_dim = dim // num_heads
63
+ self.scale = qk_scale or head_dim ** -0.5
64
+
65
+ self.q = nn.Linear(dim, dim, bias=qkv_bias)
66
+ self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
67
+ self.attn_drop = nn.Dropout(attn_drop)
68
+ self.proj = nn.Linear(dim, dim)
69
+ self.proj_drop = nn.Dropout(proj_drop)
70
+
71
+ self.sr_ratio = sr_ratio
72
+ if sr_ratio > 1:
73
+ self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
74
+ self.norm = nn.LayerNorm(dim)
75
+
76
+ self.apply(self._init_weights)
77
+
78
+ def _init_weights(self, m):
79
+ if isinstance(m, nn.Linear):
80
+ trunc_normal_(m.weight, std=.02)
81
+ if isinstance(m, nn.Linear) and m.bias is not None:
82
+ nn.init.constant_(m.bias, 0)
83
+ elif isinstance(m, nn.LayerNorm):
84
+ nn.init.constant_(m.bias, 0)
85
+ nn.init.constant_(m.weight, 1.0)
86
+ elif isinstance(m, nn.Conv2d):
87
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
88
+ fan_out //= m.groups
89
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
90
+ if m.bias is not None:
91
+ m.bias.data.zero_()
92
+
93
+ def forward(self, x, H, W):
94
+ B, N, C = x.shape
95
+ q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
96
+
97
+ if self.sr_ratio > 1:
98
+ x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
99
+ x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
100
+ x_ = self.norm(x_)
101
+ kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
102
+ else:
103
+ kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
104
+ k, v = kv[0], kv[1]
105
+
106
+ attn = (q @ k.transpose(-2, -1)) * self.scale
107
+ attn = attn.softmax(dim=-1)
108
+ attn = self.attn_drop(attn)
109
+
110
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
111
+ x = self.proj(x)
112
+ x = self.proj_drop(x)
113
+
114
+ return x
115
+
116
+
117
+ class Block(nn.Module):
118
+
119
+ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
120
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
121
+ super().__init__()
122
+ self.norm1 = norm_layer(dim)
123
+ self.attn = Attention(
124
+ dim,
125
+ num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
126
+ attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
127
+ # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
128
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
129
+ self.norm2 = norm_layer(dim)
130
+ mlp_hidden_dim = int(dim * mlp_ratio)
131
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
132
+
133
+ self.apply(self._init_weights)
134
+
135
+ def _init_weights(self, m):
136
+ if isinstance(m, nn.Linear):
137
+ trunc_normal_(m.weight, std=.02)
138
+ if isinstance(m, nn.Linear) and m.bias is not None:
139
+ nn.init.constant_(m.bias, 0)
140
+ elif isinstance(m, nn.LayerNorm):
141
+ nn.init.constant_(m.bias, 0)
142
+ nn.init.constant_(m.weight, 1.0)
143
+ elif isinstance(m, nn.Conv2d):
144
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
145
+ fan_out //= m.groups
146
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
147
+ if m.bias is not None:
148
+ m.bias.data.zero_()
149
+
150
+ def forward(self, x, H, W):
151
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
152
+ x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
153
+
154
+ return x
155
+
156
+
157
+ class OverlapPatchEmbed(nn.Module):
158
+ """ Image to Patch Embedding
159
+ """
160
+
161
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
162
+ super().__init__()
163
+ img_size = to_2tuple(img_size)
164
+ patch_size = to_2tuple(patch_size)
165
+
166
+ self.img_size = img_size
167
+ self.patch_size = patch_size
168
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
169
+ self.num_patches = self.H * self.W
170
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
171
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
172
+ self.norm = nn.LayerNorm(embed_dim)
173
+
174
+ self.apply(self._init_weights)
175
+
176
+ def _init_weights(self, m):
177
+ if isinstance(m, nn.Linear):
178
+ trunc_normal_(m.weight, std=.02)
179
+ if isinstance(m, nn.Linear) and m.bias is not None:
180
+ nn.init.constant_(m.bias, 0)
181
+ elif isinstance(m, nn.LayerNorm):
182
+ nn.init.constant_(m.bias, 0)
183
+ nn.init.constant_(m.weight, 1.0)
184
+ elif isinstance(m, nn.Conv2d):
185
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
186
+ fan_out //= m.groups
187
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
188
+ if m.bias is not None:
189
+ m.bias.data.zero_()
190
+
191
+ def forward(self, x):
192
+ x = self.proj(x)
193
+ _, _, H, W = x.shape
194
+ x = x.flatten(2).transpose(1, 2)
195
+ x = self.norm(x)
196
+
197
+ return x, H, W
198
+
199
+
200
+
201
+
202
+ class OverlapPatchEmbed43(nn.Module):
203
+ """ Image to Patch Embedding
204
+ """
205
+
206
+ def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
207
+ super().__init__()
208
+ img_size = to_2tuple(img_size)
209
+ patch_size = to_2tuple(patch_size)
210
+
211
+ self.img_size = img_size
212
+ self.patch_size = patch_size
213
+ self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
214
+ self.num_patches = self.H * self.W
215
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
216
+ padding=(patch_size[0] // 2, patch_size[1] // 2))
217
+ self.norm = nn.LayerNorm(embed_dim)
218
+
219
+ self.apply(self._init_weights)
220
+
221
+ def _init_weights(self, m):
222
+ if isinstance(m, nn.Linear):
223
+ trunc_normal_(m.weight, std=.02)
224
+ if isinstance(m, nn.Linear) and m.bias is not None:
225
+ nn.init.constant_(m.bias, 0)
226
+ elif isinstance(m, nn.LayerNorm):
227
+ nn.init.constant_(m.bias, 0)
228
+ nn.init.constant_(m.weight, 1.0)
229
+ elif isinstance(m, nn.Conv2d):
230
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
231
+ fan_out //= m.groups
232
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
233
+ if m.bias is not None:
234
+ m.bias.data.zero_()
235
+
236
+ def forward(self, x):
237
+ if x.shape[1]==4:
238
+ x = self.proj_4c(x)
239
+ else:
240
+ x = self.proj(x)
241
+ _, _, H, W = x.shape
242
+ x = x.flatten(2).transpose(1, 2)
243
+ x = self.norm(x)
244
+
245
+ return x, H, W
246
+
247
+ class MixVisionTransformer(nn.Module):
248
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
249
+ num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
250
+ attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
251
+ depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
252
+ super().__init__()
253
+ self.num_classes = num_classes
254
+ self.depths = depths
255
+
256
+ # patch_embed 43
257
+ self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
258
+ embed_dim=embed_dims[0])
259
+ self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
260
+ embed_dim=embed_dims[1])
261
+ self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
262
+ embed_dim=embed_dims[2])
263
+ self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
264
+ embed_dim=embed_dims[3])
265
+
266
+ # transformer encoder
267
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
268
+ cur = 0
269
+ self.block1 = nn.ModuleList([Block(
270
+ dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
271
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
272
+ sr_ratio=sr_ratios[0])
273
+ for i in range(depths[0])])
274
+ self.norm1 = norm_layer(embed_dims[0])
275
+
276
+ cur += depths[0]
277
+ self.block2 = nn.ModuleList([Block(
278
+ dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
279
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
280
+ sr_ratio=sr_ratios[1])
281
+ for i in range(depths[1])])
282
+ self.norm2 = norm_layer(embed_dims[1])
283
+
284
+ cur += depths[1]
285
+ self.block3 = nn.ModuleList([Block(
286
+ dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
287
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
288
+ sr_ratio=sr_ratios[2])
289
+ for i in range(depths[2])])
290
+ self.norm3 = norm_layer(embed_dims[2])
291
+
292
+ cur += depths[2]
293
+ self.block4 = nn.ModuleList([Block(
294
+ dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
295
+ drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
296
+ sr_ratio=sr_ratios[3])
297
+ for i in range(depths[3])])
298
+ self.norm4 = norm_layer(embed_dims[3])
299
+
300
+ # classification head
301
+ # self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
302
+
303
+ self.apply(self._init_weights)
304
+
305
+ def _init_weights(self, m):
306
+ if isinstance(m, nn.Linear):
307
+ trunc_normal_(m.weight, std=.02)
308
+ if isinstance(m, nn.Linear) and m.bias is not None:
309
+ nn.init.constant_(m.bias, 0)
310
+ elif isinstance(m, nn.LayerNorm):
311
+ nn.init.constant_(m.bias, 0)
312
+ nn.init.constant_(m.weight, 1.0)
313
+ elif isinstance(m, nn.Conv2d):
314
+ fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
315
+ fan_out //= m.groups
316
+ m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
317
+ if m.bias is not None:
318
+ m.bias.data.zero_()
319
+
320
+ def init_weights(self, pretrained=None):
321
+ if isinstance(pretrained, str):
322
+ logger = get_root_logger()
323
+ load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
324
+
325
+ def reset_drop_path(self, drop_path_rate):
326
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
327
+ cur = 0
328
+ for i in range(self.depths[0]):
329
+ self.block1[i].drop_path.drop_prob = dpr[cur + i]
330
+
331
+ cur += self.depths[0]
332
+ for i in range(self.depths[1]):
333
+ self.block2[i].drop_path.drop_prob = dpr[cur + i]
334
+
335
+ cur += self.depths[1]
336
+ for i in range(self.depths[2]):
337
+ self.block3[i].drop_path.drop_prob = dpr[cur + i]
338
+
339
+ cur += self.depths[2]
340
+ for i in range(self.depths[3]):
341
+ self.block4[i].drop_path.drop_prob = dpr[cur + i]
342
+
343
+ def freeze_patch_emb(self):
344
+ self.patch_embed1.requires_grad = False
345
+
346
+ @torch.jit.ignore
347
+ def no_weight_decay(self):
348
+ return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
349
+
350
+ def get_classifier(self):
351
+ return self.head
352
+
353
+ def reset_classifier(self, num_classes, global_pool=''):
354
+ self.num_classes = num_classes
355
+ self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
356
+
357
+ def forward_features(self, x):
358
+ B = x.shape[0]
359
+ outs = []
360
+
361
+ # stage 1
362
+ x, H, W = self.patch_embed1(x)
363
+ for i, blk in enumerate(self.block1):
364
+ x = blk(x, H, W)
365
+ x = self.norm1(x)
366
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
367
+ outs.append(x)
368
+
369
+ # stage 2
370
+ x, H, W = self.patch_embed2(x)
371
+ for i, blk in enumerate(self.block2):
372
+ x = blk(x, H, W)
373
+ x = self.norm2(x)
374
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
375
+ outs.append(x)
376
+
377
+ # stage 3
378
+ x, H, W = self.patch_embed3(x)
379
+ for i, blk in enumerate(self.block3):
380
+ x = blk(x, H, W)
381
+ x = self.norm3(x)
382
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
383
+ outs.append(x)
384
+
385
+ # stage 4
386
+ x, H, W = self.patch_embed4(x)
387
+ for i, blk in enumerate(self.block4):
388
+ x = blk(x, H, W)
389
+ x = self.norm4(x)
390
+ x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
391
+ outs.append(x)
392
+
393
+ return outs
394
+
395
+ def forward(self, x):
396
+ if x.dim() == 5:
397
+ x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4])
398
+ x = self.forward_features(x)
399
+ # x = self.head(x)
400
+
401
+ return x
402
+
403
+
404
+ class DWConv(nn.Module):
405
+ def __init__(self, dim=768):
406
+ super(DWConv, self).__init__()
407
+ self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
408
+
409
+ def forward(self, x, H, W):
410
+ B, N, C = x.shape
411
+ x = x.transpose(1, 2).view(B, C, H, W)
412
+ x = self.dwconv(x)
413
+ x = x.flatten(2).transpose(1, 2)
414
+
415
+ return x
416
+
417
+
418
+
419
+ #@BACKBONES.register_module()
420
+ class mit_b0(MixVisionTransformer):
421
+ def __init__(self, **kwargs):
422
+ super(mit_b0, self).__init__(
423
+ patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
424
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
425
+ drop_rate=0.0, drop_path_rate=0.1)
426
+
427
+
428
+ #@BACKBONES.register_module()
429
+ class mit_b1(MixVisionTransformer):
430
+ def __init__(self, **kwargs):
431
+ super(mit_b1, self).__init__(
432
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
433
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
434
+ drop_rate=0.0, drop_path_rate=0.1)
435
+
436
+
437
+ #@BACKBONES.register_module()
438
+ class mit_b2(MixVisionTransformer):
439
+ def __init__(self, **kwargs):
440
+ super(mit_b2, self).__init__(
441
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
442
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
443
+ drop_rate=0.0, drop_path_rate=0.1)
444
+
445
+
446
+ #@BACKBONES.register_module()
447
+ class mit_b3(MixVisionTransformer):
448
+ def __init__(self, **kwargs):
449
+ super(mit_b3, self).__init__(
450
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
451
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
452
+ drop_rate=0.0, drop_path_rate=0.1)
453
+
454
+
455
+ #@BACKBONES.register_module()
456
+ class mit_b4(MixVisionTransformer):
457
+ def __init__(self, **kwargs):
458
+ super(mit_b4, self).__init__(
459
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
460
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
461
+ drop_rate=0.0, drop_path_rate=0.1)
462
+
463
+
464
+ #@BACKBONES.register_module()
465
+ class mit_b5(MixVisionTransformer):
466
+ def __init__(self, **kwargs):
467
+ super(mit_b5, self).__init__(
468
+ patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
469
+ qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
470
+ drop_rate=0.0, drop_path_rate=0.1)
471
+
472
+
models/SpaTrackV2/models/depth_refiner/decode_head.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABCMeta, abstractmethod
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+ # from mmcv.cnn import normal_init
7
+ # from mmcv.runner import auto_fp16, force_fp32
8
+
9
+ # from mmseg.core import build_pixel_sampler
10
+ # from mmseg.ops import resize
11
+
12
+
13
+ class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
14
+ """Base class for BaseDecodeHead.
15
+
16
+ Args:
17
+ in_channels (int|Sequence[int]): Input channels.
18
+ channels (int): Channels after modules, before conv_seg.
19
+ num_classes (int): Number of classes.
20
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
21
+ conv_cfg (dict|None): Config of conv layers. Default: None.
22
+ norm_cfg (dict|None): Config of norm layers. Default: None.
23
+ act_cfg (dict): Config of activation layers.
24
+ Default: dict(type='ReLU')
25
+ in_index (int|Sequence[int]): Input feature index. Default: -1
26
+ input_transform (str|None): Transformation type of input features.
27
+ Options: 'resize_concat', 'multiple_select', None.
28
+ 'resize_concat': Multiple feature maps will be resize to the
29
+ same size as first one and than concat together.
30
+ Usually used in FCN head of HRNet.
31
+ 'multiple_select': Multiple feature maps will be bundle into
32
+ a list and passed into decode head.
33
+ None: Only one select feature map is allowed.
34
+ Default: None.
35
+ loss_decode (dict): Config of decode loss.
36
+ Default: dict(type='CrossEntropyLoss').
37
+ ignore_index (int | None): The label index to be ignored. When using
38
+ masked BCE loss, ignore_index should be set to None. Default: 255
39
+ sampler (dict|None): The config of segmentation map sampler.
40
+ Default: None.
41
+ align_corners (bool): align_corners argument of F.interpolate.
42
+ Default: False.
43
+ """
44
+
45
+ def __init__(self,
46
+ in_channels,
47
+ channels,
48
+ *,
49
+ num_classes,
50
+ dropout_ratio=0.1,
51
+ conv_cfg=None,
52
+ norm_cfg=None,
53
+ act_cfg=dict(type='ReLU'),
54
+ in_index=-1,
55
+ input_transform=None,
56
+ loss_decode=dict(
57
+ type='CrossEntropyLoss',
58
+ use_sigmoid=False,
59
+ loss_weight=1.0),
60
+ decoder_params=None,
61
+ ignore_index=255,
62
+ sampler=None,
63
+ align_corners=False):
64
+ super(BaseDecodeHead, self).__init__()
65
+ self._init_inputs(in_channels, in_index, input_transform)
66
+ self.channels = channels
67
+ self.num_classes = num_classes
68
+ self.dropout_ratio = dropout_ratio
69
+ self.conv_cfg = conv_cfg
70
+ self.norm_cfg = norm_cfg
71
+ self.act_cfg = act_cfg
72
+ self.in_index = in_index
73
+ self.ignore_index = ignore_index
74
+ self.align_corners = align_corners
75
+
76
+ if sampler is not None:
77
+ self.sampler = build_pixel_sampler(sampler, context=self)
78
+ else:
79
+ self.sampler = None
80
+
81
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
82
+ if dropout_ratio > 0:
83
+ self.dropout = nn.Dropout2d(dropout_ratio)
84
+ else:
85
+ self.dropout = None
86
+ self.fp16_enabled = False
87
+
88
+ def extra_repr(self):
89
+ """Extra repr."""
90
+ s = f'input_transform={self.input_transform}, ' \
91
+ f'ignore_index={self.ignore_index}, ' \
92
+ f'align_corners={self.align_corners}'
93
+ return s
94
+
95
+ def _init_inputs(self, in_channels, in_index, input_transform):
96
+ """Check and initialize input transforms.
97
+
98
+ The in_channels, in_index and input_transform must match.
99
+ Specifically, when input_transform is None, only single feature map
100
+ will be selected. So in_channels and in_index must be of type int.
101
+ When input_transform
102
+
103
+ Args:
104
+ in_channels (int|Sequence[int]): Input channels.
105
+ in_index (int|Sequence[int]): Input feature index.
106
+ input_transform (str|None): Transformation type of input features.
107
+ Options: 'resize_concat', 'multiple_select', None.
108
+ 'resize_concat': Multiple feature maps will be resize to the
109
+ same size as first one and than concat together.
110
+ Usually used in FCN head of HRNet.
111
+ 'multiple_select': Multiple feature maps will be bundle into
112
+ a list and passed into decode head.
113
+ None: Only one select feature map is allowed.
114
+ """
115
+
116
+ if input_transform is not None:
117
+ assert input_transform in ['resize_concat', 'multiple_select']
118
+ self.input_transform = input_transform
119
+ self.in_index = in_index
120
+ if input_transform is not None:
121
+ assert isinstance(in_channels, (list, tuple))
122
+ assert isinstance(in_index, (list, tuple))
123
+ assert len(in_channels) == len(in_index)
124
+ if input_transform == 'resize_concat':
125
+ self.in_channels = sum(in_channels)
126
+ else:
127
+ self.in_channels = in_channels
128
+ else:
129
+ assert isinstance(in_channels, int)
130
+ assert isinstance(in_index, int)
131
+ self.in_channels = in_channels
132
+
133
+ def init_weights(self):
134
+ """Initialize weights of classification layer."""
135
+ normal_init(self.conv_seg, mean=0, std=0.01)
136
+
137
+ def _transform_inputs(self, inputs):
138
+ """Transform inputs for decoder.
139
+
140
+ Args:
141
+ inputs (list[Tensor]): List of multi-level img features.
142
+
143
+ Returns:
144
+ Tensor: The transformed inputs
145
+ """
146
+
147
+ if self.input_transform == 'resize_concat':
148
+ inputs = [inputs[i] for i in self.in_index]
149
+ upsampled_inputs = [
150
+ resize(
151
+ input=x,
152
+ size=inputs[0].shape[2:],
153
+ mode='bilinear',
154
+ align_corners=self.align_corners) for x in inputs
155
+ ]
156
+ inputs = torch.cat(upsampled_inputs, dim=1)
157
+ elif self.input_transform == 'multiple_select':
158
+ inputs = [inputs[i] for i in self.in_index]
159
+ else:
160
+ inputs = inputs[self.in_index]
161
+
162
+ return inputs
163
+
164
+ # @auto_fp16()
165
+ @abstractmethod
166
+ def forward(self, inputs):
167
+ """Placeholder of forward function."""
168
+ pass
169
+
170
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
171
+ """Forward function for training.
172
+ Args:
173
+ inputs (list[Tensor]): List of multi-level img features.
174
+ img_metas (list[dict]): List of image info dict where each dict
175
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
176
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
177
+ For details on the values of these keys see
178
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
179
+ gt_semantic_seg (Tensor): Semantic segmentation masks
180
+ used if the architecture supports semantic segmentation task.
181
+ train_cfg (dict): The training config.
182
+
183
+ Returns:
184
+ dict[str, Tensor]: a dictionary of loss components
185
+ """
186
+ seg_logits = self.forward(inputs)
187
+ losses = self.losses(seg_logits, gt_semantic_seg)
188
+ return losses
189
+
190
+ def forward_test(self, inputs, img_metas, test_cfg):
191
+ """Forward function for testing.
192
+
193
+ Args:
194
+ inputs (list[Tensor]): List of multi-level img features.
195
+ img_metas (list[dict]): List of image info dict where each dict
196
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
197
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
198
+ For details on the values of these keys see
199
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
200
+ test_cfg (dict): The testing config.
201
+
202
+ Returns:
203
+ Tensor: Output segmentation map.
204
+ """
205
+ return self.forward(inputs)
206
+
207
+ def cls_seg(self, feat):
208
+ """Classify each pixel."""
209
+ if self.dropout is not None:
210
+ feat = self.dropout(feat)
211
+ output = self.conv_seg(feat)
212
+ return output
213
+
214
+
215
+ class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta):
216
+ """Base class for BaseDecodeHead_clips.
217
+
218
+ Args:
219
+ in_channels (int|Sequence[int]): Input channels.
220
+ channels (int): Channels after modules, before conv_seg.
221
+ num_classes (int): Number of classes.
222
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
223
+ conv_cfg (dict|None): Config of conv layers. Default: None.
224
+ norm_cfg (dict|None): Config of norm layers. Default: None.
225
+ act_cfg (dict): Config of activation layers.
226
+ Default: dict(type='ReLU')
227
+ in_index (int|Sequence[int]): Input feature index. Default: -1
228
+ input_transform (str|None): Transformation type of input features.
229
+ Options: 'resize_concat', 'multiple_select', None.
230
+ 'resize_concat': Multiple feature maps will be resize to the
231
+ same size as first one and than concat together.
232
+ Usually used in FCN head of HRNet.
233
+ 'multiple_select': Multiple feature maps will be bundle into
234
+ a list and passed into decode head.
235
+ None: Only one select feature map is allowed.
236
+ Default: None.
237
+ loss_decode (dict): Config of decode loss.
238
+ Default: dict(type='CrossEntropyLoss').
239
+ ignore_index (int | None): The label index to be ignored. When using
240
+ masked BCE loss, ignore_index should be set to None. Default: 255
241
+ sampler (dict|None): The config of segmentation map sampler.
242
+ Default: None.
243
+ align_corners (bool): align_corners argument of F.interpolate.
244
+ Default: False.
245
+ """
246
+
247
+ def __init__(self,
248
+ in_channels,
249
+ channels,
250
+ *,
251
+ num_classes,
252
+ dropout_ratio=0.1,
253
+ conv_cfg=None,
254
+ norm_cfg=None,
255
+ act_cfg=dict(type='ReLU'),
256
+ in_index=-1,
257
+ input_transform=None,
258
+ loss_decode=dict(
259
+ type='CrossEntropyLoss',
260
+ use_sigmoid=False,
261
+ loss_weight=1.0),
262
+ decoder_params=None,
263
+ ignore_index=255,
264
+ sampler=None,
265
+ align_corners=False,
266
+ num_clips=5):
267
+ super(BaseDecodeHead_clips, self).__init__()
268
+ self._init_inputs(in_channels, in_index, input_transform)
269
+ self.channels = channels
270
+ self.num_classes = num_classes
271
+ self.dropout_ratio = dropout_ratio
272
+ self.conv_cfg = conv_cfg
273
+ self.norm_cfg = norm_cfg
274
+ self.act_cfg = act_cfg
275
+ self.in_index = in_index
276
+ self.ignore_index = ignore_index
277
+ self.align_corners = align_corners
278
+ self.num_clips=num_clips
279
+
280
+ if sampler is not None:
281
+ self.sampler = build_pixel_sampler(sampler, context=self)
282
+ else:
283
+ self.sampler = None
284
+
285
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
286
+ if dropout_ratio > 0:
287
+ self.dropout = nn.Dropout2d(dropout_ratio)
288
+ else:
289
+ self.dropout = None
290
+ self.fp16_enabled = False
291
+
292
+ def extra_repr(self):
293
+ """Extra repr."""
294
+ s = f'input_transform={self.input_transform}, ' \
295
+ f'ignore_index={self.ignore_index}, ' \
296
+ f'align_corners={self.align_corners}'
297
+ return s
298
+
299
+ def _init_inputs(self, in_channels, in_index, input_transform):
300
+ """Check and initialize input transforms.
301
+
302
+ The in_channels, in_index and input_transform must match.
303
+ Specifically, when input_transform is None, only single feature map
304
+ will be selected. So in_channels and in_index must be of type int.
305
+ When input_transform
306
+
307
+ Args:
308
+ in_channels (int|Sequence[int]): Input channels.
309
+ in_index (int|Sequence[int]): Input feature index.
310
+ input_transform (str|None): Transformation type of input features.
311
+ Options: 'resize_concat', 'multiple_select', None.
312
+ 'resize_concat': Multiple feature maps will be resize to the
313
+ same size as first one and than concat together.
314
+ Usually used in FCN head of HRNet.
315
+ 'multiple_select': Multiple feature maps will be bundle into
316
+ a list and passed into decode head.
317
+ None: Only one select feature map is allowed.
318
+ """
319
+
320
+ if input_transform is not None:
321
+ assert input_transform in ['resize_concat', 'multiple_select']
322
+ self.input_transform = input_transform
323
+ self.in_index = in_index
324
+ if input_transform is not None:
325
+ assert isinstance(in_channels, (list, tuple))
326
+ assert isinstance(in_index, (list, tuple))
327
+ assert len(in_channels) == len(in_index)
328
+ if input_transform == 'resize_concat':
329
+ self.in_channels = sum(in_channels)
330
+ else:
331
+ self.in_channels = in_channels
332
+ else:
333
+ assert isinstance(in_channels, int)
334
+ assert isinstance(in_index, int)
335
+ self.in_channels = in_channels
336
+
337
+ def init_weights(self):
338
+ """Initialize weights of classification layer."""
339
+ normal_init(self.conv_seg, mean=0, std=0.01)
340
+
341
+ def _transform_inputs(self, inputs):
342
+ """Transform inputs for decoder.
343
+
344
+ Args:
345
+ inputs (list[Tensor]): List of multi-level img features.
346
+
347
+ Returns:
348
+ Tensor: The transformed inputs
349
+ """
350
+
351
+ if self.input_transform == 'resize_concat':
352
+ inputs = [inputs[i] for i in self.in_index]
353
+ upsampled_inputs = [
354
+ resize(
355
+ input=x,
356
+ size=inputs[0].shape[2:],
357
+ mode='bilinear',
358
+ align_corners=self.align_corners) for x in inputs
359
+ ]
360
+ inputs = torch.cat(upsampled_inputs, dim=1)
361
+ elif self.input_transform == 'multiple_select':
362
+ inputs = [inputs[i] for i in self.in_index]
363
+ else:
364
+ inputs = inputs[self.in_index]
365
+
366
+ return inputs
367
+
368
+ # @auto_fp16()
369
+ @abstractmethod
370
+ def forward(self, inputs):
371
+ """Placeholder of forward function."""
372
+ pass
373
+
374
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips):
375
+ """Forward function for training.
376
+ Args:
377
+ inputs (list[Tensor]): List of multi-level img features.
378
+ img_metas (list[dict]): List of image info dict where each dict
379
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
380
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
381
+ For details on the values of these keys see
382
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
383
+ gt_semantic_seg (Tensor): Semantic segmentation masks
384
+ used if the architecture supports semantic segmentation task.
385
+ train_cfg (dict): The training config.
386
+
387
+ Returns:
388
+ dict[str, Tensor]: a dictionary of loss components
389
+ """
390
+ seg_logits = self.forward(inputs,batch_size, num_clips)
391
+ losses = self.losses(seg_logits, gt_semantic_seg)
392
+ return losses
393
+
394
+ def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
395
+ """Forward function for testing.
396
+
397
+ Args:
398
+ inputs (list[Tensor]): List of multi-level img features.
399
+ img_metas (list[dict]): List of image info dict where each dict
400
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
401
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
402
+ For details on the values of these keys see
403
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
404
+ test_cfg (dict): The testing config.
405
+
406
+ Returns:
407
+ Tensor: Output segmentation map.
408
+ """
409
+ return self.forward(inputs, batch_size, num_clips)
410
+
411
+ def cls_seg(self, feat):
412
+ """Classify each pixel."""
413
+ if self.dropout is not None:
414
+ feat = self.dropout(feat)
415
+ output = self.conv_seg(feat)
416
+ return output
417
+
418
+ class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta):
419
+ """Base class for BaseDecodeHead_clips_flow.
420
+
421
+ Args:
422
+ in_channels (int|Sequence[int]): Input channels.
423
+ channels (int): Channels after modules, before conv_seg.
424
+ num_classes (int): Number of classes.
425
+ dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
426
+ conv_cfg (dict|None): Config of conv layers. Default: None.
427
+ norm_cfg (dict|None): Config of norm layers. Default: None.
428
+ act_cfg (dict): Config of activation layers.
429
+ Default: dict(type='ReLU')
430
+ in_index (int|Sequence[int]): Input feature index. Default: -1
431
+ input_transform (str|None): Transformation type of input features.
432
+ Options: 'resize_concat', 'multiple_select', None.
433
+ 'resize_concat': Multiple feature maps will be resize to the
434
+ same size as first one and than concat together.
435
+ Usually used in FCN head of HRNet.
436
+ 'multiple_select': Multiple feature maps will be bundle into
437
+ a list and passed into decode head.
438
+ None: Only one select feature map is allowed.
439
+ Default: None.
440
+ loss_decode (dict): Config of decode loss.
441
+ Default: dict(type='CrossEntropyLoss').
442
+ ignore_index (int | None): The label index to be ignored. When using
443
+ masked BCE loss, ignore_index should be set to None. Default: 255
444
+ sampler (dict|None): The config of segmentation map sampler.
445
+ Default: None.
446
+ align_corners (bool): align_corners argument of F.interpolate.
447
+ Default: False.
448
+ """
449
+
450
+ def __init__(self,
451
+ in_channels,
452
+ channels,
453
+ *,
454
+ num_classes,
455
+ dropout_ratio=0.1,
456
+ conv_cfg=None,
457
+ norm_cfg=None,
458
+ act_cfg=dict(type='ReLU'),
459
+ in_index=-1,
460
+ input_transform=None,
461
+ loss_decode=dict(
462
+ type='CrossEntropyLoss',
463
+ use_sigmoid=False,
464
+ loss_weight=1.0),
465
+ decoder_params=None,
466
+ ignore_index=255,
467
+ sampler=None,
468
+ align_corners=False,
469
+ num_clips=5):
470
+ super(BaseDecodeHead_clips_flow, self).__init__()
471
+ self._init_inputs(in_channels, in_index, input_transform)
472
+ self.channels = channels
473
+ self.num_classes = num_classes
474
+ self.dropout_ratio = dropout_ratio
475
+ self.conv_cfg = conv_cfg
476
+ self.norm_cfg = norm_cfg
477
+ self.act_cfg = act_cfg
478
+ self.in_index = in_index
479
+ self.ignore_index = ignore_index
480
+ self.align_corners = align_corners
481
+ self.num_clips=num_clips
482
+
483
+ if sampler is not None:
484
+ self.sampler = build_pixel_sampler(sampler, context=self)
485
+ else:
486
+ self.sampler = None
487
+
488
+ self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
489
+ if dropout_ratio > 0:
490
+ self.dropout = nn.Dropout2d(dropout_ratio)
491
+ else:
492
+ self.dropout = None
493
+ self.fp16_enabled = False
494
+
495
+ def extra_repr(self):
496
+ """Extra repr."""
497
+ s = f'input_transform={self.input_transform}, ' \
498
+ f'ignore_index={self.ignore_index}, ' \
499
+ f'align_corners={self.align_corners}'
500
+ return s
501
+
502
+ def _init_inputs(self, in_channels, in_index, input_transform):
503
+ """Check and initialize input transforms.
504
+
505
+ The in_channels, in_index and input_transform must match.
506
+ Specifically, when input_transform is None, only single feature map
507
+ will be selected. So in_channels and in_index must be of type int.
508
+ When input_transform
509
+
510
+ Args:
511
+ in_channels (int|Sequence[int]): Input channels.
512
+ in_index (int|Sequence[int]): Input feature index.
513
+ input_transform (str|None): Transformation type of input features.
514
+ Options: 'resize_concat', 'multiple_select', None.
515
+ 'resize_concat': Multiple feature maps will be resize to the
516
+ same size as first one and than concat together.
517
+ Usually used in FCN head of HRNet.
518
+ 'multiple_select': Multiple feature maps will be bundle into
519
+ a list and passed into decode head.
520
+ None: Only one select feature map is allowed.
521
+ """
522
+
523
+ if input_transform is not None:
524
+ assert input_transform in ['resize_concat', 'multiple_select']
525
+ self.input_transform = input_transform
526
+ self.in_index = in_index
527
+ if input_transform is not None:
528
+ assert isinstance(in_channels, (list, tuple))
529
+ assert isinstance(in_index, (list, tuple))
530
+ assert len(in_channels) == len(in_index)
531
+ if input_transform == 'resize_concat':
532
+ self.in_channels = sum(in_channels)
533
+ else:
534
+ self.in_channels = in_channels
535
+ else:
536
+ assert isinstance(in_channels, int)
537
+ assert isinstance(in_index, int)
538
+ self.in_channels = in_channels
539
+
540
+ def init_weights(self):
541
+ """Initialize weights of classification layer."""
542
+ normal_init(self.conv_seg, mean=0, std=0.01)
543
+
544
+ def _transform_inputs(self, inputs):
545
+ """Transform inputs for decoder.
546
+
547
+ Args:
548
+ inputs (list[Tensor]): List of multi-level img features.
549
+
550
+ Returns:
551
+ Tensor: The transformed inputs
552
+ """
553
+
554
+ if self.input_transform == 'resize_concat':
555
+ inputs = [inputs[i] for i in self.in_index]
556
+ upsampled_inputs = [
557
+ resize(
558
+ input=x,
559
+ size=inputs[0].shape[2:],
560
+ mode='bilinear',
561
+ align_corners=self.align_corners) for x in inputs
562
+ ]
563
+ inputs = torch.cat(upsampled_inputs, dim=1)
564
+ elif self.input_transform == 'multiple_select':
565
+ inputs = [inputs[i] for i in self.in_index]
566
+ else:
567
+ inputs = inputs[self.in_index]
568
+
569
+ return inputs
570
+
571
+ # @auto_fp16()
572
+ @abstractmethod
573
+ def forward(self, inputs):
574
+ """Placeholder of forward function."""
575
+ pass
576
+
577
+ def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips,img=None):
578
+ """Forward function for training.
579
+ Args:
580
+ inputs (list[Tensor]): List of multi-level img features.
581
+ img_metas (list[dict]): List of image info dict where each dict
582
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
583
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
584
+ For details on the values of these keys see
585
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
586
+ gt_semantic_seg (Tensor): Semantic segmentation masks
587
+ used if the architecture supports semantic segmentation task.
588
+ train_cfg (dict): The training config.
589
+
590
+ Returns:
591
+ dict[str, Tensor]: a dictionary of loss components
592
+ """
593
+ seg_logits = self.forward(inputs,batch_size, num_clips,img)
594
+ losses = self.losses(seg_logits, gt_semantic_seg)
595
+ return losses
596
+
597
+ def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
598
+ """Forward function for testing.
599
+
600
+ Args:
601
+ inputs (list[Tensor]): List of multi-level img features.
602
+ img_metas (list[dict]): List of image info dict where each dict
603
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
604
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
605
+ For details on the values of these keys see
606
+ `mmseg/datasets/pipelines/formatting.py:Collect`.
607
+ test_cfg (dict): The testing config.
608
+
609
+ Returns:
610
+ Tensor: Output segmentation map.
611
+ """
612
+ return self.forward(inputs, batch_size, num_clips,img)
613
+
614
+ def cls_seg(self, feat):
615
+ """Classify each pixel."""
616
+ if self.dropout is not None:
617
+ feat = self.dropout(feat)
618
+ output = self.conv_seg(feat)
619
+ return output
models/SpaTrackV2/models/depth_refiner/depth_refiner.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
5
+ from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
6
+ from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
7
+ from einops import rearrange
8
+ class TrackStablizer(nn.Module):
9
+ def __init__(self):
10
+ super().__init__()
11
+
12
+ self.backbone = mit_b3()
13
+
14
+ old_conv = self.backbone.patch_embed1.proj
15
+ new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding)
16
+
17
+ new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
18
+ self.backbone.patch_embed1.proj = new_conv
19
+
20
+ self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
21
+ in_index=[0, 1, 2, 3],
22
+ feature_strides=[4, 8, 16, 32],
23
+ channels=128,
24
+ dropout_ratio=0.1,
25
+ num_classes=1,
26
+ align_corners=False,
27
+ decoder_params=dict(embed_dim=256, depths=4),
28
+ num_clips=16,
29
+ norm_cfg = dict(type='SyncBN', requires_grad=True))
30
+
31
+ self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),\
32
+ nn.ReLU(inplace=True))
33
+ self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),\
34
+ nn.ReLU(inplace=True))
35
+ self.success = False
36
+ self.x = None
37
+
38
+ def buffer_forward(self, inputs, num_clips=16):
39
+ """
40
+ buffer forward for getting the pointmap and image features
41
+ """
42
+ B, T, C, H, W = inputs.shape
43
+ self.x = self.backbone(inputs)
44
+ scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
45
+ self.success = True
46
+ return scale, shift
47
+
48
+ def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):
49
+
50
+ """
51
+ Args:
52
+ inputs: [B, T, C, H, W], RGB + PointMap + Mask
53
+ tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
54
+ num_clips: int, number of clips to use
55
+ """
56
+ B, T, C, H, W = inputs.shape
57
+ edge_feat = self.edge_conv(inputs.view(B*T,4,H,W))
58
+ edge_feat1 = self.edge_conv1(edge_feat)
59
+
60
+ if not self.success:
61
+ scale, shift = self.Track_Stabilizer.buffer_forward(self.x,num_clips=num_clips)
62
+ self.success = True
63
+ update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
64
+ else:
65
+ update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
66
+
67
+ return update
68
+
69
+ def reset_success(self):
70
+ self.success = False
71
+ self.x = None
72
+ self.Track_Stabilizer.reset_success()
73
+
74
+
75
+ if __name__ == "__main__":
76
+ # Create test input tensors
77
+ batch_size = 1
78
+ seq_len = 16
79
+ channels = 7 # 3 for RGB + 3 for PointMap + 1 for Mask
80
+ height = 384
81
+ width = 512
82
+
83
+ # Create random input tensor with shape [B, T, C, H, W]
84
+ inputs = torch.randn(batch_size, seq_len, channels, height, width)
85
+
86
+ # Create random tracks
87
+ tracks = torch.randn(batch_size, seq_len, 1024, 4)
88
+
89
+ # Create random test images
90
+ test_imgs = torch.randn(batch_size, seq_len, 3, height, width)
91
+
92
+ # Initialize model and move to GPU
93
+ model = TrackStablizer().cuda()
94
+
95
+ # Move inputs to GPU and run forward pass
96
+ inputs = inputs.cuda()
97
+ tracks = tracks.cuda()
98
+ outputs = model.buffer_forward(inputs, num_clips=seq_len)
99
+ import time
100
+ start_time = time.time()
101
+ outputs = model(inputs, tracks, num_clips=seq_len)
102
+ end_time = time.time()
103
+ print(f"Time taken: {end_time - start_time} seconds")
104
+ import pdb; pdb.set_trace()
105
+ # # Print shapes for verification
106
+ # print(f"Input shape: {inputs.shape}")
107
+ # print(f"Output shape: {outputs.shape}")
108
+
109
+ # # Basic tests
110
+ # assert outputs.shape[0] == batch_size, "Batch size mismatch"
111
+ # assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
112
+ # assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"
113
+
114
+ # print("All tests passed!")
115
+
models/SpaTrackV2/models/depth_refiner/network.py ADDED
@@ -0,0 +1,429 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ '''
4
+ Author: Ke Xian
5
+ Email: kexian@hust.edu.cn
6
+ Date: 2020/07/20
7
+ '''
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.init as init
12
+
13
+ # ==============================================================================================================
14
+
15
+ class FTB(nn.Module):
16
+ def __init__(self, inchannels, midchannels=512):
17
+ super(FTB, self).__init__()
18
+ self.in1 = inchannels
19
+ self.mid = midchannels
20
+
21
+ self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
22
+ self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\
23
+ nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\
24
+ #nn.BatchNorm2d(num_features=self.mid),\
25
+ nn.ReLU(inplace=True),\
26
+ nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True))
27
+ self.relu = nn.ReLU(inplace=True)
28
+
29
+ self.init_params()
30
+
31
+ def forward(self, x):
32
+ x = self.conv1(x)
33
+ x = x + self.conv_branch(x)
34
+ x = self.relu(x)
35
+
36
+ return x
37
+
38
+ def init_params(self):
39
+ for m in self.modules():
40
+ if isinstance(m, nn.Conv2d):
41
+ #init.kaiming_normal_(m.weight, mode='fan_out')
42
+ init.normal_(m.weight, std=0.01)
43
+ # init.xavier_normal_(m.weight)
44
+ if m.bias is not None:
45
+ init.constant_(m.bias, 0)
46
+ elif isinstance(m, nn.ConvTranspose2d):
47
+ #init.kaiming_normal_(m.weight, mode='fan_out')
48
+ init.normal_(m.weight, std=0.01)
49
+ # init.xavier_normal_(m.weight)
50
+ if m.bias is not None:
51
+ init.constant_(m.bias, 0)
52
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
53
+ init.constant_(m.weight, 1)
54
+ init.constant_(m.bias, 0)
55
+ elif isinstance(m, nn.Linear):
56
+ init.normal_(m.weight, std=0.01)
57
+ if m.bias is not None:
58
+ init.constant_(m.bias, 0)
59
+
60
+ class ATA(nn.Module):
61
+ def __init__(self, inchannels, reduction = 8):
62
+ super(ATA, self).__init__()
63
+ self.inchannels = inchannels
64
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
65
+ self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction),
66
+ nn.ReLU(inplace=True),
67
+ nn.Linear(self.inchannels // reduction, self.inchannels),
68
+ nn.Sigmoid())
69
+ self.init_params()
70
+
71
+ def forward(self, low_x, high_x):
72
+ n, c, _, _ = low_x.size()
73
+ x = torch.cat([low_x, high_x], 1)
74
+ x = self.avg_pool(x)
75
+ x = x.view(n, -1)
76
+ x = self.fc(x).view(n,c,1,1)
77
+ x = low_x * x + high_x
78
+
79
+ return x
80
+
81
+ def init_params(self):
82
+ for m in self.modules():
83
+ if isinstance(m, nn.Conv2d):
84
+ #init.kaiming_normal_(m.weight, mode='fan_out')
85
+ #init.normal(m.weight, std=0.01)
86
+ init.xavier_normal_(m.weight)
87
+ if m.bias is not None:
88
+ init.constant_(m.bias, 0)
89
+ elif isinstance(m, nn.ConvTranspose2d):
90
+ #init.kaiming_normal_(m.weight, mode='fan_out')
91
+ #init.normal_(m.weight, std=0.01)
92
+ init.xavier_normal_(m.weight)
93
+ if m.bias is not None:
94
+ init.constant_(m.bias, 0)
95
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
96
+ init.constant_(m.weight, 1)
97
+ init.constant_(m.bias, 0)
98
+ elif isinstance(m, nn.Linear):
99
+ init.normal_(m.weight, std=0.01)
100
+ if m.bias is not None:
101
+ init.constant_(m.bias, 0)
102
+
103
+
104
+ class FFM(nn.Module):
105
+ def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
106
+ super(FFM, self).__init__()
107
+ self.inchannels = inchannels
108
+ self.midchannels = midchannels
109
+ self.outchannels = outchannels
110
+ self.upfactor = upfactor
111
+
112
+ self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
113
+ self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
114
+
115
+ self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
116
+
117
+ self.init_params()
118
+ #self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
119
+ #self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
120
+ #self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
121
+
122
+ def forward(self, low_x, high_x):
123
+ x = self.ftb1(low_x)
124
+
125
+ '''
126
+ x = torch.cat((x,high_x),1)
127
+ if x.shape[2] == 12:
128
+ x = self.p1(x)
129
+ elif x.shape[2] == 24:
130
+ x = self.p2(x)
131
+ elif x.shape[2] == 48:
132
+ x = self.p3(x)
133
+ '''
134
+ x = x + high_x ###high_x
135
+ x = self.ftb2(x)
136
+ x = self.upsample(x)
137
+
138
+ return x
139
+
140
+ def init_params(self):
141
+ for m in self.modules():
142
+ if isinstance(m, nn.Conv2d):
143
+ #init.kaiming_normal_(m.weight, mode='fan_out')
144
+ init.normal_(m.weight, std=0.01)
145
+ #init.xavier_normal_(m.weight)
146
+ if m.bias is not None:
147
+ init.constant_(m.bias, 0)
148
+ elif isinstance(m, nn.ConvTranspose2d):
149
+ #init.kaiming_normal_(m.weight, mode='fan_out')
150
+ init.normal_(m.weight, std=0.01)
151
+ #init.xavier_normal_(m.weight)
152
+ if m.bias is not None:
153
+ init.constant_(m.bias, 0)
154
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
155
+ init.constant_(m.weight, 1)
156
+ init.constant_(m.bias, 0)
157
+ elif isinstance(m, nn.Linear):
158
+ init.normal_(m.weight, std=0.01)
159
+ if m.bias is not None:
160
+ init.constant_(m.bias, 0)
161
+
162
+
163
+
164
+ class noFFM(nn.Module):
165
+ def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
166
+ super(noFFM, self).__init__()
167
+ self.inchannels = inchannels
168
+ self.midchannels = midchannels
169
+ self.outchannels = outchannels
170
+ self.upfactor = upfactor
171
+
172
+ self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
173
+
174
+ self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
175
+
176
+ self.init_params()
177
+ #self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
178
+ #self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
179
+ #self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
180
+
181
+ def forward(self, low_x, high_x):
182
+
183
+ #x = self.ftb1(low_x)
184
+ x = high_x ###high_x
185
+ x = self.ftb2(x)
186
+ x = self.upsample(x)
187
+
188
+ return x
189
+
190
+ def init_params(self):
191
+ for m in self.modules():
192
+ if isinstance(m, nn.Conv2d):
193
+ #init.kaiming_normal_(m.weight, mode='fan_out')
194
+ init.normal_(m.weight, std=0.01)
195
+ #init.xavier_normal_(m.weight)
196
+ if m.bias is not None:
197
+ init.constant_(m.bias, 0)
198
+ elif isinstance(m, nn.ConvTranspose2d):
199
+ #init.kaiming_normal_(m.weight, mode='fan_out')
200
+ init.normal_(m.weight, std=0.01)
201
+ #init.xavier_normal_(m.weight)
202
+ if m.bias is not None:
203
+ init.constant_(m.bias, 0)
204
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
205
+ init.constant_(m.weight, 1)
206
+ init.constant_(m.bias, 0)
207
+ elif isinstance(m, nn.Linear):
208
+ init.normal_(m.weight, std=0.01)
209
+ if m.bias is not None:
210
+ init.constant_(m.bias, 0)
211
+
212
+
213
+
214
+
215
+ class AO(nn.Module):
216
+ # Adaptive output module
217
+ def __init__(self, inchannels, outchannels, upfactor=2):
218
+ super(AO, self).__init__()
219
+ self.inchannels = inchannels
220
+ self.outchannels = outchannels
221
+ self.upfactor = upfactor
222
+
223
+ """
224
+ self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
225
+ nn.BatchNorm2d(num_features=self.inchannels//2),\
226
+ nn.ReLU(inplace=True),\
227
+ nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\
228
+ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) )#,\
229
+ #nn.ReLU(inplace=True)) ## get positive values
230
+ """
231
+ self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
232
+ #nn.BatchNorm2d(num_features=self.inchannels//2),\
233
+ nn.ReLU(inplace=True),\
234
+ nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True), \
235
+ nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))
236
+
237
+ #nn.ReLU(inplace=True)) ## get positive values
238
+
239
+ self.init_params()
240
+
241
+ def forward(self, x):
242
+ x = self.adapt_conv(x)
243
+ return x
244
+
245
+ def init_params(self):
246
+ for m in self.modules():
247
+ if isinstance(m, nn.Conv2d):
248
+ #init.kaiming_normal_(m.weight, mode='fan_out')
249
+ init.normal_(m.weight, std=0.01)
250
+ #init.xavier_normal_(m.weight)
251
+ if m.bias is not None:
252
+ init.constant_(m.bias, 0)
253
+ elif isinstance(m, nn.ConvTranspose2d):
254
+ #init.kaiming_normal_(m.weight, mode='fan_out')
255
+ init.normal_(m.weight, std=0.01)
256
+ #init.xavier_normal_(m.weight)
257
+ if m.bias is not None:
258
+ init.constant_(m.bias, 0)
259
+ elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
260
+ init.constant_(m.weight, 1)
261
+ init.constant_(m.bias, 0)
262
+ elif isinstance(m, nn.Linear):
263
+ init.normal_(m.weight, std=0.01)
264
+ if m.bias is not None:
265
+ init.constant_(m.bias, 0)
266
+
267
+ class ASPP(nn.Module):
268
+ def __init__(self, inchannels=256, planes=128, rates = [1, 6, 12, 18]):
269
+ super(ASPP, self).__init__()
270
+ self.inchannels = inchannels
271
+ self.planes = planes
272
+ self.rates = rates
273
+ self.kernel_sizes = []
274
+ self.paddings = []
275
+ for rate in self.rates:
276
+ if rate == 1:
277
+ self.kernel_sizes.append(1)
278
+ self.paddings.append(0)
279
+ else:
280
+ self.kernel_sizes.append(3)
281
+ self.paddings.append(rate)
282
+ self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
283
+ stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
284
+ nn.ReLU(inplace=True),
285
+ nn.BatchNorm2d(num_features=self.planes)
286
+ )
287
+ self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
288
+ stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
289
+ nn.ReLU(inplace=True),
290
+ nn.BatchNorm2d(num_features=self.planes),
291
+ )
292
+ self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
293
+ stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
294
+ nn.ReLU(inplace=True),
295
+ nn.BatchNorm2d(num_features=self.planes),
296
+ )
297
+ self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
298
+ stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
299
+ nn.ReLU(inplace=True),
300
+ nn.BatchNorm2d(num_features=self.planes),
301
+ )
302
+
303
+ #self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
304
+ def forward(self, x):
305
+ x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)],1)
306
+ #x = self.conv(x)
307
+
308
+ return x
309
+
310
+ # ==============================================================================================================
311
+
312
+
313
+ class ResidualConv(nn.Module):
314
+ def __init__(self, inchannels):
315
+ super(ResidualConv, self).__init__()
316
+ #nn.BatchNorm2d
317
+ self.conv = nn.Sequential(
318
+ #nn.BatchNorm2d(num_features=inchannels),
319
+ nn.ReLU(inplace=False),
320
+ #nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
321
+ #nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
322
+ nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False),
323
+ nn.BatchNorm2d(num_features=inchannels//2),
324
+ nn.ReLU(inplace=False),
325
+ nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
326
+ )
327
+ self.init_params()
328
+
329
+ def forward(self, x):
330
+ x = self.conv(x)+x
331
+ return x
332
+
333
+ def init_params(self):
334
+ for m in self.modules():
335
+ if isinstance(m, nn.Conv2d):
336
+ #init.kaiming_normal_(m.weight, mode='fan_out')
337
+ init.normal_(m.weight, std=0.01)
338
+ #init.xavier_normal_(m.weight)
339
+ if m.bias is not None:
340
+ init.constant_(m.bias, 0)
341
+ elif isinstance(m, nn.ConvTranspose2d):
342
+ #init.kaiming_normal_(m.weight, mode='fan_out')
343
+ init.normal_(m.weight, std=0.01)
344
+ #init.xavier_normal_(m.weight)
345
+ if m.bias is not None:
346
+ init.constant_(m.bias, 0)
347
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
348
+ init.constant_(m.weight, 1)
349
+ init.constant_(m.bias, 0)
350
+ elif isinstance(m, nn.Linear):
351
+ init.normal_(m.weight, std=0.01)
352
+ if m.bias is not None:
353
+ init.constant_(m.bias, 0)
354
+
355
+
356
+ class FeatureFusion(nn.Module):
357
+ def __init__(self, inchannels, outchannels):
358
+ super(FeatureFusion, self).__init__()
359
+ self.conv = ResidualConv(inchannels=inchannels)
360
+ #nn.BatchNorm2d
361
+ self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
362
+ nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,stride=2, padding=1, output_padding=1),
363
+ nn.BatchNorm2d(num_features=outchannels),
364
+ nn.ReLU(inplace=True))
365
+
366
+ def forward(self, lowfeat, highfeat):
367
+ return self.up(highfeat + self.conv(lowfeat))
368
+
369
+ def init_params(self):
370
+ for m in self.modules():
371
+ if isinstance(m, nn.Conv2d):
372
+ #init.kaiming_normal_(m.weight, mode='fan_out')
373
+ init.normal_(m.weight, std=0.01)
374
+ #init.xavier_normal_(m.weight)
375
+ if m.bias is not None:
376
+ init.constant_(m.bias, 0)
377
+ elif isinstance(m, nn.ConvTranspose2d):
378
+ #init.kaiming_normal_(m.weight, mode='fan_out')
379
+ init.normal_(m.weight, std=0.01)
380
+ #init.xavier_normal_(m.weight)
381
+ if m.bias is not None:
382
+ init.constant_(m.bias, 0)
383
+ elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
384
+ init.constant_(m.weight, 1)
385
+ init.constant_(m.bias, 0)
386
+ elif isinstance(m, nn.Linear):
387
+ init.normal_(m.weight, std=0.01)
388
+ if m.bias is not None:
389
+ init.constant_(m.bias, 0)
390
+
391
+
392
+ class SenceUnderstand(nn.Module):
393
+ def __init__(self, channels):
394
+ super(SenceUnderstand, self).__init__()
395
+ self.channels = channels
396
+ self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
397
+ nn.ReLU(inplace = True))
398
+ self.pool = nn.AdaptiveAvgPool2d(8)
399
+ self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels),
400
+ nn.ReLU(inplace = True))
401
+ self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
402
+ nn.ReLU(inplace=True))
403
+ self.initial_params()
404
+
405
+ def forward(self, x):
406
+ n,c,h,w = x.size()
407
+ x = self.conv1(x)
408
+ x = self.pool(x)
409
+ x = x.view(n,-1)
410
+ x = self.fc(x)
411
+ x = x.view(n, self.channels, 1, 1)
412
+ x = self.conv2(x)
413
+ x = x.repeat(1,1,h,w)
414
+ return x
415
+
416
+ def initial_params(self, dev=0.01):
417
+ for m in self.modules():
418
+ if isinstance(m, nn.Conv2d):
419
+ #print torch.sum(m.weight)
420
+ m.weight.data.normal_(0, dev)
421
+ if m.bias is not None:
422
+ m.bias.data.fill_(0)
423
+ elif isinstance(m, nn.ConvTranspose2d):
424
+ #print torch.sum(m.weight)
425
+ m.weight.data.normal_(0, dev)
426
+ if m.bias is not None:
427
+ m.bias.data.fill_(0)
428
+ elif isinstance(m, nn.Linear):
429
+ m.weight.data.normal_(0, dev)
models/SpaTrackV2/models/depth_refiner/stablilization_attention.py ADDED
@@ -0,0 +1,1187 @@