Spaces:
Build error
Build error
xiaoyuxi commited on
Commit ·
c8d9d42
0
Parent(s):
Cleaned history, reset to current state
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +36 -0
- .gitignore +69 -0
- README.md +14 -0
- _viz/viz_template.html +1778 -0
- app.py +1118 -0
- app_3rd/README.md +12 -0
- app_3rd/sam_utils/hf_sam_predictor.py +129 -0
- app_3rd/sam_utils/inference.py +123 -0
- app_3rd/spatrack_utils/infer_track.py +194 -0
- config/__init__.py +0 -0
- config/magic_infer_moge.yaml +48 -0
- examples/backpack.mp4 +3 -0
- examples/ball.mp4 +3 -0
- examples/basketball.mp4 +3 -0
- examples/biker.mp4 +3 -0
- examples/cinema_0.mp4 +3 -0
- examples/cinema_1.mp4 +3 -0
- examples/drifting.mp4 +3 -0
- examples/ego_kc1.mp4 +3 -0
- examples/ego_teaser.mp4 +3 -0
- examples/handwave.mp4 +3 -0
- examples/hockey.mp4 +3 -0
- examples/ken_block_0.mp4 +3 -0
- examples/kiss.mp4 +3 -0
- examples/kitchen.mp4 +3 -0
- examples/kitchen_egocentric.mp4 +3 -0
- examples/pillow.mp4 +3 -0
- examples/protein.mp4 +3 -0
- examples/pusht.mp4 +3 -0
- examples/robot1.mp4 +3 -0
- examples/robot2.mp4 +3 -0
- examples/robot_3.mp4 +3 -0
- examples/robot_unitree.mp4 +3 -0
- examples/running.mp4 +3 -0
- examples/teleop2.mp4 +3 -0
- examples/vertical_place.mp4 +3 -0
- models/SpaTrackV2/models/SpaTrack.py +759 -0
- models/SpaTrackV2/models/__init__.py +0 -0
- models/SpaTrackV2/models/blocks.py +519 -0
- models/SpaTrackV2/models/camera_transform.py +248 -0
- models/SpaTrackV2/models/depth_refiner/backbone.py +472 -0
- models/SpaTrackV2/models/depth_refiner/decode_head.py +619 -0
- models/SpaTrackV2/models/depth_refiner/depth_refiner.py +115 -0
- models/SpaTrackV2/models/depth_refiner/network.py +429 -0
- models/SpaTrackV2/models/depth_refiner/stablilization_attention.py +1187 -0
- models/SpaTrackV2/models/depth_refiner/stablizer.py +342 -0
- models/SpaTrackV2/models/predictor.py +153 -0
- models/SpaTrackV2/models/tracker3D/TrackRefiner.py +1478 -0
- models/SpaTrackV2/models/tracker3D/co_tracker/cotracker_base.py +418 -0
- models/SpaTrackV2/models/tracker3D/co_tracker/utils.py +929 -0
.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.mp4 filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ignore the multi media
|
| 2 |
+
checkpoints
|
| 3 |
+
**/checkpoints/
|
| 4 |
+
**/temp/
|
| 5 |
+
temp
|
| 6 |
+
assets_dev
|
| 7 |
+
assets/example0/results
|
| 8 |
+
assets/example0/snowboard.npz
|
| 9 |
+
assets/example1/results
|
| 10 |
+
assets/davis_eval
|
| 11 |
+
assets/*/results
|
| 12 |
+
*gradio*
|
| 13 |
+
#
|
| 14 |
+
models/monoD/zoeDepth/ckpts/*
|
| 15 |
+
models/monoD/depth_anything/ckpts/*
|
| 16 |
+
vis_results
|
| 17 |
+
dist_encrypted
|
| 18 |
+
# remove the dependencies
|
| 19 |
+
deps
|
| 20 |
+
|
| 21 |
+
# filter the __pycache__ files
|
| 22 |
+
__pycache__/
|
| 23 |
+
/**/**/__pycache__
|
| 24 |
+
/**/__pycache__
|
| 25 |
+
|
| 26 |
+
outputs
|
| 27 |
+
scripts/lauch_exp/config
|
| 28 |
+
scripts/lauch_exp/submit_job.log
|
| 29 |
+
scripts/lauch_exp/hydra_output
|
| 30 |
+
scripts/lauch_wulan
|
| 31 |
+
scripts/custom_video
|
| 32 |
+
# ignore the visualizer
|
| 33 |
+
viser
|
| 34 |
+
viser_result
|
| 35 |
+
benchmark/results
|
| 36 |
+
benchmark
|
| 37 |
+
|
| 38 |
+
ossutil_output
|
| 39 |
+
|
| 40 |
+
prev_version
|
| 41 |
+
spat_ceres
|
| 42 |
+
wandb
|
| 43 |
+
*.log
|
| 44 |
+
seg_target.py
|
| 45 |
+
|
| 46 |
+
eval_davis.py
|
| 47 |
+
eval_multiple_gpu.py
|
| 48 |
+
eval_pose_scan.py
|
| 49 |
+
eval_single_gpu.py
|
| 50 |
+
|
| 51 |
+
infer_cam.py
|
| 52 |
+
infer_stream.py
|
| 53 |
+
|
| 54 |
+
*.egg-info/
|
| 55 |
+
**/*.egg-info
|
| 56 |
+
|
| 57 |
+
eval_kinectics.py
|
| 58 |
+
models/SpaTrackV2/datasets
|
| 59 |
+
|
| 60 |
+
scripts
|
| 61 |
+
config/fix_2d.yaml
|
| 62 |
+
|
| 63 |
+
models/SpaTrackV2/datasets
|
| 64 |
+
scripts/
|
| 65 |
+
|
| 66 |
+
models/**/build
|
| 67 |
+
models/**/dist
|
| 68 |
+
|
| 69 |
+
temp_local
|
README.md
ADDED
|
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: SpatialTrackerV2
|
| 3 |
+
emoji: ⚡️
|
| 4 |
+
colorFrom: yellow
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.31.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
license: mit
|
| 11 |
+
short_description: Official Space for SpatialTrackerV2
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
_viz/viz_template.html
ADDED
|
@@ -0,0 +1,1778 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
| 2 |
+
<html lang="en">
|
| 3 |
+
<head>
|
| 4 |
+
<meta charset="UTF-8">
|
| 5 |
+
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
| 6 |
+
<title>3D Point Cloud Visualizer</title>
|
| 7 |
+
<style>
|
| 8 |
+
:root {
|
| 9 |
+
--primary: #9b59b6; /* Brighter purple for dark mode */
|
| 10 |
+
--primary-light: #3a2e4a;
|
| 11 |
+
--secondary: #a86add;
|
| 12 |
+
--accent: #ff6e6e;
|
| 13 |
+
--bg: #1a1a1a;
|
| 14 |
+
--surface: #2c2c2c;
|
| 15 |
+
--text: #e0e0e0;
|
| 16 |
+
--text-secondary: #a0a0a0;
|
| 17 |
+
--border: #444444;
|
| 18 |
+
--shadow: rgba(0, 0, 0, 0.2);
|
| 19 |
+
--shadow-hover: rgba(0, 0, 0, 0.3);
|
| 20 |
+
|
| 21 |
+
--space-sm: 16px;
|
| 22 |
+
--space-md: 24px;
|
| 23 |
+
--space-lg: 32px;
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
body {
|
| 27 |
+
margin: 0;
|
| 28 |
+
overflow: hidden;
|
| 29 |
+
background: var(--bg);
|
| 30 |
+
color: var(--text);
|
| 31 |
+
font-family: 'Inter', sans-serif;
|
| 32 |
+
-webkit-font-smoothing: antialiased;
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
#canvas-container {
|
| 36 |
+
position: absolute;
|
| 37 |
+
width: 100%;
|
| 38 |
+
height: 100%;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
#ui-container {
|
| 42 |
+
position: absolute;
|
| 43 |
+
top: 0;
|
| 44 |
+
left: 0;
|
| 45 |
+
width: 100%;
|
| 46 |
+
height: 100%;
|
| 47 |
+
pointer-events: none;
|
| 48 |
+
z-index: 10;
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
#status-bar {
|
| 52 |
+
position: absolute;
|
| 53 |
+
top: 16px;
|
| 54 |
+
left: 16px;
|
| 55 |
+
background: rgba(30, 30, 30, 0.9);
|
| 56 |
+
padding: 8px 16px;
|
| 57 |
+
border-radius: 8px;
|
| 58 |
+
pointer-events: auto;
|
| 59 |
+
box-shadow: 0 4px 6px var(--shadow);
|
| 60 |
+
backdrop-filter: blur(4px);
|
| 61 |
+
border: 1px solid var(--border);
|
| 62 |
+
color: var(--text);
|
| 63 |
+
transition: opacity 0.5s ease, transform 0.5s ease;
|
| 64 |
+
font-weight: 500;
|
| 65 |
+
}
|
| 66 |
+
|
| 67 |
+
#status-bar.hidden {
|
| 68 |
+
opacity: 0;
|
| 69 |
+
transform: translateY(-20px);
|
| 70 |
+
pointer-events: none;
|
| 71 |
+
}
|
| 72 |
+
|
| 73 |
+
#control-panel {
|
| 74 |
+
position: absolute;
|
| 75 |
+
bottom: 16px;
|
| 76 |
+
left: 50%;
|
| 77 |
+
transform: translateX(-50%);
|
| 78 |
+
background: rgba(44, 44, 44, 0.95);
|
| 79 |
+
padding: 6px 8px;
|
| 80 |
+
border-radius: 6px;
|
| 81 |
+
display: flex;
|
| 82 |
+
gap: 8px;
|
| 83 |
+
align-items: center;
|
| 84 |
+
justify-content: space-between;
|
| 85 |
+
pointer-events: auto;
|
| 86 |
+
box-shadow: 0 4px 10px var(--shadow);
|
| 87 |
+
backdrop-filter: blur(4px);
|
| 88 |
+
border: 1px solid var(--border);
|
| 89 |
+
}
|
| 90 |
+
|
| 91 |
+
#timeline {
|
| 92 |
+
width: 150px;
|
| 93 |
+
height: 4px;
|
| 94 |
+
background: rgba(255, 255, 255, 0.1);
|
| 95 |
+
border-radius: 2px;
|
| 96 |
+
position: relative;
|
| 97 |
+
cursor: pointer;
|
| 98 |
+
}
|
| 99 |
+
|
| 100 |
+
#progress {
|
| 101 |
+
position: absolute;
|
| 102 |
+
height: 100%;
|
| 103 |
+
background: var(--primary);
|
| 104 |
+
border-radius: 2px;
|
| 105 |
+
width: 0%;
|
| 106 |
+
}
|
| 107 |
+
|
| 108 |
+
#playback-controls {
|
| 109 |
+
display: flex;
|
| 110 |
+
gap: 4px;
|
| 111 |
+
align-items: center;
|
| 112 |
+
}
|
| 113 |
+
|
| 114 |
+
button {
|
| 115 |
+
background: rgba(255, 255, 255, 0.08);
|
| 116 |
+
border: 1px solid var(--border);
|
| 117 |
+
color: var(--text);
|
| 118 |
+
padding: 4px 6px;
|
| 119 |
+
border-radius: 3px;
|
| 120 |
+
cursor: pointer;
|
| 121 |
+
display: flex;
|
| 122 |
+
align-items: center;
|
| 123 |
+
justify-content: center;
|
| 124 |
+
transition: background 0.2s, transform 0.2s;
|
| 125 |
+
font-family: 'Inter', sans-serif;
|
| 126 |
+
font-weight: 500;
|
| 127 |
+
font-size: 6px;
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
button:hover {
|
| 131 |
+
background: rgba(255, 255, 255, 0.15);
|
| 132 |
+
transform: translateY(-1px);
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
button.active {
|
| 136 |
+
background: var(--primary);
|
| 137 |
+
color: white;
|
| 138 |
+
box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
|
| 139 |
+
}
|
| 140 |
+
|
| 141 |
+
select, input {
|
| 142 |
+
background: rgba(255, 255, 255, 0.08);
|
| 143 |
+
border: 1px solid var(--border);
|
| 144 |
+
color: var(--text);
|
| 145 |
+
padding: 4px 6px;
|
| 146 |
+
border-radius: 3px;
|
| 147 |
+
cursor: pointer;
|
| 148 |
+
font-family: 'Inter', sans-serif;
|
| 149 |
+
font-size: 6px;
|
| 150 |
+
}
|
| 151 |
+
|
| 152 |
+
.icon {
|
| 153 |
+
width: 10px;
|
| 154 |
+
height: 10px;
|
| 155 |
+
fill: currentColor;
|
| 156 |
+
}
|
| 157 |
+
|
| 158 |
+
.tooltip {
|
| 159 |
+
position: absolute;
|
| 160 |
+
bottom: 100%;
|
| 161 |
+
left: 50%;
|
| 162 |
+
transform: translateX(-50%);
|
| 163 |
+
background: var(--surface);
|
| 164 |
+
color: var(--text);
|
| 165 |
+
padding: 3px 6px;
|
| 166 |
+
border-radius: 3px;
|
| 167 |
+
font-size: 7px;
|
| 168 |
+
white-space: nowrap;
|
| 169 |
+
margin-bottom: 4px;
|
| 170 |
+
opacity: 0;
|
| 171 |
+
transition: opacity 0.2s;
|
| 172 |
+
pointer-events: none;
|
| 173 |
+
box-shadow: 0 2px 4px var(--shadow);
|
| 174 |
+
border: 1px solid var(--border);
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
button:hover .tooltip {
|
| 178 |
+
opacity: 1;
|
| 179 |
+
}
|
| 180 |
+
|
| 181 |
+
#settings-panel {
|
| 182 |
+
position: absolute;
|
| 183 |
+
top: 16px;
|
| 184 |
+
right: 16px;
|
| 185 |
+
background: rgba(44, 44, 44, 0.98);
|
| 186 |
+
padding: 10px;
|
| 187 |
+
border-radius: 6px;
|
| 188 |
+
width: 195px;
|
| 189 |
+
max-height: calc(100vh - 40px);
|
| 190 |
+
overflow-y: auto;
|
| 191 |
+
pointer-events: auto;
|
| 192 |
+
box-shadow: 0 4px 15px var(--shadow);
|
| 193 |
+
backdrop-filter: blur(4px);
|
| 194 |
+
border: 1px solid var(--border);
|
| 195 |
+
display: block;
|
| 196 |
+
opacity: 1;
|
| 197 |
+
scrollbar-width: thin;
|
| 198 |
+
scrollbar-color: var(--primary-light) transparent;
|
| 199 |
+
transition: transform 0.35s ease-in-out, opacity 0.3s ease-in-out;
|
| 200 |
+
}
|
| 201 |
+
|
| 202 |
+
#settings-panel.is-hidden {
|
| 203 |
+
transform: translateX(calc(100% + 20px));
|
| 204 |
+
opacity: 0;
|
| 205 |
+
pointer-events: none;
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
#settings-panel::-webkit-scrollbar {
|
| 209 |
+
width: 3px;
|
| 210 |
+
}
|
| 211 |
+
|
| 212 |
+
#settings-panel::-webkit-scrollbar-track {
|
| 213 |
+
background: transparent;
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
#settings-panel::-webkit-scrollbar-thumb {
|
| 217 |
+
background-color: var(--primary-light);
|
| 218 |
+
border-radius: 3px;
|
| 219 |
+
}
|
| 220 |
+
|
| 221 |
+
@media (max-height: 700px) {
|
| 222 |
+
#settings-panel {
|
| 223 |
+
max-height: calc(100vh - 40px);
|
| 224 |
+
}
|
| 225 |
+
}
|
| 226 |
+
|
| 227 |
+
@media (max-width: 768px) {
|
| 228 |
+
#control-panel {
|
| 229 |
+
width: 90%;
|
| 230 |
+
flex-wrap: wrap;
|
| 231 |
+
justify-content: center;
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
#timeline {
|
| 235 |
+
width: 100%;
|
| 236 |
+
order: 3;
|
| 237 |
+
margin-top: 10px;
|
| 238 |
+
}
|
| 239 |
+
|
| 240 |
+
#settings-panel {
|
| 241 |
+
width: 140px;
|
| 242 |
+
right: 10px;
|
| 243 |
+
top: 10px;
|
| 244 |
+
max-height: calc(100vh - 20px);
|
| 245 |
+
}
|
| 246 |
+
}
|
| 247 |
+
|
| 248 |
+
.settings-group {
|
| 249 |
+
margin-bottom: 8px;
|
| 250 |
+
}
|
| 251 |
+
|
| 252 |
+
.settings-group h3 {
|
| 253 |
+
margin: 0 0 6px 0;
|
| 254 |
+
font-size: 10px;
|
| 255 |
+
font-weight: 500;
|
| 256 |
+
color: var(--text-secondary);
|
| 257 |
+
}
|
| 258 |
+
|
| 259 |
+
.slider-container {
|
| 260 |
+
display: flex;
|
| 261 |
+
align-items: center;
|
| 262 |
+
gap: 6px;
|
| 263 |
+
width: 100%;
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
.slider-container label {
|
| 267 |
+
min-width: 60px;
|
| 268 |
+
font-size: 10px;
|
| 269 |
+
flex-shrink: 0;
|
| 270 |
+
}
|
| 271 |
+
|
| 272 |
+
input[type="range"] {
|
| 273 |
+
flex: 1;
|
| 274 |
+
height: 2px;
|
| 275 |
+
-webkit-appearance: none;
|
| 276 |
+
background: rgba(255, 255, 255, 0.1);
|
| 277 |
+
border-radius: 1px;
|
| 278 |
+
min-width: 0;
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
input[type="range"]::-webkit-slider-thumb {
|
| 282 |
+
-webkit-appearance: none;
|
| 283 |
+
width: 8px;
|
| 284 |
+
height: 8px;
|
| 285 |
+
border-radius: 50%;
|
| 286 |
+
background: var(--primary);
|
| 287 |
+
cursor: pointer;
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
.toggle-switch {
|
| 291 |
+
position: relative;
|
| 292 |
+
display: inline-block;
|
| 293 |
+
width: 20px;
|
| 294 |
+
height: 10px;
|
| 295 |
+
}
|
| 296 |
+
|
| 297 |
+
.toggle-switch input {
|
| 298 |
+
opacity: 0;
|
| 299 |
+
width: 0;
|
| 300 |
+
height: 0;
|
| 301 |
+
}
|
| 302 |
+
|
| 303 |
+
.toggle-slider {
|
| 304 |
+
position: absolute;
|
| 305 |
+
cursor: pointer;
|
| 306 |
+
top: 0;
|
| 307 |
+
left: 0;
|
| 308 |
+
right: 0;
|
| 309 |
+
bottom: 0;
|
| 310 |
+
background: rgba(255, 255, 255, 0.1);
|
| 311 |
+
transition: .4s;
|
| 312 |
+
border-radius: 10px;
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
.toggle-slider:before {
|
| 316 |
+
position: absolute;
|
| 317 |
+
content: "";
|
| 318 |
+
height: 8px;
|
| 319 |
+
width: 8px;
|
| 320 |
+
left: 1px;
|
| 321 |
+
bottom: 1px;
|
| 322 |
+
background: var(--surface);
|
| 323 |
+
border: 1px solid var(--border);
|
| 324 |
+
transition: .4s;
|
| 325 |
+
border-radius: 50%;
|
| 326 |
+
}
|
| 327 |
+
|
| 328 |
+
input:checked + .toggle-slider {
|
| 329 |
+
background: var(--primary);
|
| 330 |
+
}
|
| 331 |
+
|
| 332 |
+
input:checked + .toggle-slider:before {
|
| 333 |
+
transform: translateX(10px);
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
.checkbox-container {
|
| 337 |
+
display: flex;
|
| 338 |
+
align-items: center;
|
| 339 |
+
gap: 4px;
|
| 340 |
+
margin-bottom: 4px;
|
| 341 |
+
}
|
| 342 |
+
|
| 343 |
+
.checkbox-container label {
|
| 344 |
+
font-size: 10px;
|
| 345 |
+
cursor: pointer;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
#loading-overlay {
|
| 349 |
+
position: absolute;
|
| 350 |
+
top: 0;
|
| 351 |
+
left: 0;
|
| 352 |
+
width: 100%;
|
| 353 |
+
height: 100%;
|
| 354 |
+
background: var(--bg);
|
| 355 |
+
display: flex;
|
| 356 |
+
flex-direction: column;
|
| 357 |
+
align-items: center;
|
| 358 |
+
justify-content: center;
|
| 359 |
+
z-index: 100;
|
| 360 |
+
transition: opacity 0.5s;
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
#loading-overlay.fade-out {
|
| 364 |
+
opacity: 0;
|
| 365 |
+
pointer-events: none;
|
| 366 |
+
}
|
| 367 |
+
|
| 368 |
+
.spinner {
|
| 369 |
+
width: 50px;
|
| 370 |
+
height: 50px;
|
| 371 |
+
border: 5px solid rgba(155, 89, 182, 0.2);
|
| 372 |
+
border-radius: 50%;
|
| 373 |
+
border-top-color: var(--primary);
|
| 374 |
+
animation: spin 1s ease-in-out infinite;
|
| 375 |
+
margin-bottom: 16px;
|
| 376 |
+
}
|
| 377 |
+
|
| 378 |
+
@keyframes spin {
|
| 379 |
+
to { transform: rotate(360deg); }
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
#loading-text {
|
| 383 |
+
margin-top: 16px;
|
| 384 |
+
font-size: 18px;
|
| 385 |
+
color: var(--text);
|
| 386 |
+
font-weight: 500;
|
| 387 |
+
}
|
| 388 |
+
|
| 389 |
+
#frame-counter {
|
| 390 |
+
color: var(--text-secondary);
|
| 391 |
+
font-size: 7px;
|
| 392 |
+
font-weight: 500;
|
| 393 |
+
min-width: 60px;
|
| 394 |
+
text-align: center;
|
| 395 |
+
padding: 0 4px;
|
| 396 |
+
}
|
| 397 |
+
|
| 398 |
+
.control-btn {
|
| 399 |
+
background: rgba(255, 255, 255, 0.08);
|
| 400 |
+
border: 1px solid var(--border);
|
| 401 |
+
padding: 4px 6px;
|
| 402 |
+
border-radius: 3px;
|
| 403 |
+
cursor: pointer;
|
| 404 |
+
display: flex;
|
| 405 |
+
align-items: center;
|
| 406 |
+
justify-content: center;
|
| 407 |
+
transition: all 0.2s ease;
|
| 408 |
+
font-size: 6px;
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
.control-btn:hover {
|
| 412 |
+
background: rgba(255, 255, 255, 0.15);
|
| 413 |
+
transform: translateY(-1px);
|
| 414 |
+
}
|
| 415 |
+
|
| 416 |
+
.control-btn.active {
|
| 417 |
+
background: var(--primary);
|
| 418 |
+
color: white;
|
| 419 |
+
}
|
| 420 |
+
|
| 421 |
+
.control-btn.active:hover {
|
| 422 |
+
background: var(--primary);
|
| 423 |
+
box-shadow: 0 2px 8px rgba(155, 89, 182, 0.4);
|
| 424 |
+
}
|
| 425 |
+
|
| 426 |
+
#settings-toggle-btn {
|
| 427 |
+
position: relative;
|
| 428 |
+
border-radius: 6px;
|
| 429 |
+
z-index: 20;
|
| 430 |
+
}
|
| 431 |
+
|
| 432 |
+
#settings-toggle-btn.active {
|
| 433 |
+
background: var(--primary);
|
| 434 |
+
color: white;
|
| 435 |
+
}
|
| 436 |
+
|
| 437 |
+
#status-bar,
|
| 438 |
+
#control-panel,
|
| 439 |
+
#settings-panel,
|
| 440 |
+
button,
|
| 441 |
+
input,
|
| 442 |
+
select,
|
| 443 |
+
.toggle-switch {
|
| 444 |
+
pointer-events: auto;
|
| 445 |
+
}
|
| 446 |
+
|
| 447 |
+
h2 {
|
| 448 |
+
font-size: 0.9rem;
|
| 449 |
+
font-weight: 600;
|
| 450 |
+
margin-top: 0;
|
| 451 |
+
margin-bottom: 12px;
|
| 452 |
+
color: var(--primary);
|
| 453 |
+
cursor: move;
|
| 454 |
+
user-select: none;
|
| 455 |
+
display: flex;
|
| 456 |
+
align-items: center;
|
| 457 |
+
}
|
| 458 |
+
|
| 459 |
+
.drag-handle {
|
| 460 |
+
font-size: 10px;
|
| 461 |
+
margin-right: 4px;
|
| 462 |
+
opacity: 0.6;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
h2:hover .drag-handle {
|
| 466 |
+
opacity: 1;
|
| 467 |
+
}
|
| 468 |
+
|
| 469 |
+
.loading-subtitle {
|
| 470 |
+
font-size: 7px;
|
| 471 |
+
color: var(--text-secondary);
|
| 472 |
+
margin-top: 4px;
|
| 473 |
+
}
|
| 474 |
+
|
| 475 |
+
#reset-view-btn {
|
| 476 |
+
background: var(--primary-light);
|
| 477 |
+
color: var(--primary);
|
| 478 |
+
border: 1px solid rgba(155, 89, 182, 0.2);
|
| 479 |
+
font-weight: 600;
|
| 480 |
+
font-size: 9px;
|
| 481 |
+
padding: 4px 6px;
|
| 482 |
+
transition: all 0.2s;
|
| 483 |
+
}
|
| 484 |
+
|
| 485 |
+
#reset-view-btn:hover {
|
| 486 |
+
background: var(--primary);
|
| 487 |
+
color: white;
|
| 488 |
+
transform: translateY(-2px);
|
| 489 |
+
box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
|
| 490 |
+
}
|
| 491 |
+
|
| 492 |
+
#show-settings-btn {
|
| 493 |
+
position: absolute;
|
| 494 |
+
top: 16px;
|
| 495 |
+
right: 16px;
|
| 496 |
+
z-index: 15;
|
| 497 |
+
display: none;
|
| 498 |
+
}
|
| 499 |
+
|
| 500 |
+
#settings-panel.visible {
|
| 501 |
+
display: block;
|
| 502 |
+
opacity: 1;
|
| 503 |
+
animation: slideIn 0.3s ease forwards;
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
@keyframes slideIn {
|
| 507 |
+
from {
|
| 508 |
+
transform: translateY(20px);
|
| 509 |
+
opacity: 0;
|
| 510 |
+
}
|
| 511 |
+
to {
|
| 512 |
+
transform: translateY(0);
|
| 513 |
+
opacity: 1;
|
| 514 |
+
}
|
| 515 |
+
}
|
| 516 |
+
|
| 517 |
+
.dragging {
|
| 518 |
+
opacity: 0.9;
|
| 519 |
+
box-shadow: 0 8px 20px rgba(0, 0, 0, 0.15) !important;
|
| 520 |
+
transition: none !important;
|
| 521 |
+
}
|
| 522 |
+
|
| 523 |
+
/* Tooltip for draggable element */
|
| 524 |
+
.tooltip-drag {
|
| 525 |
+
position: absolute;
|
| 526 |
+
left: 50%;
|
| 527 |
+
transform: translateX(-50%);
|
| 528 |
+
background: var(--primary);
|
| 529 |
+
color: white;
|
| 530 |
+
font-size: 9px;
|
| 531 |
+
padding: 2px 4px;
|
| 532 |
+
border-radius: 2px;
|
| 533 |
+
opacity: 0;
|
| 534 |
+
pointer-events: none;
|
| 535 |
+
transition: opacity 0.3s;
|
| 536 |
+
white-space: nowrap;
|
| 537 |
+
bottom: 100%;
|
| 538 |
+
margin-bottom: 4px;
|
| 539 |
+
}
|
| 540 |
+
|
| 541 |
+
h2:hover .tooltip-drag {
|
| 542 |
+
opacity: 1;
|
| 543 |
+
}
|
| 544 |
+
|
| 545 |
+
.btn-group {
|
| 546 |
+
display: flex;
|
| 547 |
+
margin-top: 8px;
|
| 548 |
+
}
|
| 549 |
+
|
| 550 |
+
#reset-settings-btn {
|
| 551 |
+
background: var(--primary-light);
|
| 552 |
+
color: var(--primary);
|
| 553 |
+
border: 1px solid rgba(155, 89, 182, 0.2);
|
| 554 |
+
font-weight: 600;
|
| 555 |
+
font-size: 9px;
|
| 556 |
+
padding: 4px 6px;
|
| 557 |
+
transition: all 0.2s;
|
| 558 |
+
}
|
| 559 |
+
|
| 560 |
+
#reset-settings-btn:hover {
|
| 561 |
+
background: var(--primary);
|
| 562 |
+
color: white;
|
| 563 |
+
transform: translateY(-2px);
|
| 564 |
+
box-shadow: 0 4px 8px rgba(155, 89, 182, 0.3);
|
| 565 |
+
}
|
| 566 |
+
</style>
|
| 567 |
+
</head>
|
| 568 |
+
<body>
|
| 569 |
+
<link rel="preconnect" href="https://fonts.googleapis.com">
|
| 570 |
+
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
|
| 571 |
+
<link href="https://fonts.googleapis.com/css2?family=Inter:wght@300;400;500;600;700&display=swap" rel="stylesheet">
|
| 572 |
+
|
| 573 |
+
<div id="canvas-container"></div>
|
| 574 |
+
|
| 575 |
+
<div id="ui-container">
|
| 576 |
+
<div id="status-bar">Initializing...</div>
|
| 577 |
+
|
| 578 |
+
<div id="control-panel">
|
| 579 |
+
<button id="play-pause-btn" class="control-btn">
|
| 580 |
+
<svg class="icon" viewBox="0 0 24 24">
|
| 581 |
+
<path id="play-icon" d="M8 5v14l11-7z"/>
|
| 582 |
+
<path id="pause-icon" d="M6 19h4V5H6v14zm8-14v14h4V5h-4z" style="display: none;"/>
|
| 583 |
+
</svg>
|
| 584 |
+
<span class="tooltip">Play/Pause</span>
|
| 585 |
+
</button>
|
| 586 |
+
|
| 587 |
+
<div id="timeline">
|
| 588 |
+
<div id="progress"></div>
|
| 589 |
+
</div>
|
| 590 |
+
|
| 591 |
+
<div id="frame-counter">Frame 0 / 0</div>
|
| 592 |
+
|
| 593 |
+
<div id="playback-controls">
|
| 594 |
+
<button id="speed-btn" class="control-btn">1x</button>
|
| 595 |
+
</div>
|
| 596 |
+
</div>
|
| 597 |
+
|
| 598 |
+
<div id="settings-panel">
|
| 599 |
+
<h2>
|
| 600 |
+
<span class="drag-handle">☰</span>
|
| 601 |
+
Visualization Settings
|
| 602 |
+
<button id="hide-settings-btn" class="control-btn" style="margin-left: auto; padding: 2px;" title="Hide Panel">
|
| 603 |
+
<svg class="icon" viewBox="0 0 24 24" style="width: 9px; height: 9px;">
|
| 604 |
+
<path d="M14.59 7.41L18.17 11H4v2h14.17l-3.58 3.59L16 18l6-6-6-6-1.41 1.41z"/>
|
| 605 |
+
</svg>
|
| 606 |
+
</button>
|
| 607 |
+
</h2>
|
| 608 |
+
|
| 609 |
+
<div class="settings-group">
|
| 610 |
+
<h3>Point Cloud</h3>
|
| 611 |
+
<div class="slider-container">
|
| 612 |
+
<label for="point-size">Size</label>
|
| 613 |
+
<input type="range" id="point-size" min="0.005" max="0.1" step="0.005" value="0.03">
|
| 614 |
+
</div>
|
| 615 |
+
<div class="slider-container">
|
| 616 |
+
<label for="point-opacity">Opacity</label>
|
| 617 |
+
<input type="range" id="point-opacity" min="0.1" max="1" step="0.05" value="1">
|
| 618 |
+
</div>
|
| 619 |
+
<div class="slider-container">
|
| 620 |
+
<label for="max-depth">Max Depth</label>
|
| 621 |
+
<input type="range" id="max-depth" min="0.1" max="10" step="0.2" value="100">
|
| 622 |
+
</div>
|
| 623 |
+
</div>
|
| 624 |
+
|
| 625 |
+
<div class="settings-group">
|
| 626 |
+
<h3>Trajectory</h3>
|
| 627 |
+
<div class="checkbox-container">
|
| 628 |
+
<label class="toggle-switch">
|
| 629 |
+
<input type="checkbox" id="show-trajectory" checked>
|
| 630 |
+
<span class="toggle-slider"></span>
|
| 631 |
+
</label>
|
| 632 |
+
<label for="show-trajectory">Show Trajectory</label>
|
| 633 |
+
</div>
|
| 634 |
+
<div class="checkbox-container">
|
| 635 |
+
<label class="toggle-switch">
|
| 636 |
+
<input type="checkbox" id="enable-rich-trail">
|
| 637 |
+
<span class="toggle-slider"></span>
|
| 638 |
+
</label>
|
| 639 |
+
<label for="enable-rich-trail">Visual-Rich Trail</label>
|
| 640 |
+
</div>
|
| 641 |
+
<div class="slider-container">
|
| 642 |
+
<label for="trajectory-line-width">Line Width</label>
|
| 643 |
+
<input type="range" id="trajectory-line-width" min="0.5" max="5" step="0.5" value="1.5">
|
| 644 |
+
</div>
|
| 645 |
+
<div class="slider-container">
|
| 646 |
+
<label for="trajectory-ball-size">Ball Size</label>
|
| 647 |
+
<input type="range" id="trajectory-ball-size" min="0.005" max="0.05" step="0.001" value="0.02">
|
| 648 |
+
</div>
|
| 649 |
+
<div class="slider-container">
|
| 650 |
+
<label for="trajectory-history">History Frames</label>
|
| 651 |
+
<input type="range" id="trajectory-history" min="1" max="500" step="1" value="30">
|
| 652 |
+
</div>
|
| 653 |
+
<div class="slider-container" id="tail-opacity-container" style="display: none;">
|
| 654 |
+
<label for="trajectory-fade">Tail Opacity</label>
|
| 655 |
+
<input type="range" id="trajectory-fade" min="0" max="1" step="0.05" value="0.0">
|
| 656 |
+
</div>
|
| 657 |
+
</div>
|
| 658 |
+
|
| 659 |
+
<div class="settings-group">
|
| 660 |
+
<h3>Camera</h3>
|
| 661 |
+
<div class="checkbox-container">
|
| 662 |
+
<label class="toggle-switch">
|
| 663 |
+
<input type="checkbox" id="show-camera-frustum" checked>
|
| 664 |
+
<span class="toggle-slider"></span>
|
| 665 |
+
</label>
|
| 666 |
+
<label for="show-camera-frustum">Show Camera Frustum</label>
|
| 667 |
+
</div>
|
| 668 |
+
<div class="slider-container">
|
| 669 |
+
<label for="frustum-size">Size</label>
|
| 670 |
+
<input type="range" id="frustum-size" min="0.02" max="0.5" step="0.01" value="0.2">
|
| 671 |
+
</div>
|
| 672 |
+
</div>
|
| 673 |
+
|
| 674 |
+
<div class="settings-group">
|
| 675 |
+
<div class="btn-group">
|
| 676 |
+
<button id="reset-view-btn" style="flex: 1; margin-right: 5px;">Reset View</button>
|
| 677 |
+
<button id="reset-settings-btn" style="flex: 1; margin-left: 5px;">Reset Settings</button>
|
| 678 |
+
</div>
|
| 679 |
+
</div>
|
| 680 |
+
</div>
|
| 681 |
+
|
| 682 |
+
<button id="show-settings-btn" class="control-btn" title="Show Settings">
|
| 683 |
+
<svg class="icon" viewBox="0 0 24 24">
|
| 684 |
+
<path d="M19.14,12.94c0.04-0.3,0.06-0.61,0.06-0.94c0-0.32-0.02-0.64-0.07-0.94l2.03-1.58c0.18-0.14,0.23-0.41,0.12-0.61 l-1.92-3.32c-0.12-0.22-0.37-0.29-0.59-0.22l-2.39,0.96c-0.5-0.38-1.03-0.7-1.62-0.94L14.4,2.81c-0.04-0.24-0.24-0.41-0.48-0.41 h-3.84c-0.24,0-0.43,0.17-0.47,0.41L9.25,5.35C8.66,5.59,8.12,5.92,7.63,6.29L5.24,5.33c-0.22-0.08-0.47,0-0.59,0.22L2.74,8.87 C2.62,9.08,2.66,9.34,2.86,9.48l2.03,1.58C4.84,11.36,4.8,11.69,4.8,12s0.02,0.64,0.07,0.94l-2.03,1.58 c-0.18,0.14-0.23,0.41-0.12,0.61l1.92,3.32c0.12,0.22,0.37,0.29,0.59,0.22l2.39-0.96c0.5,0.38,1.03,0.7,1.62,0.94l0.36,2.54 c0.04,0.24,0.24,0.41,0.48,0.41h3.84c0.24,0,0.44-0.17,0.47-0.41l0.36-2.54c0.59-0.24,1.13-0.56,1.62-0.94l2.39,0.96 c0.22,0.08,0.47,0,0.59-0.22l1.92-3.32c0.12-0.22,0.07-0.47-0.12-0.61L19.14,12.94z M12,15.6c-1.98,0-3.6-1.62-3.6-3.6 s1.62-3.6,3.6-3.6s3.6,1.62,3.6,3.6S13.98,15.6,12,15.6z"/>
|
| 685 |
+
</svg>
|
| 686 |
+
</button>
|
| 687 |
+
</div>
|
| 688 |
+
|
| 689 |
+
<div id="loading-overlay">
|
| 690 |
+
<!-- <div class="spinner"></div> -->
|
| 691 |
+
<div id="loading-text"></div>
|
| 692 |
+
<div class="loading-subtitle" style="font-size: medium;">Interactive Viewer of 3D Tracking</div>
|
| 693 |
+
</div>
|
| 694 |
+
|
| 695 |
+
<!-- Libraries -->
|
| 696 |
+
<script src="https://cdnjs.cloudflare.com/ajax/libs/pako/2.1.0/pako.min.js"></script>
|
| 697 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/build/three.min.js"></script>
|
| 698 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/controls/OrbitControls.js"></script>
|
| 699 |
+
<script src="https://cdn.jsdelivr.net/npm/dat.gui@0.7.7/build/dat.gui.min.js"></script>
|
| 700 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineSegmentsGeometry.js"></script>
|
| 701 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineGeometry.js"></script>
|
| 702 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineMaterial.js"></script>
|
| 703 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/LineSegments2.js"></script>
|
| 704 |
+
<script src="https://cdn.jsdelivr.net/npm/three@0.132.2/examples/js/lines/Line2.js"></script>
|
| 705 |
+
|
| 706 |
+
<script>
|
| 707 |
+
class PointCloudVisualizer {
|
| 708 |
+
constructor() {
|
| 709 |
+
this.data = null;
|
| 710 |
+
this.config = {};
|
| 711 |
+
this.currentFrame = 0;
|
| 712 |
+
this.isPlaying = false;
|
| 713 |
+
this.playbackSpeed = 1;
|
| 714 |
+
this.lastFrameTime = 0;
|
| 715 |
+
this.defaultSettings = null;
|
| 716 |
+
|
| 717 |
+
this.ui = {
|
| 718 |
+
statusBar: document.getElementById('status-bar'),
|
| 719 |
+
playPauseBtn: document.getElementById('play-pause-btn'),
|
| 720 |
+
speedBtn: document.getElementById('speed-btn'),
|
| 721 |
+
timeline: document.getElementById('timeline'),
|
| 722 |
+
progress: document.getElementById('progress'),
|
| 723 |
+
settingsPanel: document.getElementById('settings-panel'),
|
| 724 |
+
loadingOverlay: document.getElementById('loading-overlay'),
|
| 725 |
+
loadingText: document.getElementById('loading-text'),
|
| 726 |
+
settingsToggleBtn: document.getElementById('settings-toggle-btn'),
|
| 727 |
+
frameCounter: document.getElementById('frame-counter'),
|
| 728 |
+
pointSize: document.getElementById('point-size'),
|
| 729 |
+
pointOpacity: document.getElementById('point-opacity'),
|
| 730 |
+
maxDepth: document.getElementById('max-depth'),
|
| 731 |
+
showTrajectory: document.getElementById('show-trajectory'),
|
| 732 |
+
enableRichTrail: document.getElementById('enable-rich-trail'),
|
| 733 |
+
trajectoryLineWidth: document.getElementById('trajectory-line-width'),
|
| 734 |
+
trajectoryBallSize: document.getElementById('trajectory-ball-size'),
|
| 735 |
+
trajectoryHistory: document.getElementById('trajectory-history'),
|
| 736 |
+
trajectoryFade: document.getElementById('trajectory-fade'),
|
| 737 |
+
tailOpacityContainer: document.getElementById('tail-opacity-container'),
|
| 738 |
+
resetViewBtn: document.getElementById('reset-view-btn'),
|
| 739 |
+
showCameraFrustum: document.getElementById('show-camera-frustum'),
|
| 740 |
+
frustumSize: document.getElementById('frustum-size'),
|
| 741 |
+
hideSettingsBtn: document.getElementById('hide-settings-btn'),
|
| 742 |
+
showSettingsBtn: document.getElementById('show-settings-btn')
|
| 743 |
+
};
|
| 744 |
+
|
| 745 |
+
this.scene = null;
|
| 746 |
+
this.camera = null;
|
| 747 |
+
this.renderer = null;
|
| 748 |
+
this.controls = null;
|
| 749 |
+
this.pointCloud = null;
|
| 750 |
+
this.trajectories = [];
|
| 751 |
+
this.cameraFrustum = null;
|
| 752 |
+
|
| 753 |
+
this.initThreeJS();
|
| 754 |
+
this.loadDefaultSettings().then(() => {
|
| 755 |
+
this.initEventListeners();
|
| 756 |
+
this.loadData();
|
| 757 |
+
});
|
| 758 |
+
}
|
| 759 |
+
|
| 760 |
+
async loadDefaultSettings() {
|
| 761 |
+
try {
|
| 762 |
+
const urlParams = new URLSearchParams(window.location.search);
|
| 763 |
+
const dataPath = urlParams.get('data') || '';
|
| 764 |
+
|
| 765 |
+
const defaultSettings = {
|
| 766 |
+
pointSize: 0.03,
|
| 767 |
+
pointOpacity: 1.0,
|
| 768 |
+
showTrajectory: true,
|
| 769 |
+
trajectoryLineWidth: 2.5,
|
| 770 |
+
trajectoryBallSize: 0.015,
|
| 771 |
+
trajectoryHistory: 0,
|
| 772 |
+
showCameraFrustum: true,
|
| 773 |
+
frustumSize: 0.2
|
| 774 |
+
};
|
| 775 |
+
|
| 776 |
+
if (!dataPath) {
|
| 777 |
+
this.defaultSettings = defaultSettings;
|
| 778 |
+
this.applyDefaultSettings();
|
| 779 |
+
return;
|
| 780 |
+
}
|
| 781 |
+
|
| 782 |
+
// Try to extract dataset and videoId from the data path
|
| 783 |
+
// Expected format: demos/datasetname/videoid.bin
|
| 784 |
+
const pathParts = dataPath.split('/');
|
| 785 |
+
if (pathParts.length < 3) {
|
| 786 |
+
this.defaultSettings = defaultSettings;
|
| 787 |
+
this.applyDefaultSettings();
|
| 788 |
+
return;
|
| 789 |
+
}
|
| 790 |
+
|
| 791 |
+
const datasetName = pathParts[pathParts.length - 2];
|
| 792 |
+
let videoId = pathParts[pathParts.length - 1].replace('.bin', '');
|
| 793 |
+
|
| 794 |
+
// Load settings from data.json
|
| 795 |
+
const response = await fetch('./data.json');
|
| 796 |
+
if (!response.ok) {
|
| 797 |
+
this.defaultSettings = defaultSettings;
|
| 798 |
+
this.applyDefaultSettings();
|
| 799 |
+
return;
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
const settingsData = await response.json();
|
| 803 |
+
|
| 804 |
+
// Check if this dataset and video exist
|
| 805 |
+
if (settingsData[datasetName] && settingsData[datasetName][videoId]) {
|
| 806 |
+
this.defaultSettings = settingsData[datasetName][videoId];
|
| 807 |
+
} else {
|
| 808 |
+
this.defaultSettings = defaultSettings;
|
| 809 |
+
}
|
| 810 |
+
|
| 811 |
+
this.applyDefaultSettings();
|
| 812 |
+
} catch (error) {
|
| 813 |
+
console.error("Error loading default settings:", error);
|
| 814 |
+
|
| 815 |
+
this.defaultSettings = {
|
| 816 |
+
pointSize: 0.03,
|
| 817 |
+
pointOpacity: 1.0,
|
| 818 |
+
showTrajectory: true,
|
| 819 |
+
trajectoryLineWidth: 2.5,
|
| 820 |
+
trajectoryBallSize: 0.015,
|
| 821 |
+
trajectoryHistory: 0,
|
| 822 |
+
showCameraFrustum: true,
|
| 823 |
+
frustumSize: 0.2
|
| 824 |
+
};
|
| 825 |
+
|
| 826 |
+
this.applyDefaultSettings();
|
| 827 |
+
}
|
| 828 |
+
}
|
| 829 |
+
|
| 830 |
+
applyDefaultSettings() {
|
| 831 |
+
if (!this.defaultSettings) return;
|
| 832 |
+
|
| 833 |
+
if (this.ui.pointSize) {
|
| 834 |
+
this.ui.pointSize.value = this.defaultSettings.pointSize;
|
| 835 |
+
}
|
| 836 |
+
|
| 837 |
+
if (this.ui.pointOpacity) {
|
| 838 |
+
this.ui.pointOpacity.value = this.defaultSettings.pointOpacity;
|
| 839 |
+
}
|
| 840 |
+
|
| 841 |
+
if (this.ui.maxDepth) {
|
| 842 |
+
this.ui.maxDepth.value = this.defaultSettings.maxDepth || 100.0;
|
| 843 |
+
}
|
| 844 |
+
|
| 845 |
+
if (this.ui.showTrajectory) {
|
| 846 |
+
this.ui.showTrajectory.checked = this.defaultSettings.showTrajectory;
|
| 847 |
+
}
|
| 848 |
+
|
| 849 |
+
if (this.ui.trajectoryLineWidth) {
|
| 850 |
+
this.ui.trajectoryLineWidth.value = this.defaultSettings.trajectoryLineWidth;
|
| 851 |
+
}
|
| 852 |
+
|
| 853 |
+
if (this.ui.trajectoryBallSize) {
|
| 854 |
+
this.ui.trajectoryBallSize.value = this.defaultSettings.trajectoryBallSize;
|
| 855 |
+
}
|
| 856 |
+
|
| 857 |
+
if (this.ui.trajectoryHistory) {
|
| 858 |
+
this.ui.trajectoryHistory.value = this.defaultSettings.trajectoryHistory;
|
| 859 |
+
}
|
| 860 |
+
|
| 861 |
+
if (this.ui.showCameraFrustum) {
|
| 862 |
+
this.ui.showCameraFrustum.checked = this.defaultSettings.showCameraFrustum;
|
| 863 |
+
}
|
| 864 |
+
|
| 865 |
+
if (this.ui.frustumSize) {
|
| 866 |
+
this.ui.frustumSize.value = this.defaultSettings.frustumSize;
|
| 867 |
+
}
|
| 868 |
+
}
|
| 869 |
+
|
| 870 |
+
initThreeJS() {
|
| 871 |
+
this.scene = new THREE.Scene();
|
| 872 |
+
this.scene.background = new THREE.Color(0x1a1a1a);
|
| 873 |
+
|
| 874 |
+
this.camera = new THREE.PerspectiveCamera(60, window.innerWidth / window.innerHeight, 0.1, 10000);
|
| 875 |
+
this.camera.position.set(0, 0, 0);
|
| 876 |
+
|
| 877 |
+
this.renderer = new THREE.WebGLRenderer({ antialias: true });
|
| 878 |
+
this.renderer.setPixelRatio(window.devicePixelRatio);
|
| 879 |
+
this.renderer.setSize(window.innerWidth, window.innerHeight);
|
| 880 |
+
document.getElementById('canvas-container').appendChild(this.renderer.domElement);
|
| 881 |
+
|
| 882 |
+
this.controls = new THREE.OrbitControls(this.camera, this.renderer.domElement);
|
| 883 |
+
this.controls.enableDamping = true;
|
| 884 |
+
this.controls.dampingFactor = 0.05;
|
| 885 |
+
this.controls.target.set(0, 0, 0);
|
| 886 |
+
this.controls.minDistance = 0.1;
|
| 887 |
+
this.controls.maxDistance = 1000;
|
| 888 |
+
this.controls.update();
|
| 889 |
+
|
| 890 |
+
const ambientLight = new THREE.AmbientLight(0xffffff, 0.5);
|
| 891 |
+
this.scene.add(ambientLight);
|
| 892 |
+
|
| 893 |
+
const directionalLight = new THREE.DirectionalLight(0xffffff, 0.8);
|
| 894 |
+
directionalLight.position.set(1, 1, 1);
|
| 895 |
+
this.scene.add(directionalLight);
|
| 896 |
+
}
|
| 897 |
+
|
| 898 |
+
initEventListeners() {
|
| 899 |
+
window.addEventListener('resize', () => this.onWindowResize());
|
| 900 |
+
|
| 901 |
+
this.ui.playPauseBtn.addEventListener('click', () => this.togglePlayback());
|
| 902 |
+
|
| 903 |
+
this.ui.timeline.addEventListener('click', (e) => {
|
| 904 |
+
const rect = this.ui.timeline.getBoundingClientRect();
|
| 905 |
+
const pos = (e.clientX - rect.left) / rect.width;
|
| 906 |
+
this.seekTo(pos);
|
| 907 |
+
});
|
| 908 |
+
|
| 909 |
+
this.ui.speedBtn.addEventListener('click', () => this.cyclePlaybackSpeed());
|
| 910 |
+
|
| 911 |
+
this.ui.pointSize.addEventListener('input', () => this.updatePointCloudSettings());
|
| 912 |
+
this.ui.pointOpacity.addEventListener('input', () => this.updatePointCloudSettings());
|
| 913 |
+
this.ui.maxDepth.addEventListener('input', () => this.updatePointCloudSettings());
|
| 914 |
+
this.ui.showTrajectory.addEventListener('change', () => {
|
| 915 |
+
this.trajectories.forEach(trajectory => {
|
| 916 |
+
trajectory.visible = this.ui.showTrajectory.checked;
|
| 917 |
+
});
|
| 918 |
+
});
|
| 919 |
+
|
| 920 |
+
this.ui.enableRichTrail.addEventListener('change', () => {
|
| 921 |
+
this.ui.tailOpacityContainer.style.display = this.ui.enableRichTrail.checked ? 'flex' : 'none';
|
| 922 |
+
this.updateTrajectories(this.currentFrame);
|
| 923 |
+
});
|
| 924 |
+
|
| 925 |
+
this.ui.trajectoryLineWidth.addEventListener('input', () => this.updateTrajectorySettings());
|
| 926 |
+
this.ui.trajectoryBallSize.addEventListener('input', () => this.updateTrajectorySettings());
|
| 927 |
+
this.ui.trajectoryHistory.addEventListener('input', () => {
|
| 928 |
+
this.updateTrajectories(this.currentFrame);
|
| 929 |
+
});
|
| 930 |
+
this.ui.trajectoryFade.addEventListener('input', () => {
|
| 931 |
+
this.updateTrajectories(this.currentFrame);
|
| 932 |
+
});
|
| 933 |
+
|
| 934 |
+
this.ui.resetViewBtn.addEventListener('click', () => this.resetView());
|
| 935 |
+
|
| 936 |
+
const resetSettingsBtn = document.getElementById('reset-settings-btn');
|
| 937 |
+
if (resetSettingsBtn) {
|
| 938 |
+
resetSettingsBtn.addEventListener('click', () => this.resetSettings());
|
| 939 |
+
}
|
| 940 |
+
|
| 941 |
+
document.addEventListener('keydown', (e) => {
|
| 942 |
+
if (e.key === 'Escape' && this.ui.settingsPanel.classList.contains('visible')) {
|
| 943 |
+
this.ui.settingsPanel.classList.remove('visible');
|
| 944 |
+
this.ui.settingsToggleBtn.classList.remove('active');
|
| 945 |
+
}
|
| 946 |
+
});
|
| 947 |
+
|
| 948 |
+
if (this.ui.settingsToggleBtn) {
|
| 949 |
+
this.ui.settingsToggleBtn.addEventListener('click', () => {
|
| 950 |
+
const isVisible = this.ui.settingsPanel.classList.toggle('visible');
|
| 951 |
+
this.ui.settingsToggleBtn.classList.toggle('active', isVisible);
|
| 952 |
+
|
| 953 |
+
if (isVisible) {
|
| 954 |
+
const panelRect = this.ui.settingsPanel.getBoundingClientRect();
|
| 955 |
+
const viewportHeight = window.innerHeight;
|
| 956 |
+
|
| 957 |
+
if (panelRect.bottom > viewportHeight) {
|
| 958 |
+
this.ui.settingsPanel.style.bottom = 'auto';
|
| 959 |
+
this.ui.settingsPanel.style.top = '80px';
|
| 960 |
+
}
|
| 961 |
+
}
|
| 962 |
+
});
|
| 963 |
+
}
|
| 964 |
+
|
| 965 |
+
if (this.ui.frustumSize) {
|
| 966 |
+
this.ui.frustumSize.addEventListener('input', () => this.updateFrustumDimensions());
|
| 967 |
+
}
|
| 968 |
+
|
| 969 |
+
if (this.ui.hideSettingsBtn && this.ui.showSettingsBtn && this.ui.settingsPanel) {
|
| 970 |
+
this.ui.hideSettingsBtn.addEventListener('click', () => {
|
| 971 |
+
this.ui.settingsPanel.classList.add('is-hidden');
|
| 972 |
+
this.ui.showSettingsBtn.style.display = 'flex';
|
| 973 |
+
});
|
| 974 |
+
|
| 975 |
+
this.ui.showSettingsBtn.addEventListener('click', () => {
|
| 976 |
+
this.ui.settingsPanel.classList.remove('is-hidden');
|
| 977 |
+
this.ui.showSettingsBtn.style.display = 'none';
|
| 978 |
+
});
|
| 979 |
+
}
|
| 980 |
+
}
|
| 981 |
+
|
| 982 |
+
makeElementDraggable(element) {
|
| 983 |
+
let pos1 = 0, pos2 = 0, pos3 = 0, pos4 = 0;
|
| 984 |
+
|
| 985 |
+
const dragHandle = element.querySelector('h2');
|
| 986 |
+
|
| 987 |
+
if (dragHandle) {
|
| 988 |
+
dragHandle.onmousedown = dragMouseDown;
|
| 989 |
+
dragHandle.title = "Drag to move panel";
|
| 990 |
+
} else {
|
| 991 |
+
element.onmousedown = dragMouseDown;
|
| 992 |
+
}
|
| 993 |
+
|
| 994 |
+
function dragMouseDown(e) {
|
| 995 |
+
e = e || window.event;
|
| 996 |
+
e.preventDefault();
|
| 997 |
+
pos3 = e.clientX;
|
| 998 |
+
pos4 = e.clientY;
|
| 999 |
+
document.onmouseup = closeDragElement;
|
| 1000 |
+
document.onmousemove = elementDrag;
|
| 1001 |
+
|
| 1002 |
+
element.classList.add('dragging');
|
| 1003 |
+
}
|
| 1004 |
+
|
| 1005 |
+
function elementDrag(e) {
|
| 1006 |
+
e = e || window.event;
|
| 1007 |
+
e.preventDefault();
|
| 1008 |
+
pos1 = pos3 - e.clientX;
|
| 1009 |
+
pos2 = pos4 - e.clientY;
|
| 1010 |
+
pos3 = e.clientX;
|
| 1011 |
+
pos4 = e.clientY;
|
| 1012 |
+
|
| 1013 |
+
const newTop = element.offsetTop - pos2;
|
| 1014 |
+
const newLeft = element.offsetLeft - pos1;
|
| 1015 |
+
|
| 1016 |
+
const viewportWidth = window.innerWidth;
|
| 1017 |
+
const viewportHeight = window.innerHeight;
|
| 1018 |
+
|
| 1019 |
+
const panelRect = element.getBoundingClientRect();
|
| 1020 |
+
|
| 1021 |
+
const maxTop = viewportHeight - 50;
|
| 1022 |
+
const maxLeft = viewportWidth - 50;
|
| 1023 |
+
|
| 1024 |
+
element.style.top = Math.min(Math.max(newTop, 0), maxTop) + "px";
|
| 1025 |
+
element.style.left = Math.min(Math.max(newLeft, 0), maxLeft) + "px";
|
| 1026 |
+
|
| 1027 |
+
// Remove bottom/right settings when dragging
|
| 1028 |
+
element.style.bottom = 'auto';
|
| 1029 |
+
element.style.right = 'auto';
|
| 1030 |
+
}
|
| 1031 |
+
|
| 1032 |
+
function closeDragElement() {
|
| 1033 |
+
document.onmouseup = null;
|
| 1034 |
+
document.onmousemove = null;
|
| 1035 |
+
|
| 1036 |
+
element.classList.remove('dragging');
|
| 1037 |
+
}
|
| 1038 |
+
}
|
| 1039 |
+
|
| 1040 |
+
async loadData() {
|
| 1041 |
+
try {
|
| 1042 |
+
// this.ui.loadingText.textContent = "Loading binary data...";
|
| 1043 |
+
|
| 1044 |
+
let arrayBuffer;
|
| 1045 |
+
|
| 1046 |
+
if (window.embeddedBase64) {
|
| 1047 |
+
// Base64 embedded path
|
| 1048 |
+
const binaryString = atob(window.embeddedBase64);
|
| 1049 |
+
const len = binaryString.length;
|
| 1050 |
+
const bytes = new Uint8Array(len);
|
| 1051 |
+
for (let i = 0; i < len; i++) {
|
| 1052 |
+
bytes[i] = binaryString.charCodeAt(i);
|
| 1053 |
+
}
|
| 1054 |
+
arrayBuffer = bytes.buffer;
|
| 1055 |
+
} else {
|
| 1056 |
+
// Default fetch path (fallback)
|
| 1057 |
+
const urlParams = new URLSearchParams(window.location.search);
|
| 1058 |
+
const dataPath = urlParams.get('data') || 'data.bin';
|
| 1059 |
+
|
| 1060 |
+
const response = await fetch(dataPath);
|
| 1061 |
+
if (!response.ok) throw new Error(`Failed to load ${dataPath}`);
|
| 1062 |
+
arrayBuffer = await response.arrayBuffer();
|
| 1063 |
+
}
|
| 1064 |
+
|
| 1065 |
+
const dataView = new DataView(arrayBuffer);
|
| 1066 |
+
const headerLen = dataView.getUint32(0, true);
|
| 1067 |
+
|
| 1068 |
+
const headerText = new TextDecoder("utf-8").decode(arrayBuffer.slice(4, 4 + headerLen));
|
| 1069 |
+
const header = JSON.parse(headerText);
|
| 1070 |
+
|
| 1071 |
+
const compressedBlob = new Uint8Array(arrayBuffer, 4 + headerLen);
|
| 1072 |
+
const decompressed = pako.inflate(compressedBlob).buffer;
|
| 1073 |
+
|
| 1074 |
+
const arrays = {};
|
| 1075 |
+
for (const key in header) {
|
| 1076 |
+
if (key === "meta") continue;
|
| 1077 |
+
|
| 1078 |
+
const meta = header[key];
|
| 1079 |
+
const { dtype, shape, offset, length } = meta;
|
| 1080 |
+
const slice = decompressed.slice(offset, offset + length);
|
| 1081 |
+
|
| 1082 |
+
let typedArray;
|
| 1083 |
+
switch (dtype) {
|
| 1084 |
+
case "uint8": typedArray = new Uint8Array(slice); break;
|
| 1085 |
+
case "uint16": typedArray = new Uint16Array(slice); break;
|
| 1086 |
+
case "float32": typedArray = new Float32Array(slice); break;
|
| 1087 |
+
case "float64": typedArray = new Float64Array(slice); break;
|
| 1088 |
+
default: throw new Error(`Unknown dtype: ${dtype}`);
|
| 1089 |
+
}
|
| 1090 |
+
|
| 1091 |
+
arrays[key] = { data: typedArray, shape: shape };
|
| 1092 |
+
}
|
| 1093 |
+
|
| 1094 |
+
this.data = arrays;
|
| 1095 |
+
this.config = header.meta;
|
| 1096 |
+
|
| 1097 |
+
this.initCameraWithCorrectFOV();
|
| 1098 |
+
this.ui.loadingText.textContent = "Creating point cloud...";
|
| 1099 |
+
|
| 1100 |
+
this.initPointCloud();
|
| 1101 |
+
this.initTrajectories();
|
| 1102 |
+
|
| 1103 |
+
setTimeout(() => {
|
| 1104 |
+
this.ui.loadingOverlay.classList.add('fade-out');
|
| 1105 |
+
this.ui.statusBar.classList.add('hidden');
|
| 1106 |
+
this.startAnimation();
|
| 1107 |
+
}, 500);
|
| 1108 |
+
} catch (error) {
|
| 1109 |
+
console.error("Error loading data:", error);
|
| 1110 |
+
this.ui.statusBar.textContent = `Error: ${error.message}`;
|
| 1111 |
+
// this.ui.loadingText.textContent = `Error loading data: ${error.message}`;
|
| 1112 |
+
}
|
| 1113 |
+
}
|
| 1114 |
+
|
| 1115 |
+
initPointCloud() {
|
| 1116 |
+
const numPoints = this.config.resolution[0] * this.config.resolution[1];
|
| 1117 |
+
const positions = new Float32Array(numPoints * 3);
|
| 1118 |
+
const colors = new Float32Array(numPoints * 3);
|
| 1119 |
+
|
| 1120 |
+
const geometry = new THREE.BufferGeometry();
|
| 1121 |
+
geometry.setAttribute('position', new THREE.BufferAttribute(positions, 3).setUsage(THREE.DynamicDrawUsage));
|
| 1122 |
+
geometry.setAttribute('color', new THREE.BufferAttribute(colors, 3).setUsage(THREE.DynamicDrawUsage));
|
| 1123 |
+
|
| 1124 |
+
const pointSize = parseFloat(this.ui.pointSize.value) || this.defaultSettings.pointSize;
|
| 1125 |
+
const pointOpacity = parseFloat(this.ui.pointOpacity.value) || this.defaultSettings.pointOpacity;
|
| 1126 |
+
|
| 1127 |
+
const material = new THREE.PointsMaterial({
|
| 1128 |
+
size: pointSize,
|
| 1129 |
+
vertexColors: true,
|
| 1130 |
+
transparent: true,
|
| 1131 |
+
opacity: pointOpacity,
|
| 1132 |
+
sizeAttenuation: true
|
| 1133 |
+
});
|
| 1134 |
+
|
| 1135 |
+
this.pointCloud = new THREE.Points(geometry, material);
|
| 1136 |
+
this.scene.add(this.pointCloud);
|
| 1137 |
+
}
|
| 1138 |
+
|
| 1139 |
+
initTrajectories() {
|
| 1140 |
+
if (!this.data.trajectories) return;
|
| 1141 |
+
|
| 1142 |
+
this.trajectories.forEach(trajectory => {
|
| 1143 |
+
if (trajectory.userData.lineSegments) {
|
| 1144 |
+
trajectory.userData.lineSegments.forEach(segment => {
|
| 1145 |
+
segment.geometry.dispose();
|
| 1146 |
+
segment.material.dispose();
|
| 1147 |
+
});
|
| 1148 |
+
}
|
| 1149 |
+
this.scene.remove(trajectory);
|
| 1150 |
+
});
|
| 1151 |
+
this.trajectories = [];
|
| 1152 |
+
|
| 1153 |
+
const shape = this.data.trajectories.shape;
|
| 1154 |
+
if (!shape || shape.length < 2) return;
|
| 1155 |
+
|
| 1156 |
+
const [totalFrames, numTrajectories] = shape;
|
| 1157 |
+
const palette = this.createColorPalette(numTrajectories);
|
| 1158 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1159 |
+
const maxHistory = 500; // Max value of the history slider, for the object pool
|
| 1160 |
+
|
| 1161 |
+
for (let i = 0; i < numTrajectories; i++) {
|
| 1162 |
+
const trajectoryGroup = new THREE.Group();
|
| 1163 |
+
|
| 1164 |
+
const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
|
| 1165 |
+
const sphereGeometry = new THREE.SphereGeometry(ballSize, 16, 16);
|
| 1166 |
+
const sphereMaterial = new THREE.MeshBasicMaterial({ color: palette[i], transparent: true });
|
| 1167 |
+
const positionMarker = new THREE.Mesh(sphereGeometry, sphereMaterial);
|
| 1168 |
+
trajectoryGroup.add(positionMarker);
|
| 1169 |
+
|
| 1170 |
+
// High-Performance Line (default)
|
| 1171 |
+
const simpleLineGeometry = new THREE.BufferGeometry();
|
| 1172 |
+
const simpleLinePositions = new Float32Array(maxHistory * 3);
|
| 1173 |
+
simpleLineGeometry.setAttribute('position', new THREE.BufferAttribute(simpleLinePositions, 3).setUsage(THREE.DynamicDrawUsage));
|
| 1174 |
+
const simpleLine = new THREE.Line(simpleLineGeometry, new THREE.LineBasicMaterial({ color: palette[i] }));
|
| 1175 |
+
simpleLine.frustumCulled = false;
|
| 1176 |
+
trajectoryGroup.add(simpleLine);
|
| 1177 |
+
|
| 1178 |
+
// High-Quality Line Segments (for rich trail)
|
| 1179 |
+
const lineSegments = [];
|
| 1180 |
+
const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
|
| 1181 |
+
|
| 1182 |
+
// Create a pool of line segment objects
|
| 1183 |
+
for (let j = 0; j < maxHistory - 1; j++) {
|
| 1184 |
+
const lineGeometry = new THREE.LineGeometry();
|
| 1185 |
+
lineGeometry.setPositions([0, 0, 0, 0, 0, 0]);
|
| 1186 |
+
const lineMaterial = new THREE.LineMaterial({
|
| 1187 |
+
color: palette[i],
|
| 1188 |
+
linewidth: lineWidth,
|
| 1189 |
+
resolution: resolution,
|
| 1190 |
+
transparent: true,
|
| 1191 |
+
depthWrite: false, // Correctly handle transparency
|
| 1192 |
+
opacity: 0
|
| 1193 |
+
});
|
| 1194 |
+
const segment = new THREE.Line2(lineGeometry, lineMaterial);
|
| 1195 |
+
segment.frustumCulled = false;
|
| 1196 |
+
segment.visible = false; // Start with all segments hidden
|
| 1197 |
+
trajectoryGroup.add(segment);
|
| 1198 |
+
lineSegments.push(segment);
|
| 1199 |
+
}
|
| 1200 |
+
|
| 1201 |
+
trajectoryGroup.userData = {
|
| 1202 |
+
marker: positionMarker,
|
| 1203 |
+
simpleLine: simpleLine,
|
| 1204 |
+
lineSegments: lineSegments,
|
| 1205 |
+
color: palette[i]
|
| 1206 |
+
};
|
| 1207 |
+
|
| 1208 |
+
this.scene.add(trajectoryGroup);
|
| 1209 |
+
this.trajectories.push(trajectoryGroup);
|
| 1210 |
+
}
|
| 1211 |
+
|
| 1212 |
+
const showTrajectory = this.ui.showTrajectory.checked;
|
| 1213 |
+
this.trajectories.forEach(trajectory => trajectory.visible = showTrajectory);
|
| 1214 |
+
}
|
| 1215 |
+
|
| 1216 |
+
createColorPalette(count) {
|
| 1217 |
+
const colors = [];
|
| 1218 |
+
const hueStep = 360 / count;
|
| 1219 |
+
|
| 1220 |
+
for (let i = 0; i < count; i++) {
|
| 1221 |
+
const hue = (i * hueStep) % 360;
|
| 1222 |
+
const color = new THREE.Color().setHSL(hue / 360, 0.8, 0.6);
|
| 1223 |
+
colors.push(color);
|
| 1224 |
+
}
|
| 1225 |
+
|
| 1226 |
+
return colors;
|
| 1227 |
+
}
|
| 1228 |
+
|
| 1229 |
+
updatePointCloud(frameIndex) {
|
| 1230 |
+
if (!this.data || !this.pointCloud) return;
|
| 1231 |
+
|
| 1232 |
+
const positions = this.pointCloud.geometry.attributes.position.array;
|
| 1233 |
+
const colors = this.pointCloud.geometry.attributes.color.array;
|
| 1234 |
+
|
| 1235 |
+
const rgbVideo = this.data.rgb_video;
|
| 1236 |
+
const depthsRgb = this.data.depths_rgb;
|
| 1237 |
+
const intrinsics = this.data.intrinsics;
|
| 1238 |
+
const invExtrinsics = this.data.inv_extrinsics;
|
| 1239 |
+
|
| 1240 |
+
const width = this.config.resolution[0];
|
| 1241 |
+
const height = this.config.resolution[1];
|
| 1242 |
+
const numPoints = width * height;
|
| 1243 |
+
|
| 1244 |
+
const K = this.get3x3Matrix(intrinsics.data, intrinsics.shape, frameIndex);
|
| 1245 |
+
const fx = K[0][0], fy = K[1][1], cx = K[0][2], cy = K[1][2];
|
| 1246 |
+
|
| 1247 |
+
const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
|
| 1248 |
+
const transform = this.getTransformElements(invExtrMat);
|
| 1249 |
+
|
| 1250 |
+
const rgbFrame = this.getFrame(rgbVideo.data, rgbVideo.shape, frameIndex);
|
| 1251 |
+
const depthFrame = this.getFrame(depthsRgb.data, depthsRgb.shape, frameIndex);
|
| 1252 |
+
|
| 1253 |
+
const maxDepth = parseFloat(this.ui.maxDepth.value) || 10.0;
|
| 1254 |
+
|
| 1255 |
+
let validPointCount = 0;
|
| 1256 |
+
|
| 1257 |
+
for (let i = 0; i < numPoints; i++) {
|
| 1258 |
+
const xPix = i % width;
|
| 1259 |
+
const yPix = Math.floor(i / width);
|
| 1260 |
+
|
| 1261 |
+
const d0 = depthFrame[i * 3];
|
| 1262 |
+
const d1 = depthFrame[i * 3 + 1];
|
| 1263 |
+
const depthEncoded = d0 | (d1 << 8);
|
| 1264 |
+
const depthValue = (depthEncoded / ((1 << 16) - 1)) *
|
| 1265 |
+
(this.config.depthRange[1] - this.config.depthRange[0]) +
|
| 1266 |
+
this.config.depthRange[0];
|
| 1267 |
+
|
| 1268 |
+
if (depthValue === 0 || depthValue > maxDepth) {
|
| 1269 |
+
continue;
|
| 1270 |
+
}
|
| 1271 |
+
|
| 1272 |
+
const X = ((xPix - cx) * depthValue) / fx;
|
| 1273 |
+
const Y = ((yPix - cy) * depthValue) / fy;
|
| 1274 |
+
const Z = depthValue;
|
| 1275 |
+
|
| 1276 |
+
const tx = transform.m11 * X + transform.m12 * Y + transform.m13 * Z + transform.m14;
|
| 1277 |
+
const ty = transform.m21 * X + transform.m22 * Y + transform.m23 * Z + transform.m24;
|
| 1278 |
+
const tz = transform.m31 * X + transform.m32 * Y + transform.m33 * Z + transform.m34;
|
| 1279 |
+
|
| 1280 |
+
const index = validPointCount * 3;
|
| 1281 |
+
positions[index] = tx;
|
| 1282 |
+
positions[index + 1] = -ty;
|
| 1283 |
+
positions[index + 2] = -tz;
|
| 1284 |
+
|
| 1285 |
+
colors[index] = rgbFrame[i * 3] / 255;
|
| 1286 |
+
colors[index + 1] = rgbFrame[i * 3 + 1] / 255;
|
| 1287 |
+
colors[index + 2] = rgbFrame[i * 3 + 2] / 255;
|
| 1288 |
+
|
| 1289 |
+
validPointCount++;
|
| 1290 |
+
}
|
| 1291 |
+
|
| 1292 |
+
this.pointCloud.geometry.setDrawRange(0, validPointCount);
|
| 1293 |
+
this.pointCloud.geometry.attributes.position.needsUpdate = true;
|
| 1294 |
+
this.pointCloud.geometry.attributes.color.needsUpdate = true;
|
| 1295 |
+
this.pointCloud.geometry.computeBoundingSphere(); // Important for camera culling
|
| 1296 |
+
|
| 1297 |
+
this.updateTrajectories(frameIndex);
|
| 1298 |
+
|
| 1299 |
+
const progress = (frameIndex + 1) / this.config.totalFrames;
|
| 1300 |
+
this.ui.progress.style.width = `${progress * 100}%`;
|
| 1301 |
+
|
| 1302 |
+
if (this.ui.frameCounter && this.config.totalFrames) {
|
| 1303 |
+
this.ui.frameCounter.textContent = `Frame ${frameIndex} / ${this.config.totalFrames - 1}`;
|
| 1304 |
+
}
|
| 1305 |
+
|
| 1306 |
+
this.updateCameraFrustum(frameIndex);
|
| 1307 |
+
}
|
| 1308 |
+
|
| 1309 |
+
updateTrajectories(frameIndex) {
|
| 1310 |
+
if (!this.data.trajectories || this.trajectories.length === 0) return;
|
| 1311 |
+
|
| 1312 |
+
const trajectoryData = this.data.trajectories.data;
|
| 1313 |
+
const [totalFrames, numTrajectories] = this.data.trajectories.shape;
|
| 1314 |
+
const historyFrames = parseInt(this.ui.trajectoryHistory.value);
|
| 1315 |
+
const tailOpacity = parseFloat(this.ui.trajectoryFade.value);
|
| 1316 |
+
|
| 1317 |
+
const isRichMode = this.ui.enableRichTrail.checked;
|
| 1318 |
+
|
| 1319 |
+
for (let i = 0; i < numTrajectories; i++) {
|
| 1320 |
+
const trajectoryGroup = this.trajectories[i];
|
| 1321 |
+
const { marker, simpleLine, lineSegments } = trajectoryGroup.userData;
|
| 1322 |
+
|
| 1323 |
+
const currentPos = new THREE.Vector3();
|
| 1324 |
+
const currentOffset = (frameIndex * numTrajectories + i) * 3;
|
| 1325 |
+
|
| 1326 |
+
currentPos.x = trajectoryData[currentOffset];
|
| 1327 |
+
currentPos.y = -trajectoryData[currentOffset + 1];
|
| 1328 |
+
currentPos.z = -trajectoryData[currentOffset + 2];
|
| 1329 |
+
|
| 1330 |
+
marker.position.copy(currentPos);
|
| 1331 |
+
marker.material.opacity = 1.0;
|
| 1332 |
+
|
| 1333 |
+
const historyToShow = Math.min(historyFrames, frameIndex + 1);
|
| 1334 |
+
|
| 1335 |
+
if (isRichMode) {
|
| 1336 |
+
// --- High-Quality Mode ---
|
| 1337 |
+
simpleLine.visible = false;
|
| 1338 |
+
|
| 1339 |
+
for (let j = 0; j < lineSegments.length; j++) {
|
| 1340 |
+
const segment = lineSegments[j];
|
| 1341 |
+
if (j < historyToShow - 1) {
|
| 1342 |
+
const headFrame = frameIndex - j;
|
| 1343 |
+
const tailFrame = frameIndex - j - 1;
|
| 1344 |
+
const headOffset = (headFrame * numTrajectories + i) * 3;
|
| 1345 |
+
const tailOffset = (tailFrame * numTrajectories + i) * 3;
|
| 1346 |
+
const positions = [
|
| 1347 |
+
trajectoryData[headOffset], -trajectoryData[headOffset + 1], -trajectoryData[headOffset + 2],
|
| 1348 |
+
trajectoryData[tailOffset], -trajectoryData[tailOffset + 1], -trajectoryData[tailOffset + 2]
|
| 1349 |
+
];
|
| 1350 |
+
segment.geometry.setPositions(positions);
|
| 1351 |
+
const headOpacity = 1.0;
|
| 1352 |
+
const normalizedAge = j / Math.max(1, historyToShow - 2);
|
| 1353 |
+
const alpha = headOpacity - (headOpacity - tailOpacity) * normalizedAge;
|
| 1354 |
+
segment.material.opacity = Math.max(0, alpha);
|
| 1355 |
+
segment.visible = true;
|
| 1356 |
+
} else {
|
| 1357 |
+
segment.visible = false;
|
| 1358 |
+
}
|
| 1359 |
+
}
|
| 1360 |
+
} else {
|
| 1361 |
+
// --- Performance Mode ---
|
| 1362 |
+
lineSegments.forEach(s => s.visible = false);
|
| 1363 |
+
simpleLine.visible = true;
|
| 1364 |
+
|
| 1365 |
+
const positions = simpleLine.geometry.attributes.position.array;
|
| 1366 |
+
for (let j = 0; j < historyToShow; j++) {
|
| 1367 |
+
const historyFrame = Math.max(0, frameIndex - j);
|
| 1368 |
+
const offset = (historyFrame * numTrajectories + i) * 3;
|
| 1369 |
+
positions[j * 3] = trajectoryData[offset];
|
| 1370 |
+
positions[j * 3 + 1] = -trajectoryData[offset + 1];
|
| 1371 |
+
positions[j * 3 + 2] = -trajectoryData[offset + 2];
|
| 1372 |
+
}
|
| 1373 |
+
simpleLine.geometry.setDrawRange(0, historyToShow);
|
| 1374 |
+
simpleLine.geometry.attributes.position.needsUpdate = true;
|
| 1375 |
+
}
|
| 1376 |
+
}
|
| 1377 |
+
}
|
| 1378 |
+
|
| 1379 |
+
updateTrajectorySettings() {
|
| 1380 |
+
if (!this.trajectories || this.trajectories.length === 0) return;
|
| 1381 |
+
|
| 1382 |
+
const ballSize = parseFloat(this.ui.trajectoryBallSize.value);
|
| 1383 |
+
const lineWidth = parseFloat(this.ui.trajectoryLineWidth.value);
|
| 1384 |
+
|
| 1385 |
+
this.trajectories.forEach(trajectoryGroup => {
|
| 1386 |
+
const { marker, lineSegments } = trajectoryGroup.userData;
|
| 1387 |
+
|
| 1388 |
+
marker.geometry.dispose();
|
| 1389 |
+
marker.geometry = new THREE.SphereGeometry(ballSize, 16, 16);
|
| 1390 |
+
|
| 1391 |
+
// Line width only affects rich mode
|
| 1392 |
+
lineSegments.forEach(segment => {
|
| 1393 |
+
if (segment.material) {
|
| 1394 |
+
segment.material.linewidth = lineWidth;
|
| 1395 |
+
}
|
| 1396 |
+
});
|
| 1397 |
+
});
|
| 1398 |
+
|
| 1399 |
+
this.updateTrajectories(this.currentFrame);
|
| 1400 |
+
}
|
| 1401 |
+
|
| 1402 |
+
getDepthColor(normalizedDepth) {
|
| 1403 |
+
const hue = (1 - normalizedDepth) * 240 / 360;
|
| 1404 |
+
const color = new THREE.Color().setHSL(hue, 1.0, 0.5);
|
| 1405 |
+
return color;
|
| 1406 |
+
}
|
| 1407 |
+
|
| 1408 |
+
getFrame(typedArray, shape, frameIndex) {
|
| 1409 |
+
const [T, H, W, C] = shape;
|
| 1410 |
+
const frameSize = H * W * C;
|
| 1411 |
+
const offset = frameIndex * frameSize;
|
| 1412 |
+
return typedArray.subarray(offset, offset + frameSize);
|
| 1413 |
+
}
|
| 1414 |
+
|
| 1415 |
+
get3x3Matrix(typedArray, shape, frameIndex) {
|
| 1416 |
+
const frameSize = 9;
|
| 1417 |
+
const offset = frameIndex * frameSize;
|
| 1418 |
+
const K = [];
|
| 1419 |
+
for (let i = 0; i < 3; i++) {
|
| 1420 |
+
const row = [];
|
| 1421 |
+
for (let j = 0; j < 3; j++) {
|
| 1422 |
+
row.push(typedArray[offset + i * 3 + j]);
|
| 1423 |
+
}
|
| 1424 |
+
K.push(row);
|
| 1425 |
+
}
|
| 1426 |
+
return K;
|
| 1427 |
+
}
|
| 1428 |
+
|
| 1429 |
+
get4x4Matrix(typedArray, shape, frameIndex) {
|
| 1430 |
+
const frameSize = 16;
|
| 1431 |
+
const offset = frameIndex * frameSize;
|
| 1432 |
+
const M = [];
|
| 1433 |
+
for (let i = 0; i < 4; i++) {
|
| 1434 |
+
const row = [];
|
| 1435 |
+
for (let j = 0; j < 4; j++) {
|
| 1436 |
+
row.push(typedArray[offset + i * 4 + j]);
|
| 1437 |
+
}
|
| 1438 |
+
M.push(row);
|
| 1439 |
+
}
|
| 1440 |
+
return M;
|
| 1441 |
+
}
|
| 1442 |
+
|
| 1443 |
+
getTransformElements(matrix) {
|
| 1444 |
+
return {
|
| 1445 |
+
m11: matrix[0][0], m12: matrix[0][1], m13: matrix[0][2], m14: matrix[0][3],
|
| 1446 |
+
m21: matrix[1][0], m22: matrix[1][1], m23: matrix[1][2], m24: matrix[1][3],
|
| 1447 |
+
m31: matrix[2][0], m32: matrix[2][1], m33: matrix[2][2], m34: matrix[2][3]
|
| 1448 |
+
};
|
| 1449 |
+
}
|
| 1450 |
+
|
| 1451 |
+
togglePlayback() {
|
| 1452 |
+
this.isPlaying = !this.isPlaying;
|
| 1453 |
+
|
| 1454 |
+
const playIcon = document.getElementById('play-icon');
|
| 1455 |
+
const pauseIcon = document.getElementById('pause-icon');
|
| 1456 |
+
|
| 1457 |
+
if (this.isPlaying) {
|
| 1458 |
+
playIcon.style.display = 'none';
|
| 1459 |
+
pauseIcon.style.display = 'block';
|
| 1460 |
+
this.lastFrameTime = performance.now();
|
| 1461 |
+
} else {
|
| 1462 |
+
playIcon.style.display = 'block';
|
| 1463 |
+
pauseIcon.style.display = 'none';
|
| 1464 |
+
}
|
| 1465 |
+
}
|
| 1466 |
+
|
| 1467 |
+
cyclePlaybackSpeed() {
|
| 1468 |
+
const speeds = [0.5, 1, 2, 4, 8];
|
| 1469 |
+
const speedRates = speeds.map(s => s * this.config.baseFrameRate);
|
| 1470 |
+
|
| 1471 |
+
let currentIndex = 0;
|
| 1472 |
+
const normalizedSpeed = this.playbackSpeed / this.config.baseFrameRate;
|
| 1473 |
+
|
| 1474 |
+
for (let i = 0; i < speeds.length; i++) {
|
| 1475 |
+
if (Math.abs(normalizedSpeed - speeds[i]) < Math.abs(normalizedSpeed - speeds[currentIndex])) {
|
| 1476 |
+
currentIndex = i;
|
| 1477 |
+
}
|
| 1478 |
+
}
|
| 1479 |
+
|
| 1480 |
+
const nextIndex = (currentIndex + 1) % speeds.length;
|
| 1481 |
+
this.playbackSpeed = speedRates[nextIndex];
|
| 1482 |
+
this.ui.speedBtn.textContent = `${speeds[nextIndex]}x`;
|
| 1483 |
+
|
| 1484 |
+
if (speeds[nextIndex] === 1) {
|
| 1485 |
+
this.ui.speedBtn.classList.remove('active');
|
| 1486 |
+
} else {
|
| 1487 |
+
this.ui.speedBtn.classList.add('active');
|
| 1488 |
+
}
|
| 1489 |
+
}
|
| 1490 |
+
|
| 1491 |
+
seekTo(position) {
|
| 1492 |
+
const frameIndex = Math.floor(position * this.config.totalFrames);
|
| 1493 |
+
this.currentFrame = Math.max(0, Math.min(frameIndex, this.config.totalFrames - 1));
|
| 1494 |
+
this.updatePointCloud(this.currentFrame);
|
| 1495 |
+
}
|
| 1496 |
+
|
| 1497 |
+
updatePointCloudSettings() {
|
| 1498 |
+
if (!this.pointCloud) return;
|
| 1499 |
+
|
| 1500 |
+
const size = parseFloat(this.ui.pointSize.value);
|
| 1501 |
+
const opacity = parseFloat(this.ui.pointOpacity.value);
|
| 1502 |
+
|
| 1503 |
+
this.pointCloud.material.size = size;
|
| 1504 |
+
this.pointCloud.material.opacity = opacity;
|
| 1505 |
+
this.pointCloud.material.needsUpdate = true;
|
| 1506 |
+
|
| 1507 |
+
this.updatePointCloud(this.currentFrame);
|
| 1508 |
+
}
|
| 1509 |
+
|
| 1510 |
+
updateControls() {
|
| 1511 |
+
if (!this.controls) return;
|
| 1512 |
+
this.controls.update();
|
| 1513 |
+
}
|
| 1514 |
+
|
| 1515 |
+
resetView() {
|
| 1516 |
+
if (!this.camera || !this.controls) return;
|
| 1517 |
+
|
| 1518 |
+
// Reset camera position
|
| 1519 |
+
this.camera.position.set(0, 0, this.config.cameraZ || 0);
|
| 1520 |
+
|
| 1521 |
+
// Reset controls
|
| 1522 |
+
this.controls.reset();
|
| 1523 |
+
|
| 1524 |
+
// Set target slightly in front of camera
|
| 1525 |
+
this.controls.target.set(0, 0, -1);
|
| 1526 |
+
this.controls.update();
|
| 1527 |
+
|
| 1528 |
+
// Show status message
|
| 1529 |
+
this.ui.statusBar.textContent = "View reset";
|
| 1530 |
+
this.ui.statusBar.classList.remove('hidden');
|
| 1531 |
+
|
| 1532 |
+
// Hide status message after a few seconds
|
| 1533 |
+
setTimeout(() => {
|
| 1534 |
+
this.ui.statusBar.classList.add('hidden');
|
| 1535 |
+
}, 3000);
|
| 1536 |
+
}
|
| 1537 |
+
|
| 1538 |
+
onWindowResize() {
|
| 1539 |
+
if (!this.camera || !this.renderer) return;
|
| 1540 |
+
|
| 1541 |
+
const windowAspect = window.innerWidth / window.innerHeight;
|
| 1542 |
+
this.camera.aspect = windowAspect;
|
| 1543 |
+
this.camera.updateProjectionMatrix();
|
| 1544 |
+
this.renderer.setSize(window.innerWidth, window.innerHeight);
|
| 1545 |
+
|
| 1546 |
+
if (this.trajectories && this.trajectories.length > 0) {
|
| 1547 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1548 |
+
this.trajectories.forEach(trajectory => {
|
| 1549 |
+
const { lineSegments } = trajectory.userData;
|
| 1550 |
+
if (lineSegments && lineSegments.length > 0) {
|
| 1551 |
+
lineSegments.forEach(segment => {
|
| 1552 |
+
if (segment.material && segment.material.resolution) {
|
| 1553 |
+
segment.material.resolution.copy(resolution);
|
| 1554 |
+
}
|
| 1555 |
+
});
|
| 1556 |
+
}
|
| 1557 |
+
});
|
| 1558 |
+
}
|
| 1559 |
+
|
| 1560 |
+
if (this.cameraFrustum) {
|
| 1561 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1562 |
+
this.cameraFrustum.children.forEach(line => {
|
| 1563 |
+
if (line.material && line.material.resolution) {
|
| 1564 |
+
line.material.resolution.copy(resolution);
|
| 1565 |
+
}
|
| 1566 |
+
});
|
| 1567 |
+
}
|
| 1568 |
+
}
|
| 1569 |
+
|
| 1570 |
+
startAnimation() {
|
| 1571 |
+
this.isPlaying = true;
|
| 1572 |
+
this.lastFrameTime = performance.now();
|
| 1573 |
+
|
| 1574 |
+
this.camera.position.set(0, 0, this.config.cameraZ || 0);
|
| 1575 |
+
this.controls.target.set(0, 0, -1);
|
| 1576 |
+
this.controls.update();
|
| 1577 |
+
|
| 1578 |
+
this.playbackSpeed = this.config.baseFrameRate;
|
| 1579 |
+
|
| 1580 |
+
document.getElementById('play-icon').style.display = 'none';
|
| 1581 |
+
document.getElementById('pause-icon').style.display = 'block';
|
| 1582 |
+
|
| 1583 |
+
this.animate();
|
| 1584 |
+
}
|
| 1585 |
+
|
| 1586 |
+
animate() {
|
| 1587 |
+
requestAnimationFrame(() => this.animate());
|
| 1588 |
+
|
| 1589 |
+
if (this.controls) {
|
| 1590 |
+
this.controls.update();
|
| 1591 |
+
}
|
| 1592 |
+
|
| 1593 |
+
if (this.isPlaying && this.data) {
|
| 1594 |
+
const now = performance.now();
|
| 1595 |
+
const delta = (now - this.lastFrameTime) / 1000;
|
| 1596 |
+
|
| 1597 |
+
const framesToAdvance = Math.floor(delta * this.config.baseFrameRate * this.playbackSpeed);
|
| 1598 |
+
if (framesToAdvance > 0) {
|
| 1599 |
+
this.currentFrame = (this.currentFrame + framesToAdvance) % this.config.totalFrames;
|
| 1600 |
+
this.lastFrameTime = now;
|
| 1601 |
+
this.updatePointCloud(this.currentFrame);
|
| 1602 |
+
}
|
| 1603 |
+
}
|
| 1604 |
+
|
| 1605 |
+
if (this.renderer && this.scene && this.camera) {
|
| 1606 |
+
this.renderer.render(this.scene, this.camera);
|
| 1607 |
+
}
|
| 1608 |
+
}
|
| 1609 |
+
|
| 1610 |
+
initCameraWithCorrectFOV() {
|
| 1611 |
+
const fov = this.config.fov || 60;
|
| 1612 |
+
|
| 1613 |
+
const windowAspect = window.innerWidth / window.innerHeight;
|
| 1614 |
+
|
| 1615 |
+
this.camera = new THREE.PerspectiveCamera(
|
| 1616 |
+
fov,
|
| 1617 |
+
windowAspect,
|
| 1618 |
+
0.1,
|
| 1619 |
+
10000
|
| 1620 |
+
);
|
| 1621 |
+
|
| 1622 |
+
this.controls.object = this.camera;
|
| 1623 |
+
this.controls.update();
|
| 1624 |
+
|
| 1625 |
+
this.initCameraFrustum();
|
| 1626 |
+
}
|
| 1627 |
+
|
| 1628 |
+
initCameraFrustum() {
|
| 1629 |
+
this.cameraFrustum = new THREE.Group();
|
| 1630 |
+
|
| 1631 |
+
this.scene.add(this.cameraFrustum);
|
| 1632 |
+
|
| 1633 |
+
this.initCameraFrustumGeometry();
|
| 1634 |
+
|
| 1635 |
+
const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : (this.defaultSettings ? this.defaultSettings.showCameraFrustum : false);
|
| 1636 |
+
|
| 1637 |
+
this.cameraFrustum.visible = showCameraFrustum;
|
| 1638 |
+
}
|
| 1639 |
+
|
| 1640 |
+
initCameraFrustumGeometry() {
|
| 1641 |
+
const fov = this.config.fov || 60;
|
| 1642 |
+
const originalAspect = this.config.original_aspect_ratio || 1.33;
|
| 1643 |
+
|
| 1644 |
+
const size = parseFloat(this.ui.frustumSize.value) || this.defaultSettings.frustumSize;
|
| 1645 |
+
|
| 1646 |
+
const halfHeight = Math.tan(THREE.MathUtils.degToRad(fov / 2)) * size;
|
| 1647 |
+
const halfWidth = halfHeight * originalAspect;
|
| 1648 |
+
|
| 1649 |
+
const vertices = [
|
| 1650 |
+
new THREE.Vector3(0, 0, 0),
|
| 1651 |
+
new THREE.Vector3(-halfWidth, -halfHeight, size),
|
| 1652 |
+
new THREE.Vector3(halfWidth, -halfHeight, size),
|
| 1653 |
+
new THREE.Vector3(halfWidth, halfHeight, size),
|
| 1654 |
+
new THREE.Vector3(-halfWidth, halfHeight, size)
|
| 1655 |
+
];
|
| 1656 |
+
|
| 1657 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1658 |
+
|
| 1659 |
+
const linePairs = [
|
| 1660 |
+
[1, 2], [2, 3], [3, 4], [4, 1],
|
| 1661 |
+
[0, 1], [0, 2], [0, 3], [0, 4]
|
| 1662 |
+
];
|
| 1663 |
+
|
| 1664 |
+
const colors = {
|
| 1665 |
+
edge: new THREE.Color(0x3366ff),
|
| 1666 |
+
ray: new THREE.Color(0x33cc66)
|
| 1667 |
+
};
|
| 1668 |
+
|
| 1669 |
+
linePairs.forEach((pair, index) => {
|
| 1670 |
+
const positions = [
|
| 1671 |
+
vertices[pair[0]].x, vertices[pair[0]].y, vertices[pair[0]].z,
|
| 1672 |
+
vertices[pair[1]].x, vertices[pair[1]].y, vertices[pair[1]].z
|
| 1673 |
+
];
|
| 1674 |
+
|
| 1675 |
+
const lineGeometry = new THREE.LineGeometry();
|
| 1676 |
+
lineGeometry.setPositions(positions);
|
| 1677 |
+
|
| 1678 |
+
let color = index < 4 ? colors.edge : colors.ray;
|
| 1679 |
+
|
| 1680 |
+
const lineMaterial = new THREE.LineMaterial({
|
| 1681 |
+
color: color,
|
| 1682 |
+
linewidth: 2,
|
| 1683 |
+
resolution: resolution,
|
| 1684 |
+
dashed: false
|
| 1685 |
+
});
|
| 1686 |
+
|
| 1687 |
+
const line = new THREE.Line2(lineGeometry, lineMaterial);
|
| 1688 |
+
this.cameraFrustum.add(line);
|
| 1689 |
+
});
|
| 1690 |
+
}
|
| 1691 |
+
|
| 1692 |
+
updateCameraFrustum(frameIndex) {
|
| 1693 |
+
if (!this.cameraFrustum || !this.data) return;
|
| 1694 |
+
|
| 1695 |
+
const invExtrinsics = this.data.inv_extrinsics;
|
| 1696 |
+
if (!invExtrinsics) return;
|
| 1697 |
+
|
| 1698 |
+
const invExtrMat = this.get4x4Matrix(invExtrinsics.data, invExtrinsics.shape, frameIndex);
|
| 1699 |
+
|
| 1700 |
+
const matrix = new THREE.Matrix4();
|
| 1701 |
+
matrix.set(
|
| 1702 |
+
invExtrMat[0][0], invExtrMat[0][1], invExtrMat[0][2], invExtrMat[0][3],
|
| 1703 |
+
invExtrMat[1][0], invExtrMat[1][1], invExtrMat[1][2], invExtrMat[1][3],
|
| 1704 |
+
invExtrMat[2][0], invExtrMat[2][1], invExtrMat[2][2], invExtrMat[2][3],
|
| 1705 |
+
invExtrMat[3][0], invExtrMat[3][1], invExtrMat[3][2], invExtrMat[3][3]
|
| 1706 |
+
);
|
| 1707 |
+
|
| 1708 |
+
const position = new THREE.Vector3();
|
| 1709 |
+
position.setFromMatrixPosition(matrix);
|
| 1710 |
+
|
| 1711 |
+
const rotMatrix = new THREE.Matrix4().extractRotation(matrix);
|
| 1712 |
+
|
| 1713 |
+
const coordinateCorrection = new THREE.Matrix4().makeRotationX(Math.PI);
|
| 1714 |
+
|
| 1715 |
+
const finalRotation = new THREE.Matrix4().multiplyMatrices(coordinateCorrection, rotMatrix);
|
| 1716 |
+
|
| 1717 |
+
const quaternion = new THREE.Quaternion();
|
| 1718 |
+
quaternion.setFromRotationMatrix(finalRotation);
|
| 1719 |
+
|
| 1720 |
+
position.y = -position.y;
|
| 1721 |
+
position.z = -position.z;
|
| 1722 |
+
|
| 1723 |
+
this.cameraFrustum.position.copy(position);
|
| 1724 |
+
this.cameraFrustum.quaternion.copy(quaternion);
|
| 1725 |
+
|
| 1726 |
+
const showCameraFrustum = this.ui.showCameraFrustum ? this.ui.showCameraFrustum.checked : this.defaultSettings.showCameraFrustum;
|
| 1727 |
+
|
| 1728 |
+
if (this.cameraFrustum.visible !== showCameraFrustum) {
|
| 1729 |
+
this.cameraFrustum.visible = showCameraFrustum;
|
| 1730 |
+
}
|
| 1731 |
+
|
| 1732 |
+
const resolution = new THREE.Vector2(window.innerWidth, window.innerHeight);
|
| 1733 |
+
this.cameraFrustum.children.forEach(line => {
|
| 1734 |
+
if (line.material && line.material.resolution) {
|
| 1735 |
+
line.material.resolution.copy(resolution);
|
| 1736 |
+
}
|
| 1737 |
+
});
|
| 1738 |
+
}
|
| 1739 |
+
|
| 1740 |
+
updateFrustumDimensions() {
|
| 1741 |
+
if (!this.cameraFrustum) return;
|
| 1742 |
+
|
| 1743 |
+
while(this.cameraFrustum.children.length > 0) {
|
| 1744 |
+
const child = this.cameraFrustum.children[0];
|
| 1745 |
+
if (child.geometry) child.geometry.dispose();
|
| 1746 |
+
if (child.material) child.material.dispose();
|
| 1747 |
+
this.cameraFrustum.remove(child);
|
| 1748 |
+
}
|
| 1749 |
+
|
| 1750 |
+
this.initCameraFrustumGeometry();
|
| 1751 |
+
|
| 1752 |
+
this.updateCameraFrustum(this.currentFrame);
|
| 1753 |
+
}
|
| 1754 |
+
|
| 1755 |
+
resetSettings() {
|
| 1756 |
+
if (!this.defaultSettings) return;
|
| 1757 |
+
|
| 1758 |
+
this.applyDefaultSettings();
|
| 1759 |
+
|
| 1760 |
+
this.updatePointCloudSettings();
|
| 1761 |
+
this.updateTrajectorySettings();
|
| 1762 |
+
this.updateFrustumDimensions();
|
| 1763 |
+
|
| 1764 |
+
this.ui.statusBar.textContent = "Settings reset to defaults";
|
| 1765 |
+
this.ui.statusBar.classList.remove('hidden');
|
| 1766 |
+
|
| 1767 |
+
setTimeout(() => {
|
| 1768 |
+
this.ui.statusBar.classList.add('hidden');
|
| 1769 |
+
}, 3000);
|
| 1770 |
+
}
|
| 1771 |
+
}
|
| 1772 |
+
|
| 1773 |
+
window.addEventListener('DOMContentLoaded', () => {
|
| 1774 |
+
new PointCloudVisualizer();
|
| 1775 |
+
});
|
| 1776 |
+
</script>
|
| 1777 |
+
</body>
|
| 1778 |
+
</html>
|
app.py
ADDED
|
@@ -0,0 +1,1118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import os
|
| 3 |
+
import json
|
| 4 |
+
import numpy as np
|
| 5 |
+
import cv2
|
| 6 |
+
import base64
|
| 7 |
+
import time
|
| 8 |
+
import tempfile
|
| 9 |
+
import shutil
|
| 10 |
+
import glob
|
| 11 |
+
import threading
|
| 12 |
+
import subprocess
|
| 13 |
+
import struct
|
| 14 |
+
import zlib
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
from typing import List, Tuple, Union
|
| 18 |
+
try:
|
| 19 |
+
import spaces
|
| 20 |
+
except ImportError:
|
| 21 |
+
# Fallback for local development
|
| 22 |
+
def spaces(func):
|
| 23 |
+
return func
|
| 24 |
+
import torch
|
| 25 |
+
import logging
|
| 26 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 27 |
+
import atexit
|
| 28 |
+
import uuid
|
| 29 |
+
|
| 30 |
+
# Configure logging
|
| 31 |
+
logging.basicConfig(level=logging.INFO)
|
| 32 |
+
logger = logging.getLogger(__name__)
|
| 33 |
+
|
| 34 |
+
# Import custom modules with error handling
|
| 35 |
+
try:
|
| 36 |
+
from app_3rd.sam_utils.inference import SamPredictor, get_sam_predictor, run_inference
|
| 37 |
+
from app_3rd.spatrack_utils.infer_track import get_tracker_predictor, run_tracker, get_points_on_a_grid
|
| 38 |
+
except ImportError as e:
|
| 39 |
+
logger.error(f"Failed to import custom modules: {e}")
|
| 40 |
+
raise
|
| 41 |
+
|
| 42 |
+
# Constants
|
| 43 |
+
MAX_FRAMES = 80
|
| 44 |
+
COLORS = [(0, 0, 255), (0, 255, 255)] # BGR: Red for negative, Yellow for positive
|
| 45 |
+
MARKERS = [1, 5] # Cross for negative, Star for positive
|
| 46 |
+
MARKER_SIZE = 8
|
| 47 |
+
|
| 48 |
+
# Thread pool for delayed deletion
|
| 49 |
+
thread_pool_executor = ThreadPoolExecutor(max_workers=2)
|
| 50 |
+
|
| 51 |
+
def delete_later(path: Union[str, os.PathLike], delay: int = 600):
|
| 52 |
+
"""Delete file or directory after specified delay (default 10 minutes)"""
|
| 53 |
+
def _delete():
|
| 54 |
+
try:
|
| 55 |
+
if os.path.isfile(path):
|
| 56 |
+
os.remove(path)
|
| 57 |
+
elif os.path.isdir(path):
|
| 58 |
+
shutil.rmtree(path)
|
| 59 |
+
except Exception as e:
|
| 60 |
+
logger.warning(f"Failed to delete {path}: {e}")
|
| 61 |
+
|
| 62 |
+
def _wait_and_delete():
|
| 63 |
+
time.sleep(delay)
|
| 64 |
+
_delete()
|
| 65 |
+
|
| 66 |
+
thread_pool_executor.submit(_wait_and_delete)
|
| 67 |
+
atexit.register(_delete)
|
| 68 |
+
|
| 69 |
+
def create_user_temp_dir():
|
| 70 |
+
"""Create a unique temporary directory for each user session"""
|
| 71 |
+
session_id = str(uuid.uuid4())[:8] # Short unique ID
|
| 72 |
+
temp_dir = os.path.join("temp_local", f"session_{session_id}")
|
| 73 |
+
os.makedirs(temp_dir, exist_ok=True)
|
| 74 |
+
|
| 75 |
+
# Schedule deletion after 10 minutes
|
| 76 |
+
delete_later(temp_dir, delay=600)
|
| 77 |
+
|
| 78 |
+
return temp_dir
|
| 79 |
+
|
| 80 |
+
from huggingface_hub import hf_hub_download
|
| 81 |
+
# init the model
|
| 82 |
+
os.environ["VGGT_DIR"] = hf_hub_download("Yuxihenry/SpatialTrackerCkpts", "spatrack_front.pth") #, force_download=True)
|
| 83 |
+
|
| 84 |
+
if os.environ.get("VGGT_DIR", None) is not None:
|
| 85 |
+
from models.vggt.vggt.models.vggt_moe import VGGT_MoE
|
| 86 |
+
from models.vggt.vggt.utils.load_fn import preprocess_image
|
| 87 |
+
vggt_model = VGGT_MoE()
|
| 88 |
+
vggt_model.load_state_dict(torch.load(os.environ.get("VGGT_DIR")), strict=False)
|
| 89 |
+
vggt_model.eval()
|
| 90 |
+
vggt_model = vggt_model.to("cuda")
|
| 91 |
+
|
| 92 |
+
# Global model initialization
|
| 93 |
+
print("🚀 Initializing local models...")
|
| 94 |
+
tracker_model, _ = get_tracker_predictor(".", vo_points=756)
|
| 95 |
+
predictor = get_sam_predictor()
|
| 96 |
+
print("✅ Models loaded successfully!")
|
| 97 |
+
|
| 98 |
+
gr.set_static_paths(paths=[Path.cwd().absolute()/"_viz"])
|
| 99 |
+
|
| 100 |
+
@spaces.GPU
|
| 101 |
+
def gpu_run_inference(predictor_arg, image, points, boxes):
|
| 102 |
+
"""GPU-accelerated SAM inference"""
|
| 103 |
+
if predictor_arg is None:
|
| 104 |
+
print("Initializing SAM predictor inside GPU function...")
|
| 105 |
+
predictor_arg = get_sam_predictor(predictor=predictor)
|
| 106 |
+
|
| 107 |
+
# Ensure predictor is on GPU
|
| 108 |
+
try:
|
| 109 |
+
if hasattr(predictor_arg, 'model'):
|
| 110 |
+
predictor_arg.model = predictor_arg.model.cuda()
|
| 111 |
+
elif hasattr(predictor_arg, 'sam'):
|
| 112 |
+
predictor_arg.sam = predictor_arg.sam.cuda()
|
| 113 |
+
elif hasattr(predictor_arg, 'to'):
|
| 114 |
+
predictor_arg = predictor_arg.to('cuda')
|
| 115 |
+
|
| 116 |
+
if hasattr(image, 'cuda'):
|
| 117 |
+
image = image.cuda()
|
| 118 |
+
|
| 119 |
+
except Exception as e:
|
| 120 |
+
print(f"Warning: Could not move predictor to GPU: {e}")
|
| 121 |
+
|
| 122 |
+
return run_inference(predictor_arg, image, points, boxes)
|
| 123 |
+
|
| 124 |
+
@spaces.GPU
|
| 125 |
+
def gpu_run_tracker(tracker_model_arg, tracker_viser_arg, temp_dir, video_name, grid_size, vo_points, fps, mode="offline"):
|
| 126 |
+
"""GPU-accelerated tracking"""
|
| 127 |
+
import torchvision.transforms as T
|
| 128 |
+
import decord
|
| 129 |
+
|
| 130 |
+
if tracker_model_arg is None or tracker_viser_arg is None:
|
| 131 |
+
print("Initializing tracker models inside GPU function...")
|
| 132 |
+
out_dir = os.path.join(temp_dir, "results")
|
| 133 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 134 |
+
tracker_model_arg, tracker_viser_arg = get_tracker_predictor(out_dir, vo_points=vo_points, tracker_model=tracker_model)
|
| 135 |
+
|
| 136 |
+
# Setup paths
|
| 137 |
+
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
| 138 |
+
mask_path = os.path.join(temp_dir, f"{video_name}.png")
|
| 139 |
+
out_dir = os.path.join(temp_dir, "results")
|
| 140 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 141 |
+
|
| 142 |
+
# Load video using decord
|
| 143 |
+
video_reader = decord.VideoReader(video_path)
|
| 144 |
+
video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2)
|
| 145 |
+
|
| 146 |
+
# Resize to ensure minimum side is 336
|
| 147 |
+
h, w = video_tensor.shape[2:]
|
| 148 |
+
scale = max(224 / h, 224 / w)
|
| 149 |
+
if scale < 1:
|
| 150 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 151 |
+
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 152 |
+
video_tensor = video_tensor[::fps].float()[:MAX_FRAMES]
|
| 153 |
+
|
| 154 |
+
# Move to GPU
|
| 155 |
+
video_tensor = video_tensor.cuda()
|
| 156 |
+
print(f"Video tensor shape: {video_tensor.shape}, device: {video_tensor.device}")
|
| 157 |
+
|
| 158 |
+
depth_tensor = None
|
| 159 |
+
intrs = None
|
| 160 |
+
extrs = None
|
| 161 |
+
data_npz_load = {}
|
| 162 |
+
|
| 163 |
+
# run vggt
|
| 164 |
+
if os.environ.get("VGGT_DIR", None) is not None:
|
| 165 |
+
# process the image tensor
|
| 166 |
+
video_tensor = preprocess_image(video_tensor)[None]
|
| 167 |
+
with torch.no_grad():
|
| 168 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 169 |
+
# Predict attributes including cameras, depth maps, and point maps.
|
| 170 |
+
predictions = vggt_model(video_tensor.cuda()/255)
|
| 171 |
+
extrinsic, intrinsic = predictions["poses_pred"], predictions["intrs"]
|
| 172 |
+
depth_map, depth_conf = predictions["points_map"][..., 2], predictions["unc_metric"]
|
| 173 |
+
|
| 174 |
+
depth_tensor = depth_map.squeeze().cpu().numpy()
|
| 175 |
+
extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
|
| 176 |
+
extrs = extrinsic.squeeze().cpu().numpy()
|
| 177 |
+
intrs = intrinsic.squeeze().cpu().numpy()
|
| 178 |
+
video_tensor = video_tensor.squeeze()
|
| 179 |
+
#NOTE: 20% of the depth is not reliable
|
| 180 |
+
# threshold = depth_conf.squeeze()[0].view(-1).quantile(0.6).item()
|
| 181 |
+
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
|
| 182 |
+
|
| 183 |
+
# Load and process mask
|
| 184 |
+
if os.path.exists(mask_path):
|
| 185 |
+
mask = cv2.imread(mask_path)
|
| 186 |
+
mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
|
| 187 |
+
mask = mask.sum(axis=-1)>0
|
| 188 |
+
else:
|
| 189 |
+
mask = np.ones_like(video_tensor[0,0].cpu().numpy())>0
|
| 190 |
+
grid_size = 10
|
| 191 |
+
|
| 192 |
+
# Get frame dimensions and create grid points
|
| 193 |
+
frame_H, frame_W = video_tensor.shape[2:]
|
| 194 |
+
grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cuda")
|
| 195 |
+
|
| 196 |
+
# Sample mask values at grid points and filter
|
| 197 |
+
if os.path.exists(mask_path):
|
| 198 |
+
grid_pts_int = grid_pts[0].long()
|
| 199 |
+
mask_values = mask[grid_pts_int.cpu()[...,1], grid_pts_int.cpu()[...,0]]
|
| 200 |
+
grid_pts = grid_pts[:, mask_values]
|
| 201 |
+
|
| 202 |
+
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].cpu().numpy()
|
| 203 |
+
print(f"Query points shape: {query_xyt.shape}")
|
| 204 |
+
|
| 205 |
+
# Run model inference
|
| 206 |
+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 207 |
+
(
|
| 208 |
+
c2w_traj, intrs, point_map, conf_depth,
|
| 209 |
+
track3d_pred, track2d_pred, vis_pred, conf_pred, video
|
| 210 |
+
) = tracker_model_arg.forward(video_tensor, depth=depth_tensor,
|
| 211 |
+
intrs=intrs, extrs=extrs,
|
| 212 |
+
queries=query_xyt,
|
| 213 |
+
fps=1, full_point=False, iters_track=4,
|
| 214 |
+
query_no_BA=True, fixed_cam=False, stage=1, unc_metric=unc_metric,
|
| 215 |
+
support_frame=len(video_tensor)-1, replace_ratio=0.2)
|
| 216 |
+
|
| 217 |
+
# Resize results to avoid large I/O
|
| 218 |
+
max_size = 224
|
| 219 |
+
h, w = video.shape[2:]
|
| 220 |
+
scale = min(max_size / h, max_size / w)
|
| 221 |
+
if scale < 1:
|
| 222 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 223 |
+
video = T.Resize((new_h, new_w))(video)
|
| 224 |
+
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 225 |
+
point_map = T.Resize((new_h, new_w))(point_map)
|
| 226 |
+
track2d_pred[...,:2] = track2d_pred[...,:2] * scale
|
| 227 |
+
intrs[:,:2,:] = intrs[:,:2,:] * scale
|
| 228 |
+
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
| 229 |
+
|
| 230 |
+
# Visualize tracks
|
| 231 |
+
tracker_viser_arg.visualize(video=video[None],
|
| 232 |
+
tracks=track2d_pred[None][...,:2],
|
| 233 |
+
visibility=vis_pred[None],filename="test")
|
| 234 |
+
|
| 235 |
+
# Save in tapip3d format
|
| 236 |
+
data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
|
| 237 |
+
data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
|
| 238 |
+
data_npz_load["intrinsics"] = intrs.cpu().numpy()
|
| 239 |
+
data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
|
| 240 |
+
data_npz_load["video"] = (video_tensor).cpu().numpy()/255
|
| 241 |
+
data_npz_load["visibs"] = vis_pred.cpu().numpy()
|
| 242 |
+
data_npz_load["confs"] = conf_pred.cpu().numpy()
|
| 243 |
+
data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
|
| 244 |
+
np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
|
| 245 |
+
|
| 246 |
+
return None
|
| 247 |
+
|
| 248 |
+
def compress_and_write(filename, header, blob):
|
| 249 |
+
header_bytes = json.dumps(header).encode("utf-8")
|
| 250 |
+
header_len = struct.pack("<I", len(header_bytes))
|
| 251 |
+
with open(filename, "wb") as f:
|
| 252 |
+
f.write(header_len)
|
| 253 |
+
f.write(header_bytes)
|
| 254 |
+
f.write(blob)
|
| 255 |
+
|
| 256 |
+
def process_point_cloud_data(npz_file, width=256, height=192, fps=4):
|
| 257 |
+
fixed_size = (width, height)
|
| 258 |
+
|
| 259 |
+
data = np.load(npz_file)
|
| 260 |
+
extrinsics = data["extrinsics"]
|
| 261 |
+
intrinsics = data["intrinsics"]
|
| 262 |
+
trajs = data["coords"]
|
| 263 |
+
T, C, H, W = data["video"].shape
|
| 264 |
+
|
| 265 |
+
fx = intrinsics[0, 0, 0]
|
| 266 |
+
fy = intrinsics[0, 1, 1]
|
| 267 |
+
fov_y = 2 * np.arctan(H / (2 * fy)) * (180 / np.pi)
|
| 268 |
+
fov_x = 2 * np.arctan(W / (2 * fx)) * (180 / np.pi)
|
| 269 |
+
original_aspect_ratio = (W / fx) / (H / fy)
|
| 270 |
+
|
| 271 |
+
rgb_video = (rearrange(data["video"], "T C H W -> T H W C") * 255).astype(np.uint8)
|
| 272 |
+
rgb_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_AREA)
|
| 273 |
+
for frame in rgb_video])
|
| 274 |
+
|
| 275 |
+
depth_video = data["depths"].astype(np.float32)
|
| 276 |
+
if "confs_depth" in data.keys():
|
| 277 |
+
confs = (data["confs_depth"].astype(np.float32) > 0.5).astype(np.float32)
|
| 278 |
+
depth_video = depth_video * confs
|
| 279 |
+
depth_video = np.stack([cv2.resize(frame, fixed_size, interpolation=cv2.INTER_NEAREST)
|
| 280 |
+
for frame in depth_video])
|
| 281 |
+
|
| 282 |
+
scale_x = fixed_size[0] / W
|
| 283 |
+
scale_y = fixed_size[1] / H
|
| 284 |
+
intrinsics = intrinsics.copy()
|
| 285 |
+
intrinsics[:, 0, :] *= scale_x
|
| 286 |
+
intrinsics[:, 1, :] *= scale_y
|
| 287 |
+
|
| 288 |
+
min_depth = float(depth_video.min()) * 0.8
|
| 289 |
+
max_depth = float(depth_video.max()) * 1.5
|
| 290 |
+
|
| 291 |
+
depth_normalized = (depth_video - min_depth) / (max_depth - min_depth)
|
| 292 |
+
depth_int = (depth_normalized * ((1 << 16) - 1)).astype(np.uint16)
|
| 293 |
+
|
| 294 |
+
depths_rgb = np.zeros((T, fixed_size[1], fixed_size[0], 3), dtype=np.uint8)
|
| 295 |
+
depths_rgb[:, :, :, 0] = (depth_int & 0xFF).astype(np.uint8)
|
| 296 |
+
depths_rgb[:, :, :, 1] = ((depth_int >> 8) & 0xFF).astype(np.uint8)
|
| 297 |
+
|
| 298 |
+
first_frame_inv = np.linalg.inv(extrinsics[0])
|
| 299 |
+
normalized_extrinsics = np.array([first_frame_inv @ ext for ext in extrinsics])
|
| 300 |
+
|
| 301 |
+
normalized_trajs = np.zeros_like(trajs)
|
| 302 |
+
for t in range(T):
|
| 303 |
+
homogeneous_trajs = np.concatenate([trajs[t], np.ones((trajs.shape[1], 1))], axis=1)
|
| 304 |
+
transformed_trajs = (first_frame_inv @ homogeneous_trajs.T).T
|
| 305 |
+
normalized_trajs[t] = transformed_trajs[:, :3]
|
| 306 |
+
|
| 307 |
+
arrays = {
|
| 308 |
+
"rgb_video": rgb_video,
|
| 309 |
+
"depths_rgb": depths_rgb,
|
| 310 |
+
"intrinsics": intrinsics,
|
| 311 |
+
"extrinsics": normalized_extrinsics,
|
| 312 |
+
"inv_extrinsics": np.linalg.inv(normalized_extrinsics),
|
| 313 |
+
"trajectories": normalized_trajs.astype(np.float32),
|
| 314 |
+
"cameraZ": 0.0
|
| 315 |
+
}
|
| 316 |
+
|
| 317 |
+
header = {}
|
| 318 |
+
blob_parts = []
|
| 319 |
+
offset = 0
|
| 320 |
+
for key, arr in arrays.items():
|
| 321 |
+
arr = np.ascontiguousarray(arr)
|
| 322 |
+
arr_bytes = arr.tobytes()
|
| 323 |
+
header[key] = {
|
| 324 |
+
"dtype": str(arr.dtype),
|
| 325 |
+
"shape": arr.shape,
|
| 326 |
+
"offset": offset,
|
| 327 |
+
"length": len(arr_bytes)
|
| 328 |
+
}
|
| 329 |
+
blob_parts.append(arr_bytes)
|
| 330 |
+
offset += len(arr_bytes)
|
| 331 |
+
|
| 332 |
+
raw_blob = b"".join(blob_parts)
|
| 333 |
+
compressed_blob = zlib.compress(raw_blob, level=9)
|
| 334 |
+
|
| 335 |
+
header["meta"] = {
|
| 336 |
+
"depthRange": [min_depth, max_depth],
|
| 337 |
+
"totalFrames": int(T),
|
| 338 |
+
"resolution": fixed_size,
|
| 339 |
+
"baseFrameRate": fps,
|
| 340 |
+
"numTrajectoryPoints": normalized_trajs.shape[1],
|
| 341 |
+
"fov": float(fov_y),
|
| 342 |
+
"fov_x": float(fov_x),
|
| 343 |
+
"original_aspect_ratio": float(original_aspect_ratio),
|
| 344 |
+
"fixed_aspect_ratio": float(fixed_size[0]/fixed_size[1])
|
| 345 |
+
}
|
| 346 |
+
|
| 347 |
+
compress_and_write('./_viz/data.bin', header, compressed_blob)
|
| 348 |
+
with open('./_viz/data.bin', "rb") as f:
|
| 349 |
+
encoded_blob = base64.b64encode(f.read()).decode("ascii")
|
| 350 |
+
os.unlink('./_viz/data.bin')
|
| 351 |
+
|
| 352 |
+
random_path = f'./_viz/_{time.time()}.html'
|
| 353 |
+
with open('./_viz/viz_template.html') as f:
|
| 354 |
+
html_template = f.read()
|
| 355 |
+
html_out = html_template.replace(
|
| 356 |
+
"<head>",
|
| 357 |
+
f"<head>\n<script>window.embeddedBase64 = `{encoded_blob}`;</script>"
|
| 358 |
+
)
|
| 359 |
+
with open(random_path,'w') as f:
|
| 360 |
+
f.write(html_out)
|
| 361 |
+
|
| 362 |
+
return random_path
|
| 363 |
+
|
| 364 |
+
def numpy_to_base64(arr):
|
| 365 |
+
"""Convert numpy array to base64 string"""
|
| 366 |
+
return base64.b64encode(arr.tobytes()).decode('utf-8')
|
| 367 |
+
|
| 368 |
+
def base64_to_numpy(b64_str, shape, dtype):
|
| 369 |
+
"""Convert base64 string back to numpy array"""
|
| 370 |
+
return np.frombuffer(base64.b64decode(b64_str), dtype=dtype).reshape(shape)
|
| 371 |
+
|
| 372 |
+
def get_video_name(video_path):
|
| 373 |
+
"""Extract video name without extension"""
|
| 374 |
+
return os.path.splitext(os.path.basename(video_path))[0]
|
| 375 |
+
|
| 376 |
+
def extract_first_frame(video_path):
|
| 377 |
+
"""Extract first frame from video file"""
|
| 378 |
+
try:
|
| 379 |
+
cap = cv2.VideoCapture(video_path)
|
| 380 |
+
ret, frame = cap.read()
|
| 381 |
+
cap.release()
|
| 382 |
+
|
| 383 |
+
if ret:
|
| 384 |
+
frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
|
| 385 |
+
return frame_rgb
|
| 386 |
+
else:
|
| 387 |
+
return None
|
| 388 |
+
except Exception as e:
|
| 389 |
+
print(f"Error extracting first frame: {e}")
|
| 390 |
+
return None
|
| 391 |
+
|
| 392 |
+
def handle_video_upload(video):
|
| 393 |
+
"""Handle video upload and extract first frame"""
|
| 394 |
+
if video is None:
|
| 395 |
+
return (None, None, [],
|
| 396 |
+
gr.update(value=50),
|
| 397 |
+
gr.update(value=756),
|
| 398 |
+
gr.update(value=3))
|
| 399 |
+
|
| 400 |
+
# Create user-specific temporary directory
|
| 401 |
+
user_temp_dir = create_user_temp_dir()
|
| 402 |
+
|
| 403 |
+
# Get original video name and copy to temp directory
|
| 404 |
+
if isinstance(video, str):
|
| 405 |
+
video_name = get_video_name(video)
|
| 406 |
+
video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
|
| 407 |
+
shutil.copy(video, video_path)
|
| 408 |
+
else:
|
| 409 |
+
video_name = get_video_name(video.name)
|
| 410 |
+
video_path = os.path.join(user_temp_dir, f"{video_name}.mp4")
|
| 411 |
+
with open(video_path, 'wb') as f:
|
| 412 |
+
f.write(video.read())
|
| 413 |
+
|
| 414 |
+
print(f"📁 Video saved to: {video_path}")
|
| 415 |
+
|
| 416 |
+
# Extract first frame
|
| 417 |
+
frame = extract_first_frame(video_path)
|
| 418 |
+
if frame is None:
|
| 419 |
+
return (None, None, [],
|
| 420 |
+
gr.update(value=50),
|
| 421 |
+
gr.update(value=756),
|
| 422 |
+
gr.update(value=3))
|
| 423 |
+
|
| 424 |
+
# Resize frame to have minimum side length of 336
|
| 425 |
+
h, w = frame.shape[:2]
|
| 426 |
+
scale = 336 / min(h, w)
|
| 427 |
+
new_h, new_w = int(h * scale)//2*2, int(w * scale)//2*2
|
| 428 |
+
frame = cv2.resize(frame, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
|
| 429 |
+
|
| 430 |
+
# Store frame data with temp directory info
|
| 431 |
+
frame_data = {
|
| 432 |
+
'data': numpy_to_base64(frame),
|
| 433 |
+
'shape': frame.shape,
|
| 434 |
+
'dtype': str(frame.dtype),
|
| 435 |
+
'temp_dir': user_temp_dir,
|
| 436 |
+
'video_name': video_name,
|
| 437 |
+
'video_path': video_path
|
| 438 |
+
}
|
| 439 |
+
|
| 440 |
+
# Get video-specific settings
|
| 441 |
+
print(f"🎬 Video path: '{video}' -> Video name: '{video_name}'")
|
| 442 |
+
grid_size_val, vo_points_val, fps_val = get_video_settings(video_name)
|
| 443 |
+
print(f"🎬 Video settings for '{video_name}': grid_size={grid_size_val}, vo_points={vo_points_val}, fps={fps_val}")
|
| 444 |
+
|
| 445 |
+
return (json.dumps(frame_data), frame, [],
|
| 446 |
+
gr.update(value=grid_size_val),
|
| 447 |
+
gr.update(value=vo_points_val),
|
| 448 |
+
gr.update(value=fps_val))
|
| 449 |
+
|
| 450 |
+
def save_masks(o_masks, video_name, temp_dir):
|
| 451 |
+
"""Save binary masks to files in user-specific temp directory"""
|
| 452 |
+
o_files = []
|
| 453 |
+
for mask, _ in o_masks:
|
| 454 |
+
o_mask = np.uint8(mask.squeeze() * 255)
|
| 455 |
+
o_file = os.path.join(temp_dir, f"{video_name}.png")
|
| 456 |
+
cv2.imwrite(o_file, o_mask)
|
| 457 |
+
o_files.append(o_file)
|
| 458 |
+
return o_files
|
| 459 |
+
|
| 460 |
+
def select_point(original_img: str, sel_pix: list, point_type: str, evt: gr.SelectData):
|
| 461 |
+
"""Handle point selection for SAM"""
|
| 462 |
+
if original_img is None:
|
| 463 |
+
return None, []
|
| 464 |
+
|
| 465 |
+
try:
|
| 466 |
+
# Convert stored image data back to numpy array
|
| 467 |
+
frame_data = json.loads(original_img)
|
| 468 |
+
original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
|
| 469 |
+
temp_dir = frame_data.get('temp_dir', 'temp_local')
|
| 470 |
+
video_name = frame_data.get('video_name', 'video')
|
| 471 |
+
|
| 472 |
+
# Create a display image for visualization
|
| 473 |
+
display_img = original_img_array.copy()
|
| 474 |
+
new_sel_pix = sel_pix.copy() if sel_pix else []
|
| 475 |
+
new_sel_pix.append((evt.index, 1 if point_type == 'positive_point' else 0))
|
| 476 |
+
|
| 477 |
+
print(f"🎯 Running SAM inference for point: {evt.index}, type: {point_type}")
|
| 478 |
+
# Run SAM inference
|
| 479 |
+
o_masks = gpu_run_inference(None, original_img_array, new_sel_pix, [])
|
| 480 |
+
|
| 481 |
+
# Draw points on display image
|
| 482 |
+
for point, label in new_sel_pix:
|
| 483 |
+
cv2.drawMarker(display_img, point, COLORS[label], markerType=MARKERS[label], markerSize=MARKER_SIZE, thickness=2)
|
| 484 |
+
|
| 485 |
+
# Draw mask overlay on display image
|
| 486 |
+
if o_masks:
|
| 487 |
+
mask = o_masks[0][0]
|
| 488 |
+
overlay = display_img.copy()
|
| 489 |
+
overlay[mask.squeeze()!=0] = [20, 60, 200] # Light blue
|
| 490 |
+
display_img = cv2.addWeighted(overlay, 0.6, display_img, 0.4, 0)
|
| 491 |
+
|
| 492 |
+
# Save mask for tracking
|
| 493 |
+
save_masks(o_masks, video_name, temp_dir)
|
| 494 |
+
print(f"✅ Mask saved for video: {video_name}")
|
| 495 |
+
|
| 496 |
+
return display_img, new_sel_pix
|
| 497 |
+
|
| 498 |
+
except Exception as e:
|
| 499 |
+
print(f"❌ Error in select_point: {e}")
|
| 500 |
+
return None, []
|
| 501 |
+
|
| 502 |
+
def reset_points(original_img: str, sel_pix):
|
| 503 |
+
"""Reset all points and clear the mask"""
|
| 504 |
+
if original_img is None:
|
| 505 |
+
return None, []
|
| 506 |
+
|
| 507 |
+
try:
|
| 508 |
+
# Convert stored image data back to numpy array
|
| 509 |
+
frame_data = json.loads(original_img)
|
| 510 |
+
original_img_array = base64_to_numpy(frame_data['data'], frame_data['shape'], frame_data['dtype'])
|
| 511 |
+
temp_dir = frame_data.get('temp_dir', 'temp_local')
|
| 512 |
+
|
| 513 |
+
# Create a display image (just the original image)
|
| 514 |
+
display_img = original_img_array.copy()
|
| 515 |
+
|
| 516 |
+
# Clear all points
|
| 517 |
+
new_sel_pix = []
|
| 518 |
+
|
| 519 |
+
# Clear any existing masks
|
| 520 |
+
for mask_file in glob.glob(os.path.join(temp_dir, "*.png")):
|
| 521 |
+
try:
|
| 522 |
+
os.remove(mask_file)
|
| 523 |
+
except Exception as e:
|
| 524 |
+
logger.warning(f"Failed to remove mask file {mask_file}: {e}")
|
| 525 |
+
|
| 526 |
+
print("🔄 Points and masks reset")
|
| 527 |
+
return display_img, new_sel_pix
|
| 528 |
+
|
| 529 |
+
except Exception as e:
|
| 530 |
+
print(f"❌ Error in reset_points: {e}")
|
| 531 |
+
return None, []
|
| 532 |
+
|
| 533 |
+
def launch_viz(grid_size, vo_points, fps, original_image_state, mode="offline"):
|
| 534 |
+
"""Launch visualization with user-specific temp directory"""
|
| 535 |
+
if original_image_state is None:
|
| 536 |
+
return None, None, None
|
| 537 |
+
|
| 538 |
+
try:
|
| 539 |
+
# Get user's temp directory from stored frame data
|
| 540 |
+
frame_data = json.loads(original_image_state)
|
| 541 |
+
temp_dir = frame_data.get('temp_dir', 'temp_local')
|
| 542 |
+
video_name = frame_data.get('video_name', 'video')
|
| 543 |
+
|
| 544 |
+
print(f"🚀 Starting tracking for video: {video_name}")
|
| 545 |
+
print(f"📊 Parameters: grid_size={grid_size}, vo_points={vo_points}, fps={fps}")
|
| 546 |
+
|
| 547 |
+
# Check for mask files
|
| 548 |
+
mask_files = glob.glob(os.path.join(temp_dir, "*.png"))
|
| 549 |
+
video_files = glob.glob(os.path.join(temp_dir, "*.mp4"))
|
| 550 |
+
|
| 551 |
+
if not video_files:
|
| 552 |
+
print("❌ No video file found")
|
| 553 |
+
return "❌ Error: No video file found", None, None
|
| 554 |
+
|
| 555 |
+
video_path = video_files[0]
|
| 556 |
+
mask_path = mask_files[0] if mask_files else None
|
| 557 |
+
|
| 558 |
+
# Run tracker
|
| 559 |
+
print("🎯 Running tracker...")
|
| 560 |
+
out_dir = os.path.join(temp_dir, "results")
|
| 561 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 562 |
+
|
| 563 |
+
gpu_run_tracker(None, None, temp_dir, video_name, grid_size, vo_points, fps, mode=mode)
|
| 564 |
+
|
| 565 |
+
# Process results
|
| 566 |
+
npz_path = os.path.join(out_dir, "result.npz")
|
| 567 |
+
track2d_video = os.path.join(out_dir, "test_pred_track.mp4")
|
| 568 |
+
|
| 569 |
+
if os.path.exists(npz_path):
|
| 570 |
+
print("📊 Processing 3D visualization...")
|
| 571 |
+
html_path = process_point_cloud_data(npz_path)
|
| 572 |
+
|
| 573 |
+
# Schedule deletion of generated files
|
| 574 |
+
delete_later(html_path, delay=600)
|
| 575 |
+
if os.path.exists(track2d_video):
|
| 576 |
+
delete_later(track2d_video, delay=600)
|
| 577 |
+
delete_later(npz_path, delay=600)
|
| 578 |
+
|
| 579 |
+
# Create iframe HTML
|
| 580 |
+
iframe_html = f"""
|
| 581 |
+
<div style='border: 3px solid #667eea; border-radius: 10px;
|
| 582 |
+
background: #f8f9ff; height: 650px; width: 100%;
|
| 583 |
+
box-shadow: 0 8px 32px rgba(102, 126, 234, 0.3);
|
| 584 |
+
margin: 0; padding: 0; box-sizing: border-box; overflow: hidden;'>
|
| 585 |
+
<iframe id="viz_iframe" src="/gradio_api/file={html_path}"
|
| 586 |
+
width="100%" height="650" frameborder="0"
|
| 587 |
+
style="border: none; display: block; width: 100%; height: 650px;
|
| 588 |
+
margin: 0; padding: 0; border-radius: 7px;">
|
| 589 |
+
</iframe>
|
| 590 |
+
</div>
|
| 591 |
+
"""
|
| 592 |
+
|
| 593 |
+
print("✅ Tracking completed successfully!")
|
| 594 |
+
return iframe_html, track2d_video if os.path.exists(track2d_video) else None, html_path
|
| 595 |
+
else:
|
| 596 |
+
print("❌ Tracking failed - no results generated")
|
| 597 |
+
return "❌ Error: Tracking failed to generate results", None, None
|
| 598 |
+
|
| 599 |
+
except Exception as e:
|
| 600 |
+
print(f"❌ Error in launch_viz: {e}")
|
| 601 |
+
return f"❌ Error: {str(e)}", None, None
|
| 602 |
+
|
| 603 |
+
def clear_all():
|
| 604 |
+
"""Clear all buffers and temporary files"""
|
| 605 |
+
return (None, None, [],
|
| 606 |
+
gr.update(value=50),
|
| 607 |
+
gr.update(value=756),
|
| 608 |
+
gr.update(value=3))
|
| 609 |
+
|
| 610 |
+
def clear_all_with_download():
|
| 611 |
+
"""Clear all buffers including both download components"""
|
| 612 |
+
return (None, None, [],
|
| 613 |
+
gr.update(value=50),
|
| 614 |
+
gr.update(value=756),
|
| 615 |
+
gr.update(value=3),
|
| 616 |
+
None, # tracking_video_download
|
| 617 |
+
None) # HTML download component
|
| 618 |
+
|
| 619 |
+
def get_video_settings(video_name):
|
| 620 |
+
"""Get video-specific settings based on video name"""
|
| 621 |
+
video_settings = {
|
| 622 |
+
"running": (50, 512, 2),
|
| 623 |
+
"backpack": (40, 600, 2),
|
| 624 |
+
"kitchen": (60, 800, 3),
|
| 625 |
+
"pillow": (35, 500, 2),
|
| 626 |
+
"handwave": (35, 500, 8),
|
| 627 |
+
"hockey": (45, 700, 2),
|
| 628 |
+
"drifting": (35, 1000, 6),
|
| 629 |
+
"basketball": (45, 1500, 5),
|
| 630 |
+
"ken_block_0": (45, 700, 2),
|
| 631 |
+
"ego_kc1": (45, 500, 4),
|
| 632 |
+
"vertical_place": (45, 500, 3),
|
| 633 |
+
"ego_teaser": (45, 1200, 10),
|
| 634 |
+
"robot_unitree": (45, 500, 4),
|
| 635 |
+
"robot_3": (35, 400, 5),
|
| 636 |
+
"teleop2": (45, 256, 7),
|
| 637 |
+
"pusht": (45, 256, 10),
|
| 638 |
+
"cinema_0": (45, 356, 5),
|
| 639 |
+
"cinema_1": (45, 756, 3),
|
| 640 |
+
"robot1": (45, 600, 2),
|
| 641 |
+
"robot2": (45, 600, 2),
|
| 642 |
+
"protein": (45, 600, 2),
|
| 643 |
+
"kitchen_egocentric": (45, 600, 2),
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
return video_settings.get(video_name, (50, 756, 3))
|
| 647 |
+
|
| 648 |
+
# Create the Gradio interface
|
| 649 |
+
print("🎨 Creating Gradio interface...")
|
| 650 |
+
|
| 651 |
+
with gr.Blocks(
|
| 652 |
+
theme=gr.themes.Soft(),
|
| 653 |
+
title="🎯 [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)",
|
| 654 |
+
css="""
|
| 655 |
+
.gradio-container {
|
| 656 |
+
max-width: 1200px !important;
|
| 657 |
+
margin: auto !important;
|
| 658 |
+
}
|
| 659 |
+
.gr-button {
|
| 660 |
+
margin: 5px;
|
| 661 |
+
}
|
| 662 |
+
.gr-form {
|
| 663 |
+
background: white;
|
| 664 |
+
border-radius: 10px;
|
| 665 |
+
padding: 20px;
|
| 666 |
+
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
| 667 |
+
}
|
| 668 |
+
/* 移除 gr.Group 的默认灰色背景 */
|
| 669 |
+
.gr-form {
|
| 670 |
+
background: transparent !important;
|
| 671 |
+
border: none !important;
|
| 672 |
+
box-shadow: none !important;
|
| 673 |
+
padding: 0 !important;
|
| 674 |
+
}
|
| 675 |
+
/* 固定3D可视化器尺寸 */
|
| 676 |
+
#viz_container {
|
| 677 |
+
height: 650px !important;
|
| 678 |
+
min-height: 650px !important;
|
| 679 |
+
max-height: 650px !important;
|
| 680 |
+
width: 100% !important;
|
| 681 |
+
margin: 0 !important;
|
| 682 |
+
padding: 0 !important;
|
| 683 |
+
overflow: hidden !important;
|
| 684 |
+
}
|
| 685 |
+
#viz_container > div {
|
| 686 |
+
height: 650px !important;
|
| 687 |
+
min-height: 650px !important;
|
| 688 |
+
max-height: 650px !important;
|
| 689 |
+
width: 100% !important;
|
| 690 |
+
margin: 0 !important;
|
| 691 |
+
padding: 0 !important;
|
| 692 |
+
box-sizing: border-box !important;
|
| 693 |
+
}
|
| 694 |
+
#viz_container iframe {
|
| 695 |
+
height: 650px !important;
|
| 696 |
+
min-height: 650px !important;
|
| 697 |
+
max-height: 650px !important;
|
| 698 |
+
width: 100% !important;
|
| 699 |
+
border: none !important;
|
| 700 |
+
display: block !important;
|
| 701 |
+
margin: 0 !important;
|
| 702 |
+
padding: 0 !important;
|
| 703 |
+
box-sizing: border-box !important;
|
| 704 |
+
}
|
| 705 |
+
/* 固定视频上传组件高度 */
|
| 706 |
+
.gr-video {
|
| 707 |
+
height: 300px !important;
|
| 708 |
+
min-height: 300px !important;
|
| 709 |
+
max-height: 300px !important;
|
| 710 |
+
}
|
| 711 |
+
.gr-video video {
|
| 712 |
+
height: 260px !important;
|
| 713 |
+
max-height: 260px !important;
|
| 714 |
+
object-fit: contain !important;
|
| 715 |
+
background: #f8f9fa;
|
| 716 |
+
}
|
| 717 |
+
.gr-video .gr-video-player {
|
| 718 |
+
height: 260px !important;
|
| 719 |
+
max-height: 260px !important;
|
| 720 |
+
}
|
| 721 |
+
/* 强力移除examples的灰色背景 - 使用更通用的选择器 */
|
| 722 |
+
.horizontal-examples,
|
| 723 |
+
.horizontal-examples > *,
|
| 724 |
+
.horizontal-examples * {
|
| 725 |
+
background: transparent !important;
|
| 726 |
+
background-color: transparent !important;
|
| 727 |
+
border: none !important;
|
| 728 |
+
}
|
| 729 |
+
|
| 730 |
+
/* Examples组件水平滚动样式 */
|
| 731 |
+
.horizontal-examples [data-testid="examples"] {
|
| 732 |
+
background: transparent !important;
|
| 733 |
+
background-color: transparent !important;
|
| 734 |
+
}
|
| 735 |
+
|
| 736 |
+
.horizontal-examples [data-testid="examples"] > div {
|
| 737 |
+
background: transparent !important;
|
| 738 |
+
background-color: transparent !important;
|
| 739 |
+
overflow-x: auto !important;
|
| 740 |
+
overflow-y: hidden !important;
|
| 741 |
+
scrollbar-width: thin;
|
| 742 |
+
scrollbar-color: #667eea transparent;
|
| 743 |
+
padding: 0 !important;
|
| 744 |
+
margin-top: 10px;
|
| 745 |
+
border: none !important;
|
| 746 |
+
}
|
| 747 |
+
|
| 748 |
+
.horizontal-examples [data-testid="examples"] table {
|
| 749 |
+
display: flex !important;
|
| 750 |
+
flex-wrap: nowrap !important;
|
| 751 |
+
min-width: max-content !important;
|
| 752 |
+
gap: 15px !important;
|
| 753 |
+
padding: 10px 0;
|
| 754 |
+
background: transparent !important;
|
| 755 |
+
border: none !important;
|
| 756 |
+
}
|
| 757 |
+
|
| 758 |
+
.horizontal-examples [data-testid="examples"] tbody {
|
| 759 |
+
display: flex !important;
|
| 760 |
+
flex-direction: row !important;
|
| 761 |
+
flex-wrap: nowrap !important;
|
| 762 |
+
gap: 15px !important;
|
| 763 |
+
background: transparent !important;
|
| 764 |
+
}
|
| 765 |
+
|
| 766 |
+
.horizontal-examples [data-testid="examples"] tr {
|
| 767 |
+
display: flex !important;
|
| 768 |
+
flex-direction: column !important;
|
| 769 |
+
min-width: 160px !important;
|
| 770 |
+
max-width: 160px !important;
|
| 771 |
+
margin: 0 !important;
|
| 772 |
+
background: white !important;
|
| 773 |
+
border-radius: 12px;
|
| 774 |
+
box-shadow: 0 3px 12px rgba(0,0,0,0.12);
|
| 775 |
+
transition: all 0.3s ease;
|
| 776 |
+
cursor: pointer;
|
| 777 |
+
overflow: hidden;
|
| 778 |
+
border: none !important;
|
| 779 |
+
}
|
| 780 |
+
|
| 781 |
+
.horizontal-examples [data-testid="examples"] tr:hover {
|
| 782 |
+
transform: translateY(-4px);
|
| 783 |
+
box-shadow: 0 8px 20px rgba(102, 126, 234, 0.25);
|
| 784 |
+
}
|
| 785 |
+
|
| 786 |
+
.horizontal-examples [data-testid="examples"] td {
|
| 787 |
+
text-align: center !important;
|
| 788 |
+
padding: 0 !important;
|
| 789 |
+
border: none !important;
|
| 790 |
+
background: transparent !important;
|
| 791 |
+
}
|
| 792 |
+
|
| 793 |
+
.horizontal-examples [data-testid="examples"] td:first-child {
|
| 794 |
+
padding: 0 !important;
|
| 795 |
+
background: transparent !important;
|
| 796 |
+
}
|
| 797 |
+
|
| 798 |
+
.horizontal-examples [data-testid="examples"] video {
|
| 799 |
+
border-radius: 8px 8px 0 0 !important;
|
| 800 |
+
width: 100% !important;
|
| 801 |
+
height: 90px !important;
|
| 802 |
+
object-fit: cover !important;
|
| 803 |
+
background: #f8f9fa !important;
|
| 804 |
+
}
|
| 805 |
+
|
| 806 |
+
.horizontal-examples [data-testid="examples"] td:last-child {
|
| 807 |
+
font-size: 11px !important;
|
| 808 |
+
font-weight: 600 !important;
|
| 809 |
+
color: #333 !important;
|
| 810 |
+
padding: 8px 12px !important;
|
| 811 |
+
background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%) !important;
|
| 812 |
+
border-radius: 0 0 8px 8px;
|
| 813 |
+
}
|
| 814 |
+
|
| 815 |
+
/* 滚动条样式 */
|
| 816 |
+
.horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar {
|
| 817 |
+
height: 8px;
|
| 818 |
+
}
|
| 819 |
+
.horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-track {
|
| 820 |
+
background: transparent;
|
| 821 |
+
border-radius: 4px;
|
| 822 |
+
}
|
| 823 |
+
.horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb {
|
| 824 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 825 |
+
border-radius: 4px;
|
| 826 |
+
}
|
| 827 |
+
.horizontal-examples [data-testid="examples"] > div::-webkit-scrollbar-thumb:hover {
|
| 828 |
+
background: linear-gradient(135deg, #5a6fd8 0%, #6a4190 100%);
|
| 829 |
+
}
|
| 830 |
+
"""
|
| 831 |
+
) as demo:
|
| 832 |
+
|
| 833 |
+
# Add prominent main title
|
| 834 |
+
|
| 835 |
+
gr.Markdown("""
|
| 836 |
+
# ✨ SpatialTrackerV2
|
| 837 |
+
|
| 838 |
+
Welcome to [SpatialTracker V2](https://github.com/henry123-boy/SpaTrackerV2)! This interface allows you to track any pixels in 3D using our model.
|
| 839 |
+
|
| 840 |
+
**⚡ Quick Start:** Upload video → Click "Start Tracking Now!"
|
| 841 |
+
|
| 842 |
+
**🔬 Advanced Usage with SAM:**
|
| 843 |
+
1. Upload a video file or select from examples below
|
| 844 |
+
2. Expand "Manual Point Selection" to click on specific objects for SAM-guided tracking
|
| 845 |
+
3. Adjust tracking parameters for optimal performance
|
| 846 |
+
4. Click "Start Tracking Now!" to begin 3D tracking with SAM guidance
|
| 847 |
+
|
| 848 |
+
""")
|
| 849 |
+
|
| 850 |
+
# Status indicator
|
| 851 |
+
gr.Markdown("**Status:** 🟢 Local Processing Mode")
|
| 852 |
+
|
| 853 |
+
# Main content area - video upload left, 3D visualization right
|
| 854 |
+
with gr.Row():
|
| 855 |
+
with gr.Column(scale=1):
|
| 856 |
+
# Video upload section
|
| 857 |
+
gr.Markdown("### 📂 Select Video")
|
| 858 |
+
|
| 859 |
+
# Define video_input here so it can be referenced in examples
|
| 860 |
+
video_input = gr.Video(
|
| 861 |
+
label="Upload Video or Select Example",
|
| 862 |
+
format="mp4",
|
| 863 |
+
height=250 # Matched height with 3D viz
|
| 864 |
+
)
|
| 865 |
+
|
| 866 |
+
|
| 867 |
+
# Traditional examples but with horizontal scroll styling
|
| 868 |
+
gr.Markdown("🎨**Examples:** (scroll horizontally to see all videos)")
|
| 869 |
+
with gr.Row(elem_classes=["horizontal-examples"]):
|
| 870 |
+
# Horizontal video examples with slider
|
| 871 |
+
# gr.HTML("<div style='margin-top: 5px;'></div>")
|
| 872 |
+
gr.Examples(
|
| 873 |
+
examples=[
|
| 874 |
+
["./examples/robot1.mp4"],
|
| 875 |
+
["./examples/robot2.mp4"],
|
| 876 |
+
["./examples/protein.mp4"],
|
| 877 |
+
["./examples/kitchen_egocentric.mp4"],
|
| 878 |
+
["./examples/hockey.mp4"],
|
| 879 |
+
["./examples/running.mp4"],
|
| 880 |
+
["./examples/robot_3.mp4"],
|
| 881 |
+
["./examples/backpack.mp4"],
|
| 882 |
+
["./examples/kitchen.mp4"],
|
| 883 |
+
["./examples/pillow.mp4"],
|
| 884 |
+
["./examples/handwave.mp4"],
|
| 885 |
+
["./examples/drifting.mp4"],
|
| 886 |
+
["./examples/basketball.mp4"],
|
| 887 |
+
["./examples/ken_block_0.mp4"],
|
| 888 |
+
["./examples/ego_kc1.mp4"],
|
| 889 |
+
["./examples/vertical_place.mp4"],
|
| 890 |
+
["./examples/ego_teaser.mp4"],
|
| 891 |
+
["./examples/robot_unitree.mp4"],
|
| 892 |
+
["./examples/teleop2.mp4"],
|
| 893 |
+
["./examples/pusht.mp4"],
|
| 894 |
+
["./examples/cinema_0.mp4"],
|
| 895 |
+
["./examples/cinema_1.mp4"],
|
| 896 |
+
],
|
| 897 |
+
inputs=[video_input],
|
| 898 |
+
outputs=[video_input],
|
| 899 |
+
fn=None,
|
| 900 |
+
cache_examples=False,
|
| 901 |
+
label="",
|
| 902 |
+
examples_per_page=6 # Show 6 examples per page so they can wrap to multiple rows
|
| 903 |
+
)
|
| 904 |
+
|
| 905 |
+
with gr.Column(scale=2):
|
| 906 |
+
# 3D Visualization - wider and taller to match left side
|
| 907 |
+
with gr.Group():
|
| 908 |
+
gr.Markdown("### 🌐 3D Trajectory Visualization")
|
| 909 |
+
viz_html = gr.HTML(
|
| 910 |
+
label="3D Trajectory Visualization",
|
| 911 |
+
value="""
|
| 912 |
+
<div style='border: 3px solid #667eea; border-radius: 10px;
|
| 913 |
+
background: linear-gradient(135deg, #f8f9ff 0%, #e6f3ff 100%);
|
| 914 |
+
text-align: center; height: 650px; display: flex;
|
| 915 |
+
flex-direction: column; justify-content: center; align-items: center;
|
| 916 |
+
box-shadow: 0 4px 16px rgba(102, 126, 234, 0.15);
|
| 917 |
+
margin: 0; padding: 20px; box-sizing: border-box;'>
|
| 918 |
+
<div style='font-size: 56px; margin-bottom: 25px;'>🌐</div>
|
| 919 |
+
<h3 style='color: #667eea; margin-bottom: 18px; font-size: 28px; font-weight: 600;'>
|
| 920 |
+
3D Trajectory Visualization
|
| 921 |
+
</h3>
|
| 922 |
+
<p style='color: #666; font-size: 18px; line-height: 1.6; max-width: 550px; margin-bottom: 30px;'>
|
| 923 |
+
Track any pixels in 3D space with camera motion
|
| 924 |
+
</p>
|
| 925 |
+
<div style='background: rgba(102, 126, 234, 0.1); border-radius: 30px;
|
| 926 |
+
padding: 15px 30px; border: 1px solid rgba(102, 126, 234, 0.2);'>
|
| 927 |
+
<span style='color: #667eea; font-weight: 600; font-size: 16px;'>
|
| 928 |
+
⚡ Powered by SpatialTracker V2
|
| 929 |
+
</span>
|
| 930 |
+
</div>
|
| 931 |
+
</div>
|
| 932 |
+
""",
|
| 933 |
+
elem_id="viz_container"
|
| 934 |
+
)
|
| 935 |
+
|
| 936 |
+
# Start button section - below video area
|
| 937 |
+
with gr.Row():
|
| 938 |
+
with gr.Column(scale=3):
|
| 939 |
+
launch_btn = gr.Button("🚀 Start Tracking Now!", variant="primary", size="lg")
|
| 940 |
+
with gr.Column(scale=1):
|
| 941 |
+
clear_all_btn = gr.Button("🗑️ Clear All", variant="secondary", size="sm")
|
| 942 |
+
|
| 943 |
+
# Tracking parameters section
|
| 944 |
+
with gr.Row():
|
| 945 |
+
gr.Markdown("### ⚙️ Tracking Parameters")
|
| 946 |
+
with gr.Row():
|
| 947 |
+
grid_size = gr.Slider(
|
| 948 |
+
minimum=10, maximum=100, step=10, value=50,
|
| 949 |
+
label="Grid Size", info="Tracking detail level"
|
| 950 |
+
)
|
| 951 |
+
vo_points = gr.Slider(
|
| 952 |
+
minimum=100, maximum=2000, step=50, value=756,
|
| 953 |
+
label="VO Points", info="Motion accuracy"
|
| 954 |
+
)
|
| 955 |
+
fps = gr.Slider(
|
| 956 |
+
minimum=1, maximum=20, step=1, value=3,
|
| 957 |
+
label="FPS", info="Processing speed"
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
# Advanced Point Selection with SAM - Collapsed by default
|
| 961 |
+
with gr.Row():
|
| 962 |
+
gr.Markdown("### 🎯 Advanced: Manual Point Selection with SAM")
|
| 963 |
+
with gr.Accordion("🔬 SAM Point Selection Controls", open=False):
|
| 964 |
+
gr.HTML("""
|
| 965 |
+
<div style='margin-bottom: 15px;'>
|
| 966 |
+
<ul style='color: #4a5568; font-size: 14px; line-height: 1.6; margin: 0; padding-left: 20px;'>
|
| 967 |
+
<li>Click on target objects in the image for SAM-guided segmentation</li>
|
| 968 |
+
<li>Positive points: include these areas | Negative points: exclude these areas</li>
|
| 969 |
+
<li>Get more accurate 3D tracking results with SAM's powerful segmentation</li>
|
| 970 |
+
</ul>
|
| 971 |
+
</div>
|
| 972 |
+
""")
|
| 973 |
+
|
| 974 |
+
with gr.Row():
|
| 975 |
+
with gr.Column():
|
| 976 |
+
interactive_frame = gr.Image(
|
| 977 |
+
label="Click to select tracking points with SAM guidance",
|
| 978 |
+
type="numpy",
|
| 979 |
+
interactive=True,
|
| 980 |
+
height=300
|
| 981 |
+
)
|
| 982 |
+
|
| 983 |
+
with gr.Row():
|
| 984 |
+
point_type = gr.Radio(
|
| 985 |
+
choices=["positive_point", "negative_point"],
|
| 986 |
+
value="positive_point",
|
| 987 |
+
label="Point Type",
|
| 988 |
+
info="Positive: track these areas | Negative: avoid these areas"
|
| 989 |
+
)
|
| 990 |
+
|
| 991 |
+
with gr.Row():
|
| 992 |
+
reset_points_btn = gr.Button("🔄 Reset Points", variant="secondary", size="sm")
|
| 993 |
+
|
| 994 |
+
# Downloads section - hidden but still functional for local processing
|
| 995 |
+
with gr.Row(visible=False):
|
| 996 |
+
with gr.Column(scale=1):
|
| 997 |
+
tracking_video_download = gr.File(
|
| 998 |
+
label="📹 Download 2D Tracking Video",
|
| 999 |
+
interactive=False,
|
| 1000 |
+
visible=False
|
| 1001 |
+
)
|
| 1002 |
+
with gr.Column(scale=1):
|
| 1003 |
+
html_download = gr.File(
|
| 1004 |
+
label="📄 Download 3D Visualization HTML",
|
| 1005 |
+
interactive=False,
|
| 1006 |
+
visible=False
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
# GitHub Star Section
|
| 1010 |
+
gr.HTML("""
|
| 1011 |
+
<div style='background: linear-gradient(135deg, #e8eaff 0%, #f0f2ff 100%);
|
| 1012 |
+
border-radius: 8px; padding: 20px; margin: 15px 0;
|
| 1013 |
+
box-shadow: 0 2px 8px rgba(102, 126, 234, 0.1);
|
| 1014 |
+
border: 1px solid rgba(102, 126, 234, 0.15);'>
|
| 1015 |
+
<div style='text-align: center;'>
|
| 1016 |
+
<h3 style='color: #4a5568; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
|
| 1017 |
+
⭐ Love SpatialTracker? Give us a Star! ⭐
|
| 1018 |
+
</h3>
|
| 1019 |
+
<p style='color: #666; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
|
| 1020 |
+
Help us grow by starring our repository on GitHub! Your support means a lot to the community. 🚀
|
| 1021 |
+
</p>
|
| 1022 |
+
<a href="https://github.com/henry123-boy/SpaTrackerV2" target="_blank"
|
| 1023 |
+
style='display: inline-flex; align-items: center; gap: 8px;
|
| 1024 |
+
background: rgba(102, 126, 234, 0.1); color: #4a5568;
|
| 1025 |
+
padding: 10px 20px; border-radius: 25px; text-decoration: none;
|
| 1026 |
+
font-weight: bold; font-size: 14px; border: 1px solid rgba(102, 126, 234, 0.2);
|
| 1027 |
+
transition: all 0.3s ease;'
|
| 1028 |
+
onmouseover="this.style.background='rgba(102, 126, 234, 0.15)'; this.style.transform='translateY(-2px)'"
|
| 1029 |
+
onmouseout="this.style.background='rgba(102, 126, 234, 0.1)'; this.style.transform='translateY(0)'">
|
| 1030 |
+
<span style='font-size: 16px;'>⭐</span>
|
| 1031 |
+
Star SpatialTracker V2 on GitHub
|
| 1032 |
+
</a>
|
| 1033 |
+
</div>
|
| 1034 |
+
</div>
|
| 1035 |
+
""")
|
| 1036 |
+
|
| 1037 |
+
# Acknowledgments Section
|
| 1038 |
+
gr.HTML("""
|
| 1039 |
+
<div style='background: linear-gradient(135deg, #fff8e1 0%, #fffbf0 100%);
|
| 1040 |
+
border-radius: 8px; padding: 20px; margin: 15px 0;
|
| 1041 |
+
box-shadow: 0 2px 8px rgba(255, 193, 7, 0.1);
|
| 1042 |
+
border: 1px solid rgba(255, 193, 7, 0.2);'>
|
| 1043 |
+
<div style='text-align: center;'>
|
| 1044 |
+
<h3 style='color: #5d4037; margin: 0 0 10px 0; font-size: 18px; font-weight: 600;'>
|
| 1045 |
+
📚 Acknowledgments
|
| 1046 |
+
</h3>
|
| 1047 |
+
<p style='color: #5d4037; margin: 0 0 15px 0; font-size: 14px; line-height: 1.5;'>
|
| 1048 |
+
Our 3D visualizer is adapted from <strong>TAPIP3D</strong>. We thank the authors for their excellent work and contribution to the computer vision community!
|
| 1049 |
+
</p>
|
| 1050 |
+
<a href="https://github.com/zbw001/TAPIP3D" target="_blank"
|
| 1051 |
+
style='display: inline-flex; align-items: center; gap: 8px;
|
| 1052 |
+
background: rgba(255, 193, 7, 0.15); color: #5d4037;
|
| 1053 |
+
padding: 10px 20px; border-radius: 25px; text-decoration: none;
|
| 1054 |
+
font-weight: bold; font-size: 14px; border: 1px solid rgba(255, 193, 7, 0.3);
|
| 1055 |
+
transition: all 0.3s ease;'
|
| 1056 |
+
onmouseover="this.style.background='rgba(255, 193, 7, 0.25)'; this.style.transform='translateY(-2px)'"
|
| 1057 |
+
onmouseout="this.style.background='rgba(255, 193, 7, 0.15)'; this.style.transform='translateY(0)'">
|
| 1058 |
+
📚 Visit TAPIP3D Repository
|
| 1059 |
+
</a>
|
| 1060 |
+
</div>
|
| 1061 |
+
</div>
|
| 1062 |
+
""")
|
| 1063 |
+
|
| 1064 |
+
# Footer
|
| 1065 |
+
gr.HTML("""
|
| 1066 |
+
<div style='text-align: center; margin: 20px 0 10px 0;'>
|
| 1067 |
+
<span style='font-size: 12px; color: #888; font-style: italic;'>
|
| 1068 |
+
Powered by SpatialTracker V2 | Built with ❤️ for the Computer Vision Community
|
| 1069 |
+
</span>
|
| 1070 |
+
</div>
|
| 1071 |
+
""")
|
| 1072 |
+
|
| 1073 |
+
# Hidden state variables
|
| 1074 |
+
original_image_state = gr.State(None)
|
| 1075 |
+
selected_points = gr.State([])
|
| 1076 |
+
|
| 1077 |
+
# Event handlers
|
| 1078 |
+
video_input.change(
|
| 1079 |
+
fn=handle_video_upload,
|
| 1080 |
+
inputs=[video_input],
|
| 1081 |
+
outputs=[original_image_state, interactive_frame, selected_points, grid_size, vo_points, fps]
|
| 1082 |
+
)
|
| 1083 |
+
|
| 1084 |
+
interactive_frame.select(
|
| 1085 |
+
fn=select_point,
|
| 1086 |
+
inputs=[original_image_state, selected_points, point_type],
|
| 1087 |
+
outputs=[interactive_frame, selected_points]
|
| 1088 |
+
)
|
| 1089 |
+
|
| 1090 |
+
reset_points_btn.click(
|
| 1091 |
+
fn=reset_points,
|
| 1092 |
+
inputs=[original_image_state, selected_points],
|
| 1093 |
+
outputs=[interactive_frame, selected_points]
|
| 1094 |
+
)
|
| 1095 |
+
|
| 1096 |
+
clear_all_btn.click(
|
| 1097 |
+
fn=clear_all_with_download,
|
| 1098 |
+
outputs=[video_input, interactive_frame, selected_points, grid_size, vo_points, fps, tracking_video_download, html_download]
|
| 1099 |
+
)
|
| 1100 |
+
|
| 1101 |
+
launch_btn.click(
|
| 1102 |
+
fn=launch_viz,
|
| 1103 |
+
inputs=[grid_size, vo_points, fps, original_image_state],
|
| 1104 |
+
outputs=[viz_html, tracking_video_download, html_download]
|
| 1105 |
+
)
|
| 1106 |
+
|
| 1107 |
+
# Launch the interface
|
| 1108 |
+
if __name__ == "__main__":
|
| 1109 |
+
print("🌟 Launching SpatialTracker V2 Local Version...")
|
| 1110 |
+
print("🔗 Running in Local Processing Mode")
|
| 1111 |
+
|
| 1112 |
+
demo.launch(
|
| 1113 |
+
server_name="0.0.0.0",
|
| 1114 |
+
server_port=7860,
|
| 1115 |
+
share=True,
|
| 1116 |
+
debug=True,
|
| 1117 |
+
show_error=True
|
| 1118 |
+
)
|
app_3rd/README.md
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 🌟 SpatialTrackerV2 Integrated with SAM 🌟
|
| 2 |
+
SAM receives a point prompt and generates a mask for the target object, facilitating easy interaction to obtain the object's 3D trajectories with SpaTrack2.
|
| 3 |
+
|
| 4 |
+
## Installation
|
| 5 |
+
```
|
| 6 |
+
|
| 7 |
+
python -m pip install git+https://github.com/facebookresearch/segment-anything.git
|
| 8 |
+
cd app_3rd/sam_utils
|
| 9 |
+
mkdir checkpoints
|
| 10 |
+
cd checkpoints
|
| 11 |
+
wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth
|
| 12 |
+
```
|
app_3rd/sam_utils/hf_sam_predictor.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from typing import Optional, Tuple, List, Union
|
| 5 |
+
import warnings
|
| 6 |
+
import cv2
|
| 7 |
+
try:
|
| 8 |
+
from transformers import SamModel, SamProcessor
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
HF_AVAILABLE = True
|
| 11 |
+
except ImportError:
|
| 12 |
+
HF_AVAILABLE = False
|
| 13 |
+
warnings.warn("transformers or huggingface_hub not available. HF SAM models will not work.")
|
| 14 |
+
|
| 15 |
+
# Hugging Face model mapping
|
| 16 |
+
HF_MODELS = {
|
| 17 |
+
'vit_b': 'facebook/sam-vit-base',
|
| 18 |
+
'vit_l': 'facebook/sam-vit-large',
|
| 19 |
+
'vit_h': 'facebook/sam-vit-huge'
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
class HFSamPredictor:
|
| 23 |
+
"""
|
| 24 |
+
Hugging Face version of SamPredictor that wraps the transformers SAM models.
|
| 25 |
+
This class provides the same interface as the original SamPredictor for seamless integration.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, model: SamModel, processor: SamProcessor, device: Optional[str] = None):
|
| 29 |
+
"""
|
| 30 |
+
Initialize the HF SAM predictor.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
model: The SAM model from transformers
|
| 34 |
+
processor: The SAM processor from transformers
|
| 35 |
+
device: Device to run the model on ('cuda', 'cpu', etc.)
|
| 36 |
+
"""
|
| 37 |
+
self.model = model
|
| 38 |
+
self.processor = processor
|
| 39 |
+
self.device = device or ('cuda' if torch.cuda.is_available() else 'cpu')
|
| 40 |
+
self.model.to(self.device)
|
| 41 |
+
self.model.eval()
|
| 42 |
+
|
| 43 |
+
# Store the current image and its features
|
| 44 |
+
self.original_size = None
|
| 45 |
+
self.input_size = None
|
| 46 |
+
self.features = None
|
| 47 |
+
self.image = None
|
| 48 |
+
|
| 49 |
+
@classmethod
|
| 50 |
+
def from_pretrained(cls, model_name: str, device: Optional[str] = None) -> 'HFSamPredictor':
|
| 51 |
+
"""
|
| 52 |
+
Load a SAM model from Hugging Face Hub.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
model_name: Model name from HF_MODELS or direct HF model path
|
| 56 |
+
device: Device to load the model on
|
| 57 |
+
|
| 58 |
+
Returns:
|
| 59 |
+
HFSamPredictor instance
|
| 60 |
+
"""
|
| 61 |
+
if not HF_AVAILABLE:
|
| 62 |
+
raise ImportError("transformers and huggingface_hub are required for HF SAM models")
|
| 63 |
+
|
| 64 |
+
# Map model type to HF model name if needed
|
| 65 |
+
if model_name in HF_MODELS:
|
| 66 |
+
model_name = HF_MODELS[model_name]
|
| 67 |
+
|
| 68 |
+
print(f"Loading SAM model from Hugging Face: {model_name}")
|
| 69 |
+
|
| 70 |
+
# Load model and processor
|
| 71 |
+
model = SamModel.from_pretrained(model_name)
|
| 72 |
+
processor = SamProcessor.from_pretrained(model_name)
|
| 73 |
+
return cls(model, processor, device)
|
| 74 |
+
|
| 75 |
+
def preprocess(self, image: np.ndarray,
|
| 76 |
+
input_points: List[List[float]], input_labels: List[int]) -> None:
|
| 77 |
+
"""
|
| 78 |
+
Set the image for prediction. This preprocesses the image and extracts features.
|
| 79 |
+
|
| 80 |
+
Args:
|
| 81 |
+
image: Input image as numpy array (H, W, C) in RGB format
|
| 82 |
+
"""
|
| 83 |
+
if image.dtype != np.uint8:
|
| 84 |
+
image = (image * 255).astype(np.uint8)
|
| 85 |
+
|
| 86 |
+
self.image = image
|
| 87 |
+
self.original_size = image.shape[:2]
|
| 88 |
+
|
| 89 |
+
# Use dummy point to ensure processor returns original_sizes & reshaped_input_sizes
|
| 90 |
+
inputs = self.processor(
|
| 91 |
+
images=image,
|
| 92 |
+
input_points=input_points,
|
| 93 |
+
input_labels=input_labels,
|
| 94 |
+
return_tensors="pt"
|
| 95 |
+
)
|
| 96 |
+
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
| 97 |
+
|
| 98 |
+
self.input_size = inputs['pixel_values'].shape[-2:]
|
| 99 |
+
self.features = inputs
|
| 100 |
+
return inputs
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def get_hf_sam_predictor(model_type: str = 'vit_h', device: Optional[str] = None,
|
| 104 |
+
image: Optional[np.ndarray] = None) -> HFSamPredictor:
|
| 105 |
+
"""
|
| 106 |
+
Get a Hugging Face SAM predictor with the same interface as the original get_sam_predictor.
|
| 107 |
+
|
| 108 |
+
Args:
|
| 109 |
+
model_type: Model type ('vit_b', 'vit_l', 'vit_h')
|
| 110 |
+
device: Device to run the model on
|
| 111 |
+
image: Optional image to set immediately
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
HFSamPredictor instance
|
| 115 |
+
"""
|
| 116 |
+
if not HF_AVAILABLE:
|
| 117 |
+
raise ImportError("transformers and huggingface_hub are required for HF SAM models")
|
| 118 |
+
|
| 119 |
+
if device is None:
|
| 120 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 121 |
+
|
| 122 |
+
# Load the predictor
|
| 123 |
+
predictor = HFSamPredictor.from_pretrained(model_type, device)
|
| 124 |
+
|
| 125 |
+
# Set image if provided
|
| 126 |
+
if image is not None:
|
| 127 |
+
predictor.set_image(image)
|
| 128 |
+
|
| 129 |
+
return predictor
|
app_3rd/sam_utils/inference.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gc
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from segment_anything import SamPredictor, sam_model_registry
|
| 6 |
+
|
| 7 |
+
# Try to import HF SAM support
|
| 8 |
+
try:
|
| 9 |
+
from app_3rd.sam_utils.hf_sam_predictor import get_hf_sam_predictor, HFSamPredictor
|
| 10 |
+
HF_AVAILABLE = True
|
| 11 |
+
except ImportError:
|
| 12 |
+
HF_AVAILABLE = False
|
| 13 |
+
|
| 14 |
+
models = {
|
| 15 |
+
'vit_b': 'app_3rd/sam_utils/checkpoints/sam_vit_b_01ec64.pth',
|
| 16 |
+
'vit_l': 'app_3rd/sam_utils/checkpoints/sam_vit_l_0b3195.pth',
|
| 17 |
+
'vit_h': 'app_3rd/sam_utils/checkpoints/sam_vit_h_4b8939.pth'
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_sam_predictor(model_type='vit_b', device=None, image=None, use_hf=True, predictor=None):
|
| 22 |
+
"""
|
| 23 |
+
Get SAM predictor with option to use HuggingFace version
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
model_type: Model type ('vit_b', 'vit_l', 'vit_h')
|
| 27 |
+
device: Device to run on
|
| 28 |
+
image: Optional image to set immediately
|
| 29 |
+
use_hf: Whether to use HuggingFace SAM instead of original SAM
|
| 30 |
+
"""
|
| 31 |
+
if predictor is not None:
|
| 32 |
+
return predictor
|
| 33 |
+
if use_hf:
|
| 34 |
+
if not HF_AVAILABLE:
|
| 35 |
+
raise ImportError("HuggingFace SAM not available. Install transformers and huggingface_hub.")
|
| 36 |
+
return get_hf_sam_predictor(model_type, device, image)
|
| 37 |
+
|
| 38 |
+
# Original SAM logic
|
| 39 |
+
if device is None and torch.cuda.is_available():
|
| 40 |
+
device = 'cuda'
|
| 41 |
+
elif device is None:
|
| 42 |
+
device = 'cpu'
|
| 43 |
+
# sam model
|
| 44 |
+
sam = sam_model_registry[model_type](checkpoint=models[model_type])
|
| 45 |
+
sam = sam.to(device)
|
| 46 |
+
|
| 47 |
+
predictor = SamPredictor(sam)
|
| 48 |
+
if image is not None:
|
| 49 |
+
predictor.set_image(image)
|
| 50 |
+
return predictor
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def run_inference(predictor, input_x, selected_points, multi_object: bool = False):
|
| 54 |
+
"""
|
| 55 |
+
Run inference with either original SAM or HF SAM predictor
|
| 56 |
+
|
| 57 |
+
Args:
|
| 58 |
+
predictor: SamPredictor or HFSamPredictor instance
|
| 59 |
+
input_x: Input image
|
| 60 |
+
selected_points: List of (point, label) tuples
|
| 61 |
+
multi_object: Whether to handle multiple objects
|
| 62 |
+
"""
|
| 63 |
+
if len(selected_points) == 0:
|
| 64 |
+
return []
|
| 65 |
+
|
| 66 |
+
# Check if using HF SAM
|
| 67 |
+
if isinstance(predictor, HFSamPredictor):
|
| 68 |
+
return _run_hf_inference(predictor, input_x, selected_points, multi_object)
|
| 69 |
+
else:
|
| 70 |
+
return _run_original_inference(predictor, input_x, selected_points, multi_object)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def _run_original_inference(predictor: SamPredictor, input_x, selected_points, multi_object: bool = False):
|
| 74 |
+
"""Run inference with original SAM"""
|
| 75 |
+
points = torch.Tensor(
|
| 76 |
+
[p for p, _ in selected_points]
|
| 77 |
+
).to(predictor.device).unsqueeze(1)
|
| 78 |
+
|
| 79 |
+
labels = torch.Tensor(
|
| 80 |
+
[int(l) for _, l in selected_points]
|
| 81 |
+
).to(predictor.device).unsqueeze(1)
|
| 82 |
+
|
| 83 |
+
transformed_points = predictor.transform.apply_coords_torch(
|
| 84 |
+
points, input_x.shape[:2])
|
| 85 |
+
|
| 86 |
+
masks, scores, logits = predictor.predict_torch(
|
| 87 |
+
point_coords=transformed_points[:,0][None],
|
| 88 |
+
point_labels=labels[:,0][None],
|
| 89 |
+
multimask_output=False,
|
| 90 |
+
)
|
| 91 |
+
masks = masks[0].cpu().numpy() # N 1 H W N is the number of points
|
| 92 |
+
|
| 93 |
+
gc.collect()
|
| 94 |
+
torch.cuda.empty_cache()
|
| 95 |
+
|
| 96 |
+
return [(masks, 'final_mask')]
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def _run_hf_inference(predictor: HFSamPredictor, input_x, selected_points, multi_object: bool = False):
|
| 100 |
+
"""Run inference with HF SAM"""
|
| 101 |
+
# Prepare points and labels for HF SAM
|
| 102 |
+
select_pts = [[list(p) for p, _ in selected_points]]
|
| 103 |
+
select_lbls = [[int(l) for _, l in selected_points]]
|
| 104 |
+
|
| 105 |
+
# Preprocess inputs
|
| 106 |
+
inputs = predictor.preprocess(input_x, select_pts, select_lbls)
|
| 107 |
+
|
| 108 |
+
# Run inference
|
| 109 |
+
with torch.no_grad():
|
| 110 |
+
outputs = predictor.model(**inputs)
|
| 111 |
+
|
| 112 |
+
# Post-process masks
|
| 113 |
+
masks = predictor.processor.image_processor.post_process_masks(
|
| 114 |
+
outputs.pred_masks.cpu(),
|
| 115 |
+
inputs["original_sizes"].cpu(),
|
| 116 |
+
inputs["reshaped_input_sizes"].cpu(),
|
| 117 |
+
)
|
| 118 |
+
masks = masks[0][:,:1,...].cpu().numpy()
|
| 119 |
+
|
| 120 |
+
gc.collect()
|
| 121 |
+
torch.cuda.empty_cache()
|
| 122 |
+
|
| 123 |
+
return [(masks, 'final_mask')]
|
app_3rd/spatrack_utils/infer_track.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.SpaTrackV2.models.predictor import Predictor
|
| 2 |
+
import yaml
|
| 3 |
+
import easydict
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
import torch
|
| 8 |
+
import torchvision.transforms as T
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import io
|
| 11 |
+
import moviepy.editor as mp
|
| 12 |
+
from models.SpaTrackV2.utils.visualizer import Visualizer
|
| 13 |
+
import tqdm
|
| 14 |
+
from models.SpaTrackV2.models.utils import get_points_on_a_grid
|
| 15 |
+
import glob
|
| 16 |
+
from rich import print
|
| 17 |
+
import argparse
|
| 18 |
+
import decord
|
| 19 |
+
from huggingface_hub import hf_hub_download
|
| 20 |
+
|
| 21 |
+
config = {
|
| 22 |
+
"ckpt_dir": "Yuxihenry/SpatialTrackerCkpts", # HuggingFace repo ID
|
| 23 |
+
"cfg_dir": "config/magic_infer_moge.yaml",
|
| 24 |
+
}
|
| 25 |
+
|
| 26 |
+
def get_tracker_predictor(output_dir: str, vo_points: int = 756, tracker_model=None):
|
| 27 |
+
"""
|
| 28 |
+
Initialize and return the tracker predictor and visualizer
|
| 29 |
+
Args:
|
| 30 |
+
output_dir: Directory to save visualization results
|
| 31 |
+
vo_points: Number of points for visual odometry
|
| 32 |
+
Returns:
|
| 33 |
+
Tuple of (tracker_predictor, visualizer)
|
| 34 |
+
"""
|
| 35 |
+
viz = True
|
| 36 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 37 |
+
|
| 38 |
+
with open(config["cfg_dir"], "r") as f:
|
| 39 |
+
cfg = yaml.load(f, Loader=yaml.FullLoader)
|
| 40 |
+
cfg = easydict.EasyDict(cfg)
|
| 41 |
+
cfg.out_dir = output_dir
|
| 42 |
+
cfg.model.track_num = vo_points
|
| 43 |
+
|
| 44 |
+
# Check if it's a local path or HuggingFace repo
|
| 45 |
+
if tracker_model is not None:
|
| 46 |
+
model = tracker_model
|
| 47 |
+
model.spatrack.track_num = vo_points
|
| 48 |
+
else:
|
| 49 |
+
if os.path.exists(config["ckpt_dir"]):
|
| 50 |
+
# Local file
|
| 51 |
+
model = Predictor.from_pretrained(config["ckpt_dir"], model_cfg=cfg["model"])
|
| 52 |
+
else:
|
| 53 |
+
# HuggingFace repo - download the model
|
| 54 |
+
print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
|
| 55 |
+
checkpoint_path = hf_hub_download(
|
| 56 |
+
repo_id=config["ckpt_dir"],
|
| 57 |
+
repo_type="model",
|
| 58 |
+
filename="SpaTrack3_offline.pth"
|
| 59 |
+
)
|
| 60 |
+
model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
|
| 61 |
+
model.eval()
|
| 62 |
+
model.to("cuda")
|
| 63 |
+
|
| 64 |
+
viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
|
| 65 |
+
fps=10, pad_value=0, tracks_leave_trace=5)
|
| 66 |
+
|
| 67 |
+
return model, viser
|
| 68 |
+
|
| 69 |
+
def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
|
| 70 |
+
"""
|
| 71 |
+
Run tracking on a video sequence
|
| 72 |
+
Args:
|
| 73 |
+
model: Tracker predictor instance
|
| 74 |
+
viser: Visualizer instance
|
| 75 |
+
temp_dir: Directory containing temporary files
|
| 76 |
+
video_name: Name of the video file (without extension)
|
| 77 |
+
grid_size: Size of the tracking grid
|
| 78 |
+
vo_points: Number of points for visual odometry
|
| 79 |
+
fps: Frames per second for visualization
|
| 80 |
+
"""
|
| 81 |
+
# Setup paths
|
| 82 |
+
video_path = os.path.join(temp_dir, f"{video_name}.mp4")
|
| 83 |
+
mask_path = os.path.join(temp_dir, f"{video_name}.png")
|
| 84 |
+
out_dir = os.path.join(temp_dir, "results")
|
| 85 |
+
os.makedirs(out_dir, exist_ok=True)
|
| 86 |
+
|
| 87 |
+
# Load video using decord
|
| 88 |
+
video_reader = decord.VideoReader(video_path)
|
| 89 |
+
video_tensor = torch.from_numpy(video_reader.get_batch(range(len(video_reader))).asnumpy()).permute(0, 3, 1, 2) # Convert to tensor and permute to (N, C, H, W)
|
| 90 |
+
|
| 91 |
+
# resize make sure the shortest side is 336
|
| 92 |
+
h, w = video_tensor.shape[2:]
|
| 93 |
+
scale = max(336 / h, 336 / w)
|
| 94 |
+
if scale < 1:
|
| 95 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 96 |
+
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 97 |
+
video_tensor = video_tensor[::fps].float()
|
| 98 |
+
depth_tensor = None
|
| 99 |
+
intrs = None
|
| 100 |
+
extrs = None
|
| 101 |
+
data_npz_load = {}
|
| 102 |
+
|
| 103 |
+
# Load and process mask
|
| 104 |
+
if os.path.exists(mask_path):
|
| 105 |
+
mask = cv2.imread(mask_path)
|
| 106 |
+
mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
|
| 107 |
+
mask = mask.sum(axis=-1)>0
|
| 108 |
+
else:
|
| 109 |
+
mask = np.ones_like(video_tensor[0,0].numpy())>0
|
| 110 |
+
|
| 111 |
+
# Get frame dimensions and create grid points
|
| 112 |
+
frame_H, frame_W = video_tensor.shape[2:]
|
| 113 |
+
grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
|
| 114 |
+
|
| 115 |
+
# Sample mask values at grid points and filter out points where mask=0
|
| 116 |
+
if os.path.exists(mask_path):
|
| 117 |
+
grid_pts_int = grid_pts[0].long()
|
| 118 |
+
mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
|
| 119 |
+
grid_pts = grid_pts[:, mask_values]
|
| 120 |
+
|
| 121 |
+
query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
|
| 122 |
+
|
| 123 |
+
# run vggt
|
| 124 |
+
if os.environ.get("VGGT_DIR", None) is not None:
|
| 125 |
+
vggt_model = VGGT()
|
| 126 |
+
vggt_model.load_state_dict(torch.load(VGGT_DIR))
|
| 127 |
+
vggt_model.eval()
|
| 128 |
+
vggt_model = vggt_model.to("cuda")
|
| 129 |
+
# process the image tensor
|
| 130 |
+
video_tensor = preprocess_image(video_tensor)[None]
|
| 131 |
+
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
|
| 132 |
+
# Predict attributes including cameras, depth maps, and point maps.
|
| 133 |
+
aggregated_tokens_list, ps_idx = vggt_model.aggregator(video_tensor.cuda()/255)
|
| 134 |
+
pose_enc = vggt_model.camera_head(aggregated_tokens_list)[-1]
|
| 135 |
+
# Extrinsic and intrinsic matrices, following OpenCV convention (camera from world)
|
| 136 |
+
extrinsic, intrinsic = pose_encoding_to_extri_intri(pose_enc, video_tensor.shape[-2:])
|
| 137 |
+
# Predict Depth Maps
|
| 138 |
+
depth_map, depth_conf = vggt_model.depth_head(aggregated_tokens_list, video_tensor.cuda()/255, ps_idx)
|
| 139 |
+
# clear the cache
|
| 140 |
+
del vggt_model, aggregated_tokens_list, ps_idx, pose_enc
|
| 141 |
+
torch.cuda.empty_cache()
|
| 142 |
+
depth_tensor = depth_map.squeeze().cpu().numpy()
|
| 143 |
+
extrs = np.eye(4)[None].repeat(len(depth_tensor), axis=0)
|
| 144 |
+
extrs[:, :3, :4] = extrinsic.squeeze().cpu().numpy()
|
| 145 |
+
intrs = intrinsic.squeeze().cpu().numpy()
|
| 146 |
+
video_tensor = video_tensor.squeeze()
|
| 147 |
+
#NOTE: 20% of the depth is not reliable
|
| 148 |
+
# threshold = depth_conf.squeeze().view(-1).quantile(0.5)
|
| 149 |
+
unc_metric = depth_conf.squeeze().cpu().numpy() > 0.5
|
| 150 |
+
|
| 151 |
+
# Run model inference
|
| 152 |
+
with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
|
| 153 |
+
(
|
| 154 |
+
c2w_traj, intrs, point_map, conf_depth,
|
| 155 |
+
track3d_pred, track2d_pred, vis_pred, conf_pred, video
|
| 156 |
+
) = model.forward(video_tensor, depth=depth_tensor,
|
| 157 |
+
intrs=intrs, extrs=extrs,
|
| 158 |
+
queries=query_xyt,
|
| 159 |
+
fps=1, full_point=False, iters_track=4,
|
| 160 |
+
query_no_BA=True, fixed_cam=False, stage=1,
|
| 161 |
+
support_frame=len(video_tensor)-1, replace_ratio=0.2)
|
| 162 |
+
|
| 163 |
+
# Resize results to avoid too large I/O Burden
|
| 164 |
+
max_size = 336
|
| 165 |
+
h, w = video.shape[2:]
|
| 166 |
+
scale = min(max_size / h, max_size / w)
|
| 167 |
+
if scale < 1:
|
| 168 |
+
new_h, new_w = int(h * scale), int(w * scale)
|
| 169 |
+
video = T.Resize((new_h, new_w))(video)
|
| 170 |
+
video_tensor = T.Resize((new_h, new_w))(video_tensor)
|
| 171 |
+
point_map = T.Resize((new_h, new_w))(point_map)
|
| 172 |
+
track2d_pred[...,:2] = track2d_pred[...,:2] * scale
|
| 173 |
+
intrs[:,:2,:] = intrs[:,:2,:] * scale
|
| 174 |
+
if depth_tensor is not None:
|
| 175 |
+
depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
|
| 176 |
+
conf_depth = T.Resize((new_h, new_w))(conf_depth)
|
| 177 |
+
|
| 178 |
+
# Visualize tracks
|
| 179 |
+
viser.visualize(video=video[None],
|
| 180 |
+
tracks=track2d_pred[None][...,:2],
|
| 181 |
+
visibility=vis_pred[None],filename="test")
|
| 182 |
+
|
| 183 |
+
# Save in tapip3d format
|
| 184 |
+
data_npz_load["coords"] = (torch.einsum("tij,tnj->tni", c2w_traj[:,:3,:3], track3d_pred[:,:,:3].cpu()) + c2w_traj[:,:3,3][:,None,:]).numpy()
|
| 185 |
+
data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
|
| 186 |
+
data_npz_load["intrinsics"] = intrs.cpu().numpy()
|
| 187 |
+
data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
|
| 188 |
+
data_npz_load["video"] = (video_tensor).cpu().numpy()/255
|
| 189 |
+
data_npz_load["visibs"] = vis_pred.cpu().numpy()
|
| 190 |
+
data_npz_load["confs"] = conf_pred.cpu().numpy()
|
| 191 |
+
data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
|
| 192 |
+
np.savez(os.path.join(out_dir, f'result.npz'), **data_npz_load)
|
| 193 |
+
|
| 194 |
+
print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")
|
config/__init__.py
ADDED
|
File without changes
|
config/magic_infer_moge.yaml
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
seed: 0
|
| 2 |
+
# config the hydra logger, only in hydra `$` can be decoded as cite
|
| 3 |
+
data: ./assets/room
|
| 4 |
+
vis_track: false
|
| 5 |
+
hydra:
|
| 6 |
+
run:
|
| 7 |
+
dir: .
|
| 8 |
+
output_subdir: null
|
| 9 |
+
job_logging: {}
|
| 10 |
+
hydra_logging: {}
|
| 11 |
+
mixed_precision: bf16
|
| 12 |
+
visdom:
|
| 13 |
+
viz_ip: "localhost"
|
| 14 |
+
port: 6666
|
| 15 |
+
relax_load: false
|
| 16 |
+
res_all: 336
|
| 17 |
+
# config the ckpt path
|
| 18 |
+
# ckpts: "/mnt/bn/xyxdata/home/codes/my_projs/SpaTrack2/checkpoints/new_base.pth"
|
| 19 |
+
ckpts: "Yuxihenry/SpatialTracker_Files"
|
| 20 |
+
batch_size: 1
|
| 21 |
+
input:
|
| 22 |
+
type: image
|
| 23 |
+
fps: 1
|
| 24 |
+
model_wind_size: 32
|
| 25 |
+
model:
|
| 26 |
+
backbone_cfg:
|
| 27 |
+
ckpt_dir: "checkpoints/model.pt"
|
| 28 |
+
chunk_size: 24 # downsample factor for patchified features
|
| 29 |
+
ckpt_fwd: true
|
| 30 |
+
ft_cfg:
|
| 31 |
+
mode: "fix"
|
| 32 |
+
paras_name: []
|
| 33 |
+
resolution: 336
|
| 34 |
+
max_len: 512
|
| 35 |
+
Track_cfg:
|
| 36 |
+
base_ckpt: "checkpoints/scaled_offline.pth"
|
| 37 |
+
base:
|
| 38 |
+
stride: 4
|
| 39 |
+
corr_radius: 3
|
| 40 |
+
window_len: 60
|
| 41 |
+
stablizer: True
|
| 42 |
+
mode: "online"
|
| 43 |
+
s_wind: 200
|
| 44 |
+
overlap: 4
|
| 45 |
+
track_num: 0
|
| 46 |
+
|
| 47 |
+
dist_train:
|
| 48 |
+
num_nodes: 1
|
examples/backpack.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4b5ac6b2285ffb48e3a740e419e38c781df9c963589a5fd894e5b4e13dd6a8b8
|
| 3 |
+
size 1208738
|
examples/ball.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:31f6e3bf875a85284b376c05170b4c08b546b7d5e95106848b1e3818a9d0db91
|
| 3 |
+
size 3030268
|
examples/basketball.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0df3b429d5fd64c298f2d79b2d818a4044e7341a71d70b957f60b24e313c3760
|
| 3 |
+
size 2487837
|
examples/biker.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:fba880c24bdb8fa3b84b1b491d52f2c1f426fb09e34c3013603e5a549cf3b22b
|
| 3 |
+
size 249196
|
examples/cinema_0.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a68a5643c14f61c05d48e25a98ddf5cf0344d3ffcda08ad4a0adc989d49d7a9c
|
| 3 |
+
size 1774022
|
examples/cinema_1.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:99624e2d0fb2e9f994e46aefb904e884de37a6d78e7f6b6670e286eaa397e515
|
| 3 |
+
size 2370749
|
examples/drifting.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:4f3937871117d3cc5d7da3ef31d1edf5626fc8372126b73590f75f05713fe97c
|
| 3 |
+
size 4695804
|
examples/ego_kc1.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:22fe64e458e329e8b3c3e20b3725ffd85c3a2e725fd03909cf883d3fd02c80b3
|
| 3 |
+
size 1365980
|
examples/ego_teaser.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8780b291b48046b1c7dea90712c1c3f59d60c03216df1c489f6f03e8d61fae5c
|
| 3 |
+
size 7365665
|
examples/handwave.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e6dde7cf4ffa7c66b6861bb5abdedc49dfc4b5b4dd9dd46ee8415dd4953935b6
|
| 3 |
+
size 2099369
|
examples/hockey.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c3be095777b442dc401e7d1f489b749611ffade3563a01e4e3d1e511311bd86
|
| 3 |
+
size 1795810
|
examples/ken_block_0.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7b788faeb4d3206fa604d622a05268f1321ad6a229178fe12319d20c9438deb1
|
| 3 |
+
size 196343
|
examples/kiss.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f78fffc5108d95d4e5837d7607226f3dd9796615ea3481f2629c69ccd2ccb12f
|
| 3 |
+
size 1073570
|
examples/kitchen.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3120e942a9b3d7b300928e43113b000fb5ccc209012a2c560ec26b8a04c2d5f9
|
| 3 |
+
size 543970
|
examples/kitchen_egocentric.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5468ab10d8d39b68b51fa616adc3d099dab7543e38dd221a0a7a20a2401824a2
|
| 3 |
+
size 2176685
|
examples/pillow.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f05818f586d7b0796fcd4714ea4be489c93701598cadc86ce7973fc24655fee
|
| 3 |
+
size 1407147
|
examples/protein.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2dc9cfceb0984b61ebc62fda4c826135ebe916c8c966a8123dcc3315d43b73f
|
| 3 |
+
size 2002300
|
examples/pusht.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:996d1923e36811a1069e4d6b5e8c0338d9068c0870ea09c4c04e13e9fbcd207a
|
| 3 |
+
size 5256495
|
examples/robot1.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7a3b9e4449572129fdd96a751938e211241cdd86bcc56ffd33bfd23fc4d6e9c0
|
| 3 |
+
size 1178671
|
examples/robot2.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:188b2d8824ce345c86a603bff210639a6158d72cf6119cc1d3f79d409ac68bb3
|
| 3 |
+
size 867261
|
examples/robot_3.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:784a0f9c36a316d0da5745075dbc8cefd9ce60c25b067d3d80a1d52830df8a37
|
| 3 |
+
size 1153015
|
examples/robot_unitree.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:99bc274f7613a665c6135085fe01691ebfaa9033101319071f37c550ab21d1ea
|
| 3 |
+
size 1964268
|
examples/running.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9ceb96b287fefb1c090dcd2f5db7634f808d2079413500beeb7b33023dfae51b
|
| 3 |
+
size 7307897
|
examples/teleop2.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:59ea006a18227da8cf5db1fa50cd48e71ec7eb66fef48ea2158c325088bd9fee
|
| 3 |
+
size 1077267
|
examples/vertical_place.mp4
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6c8061ae449f986113c2ecb17aefc2c13f737aecbcd41d6c057c88e6d41ac3ee
|
| 3 |
+
size 719810
|
models/SpaTrackV2/models/SpaTrack.py
ADDED
|
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#python
|
| 2 |
+
"""
|
| 3 |
+
SpaTrackerV2, which is an unified model to estimate 'intrinsic',
|
| 4 |
+
'video depth', 'extrinsic' and '3D Tracking' from casual video frames.
|
| 5 |
+
|
| 6 |
+
Contact: DM yuxixiao@zju.edu.cn
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
from typing import Literal, Union, List, Tuple, Dict
|
| 12 |
+
import cv2
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.nn.functional as F
|
| 16 |
+
# from depth anything v2
|
| 17 |
+
from huggingface_hub import PyTorchModelHubMixin # used for model hub
|
| 18 |
+
from einops import rearrange
|
| 19 |
+
from models.monoD.depth_anything_v2.dpt import DepthAnythingV2
|
| 20 |
+
from models.moge.model.v1 import MoGeModel
|
| 21 |
+
import copy
|
| 22 |
+
from functools import partial
|
| 23 |
+
from models.SpaTrackV2.models.tracker3D.TrackRefiner import TrackRefiner3D
|
| 24 |
+
import kornia
|
| 25 |
+
from models.SpaTrackV2.utils.model_utils import sample_features5d
|
| 26 |
+
import utils3d
|
| 27 |
+
from models.SpaTrackV2.models.tracker3D.spatrack_modules.utils import depth_to_points_colmap, get_nth_visible_time_index
|
| 28 |
+
from models.SpaTrackV2.models.utils import pose_enc2mat, matrix_to_quaternion, get_track_points, normalize_rgb
|
| 29 |
+
import random
|
| 30 |
+
|
| 31 |
+
class SpaTrack2(nn.Module, PyTorchModelHubMixin):
|
| 32 |
+
def __init__(
|
| 33 |
+
self,
|
| 34 |
+
loggers: list, # include [ viz, logger_tf, logger]
|
| 35 |
+
backbone_cfg,
|
| 36 |
+
Track_cfg=None,
|
| 37 |
+
chunk_size=24,
|
| 38 |
+
ckpt_fwd: bool = False,
|
| 39 |
+
ft_cfg=None,
|
| 40 |
+
resolution=518,
|
| 41 |
+
max_len=600, # the maximum video length we can preprocess,
|
| 42 |
+
track_num=768,
|
| 43 |
+
):
|
| 44 |
+
|
| 45 |
+
self.chunk_size = chunk_size
|
| 46 |
+
self.max_len = max_len
|
| 47 |
+
self.resolution = resolution
|
| 48 |
+
# config the T-Lora Dinov2
|
| 49 |
+
#NOTE: initial the base model
|
| 50 |
+
base_cfg = copy.deepcopy(backbone_cfg)
|
| 51 |
+
backbone_ckpt_dir = base_cfg.pop('ckpt_dir', None)
|
| 52 |
+
|
| 53 |
+
super(SpaTrack2, self).__init__()
|
| 54 |
+
if os.path.exists(backbone_ckpt_dir)==False:
|
| 55 |
+
base_model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
|
| 56 |
+
else:
|
| 57 |
+
checkpoint = torch.load(backbone_ckpt_dir, map_location='cpu', weights_only=True)
|
| 58 |
+
base_model = MoGeModel(**checkpoint["model_config"])
|
| 59 |
+
base_model.load_state_dict(checkpoint['model'])
|
| 60 |
+
# avoid the base_model is a member of SpaTrack2
|
| 61 |
+
object.__setattr__(self, 'base_model', base_model)
|
| 62 |
+
|
| 63 |
+
# Tracker model
|
| 64 |
+
self.Track3D = TrackRefiner3D(Track_cfg)
|
| 65 |
+
track_base_ckpt_dir = Track_cfg.base_ckpt
|
| 66 |
+
if os.path.exists(track_base_ckpt_dir):
|
| 67 |
+
track_pretrain = torch.load(track_base_ckpt_dir)
|
| 68 |
+
self.Track3D.load_state_dict(track_pretrain, strict=False)
|
| 69 |
+
|
| 70 |
+
# wrap the function of make lora trainable
|
| 71 |
+
self.make_paras_trainable = partial(self.make_paras_trainable,
|
| 72 |
+
mode=ft_cfg.mode,
|
| 73 |
+
paras_name=ft_cfg.paras_name)
|
| 74 |
+
self.track_num = track_num
|
| 75 |
+
|
| 76 |
+
def make_paras_trainable(self, mode: str = 'fix', paras_name: List[str] = []):
|
| 77 |
+
# gradient required for the lora_experts and gate
|
| 78 |
+
for name, param in self.named_parameters():
|
| 79 |
+
if any(x in name for x in paras_name):
|
| 80 |
+
if mode == 'fix':
|
| 81 |
+
param.requires_grad = False
|
| 82 |
+
else:
|
| 83 |
+
param.requires_grad = True
|
| 84 |
+
else:
|
| 85 |
+
if mode == 'fix':
|
| 86 |
+
param.requires_grad = True
|
| 87 |
+
else:
|
| 88 |
+
param.requires_grad = False
|
| 89 |
+
total_params = sum(p.numel() for p in self.parameters())
|
| 90 |
+
trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 91 |
+
print(f"Total parameters: {total_params}")
|
| 92 |
+
print(f"Trainable parameters: {trainable_params/total_params*100:.2f}%")
|
| 93 |
+
|
| 94 |
+
def ProcVid(self,
|
| 95 |
+
x: torch.Tensor):
|
| 96 |
+
"""
|
| 97 |
+
split the video into several overlapped windows.
|
| 98 |
+
|
| 99 |
+
args:
|
| 100 |
+
x: the input video frames. [B, T, C, H, W]
|
| 101 |
+
outputs:
|
| 102 |
+
patch_size: the patch size of the video features
|
| 103 |
+
raises:
|
| 104 |
+
ValueError: if the input video is longer than `max_len`.
|
| 105 |
+
|
| 106 |
+
"""
|
| 107 |
+
# normalize the input images
|
| 108 |
+
num_types = x.dtype
|
| 109 |
+
x = normalize_rgb(x, input_size=self.resolution)
|
| 110 |
+
x = x.to(num_types)
|
| 111 |
+
# get the video features
|
| 112 |
+
B, T, C, H, W = x.size()
|
| 113 |
+
if T > self.max_len:
|
| 114 |
+
raise ValueError(f"the video length should no more than {self.max_len}.")
|
| 115 |
+
# get the video features
|
| 116 |
+
patch_h, patch_w = H // 14, W // 14
|
| 117 |
+
patch_size = (patch_h, patch_w)
|
| 118 |
+
# resize and get the video features
|
| 119 |
+
x = x.view(B * T, C, H, W)
|
| 120 |
+
# operate the temporal encoding
|
| 121 |
+
return patch_size, x
|
| 122 |
+
|
| 123 |
+
def forward_stream(
|
| 124 |
+
self,
|
| 125 |
+
video: torch.Tensor,
|
| 126 |
+
queries: torch.Tensor = None,
|
| 127 |
+
T_org: int = None,
|
| 128 |
+
depth: torch.Tensor|np.ndarray|str=None,
|
| 129 |
+
unc_metric_in: torch.Tensor|np.ndarray|str=None,
|
| 130 |
+
intrs: torch.Tensor|np.ndarray|str=None,
|
| 131 |
+
extrs: torch.Tensor|np.ndarray|str=None,
|
| 132 |
+
queries_3d: torch.Tensor = None,
|
| 133 |
+
window_len: int = 16,
|
| 134 |
+
overlap_len: int = 4,
|
| 135 |
+
full_point: bool = False,
|
| 136 |
+
track2d_gt: torch.Tensor = None,
|
| 137 |
+
fixed_cam: bool = False,
|
| 138 |
+
query_no_BA: bool = False,
|
| 139 |
+
stage: int = 0,
|
| 140 |
+
support_frame: int = 0,
|
| 141 |
+
replace_ratio: float = 0.6,
|
| 142 |
+
annots_train: Dict = None,
|
| 143 |
+
iters_track=4,
|
| 144 |
+
**kwargs,
|
| 145 |
+
):
|
| 146 |
+
# step 1 allocate the query points on the grid
|
| 147 |
+
T, C, H, W = video.shape
|
| 148 |
+
|
| 149 |
+
if annots_train is not None:
|
| 150 |
+
vis_gt = annots_train["vis"]
|
| 151 |
+
_, _, N = vis_gt.shape
|
| 152 |
+
number_visible = vis_gt.sum(dim=1)
|
| 153 |
+
ratio_rand = torch.rand(1, N, device=vis_gt.device)
|
| 154 |
+
first_positive_inds = get_nth_visible_time_index(vis_gt, (number_visible*ratio_rand).long().clamp(min=1, max=T))
|
| 155 |
+
assert (torch.gather(vis_gt, 1, first_positive_inds[:, None, :].repeat(1, T, 1)) < 0).sum() == 0
|
| 156 |
+
|
| 157 |
+
first_positive_inds = first_positive_inds.long()
|
| 158 |
+
gather = torch.gather(
|
| 159 |
+
annots_train["traj_3d"][...,:2], 1, first_positive_inds[:, :, None, None].repeat(1, 1, N, 2)
|
| 160 |
+
)
|
| 161 |
+
xys = torch.diagonal(gather, dim1=1, dim2=2).permute(0, 2, 1)
|
| 162 |
+
queries = torch.cat([first_positive_inds[:, :, None], xys], dim=-1)[0].cpu().numpy()
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
# Unfold video into segments of window_len with overlap_len
|
| 166 |
+
step_slide = window_len - overlap_len
|
| 167 |
+
if T < window_len:
|
| 168 |
+
video_unf = video.unsqueeze(0)
|
| 169 |
+
if depth is not None:
|
| 170 |
+
depth_unf = depth.unsqueeze(0)
|
| 171 |
+
else:
|
| 172 |
+
depth_unf = None
|
| 173 |
+
if unc_metric_in is not None:
|
| 174 |
+
unc_metric_unf = unc_metric_in.unsqueeze(0)
|
| 175 |
+
else:
|
| 176 |
+
unc_metric_unf = None
|
| 177 |
+
if intrs is not None:
|
| 178 |
+
intrs_unf = intrs.unsqueeze(0)
|
| 179 |
+
else:
|
| 180 |
+
intrs_unf = None
|
| 181 |
+
if extrs is not None:
|
| 182 |
+
extrs_unf = extrs.unsqueeze(0)
|
| 183 |
+
else:
|
| 184 |
+
extrs_unf = None
|
| 185 |
+
else:
|
| 186 |
+
video_unf = video.unfold(0, window_len, step_slide).permute(0, 4, 1, 2, 3) # [B, S, C, H, W]
|
| 187 |
+
if depth is not None:
|
| 188 |
+
depth_unf = depth.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
|
| 189 |
+
intrs_unf = intrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
|
| 190 |
+
else:
|
| 191 |
+
depth_unf = None
|
| 192 |
+
intrs_unf = None
|
| 193 |
+
if extrs is not None:
|
| 194 |
+
extrs_unf = extrs.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
|
| 195 |
+
else:
|
| 196 |
+
extrs_unf = None
|
| 197 |
+
if unc_metric_in is not None:
|
| 198 |
+
unc_metric_unf = unc_metric_in.unfold(0, window_len, step_slide).permute(0, 3, 1, 2)
|
| 199 |
+
else:
|
| 200 |
+
unc_metric_unf = None
|
| 201 |
+
|
| 202 |
+
# parallel
|
| 203 |
+
# Get number of segments
|
| 204 |
+
B = video_unf.shape[0]
|
| 205 |
+
#TODO: Process each segment in parallel using torch.nn.DataParallel
|
| 206 |
+
c2w_traj = torch.eye(4, 4)[None].repeat(T, 1, 1)
|
| 207 |
+
intrs_out = torch.eye(3, 3)[None].repeat(T, 1, 1)
|
| 208 |
+
point_map = torch.zeros(T, 3, H, W).cuda()
|
| 209 |
+
unc_metric = torch.zeros(T, H, W).cuda()
|
| 210 |
+
# set the queries
|
| 211 |
+
N, _ = queries.shape
|
| 212 |
+
track3d_pred = torch.zeros(T, N, 6).cuda()
|
| 213 |
+
track2d_pred = torch.zeros(T, N, 3).cuda()
|
| 214 |
+
vis_pred = torch.zeros(T, N, 1).cuda()
|
| 215 |
+
conf_pred = torch.zeros(T, N, 1).cuda()
|
| 216 |
+
dyn_preds = torch.zeros(T, N, 1).cuda()
|
| 217 |
+
# sort the queries by time
|
| 218 |
+
sorted_indices = np.argsort(queries[...,0])
|
| 219 |
+
sorted_inv_indices = np.argsort(sorted_indices)
|
| 220 |
+
sort_query = queries[sorted_indices]
|
| 221 |
+
sort_query = torch.from_numpy(sort_query).cuda()
|
| 222 |
+
if queries_3d is not None:
|
| 223 |
+
sort_query_3d = queries_3d[sorted_indices]
|
| 224 |
+
sort_query_3d = torch.from_numpy(sort_query_3d).cuda()
|
| 225 |
+
|
| 226 |
+
queries_len = 0
|
| 227 |
+
overlap_d = None
|
| 228 |
+
cache = None
|
| 229 |
+
loss = 0.0
|
| 230 |
+
|
| 231 |
+
for i in range(B):
|
| 232 |
+
segment = video_unf[i:i+1].cuda()
|
| 233 |
+
# Forward pass through model
|
| 234 |
+
# detect the key points for each frames
|
| 235 |
+
|
| 236 |
+
queries_new_mask = (sort_query[...,0] < i * step_slide + window_len) * (sort_query[...,0] >= (i * step_slide + overlap_len if i > 0 else 0))
|
| 237 |
+
if queries_3d is not None:
|
| 238 |
+
queries_new_3d = sort_query_3d[queries_new_mask]
|
| 239 |
+
queries_new_3d = queries_new_3d.float()
|
| 240 |
+
else:
|
| 241 |
+
queries_new_3d = None
|
| 242 |
+
queries_new = sort_query[queries_new_mask.bool()]
|
| 243 |
+
queries_new = queries_new.float()
|
| 244 |
+
if i > 0:
|
| 245 |
+
overlap2d = track2d_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
|
| 246 |
+
overlapvis = vis_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
|
| 247 |
+
overlapconf = conf_pred[i*step_slide:(i+1)*step_slide, :queries_len, :]
|
| 248 |
+
overlap_query = (overlapvis * overlapconf).max(dim=0)[1][None, ...]
|
| 249 |
+
overlap_xy = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,2))
|
| 250 |
+
overlap_d = torch.gather(overlap2d, 0, overlap_query.repeat(1,1,3))[...,2].detach()
|
| 251 |
+
overlap_query = torch.cat([overlap_query[...,:1], overlap_xy], dim=-1)[0]
|
| 252 |
+
queries_new[...,0] -= i*step_slide
|
| 253 |
+
queries_new = torch.cat([overlap_query, queries_new], dim=0).detach()
|
| 254 |
+
|
| 255 |
+
if annots_train is None:
|
| 256 |
+
annots = {}
|
| 257 |
+
else:
|
| 258 |
+
annots = copy.deepcopy(annots_train)
|
| 259 |
+
annots["traj_3d"] = annots["traj_3d"][:, i*step_slide:i*step_slide+window_len, sorted_indices,:][...,:len(queries_new),:]
|
| 260 |
+
annots["vis"] = annots["vis"][:, i*step_slide:i*step_slide+window_len, sorted_indices][...,:len(queries_new)]
|
| 261 |
+
annots["poses_gt"] = annots["poses_gt"][:, i*step_slide:i*step_slide+window_len]
|
| 262 |
+
annots["depth_gt"] = annots["depth_gt"][:, i*step_slide:i*step_slide+window_len]
|
| 263 |
+
annots["intrs"] = annots["intrs"][:, i*step_slide:i*step_slide+window_len]
|
| 264 |
+
annots["traj_mat"] = annots["traj_mat"][:,i*step_slide:i*step_slide+window_len]
|
| 265 |
+
|
| 266 |
+
if depth is not None:
|
| 267 |
+
annots["depth_gt"] = depth_unf[i:i+1].to(segment.device).to(segment.dtype)
|
| 268 |
+
if unc_metric_in is not None:
|
| 269 |
+
annots["unc_metric"] = unc_metric_unf[i:i+1].to(segment.device).to(segment.dtype)
|
| 270 |
+
if intrs is not None:
|
| 271 |
+
intr_seg = intrs_unf[i:i+1].to(segment.device).to(segment.dtype)[0].clone()
|
| 272 |
+
focal = (intr_seg[:,0,0] / segment.shape[-1] + intr_seg[:,1,1]/segment.shape[-2]) / 2
|
| 273 |
+
pose_fake = torch.zeros(1, 8).to(depth.device).to(depth.dtype).repeat(segment.shape[1], 1)
|
| 274 |
+
pose_fake[:, -1] = focal
|
| 275 |
+
pose_fake[:,3]=1
|
| 276 |
+
annots["intrs_gt"] = intr_seg
|
| 277 |
+
if extrs is not None:
|
| 278 |
+
extrs_unf_norm = extrs_unf[i:i+1][0].clone()
|
| 279 |
+
extrs_unf_norm = torch.inverse(extrs_unf_norm[:1,...]) @ extrs_unf[i:i+1][0]
|
| 280 |
+
rot_vec = matrix_to_quaternion(extrs_unf_norm[:,:3,:3])
|
| 281 |
+
annots["poses_gt"] = torch.zeros(1, rot_vec.shape[0], 7).to(segment.device).to(segment.dtype)
|
| 282 |
+
annots["poses_gt"][:, :, 3:7] = rot_vec.to(segment.device).to(segment.dtype)[None]
|
| 283 |
+
annots["poses_gt"][:, :, :3] = extrs_unf_norm[:,:3,3].to(segment.device).to(segment.dtype)[None]
|
| 284 |
+
annots["use_extr"] = True
|
| 285 |
+
|
| 286 |
+
kwargs.update({"stage": stage})
|
| 287 |
+
|
| 288 |
+
#TODO: DEBUG
|
| 289 |
+
out = self.forward(segment, pts_q=queries_new,
|
| 290 |
+
pts_q_3d=queries_new_3d, overlap_d=overlap_d,
|
| 291 |
+
full_point=full_point,
|
| 292 |
+
fixed_cam=fixed_cam, query_no_BA=query_no_BA,
|
| 293 |
+
support_frame=segment.shape[1]-1,
|
| 294 |
+
cache=cache, replace_ratio=replace_ratio,
|
| 295 |
+
iters_track=iters_track,
|
| 296 |
+
**kwargs, annots=annots)
|
| 297 |
+
if self.training:
|
| 298 |
+
loss += out["loss"].squeeze()
|
| 299 |
+
# from models.SpaTrackV2.utils.visualizer import Visualizer
|
| 300 |
+
# vis_track = Visualizer(grayscale=False,
|
| 301 |
+
# fps=10, pad_value=50, tracks_leave_trace=0)
|
| 302 |
+
# vis_track.visualize(video=segment,
|
| 303 |
+
# tracks=out["traj_est"][...,:2],
|
| 304 |
+
# visibility=out["vis_est"],
|
| 305 |
+
# save_video=True)
|
| 306 |
+
# # visualize 4d
|
| 307 |
+
# import os, json
|
| 308 |
+
# import os.path as osp
|
| 309 |
+
# viser4d_dir = os.path.join("viser_4d_results")
|
| 310 |
+
# os.makedirs(viser4d_dir, exist_ok=True)
|
| 311 |
+
# depth_est = annots["depth_gt"][0]
|
| 312 |
+
# unc_metric = out["unc_metric"]
|
| 313 |
+
# mask = (unc_metric > 0.5).squeeze(1)
|
| 314 |
+
# # pose_est = out["poses_pred"].squeeze(0)
|
| 315 |
+
# pose_est = annots["traj_mat"][0]
|
| 316 |
+
# rgb_tracks = out["rgb_tracks"].squeeze(0)
|
| 317 |
+
# intrinsics = out["intrs"].squeeze(0)
|
| 318 |
+
# for i_k in range(out["depth"].shape[0]):
|
| 319 |
+
# img_i = out["imgs_raw"][0][i_k].permute(1, 2, 0).cpu().numpy()
|
| 320 |
+
# img_i = cv2.cvtColor(img_i, cv2.COLOR_BGR2RGB)
|
| 321 |
+
# cv2.imwrite(osp.join(viser4d_dir, f'frame_{i_k:04d}.png'), img_i)
|
| 322 |
+
# if stage == 1:
|
| 323 |
+
# depth = depth_est[i_k].squeeze().cpu().numpy()
|
| 324 |
+
# np.save(osp.join(viser4d_dir, f'frame_{i_k:04d}.npy'), depth)
|
| 325 |
+
# else:
|
| 326 |
+
# point_map_vis = out["points_map"][i_k].cpu().numpy()
|
| 327 |
+
# np.save(osp.join(viser4d_dir, f'point_{i_k:04d}.npy'), point_map_vis)
|
| 328 |
+
# np.save(os.path.join(viser4d_dir, f'intrinsics.npy'), intrinsics.cpu().numpy())
|
| 329 |
+
# np.save(os.path.join(viser4d_dir, f'extrinsics.npy'), pose_est.cpu().numpy())
|
| 330 |
+
# np.save(os.path.join(viser4d_dir, f'conf.npy'), mask.float().cpu().numpy())
|
| 331 |
+
# np.save(os.path.join(viser4d_dir, f'colored_track3d.npy'), rgb_tracks.cpu().numpy())
|
| 332 |
+
|
| 333 |
+
queries_len = len(queries_new)
|
| 334 |
+
# update the track3d and track2d
|
| 335 |
+
left_len = len(track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :])
|
| 336 |
+
track3d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["rgb_tracks"][0,:left_len,:queries_len,:]
|
| 337 |
+
track2d_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["traj_est"][0,:left_len,:queries_len,:3]
|
| 338 |
+
vis_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["vis_est"][0,:left_len,:queries_len,None]
|
| 339 |
+
conf_pred[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["conf_pred"][0,:left_len,:queries_len,None]
|
| 340 |
+
dyn_preds[i*step_slide:i*step_slide+window_len, :queries_len, :] = out["dyn_preds"][0,:left_len,:queries_len,None]
|
| 341 |
+
|
| 342 |
+
# process the output for each segment
|
| 343 |
+
seg_c2w = out["poses_pred"][0]
|
| 344 |
+
seg_intrs = out["intrs"][0]
|
| 345 |
+
seg_point_map = out["points_map"]
|
| 346 |
+
seg_conf_depth = out["unc_metric"]
|
| 347 |
+
|
| 348 |
+
# cache management
|
| 349 |
+
cache = out["cache"]
|
| 350 |
+
for k in cache.keys():
|
| 351 |
+
if "_pyramid" in k:
|
| 352 |
+
for j in range(len(cache[k])):
|
| 353 |
+
if len(cache[k][j].shape) == 5:
|
| 354 |
+
cache[k][j] = cache[k][j][:,:,:,:queries_len,:]
|
| 355 |
+
elif len(cache[k][j].shape) == 4:
|
| 356 |
+
cache[k][j] = cache[k][j][:,:1,:queries_len,:]
|
| 357 |
+
elif "_pred_cache" in k:
|
| 358 |
+
cache[k] = cache[k][-overlap_len:,:queries_len,:]
|
| 359 |
+
else:
|
| 360 |
+
cache[k] = cache[k][-overlap_len:]
|
| 361 |
+
|
| 362 |
+
# update the results
|
| 363 |
+
idx_glob = i * step_slide
|
| 364 |
+
# refine part
|
| 365 |
+
# mask_update = sort_query[..., 0] < i * step_slide + window_len
|
| 366 |
+
# sort_query_pick = sort_query[mask_update]
|
| 367 |
+
intrs_out[idx_glob:idx_glob+window_len] = seg_intrs
|
| 368 |
+
point_map[idx_glob:idx_glob+window_len] = seg_point_map
|
| 369 |
+
unc_metric[idx_glob:idx_glob+window_len] = seg_conf_depth
|
| 370 |
+
# update the camera poses
|
| 371 |
+
|
| 372 |
+
# if using the ground truth pose
|
| 373 |
+
# if extrs_unf is not None:
|
| 374 |
+
# c2w_traj[idx_glob:idx_glob+window_len] = extrs_unf[i:i+1][0].to(c2w_traj.device).to(c2w_traj.dtype)
|
| 375 |
+
# else:
|
| 376 |
+
prev_c2w = c2w_traj[idx_glob:idx_glob+window_len][:1]
|
| 377 |
+
c2w_traj[idx_glob:idx_glob+window_len] = prev_c2w@seg_c2w.to(c2w_traj.device).to(c2w_traj.dtype)
|
| 378 |
+
|
| 379 |
+
track2d_pred = track2d_pred[:T_org,sorted_inv_indices,:]
|
| 380 |
+
track3d_pred = track3d_pred[:T_org,sorted_inv_indices,:]
|
| 381 |
+
vis_pred = vis_pred[:T_org,sorted_inv_indices,:]
|
| 382 |
+
conf_pred = conf_pred[:T_org,sorted_inv_indices,:]
|
| 383 |
+
dyn_preds = dyn_preds[:T_org,sorted_inv_indices,:]
|
| 384 |
+
unc_metric = unc_metric[:T_org,:]
|
| 385 |
+
point_map = point_map[:T_org,:]
|
| 386 |
+
intrs_out = intrs_out[:T_org,:]
|
| 387 |
+
c2w_traj = c2w_traj[:T_org,:]
|
| 388 |
+
if self.training:
|
| 389 |
+
ret = {
|
| 390 |
+
"loss": loss,
|
| 391 |
+
"depth_loss": 0.0,
|
| 392 |
+
"ab_loss": 0.0,
|
| 393 |
+
"vis_loss": out["vis_loss"],
|
| 394 |
+
"track_loss": out["track_loss"],
|
| 395 |
+
"conf_loss": out["conf_loss"],
|
| 396 |
+
"dyn_loss": out["dyn_loss"],
|
| 397 |
+
"sync_loss": out["sync_loss"],
|
| 398 |
+
"poses_pred": c2w_traj[None],
|
| 399 |
+
"intrs": intrs_out[None],
|
| 400 |
+
"points_map": point_map,
|
| 401 |
+
"track3d_pred": track3d_pred[None],
|
| 402 |
+
"rgb_tracks": track3d_pred[None],
|
| 403 |
+
"track2d_pred": track2d_pred[None],
|
| 404 |
+
"traj_est": track2d_pred[None],
|
| 405 |
+
"vis_est": vis_pred[None], "conf_pred": conf_pred[None],
|
| 406 |
+
"dyn_preds": dyn_preds[None],
|
| 407 |
+
"imgs_raw": video[None],
|
| 408 |
+
"unc_metric": unc_metric,
|
| 409 |
+
}
|
| 410 |
+
|
| 411 |
+
return ret
|
| 412 |
+
else:
|
| 413 |
+
return c2w_traj, intrs_out, point_map, unc_metric, track3d_pred, track2d_pred, vis_pred, conf_pred
|
| 414 |
+
def forward(self,
|
| 415 |
+
x: torch.Tensor,
|
| 416 |
+
annots: Dict = {},
|
| 417 |
+
pts_q: torch.Tensor = None,
|
| 418 |
+
pts_q_3d: torch.Tensor = None,
|
| 419 |
+
overlap_d: torch.Tensor = None,
|
| 420 |
+
full_point = False,
|
| 421 |
+
fixed_cam = False,
|
| 422 |
+
support_frame = 0,
|
| 423 |
+
query_no_BA = False,
|
| 424 |
+
cache = None,
|
| 425 |
+
replace_ratio = 0.6,
|
| 426 |
+
iters_track=4,
|
| 427 |
+
**kwargs):
|
| 428 |
+
"""
|
| 429 |
+
forward the video camera model, which predict (
|
| 430 |
+
`intr` `camera poses` `video depth`
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
args:
|
| 434 |
+
x: the input video frames. [B, T, C, H, W]
|
| 435 |
+
annots: the annotations for video frames.
|
| 436 |
+
{
|
| 437 |
+
"poses_gt": the pose encoding for the video frames. [B, T, 7]
|
| 438 |
+
"depth_gt": the ground truth depth for the video frames. [B, T, 1, H, W],
|
| 439 |
+
"metric": bool, whether to calculate the metric for the video frames.
|
| 440 |
+
}
|
| 441 |
+
"""
|
| 442 |
+
self.support_frame = support_frame
|
| 443 |
+
|
| 444 |
+
#TODO: to adjust a little bit
|
| 445 |
+
track_loss=ab_loss=vis_loss=track_loss=conf_loss=dyn_loss=0.0
|
| 446 |
+
B, T, _, H, W = x.shape
|
| 447 |
+
imgs_raw = x.clone()
|
| 448 |
+
# get the video split and features for each segment
|
| 449 |
+
patch_size, x_resize = self.ProcVid(x)
|
| 450 |
+
x_resize = rearrange(x_resize, "(b t) c h w -> b t c h w", b=B)
|
| 451 |
+
H_resize, W_resize = x_resize.shape[-2:]
|
| 452 |
+
|
| 453 |
+
prec_fx = W / W_resize
|
| 454 |
+
prec_fy = H / H_resize
|
| 455 |
+
# get patch size
|
| 456 |
+
P_H, P_W = patch_size
|
| 457 |
+
|
| 458 |
+
# get the depth, pointmap and mask
|
| 459 |
+
#TODO: Release DepthAnything Version
|
| 460 |
+
points_map_gt = None
|
| 461 |
+
with torch.no_grad():
|
| 462 |
+
if_gt_depth = (("depth_gt" in annots.keys())) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3)
|
| 463 |
+
if if_gt_depth==False:
|
| 464 |
+
if cache is not None:
|
| 465 |
+
T_cache = cache["points_map"].shape[0]
|
| 466 |
+
T_new = T - T_cache
|
| 467 |
+
x_resize_new = x_resize[:, T_cache:]
|
| 468 |
+
else:
|
| 469 |
+
T_new = T
|
| 470 |
+
x_resize_new = x_resize
|
| 471 |
+
# infer with chunk
|
| 472 |
+
chunk_size = self.chunk_size
|
| 473 |
+
metric_depth = []
|
| 474 |
+
intrs = []
|
| 475 |
+
unc_metric = []
|
| 476 |
+
mask = []
|
| 477 |
+
points_map = []
|
| 478 |
+
normals = []
|
| 479 |
+
normals_mask = []
|
| 480 |
+
for i in range(0, B*T_new, chunk_size):
|
| 481 |
+
output = self.base_model.infer(x_resize_new.view(B*T_new, -1, H_resize, W_resize)[i:i+chunk_size])
|
| 482 |
+
metric_depth.append(output['depth'])
|
| 483 |
+
intrs.append(output['intrinsics'])
|
| 484 |
+
unc_metric.append(output['mask_prob'])
|
| 485 |
+
mask.append(output['mask'])
|
| 486 |
+
points_map.append(output['points'])
|
| 487 |
+
normals_i, normals_mask_i = utils3d.torch.points_to_normals(output['points'], mask=output['mask'])
|
| 488 |
+
normals.append(normals_i)
|
| 489 |
+
normals_mask.append(normals_mask_i)
|
| 490 |
+
|
| 491 |
+
metric_depth = torch.cat(metric_depth, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
|
| 492 |
+
intrs = torch.cat(intrs, dim=0).view(B, T_new, 3, 3).to(x.dtype)
|
| 493 |
+
intrs[:,:,0,:] *= W_resize
|
| 494 |
+
intrs[:,:,1,:] *= H_resize
|
| 495 |
+
# points_map = torch.cat(points_map, dim=0)
|
| 496 |
+
mask = torch.cat(mask, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
|
| 497 |
+
# cat the normals
|
| 498 |
+
normals = torch.cat(normals, dim=0)
|
| 499 |
+
normals_mask = torch.cat(normals_mask, dim=0)
|
| 500 |
+
|
| 501 |
+
metric_depth = metric_depth.clone()
|
| 502 |
+
metric_depth[metric_depth == torch.inf] = 0
|
| 503 |
+
_depths = metric_depth[metric_depth > 0].reshape(-1)
|
| 504 |
+
q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
|
| 505 |
+
q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
|
| 506 |
+
iqr = q75 - q25
|
| 507 |
+
upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
|
| 508 |
+
_depth_roi = torch.tensor(
|
| 509 |
+
[1e-1, upper_bound.item()],
|
| 510 |
+
dtype=metric_depth.dtype,
|
| 511 |
+
device=metric_depth.device
|
| 512 |
+
)
|
| 513 |
+
mask_roi = (metric_depth > _depth_roi[0]) & (metric_depth < _depth_roi[1])
|
| 514 |
+
mask = mask * mask_roi
|
| 515 |
+
mask = mask * (~(utils3d.torch.depth_edge(metric_depth, rtol=0.03, mask=mask.bool()))) * normals_mask[:,None,...]
|
| 516 |
+
points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T_new, 3, 3))
|
| 517 |
+
unc_metric = torch.cat(unc_metric, dim=0).view(B*T_new, 1, H_resize, W_resize).to(x.dtype)
|
| 518 |
+
unc_metric *= mask
|
| 519 |
+
if full_point:
|
| 520 |
+
unc_metric = (~(utils3d.torch.depth_edge(metric_depth, rtol=0.1, mask=torch.ones_like(metric_depth).bool()))).float() * (metric_depth != 0)
|
| 521 |
+
if cache is not None:
|
| 522 |
+
assert B==1, "only support batch size 1 right now."
|
| 523 |
+
unc_metric = torch.cat([cache["unc_metric"], unc_metric], dim=0)
|
| 524 |
+
intrs = torch.cat([cache["intrs"][None], intrs], dim=1)
|
| 525 |
+
points_map = torch.cat([cache["points_map"].permute(0,2,3,1), points_map], dim=0)
|
| 526 |
+
metric_depth = torch.cat([cache["metric_depth"], metric_depth], dim=0)
|
| 527 |
+
|
| 528 |
+
if "poses_gt" in annots.keys():
|
| 529 |
+
intrs, c2w_traj_gt = pose_enc2mat(annots["poses_gt"],
|
| 530 |
+
H_resize, W_resize, self.resolution)
|
| 531 |
+
else:
|
| 532 |
+
c2w_traj_gt = None
|
| 533 |
+
|
| 534 |
+
if "intrs_gt" in annots.keys():
|
| 535 |
+
intrs = annots["intrs_gt"].view(B, T, 3, 3)
|
| 536 |
+
fx_factor = W_resize / W
|
| 537 |
+
fy_factor = H_resize / H
|
| 538 |
+
intrs[:,:,0,:] *= fx_factor
|
| 539 |
+
intrs[:,:,1,:] *= fy_factor
|
| 540 |
+
|
| 541 |
+
if "depth_gt" in annots.keys():
|
| 542 |
+
|
| 543 |
+
metric_depth_gt = annots['depth_gt'].view(B*T, 1, H, W)
|
| 544 |
+
metric_depth_gt = F.interpolate(metric_depth_gt,
|
| 545 |
+
size=(H_resize, W_resize), mode='nearest')
|
| 546 |
+
|
| 547 |
+
_depths = metric_depth_gt[metric_depth_gt > 0].reshape(-1)
|
| 548 |
+
q25 = torch.kthvalue(_depths, int(0.25 * len(_depths))).values
|
| 549 |
+
q75 = torch.kthvalue(_depths, int(0.75 * len(_depths))).values
|
| 550 |
+
iqr = q75 - q25
|
| 551 |
+
upper_bound = (q75 + 0.8*iqr).clamp(min=1e-6, max=10*q25)
|
| 552 |
+
_depth_roi = torch.tensor(
|
| 553 |
+
[1e-1, upper_bound.item()],
|
| 554 |
+
dtype=metric_depth_gt.dtype,
|
| 555 |
+
device=metric_depth_gt.device
|
| 556 |
+
)
|
| 557 |
+
mask_roi = (metric_depth_gt > _depth_roi[0]) & (metric_depth_gt < _depth_roi[1])
|
| 558 |
+
# if (upper_bound > 200).any():
|
| 559 |
+
# import pdb; pdb.set_trace()
|
| 560 |
+
if (kwargs.get('stage', 0) == 2):
|
| 561 |
+
unc_metric = ((metric_depth_gt > 0)*(mask_roi) * (unc_metric > 0.5)).float()
|
| 562 |
+
metric_depth_gt[metric_depth_gt > 10*q25] = 0
|
| 563 |
+
else:
|
| 564 |
+
unc_metric = ((metric_depth_gt > 0)*(mask_roi)).float()
|
| 565 |
+
unc_metric *= (~(utils3d.torch.depth_edge(metric_depth_gt, rtol=0.03, mask=mask_roi.bool()))).float()
|
| 566 |
+
# filter the sky
|
| 567 |
+
metric_depth_gt[metric_depth_gt > 10*q25] = 0
|
| 568 |
+
if "unc_metric" in annots.keys():
|
| 569 |
+
unc_metric_ = F.interpolate(annots["unc_metric"].permute(1,0,2,3),
|
| 570 |
+
size=(H_resize, W_resize), mode='nearest')
|
| 571 |
+
unc_metric = unc_metric * unc_metric_
|
| 572 |
+
if if_gt_depth:
|
| 573 |
+
points_map = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
|
| 574 |
+
metric_depth = metric_depth_gt
|
| 575 |
+
points_map_gt = points_map
|
| 576 |
+
else:
|
| 577 |
+
points_map_gt = depth_to_points_colmap(metric_depth_gt.squeeze(1), intrs.view(B*T, 3, 3))
|
| 578 |
+
|
| 579 |
+
# track the 3d points
|
| 580 |
+
ret_track = None
|
| 581 |
+
regular_track = True
|
| 582 |
+
dyn_preds, final_tracks = None, None
|
| 583 |
+
|
| 584 |
+
if "use_extr" in annots.keys():
|
| 585 |
+
init_pose = True
|
| 586 |
+
else:
|
| 587 |
+
init_pose = False
|
| 588 |
+
# set the custom vid and valid only
|
| 589 |
+
custom_vid = annots.get("custom_vid", False)
|
| 590 |
+
valid_only = annots.get("data_dir", [""])[0] == "stereo4d"
|
| 591 |
+
if self.training:
|
| 592 |
+
if (annots["vis"].sum() > 0) and (kwargs.get('stage', 0)==1 or kwargs.get('stage', 0)==3):
|
| 593 |
+
traj3d = annots['traj_3d']
|
| 594 |
+
if (kwargs.get('stage', 0)==1) and (annots.get("custom_vid", False)==False):
|
| 595 |
+
support_pts_q = get_track_points(H_resize, W_resize,
|
| 596 |
+
T, x.device, query_size=self.track_num // 2,
|
| 597 |
+
support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
|
| 598 |
+
else:
|
| 599 |
+
support_pts_q = get_track_points(H_resize, W_resize,
|
| 600 |
+
T, x.device, query_size=random.randint(32, 256),
|
| 601 |
+
support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental")[None]
|
| 602 |
+
if pts_q is not None:
|
| 603 |
+
pts_q = pts_q[None,None]
|
| 604 |
+
ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
|
| 605 |
+
metric_depth,
|
| 606 |
+
unc_metric.detach(), points_map, pts_q,
|
| 607 |
+
intrs=intrs.clone(), cache=cache,
|
| 608 |
+
prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
|
| 609 |
+
vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
|
| 610 |
+
cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
|
| 611 |
+
init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
|
| 612 |
+
points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
|
| 613 |
+
else:
|
| 614 |
+
ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
|
| 615 |
+
metric_depth,
|
| 616 |
+
unc_metric.detach(), points_map, traj3d[..., :2],
|
| 617 |
+
intrs=intrs.clone(), cache=cache,
|
| 618 |
+
prec_fx=prec_fx, prec_fy=prec_fy, overlap_d=overlap_d,
|
| 619 |
+
vis_gt=annots['vis'], traj3d_gt=traj3d, iters=iters_track,
|
| 620 |
+
cam_gt=c2w_traj_gt, support_pts_q=support_pts_q, custom_vid=custom_vid,
|
| 621 |
+
init_pose=init_pose, fixed_cam=fixed_cam, stage=kwargs.get('stage', 0),
|
| 622 |
+
points_map_gt=points_map_gt, valid_only=valid_only, replace_ratio=replace_ratio)
|
| 623 |
+
regular_track = False
|
| 624 |
+
|
| 625 |
+
|
| 626 |
+
if regular_track:
|
| 627 |
+
if pts_q is None:
|
| 628 |
+
pts_q = get_track_points(H_resize, W_resize,
|
| 629 |
+
T, x.device, query_size=self.track_num,
|
| 630 |
+
support_frame=self.support_frame, unc_metric=unc_metric, mode="incremental" if self.training else "incremental")[None]
|
| 631 |
+
support_pts_q = None
|
| 632 |
+
else:
|
| 633 |
+
pts_q = pts_q[None,None]
|
| 634 |
+
# resize the query points
|
| 635 |
+
pts_q[...,1] *= W_resize / W
|
| 636 |
+
pts_q[...,2] *= H_resize / H
|
| 637 |
+
|
| 638 |
+
if pts_q_3d is not None:
|
| 639 |
+
pts_q_3d = pts_q_3d[None,None]
|
| 640 |
+
# resize the query points
|
| 641 |
+
pts_q_3d[...,1] *= W_resize / W
|
| 642 |
+
pts_q_3d[...,2] *= H_resize / H
|
| 643 |
+
else:
|
| 644 |
+
# adjust the query with uncertainty
|
| 645 |
+
if (full_point==False) and (overlap_d is None):
|
| 646 |
+
pts_q_unc = sample_features5d(unc_metric[None], pts_q).squeeze()
|
| 647 |
+
pts_q = pts_q[:,:,pts_q_unc>0.5,:]
|
| 648 |
+
if (pts_q_unc<0.5).sum() > 0:
|
| 649 |
+
# pad the query points
|
| 650 |
+
pad_num = pts_q_unc.shape[0] - pts_q.shape[2]
|
| 651 |
+
# pick the random indices
|
| 652 |
+
indices = torch.randint(0, pts_q.shape[2], (pad_num,), device=pts_q.device)
|
| 653 |
+
pad_pts = indices
|
| 654 |
+
pts_q = torch.cat([pts_q, pts_q[:,:,pad_pts,:]], dim=-2)
|
| 655 |
+
|
| 656 |
+
support_pts_q = get_track_points(H_resize, W_resize,
|
| 657 |
+
T, x.device, query_size=self.track_num,
|
| 658 |
+
support_frame=self.support_frame,
|
| 659 |
+
unc_metric=unc_metric, mode="mixed")[None]
|
| 660 |
+
|
| 661 |
+
points_map[points_map>1e3] = 0
|
| 662 |
+
points_map = depth_to_points_colmap(metric_depth.squeeze(1), intrs.view(B*T, 3, 3))
|
| 663 |
+
ret_track, dyn_preds, final_tracks, rgb_tracks, intrs_org, point_map_org_refined, cache = self.Track3D(imgs_raw,
|
| 664 |
+
metric_depth,
|
| 665 |
+
unc_metric.detach(), points_map, pts_q,
|
| 666 |
+
pts_q_3d=pts_q_3d, intrs=intrs.clone(),cache=cache,
|
| 667 |
+
overlap_d=overlap_d, cam_gt=c2w_traj_gt if kwargs.get('stage', 0)==1 else None,
|
| 668 |
+
prec_fx=prec_fx, prec_fy=prec_fy, support_pts_q=support_pts_q, custom_vid=custom_vid, valid_only=valid_only,
|
| 669 |
+
fixed_cam=fixed_cam, query_no_BA=query_no_BA, init_pose=init_pose, iters=iters_track,
|
| 670 |
+
stage=kwargs.get('stage', 0), points_map_gt=points_map_gt, replace_ratio=replace_ratio)
|
| 671 |
+
intrs = intrs_org
|
| 672 |
+
points_map = point_map_org_refined
|
| 673 |
+
c2w_traj = ret_track["cam_pred"]
|
| 674 |
+
|
| 675 |
+
if ret_track is not None:
|
| 676 |
+
if ret_track["loss"] is not None:
|
| 677 |
+
track_loss, conf_loss, dyn_loss, vis_loss, point_map_loss, scale_loss, shift_loss, sync_loss= ret_track["loss"]
|
| 678 |
+
|
| 679 |
+
# update the cache
|
| 680 |
+
cache.update({"metric_depth": metric_depth, "unc_metric": unc_metric, "points_map": points_map, "intrs": intrs[0]})
|
| 681 |
+
# output
|
| 682 |
+
depth = F.interpolate(metric_depth,
|
| 683 |
+
size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
|
| 684 |
+
points_map = F.interpolate(points_map,
|
| 685 |
+
size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
|
| 686 |
+
unc_metric = F.interpolate(unc_metric,
|
| 687 |
+
size=(H, W), mode='bilinear', align_corners=True).squeeze(1)
|
| 688 |
+
|
| 689 |
+
if self.training:
|
| 690 |
+
|
| 691 |
+
loss = track_loss + conf_loss + dyn_loss + sync_loss + vis_loss + point_map_loss + (scale_loss + shift_loss)*50
|
| 692 |
+
ret = {"loss": loss,
|
| 693 |
+
"depth_loss": point_map_loss,
|
| 694 |
+
"ab_loss": (scale_loss + shift_loss)*50,
|
| 695 |
+
"vis_loss": vis_loss, "track_loss": track_loss,
|
| 696 |
+
"poses_pred": c2w_traj, "dyn_preds": dyn_preds, "traj_est": final_tracks, "conf_loss": conf_loss,
|
| 697 |
+
"imgs_raw": imgs_raw, "rgb_tracks": rgb_tracks, "vis_est": ret_track['vis_pred'],
|
| 698 |
+
"depth": depth, "points_map": points_map, "unc_metric": unc_metric, "intrs": intrs, "dyn_loss": dyn_loss,
|
| 699 |
+
"sync_loss": sync_loss, "conf_pred": ret_track['conf_pred'], "cache": cache,
|
| 700 |
+
}
|
| 701 |
+
|
| 702 |
+
else:
|
| 703 |
+
|
| 704 |
+
if ret_track is not None:
|
| 705 |
+
traj_est = ret_track['preds']
|
| 706 |
+
traj_est[..., 0] *= W / W_resize
|
| 707 |
+
traj_est[..., 1] *= H / H_resize
|
| 708 |
+
vis_est = ret_track['vis_pred']
|
| 709 |
+
else:
|
| 710 |
+
traj_est = torch.zeros(B, self.track_num // 2, 3).to(x.device)
|
| 711 |
+
vis_est = torch.zeros(B, self.track_num // 2).to(x.device)
|
| 712 |
+
|
| 713 |
+
if intrs is not None:
|
| 714 |
+
intrs[..., 0, :] *= W / W_resize
|
| 715 |
+
intrs[..., 1, :] *= H / H_resize
|
| 716 |
+
ret = {"poses_pred": c2w_traj, "dyn_preds": dyn_preds,
|
| 717 |
+
"depth": depth, "traj_est": traj_est, "vis_est": vis_est, "imgs_raw": imgs_raw,
|
| 718 |
+
"rgb_tracks": rgb_tracks, "intrs": intrs, "unc_metric": unc_metric, "points_map": points_map,
|
| 719 |
+
"conf_pred": ret_track['conf_pred'], "cache": cache,
|
| 720 |
+
}
|
| 721 |
+
|
| 722 |
+
return ret
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
|
| 726 |
+
|
| 727 |
+
# three stages of training
|
| 728 |
+
|
| 729 |
+
# stage 1:
|
| 730 |
+
# gt depth and intrinsics synthetic (includes Dynamic Replica, Kubric, Pointodyssey, Vkitti, TartanAir and Indoor() ) Motion Patern (tapvid3d)
|
| 731 |
+
# Tracking and Pose as well -> based on gt depth and intrinsics
|
| 732 |
+
# (Finished) -> (megasam + base model) vs. tapip3d. (use depth from megasam or pose, which keep the same setting as tapip3d.)
|
| 733 |
+
|
| 734 |
+
# stage 2: fixed 3D tracking
|
| 735 |
+
# Joint depth refiner
|
| 736 |
+
# input depth from whatever + rgb -> temporal module + scale and shift token -> coarse alignment -> scale and shift
|
| 737 |
+
# estimate the 3D tracks -> 3D tracks combine with pointmap -> update for pointmap (iteratively) -> residual map B T 3 H W
|
| 738 |
+
# ongoing two days
|
| 739 |
+
|
| 740 |
+
# stage 3: train multi windows by propagation
|
| 741 |
+
# 4 frames overlapped -> train on 64 -> fozen image encoder and finetuning the transformer (learnable parameters pretty small)
|
| 742 |
+
|
| 743 |
+
# types of scenarioes:
|
| 744 |
+
# 1. auto driving (waymo open dataset)
|
| 745 |
+
# 2. robot
|
| 746 |
+
# 3. internet ego video
|
| 747 |
+
|
| 748 |
+
|
| 749 |
+
|
| 750 |
+
# Iterative Transformer -- Solver -- General Neural MegaSAM + Tracks
|
| 751 |
+
# Update Variables:
|
| 752 |
+
# 1. 3D tracks B T N 3 xyz.
|
| 753 |
+
# 2. 2D tracks B T N 2 x y.
|
| 754 |
+
# 3. Dynamic Mask B T H W.
|
| 755 |
+
# 4. Camera Pose B T 4 4.
|
| 756 |
+
# 5. Video Depth.
|
| 757 |
+
|
| 758 |
+
# (RGB, RGBD, RGBD+Pose) x (Static, Dynamic)
|
| 759 |
+
# Campatiablity by product.
|
models/SpaTrackV2/models/__init__.py
ADDED
|
File without changes
|
models/SpaTrackV2/models/blocks.py
ADDED
|
@@ -0,0 +1,519 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torch.cuda.amp import autocast
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
import collections
|
| 13 |
+
from functools import partial
|
| 14 |
+
from itertools import repeat
|
| 15 |
+
import torchvision.models as tvm
|
| 16 |
+
from torch.utils.checkpoint import checkpoint
|
| 17 |
+
from models.monoD.depth_anything.dpt import DPTHeadEnc, DPTHead
|
| 18 |
+
from typing import Union, Tuple
|
| 19 |
+
from torch import Tensor
|
| 20 |
+
|
| 21 |
+
# From PyTorch internals
|
| 22 |
+
def _ntuple(n):
|
| 23 |
+
def parse(x):
|
| 24 |
+
if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
|
| 25 |
+
return tuple(x)
|
| 26 |
+
return tuple(repeat(x, n))
|
| 27 |
+
|
| 28 |
+
return parse
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def exists(val):
|
| 32 |
+
return val is not None
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def default(val, d):
|
| 36 |
+
return val if exists(val) else d
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
to_2tuple = _ntuple(2)
|
| 40 |
+
|
| 41 |
+
class LayerScale(nn.Module):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
dim: int,
|
| 45 |
+
init_values: Union[float, Tensor] = 1e-5,
|
| 46 |
+
inplace: bool = False,
|
| 47 |
+
) -> None:
|
| 48 |
+
super().__init__()
|
| 49 |
+
self.inplace = inplace
|
| 50 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 51 |
+
|
| 52 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 53 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 54 |
+
|
| 55 |
+
class Mlp(nn.Module):
|
| 56 |
+
"""MLP as used in Vision Transformer, MLP-Mixer and related networks"""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
in_features,
|
| 61 |
+
hidden_features=None,
|
| 62 |
+
out_features=None,
|
| 63 |
+
act_layer=nn.GELU,
|
| 64 |
+
norm_layer=None,
|
| 65 |
+
bias=True,
|
| 66 |
+
drop=0.0,
|
| 67 |
+
use_conv=False,
|
| 68 |
+
):
|
| 69 |
+
super().__init__()
|
| 70 |
+
out_features = out_features or in_features
|
| 71 |
+
hidden_features = hidden_features or in_features
|
| 72 |
+
bias = to_2tuple(bias)
|
| 73 |
+
drop_probs = to_2tuple(drop)
|
| 74 |
+
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
| 75 |
+
|
| 76 |
+
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
| 77 |
+
self.act = act_layer()
|
| 78 |
+
self.drop1 = nn.Dropout(drop_probs[0])
|
| 79 |
+
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
| 80 |
+
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
| 81 |
+
self.drop2 = nn.Dropout(drop_probs[1])
|
| 82 |
+
|
| 83 |
+
def forward(self, x):
|
| 84 |
+
x = self.fc1(x)
|
| 85 |
+
x = self.act(x)
|
| 86 |
+
x = self.drop1(x)
|
| 87 |
+
x = self.fc2(x)
|
| 88 |
+
x = self.drop2(x)
|
| 89 |
+
return x
|
| 90 |
+
|
| 91 |
+
class Attention(nn.Module):
|
| 92 |
+
def __init__(self, query_dim, context_dim=None,
|
| 93 |
+
num_heads=8, dim_head=48, qkv_bias=False, flash=False):
|
| 94 |
+
super().__init__()
|
| 95 |
+
inner_dim = self.inner_dim = dim_head * num_heads
|
| 96 |
+
context_dim = default(context_dim, query_dim)
|
| 97 |
+
self.scale = dim_head**-0.5
|
| 98 |
+
self.heads = num_heads
|
| 99 |
+
self.flash = flash
|
| 100 |
+
|
| 101 |
+
self.to_q = nn.Linear(query_dim, inner_dim, bias=qkv_bias)
|
| 102 |
+
self.to_kv = nn.Linear(context_dim, inner_dim * 2, bias=qkv_bias)
|
| 103 |
+
self.to_out = nn.Linear(inner_dim, query_dim)
|
| 104 |
+
|
| 105 |
+
def forward(self, x, context=None, attn_bias=None):
|
| 106 |
+
B, N1, _ = x.shape
|
| 107 |
+
C = self.inner_dim
|
| 108 |
+
h = self.heads
|
| 109 |
+
q = self.to_q(x).reshape(B, N1, h, C // h).permute(0, 2, 1, 3)
|
| 110 |
+
context = default(context, x)
|
| 111 |
+
k, v = self.to_kv(context).chunk(2, dim=-1)
|
| 112 |
+
|
| 113 |
+
N2 = context.shape[1]
|
| 114 |
+
k = k.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
| 115 |
+
v = v.reshape(B, N2, h, C // h).permute(0, 2, 1, 3)
|
| 116 |
+
|
| 117 |
+
with torch.autocast("cuda", enabled=True, dtype=torch.bfloat16):
|
| 118 |
+
if self.flash==False:
|
| 119 |
+
sim = (q @ k.transpose(-2, -1)) * self.scale
|
| 120 |
+
if attn_bias is not None:
|
| 121 |
+
sim = sim + attn_bias
|
| 122 |
+
if sim.abs().max()>1e2:
|
| 123 |
+
import pdb; pdb.set_trace()
|
| 124 |
+
attn = sim.softmax(dim=-1)
|
| 125 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N1, C)
|
| 126 |
+
else:
|
| 127 |
+
input_args = [x.contiguous() for x in [q, k, v]]
|
| 128 |
+
x = F.scaled_dot_product_attention(*input_args).permute(0,2,1,3).reshape(B,N1,-1) # type: ignore
|
| 129 |
+
|
| 130 |
+
if self.to_out.bias.dtype != x.dtype:
|
| 131 |
+
x = x.to(self.to_out.bias.dtype)
|
| 132 |
+
|
| 133 |
+
return self.to_out(x)
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class VGG19(nn.Module):
|
| 137 |
+
def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
|
| 138 |
+
super().__init__()
|
| 139 |
+
self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
|
| 140 |
+
self.amp = amp
|
| 141 |
+
self.amp_dtype = amp_dtype
|
| 142 |
+
|
| 143 |
+
def forward(self, x, **kwargs):
|
| 144 |
+
with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
|
| 145 |
+
feats = {}
|
| 146 |
+
scale = 1
|
| 147 |
+
for layer in self.layers:
|
| 148 |
+
if isinstance(layer, nn.MaxPool2d):
|
| 149 |
+
feats[scale] = x
|
| 150 |
+
scale = scale*2
|
| 151 |
+
x = layer(x)
|
| 152 |
+
return feats
|
| 153 |
+
|
| 154 |
+
class CNNandDinov2(nn.Module):
|
| 155 |
+
def __init__(self, cnn_kwargs = None, amp = True, amp_dtype = torch.float16):
|
| 156 |
+
super().__init__()
|
| 157 |
+
# in case the Internet connection is not stable, please load the DINOv2 locally
|
| 158 |
+
self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
|
| 159 |
+
'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
|
| 160 |
+
|
| 161 |
+
state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
|
| 162 |
+
self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
cnn_kwargs = cnn_kwargs if cnn_kwargs is not None else {}
|
| 166 |
+
self.cnn = VGG19(**cnn_kwargs)
|
| 167 |
+
self.amp = amp
|
| 168 |
+
self.amp_dtype = amp_dtype
|
| 169 |
+
if self.amp:
|
| 170 |
+
dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
|
| 171 |
+
self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def train(self, mode: bool = True):
|
| 175 |
+
return self.cnn.train(mode)
|
| 176 |
+
|
| 177 |
+
def forward(self, x, upsample = False):
|
| 178 |
+
B,C,H,W = x.shape
|
| 179 |
+
feature_pyramid = self.cnn(x)
|
| 180 |
+
|
| 181 |
+
if not upsample:
|
| 182 |
+
with torch.no_grad():
|
| 183 |
+
if self.dinov2_vitl14[0].device != x.device:
|
| 184 |
+
self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
|
| 185 |
+
dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
|
| 186 |
+
features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
|
| 187 |
+
del dinov2_features_16
|
| 188 |
+
feature_pyramid[16] = features_16
|
| 189 |
+
return feature_pyramid
|
| 190 |
+
|
| 191 |
+
class Dinov2(nn.Module):
|
| 192 |
+
def __init__(self, amp = True, amp_dtype = torch.float16):
|
| 193 |
+
super().__init__()
|
| 194 |
+
# in case the Internet connection is not stable, please load the DINOv2 locally
|
| 195 |
+
self.dinov2_vitl14 = torch.hub.load('models/torchhub/facebookresearch_dinov2_main',
|
| 196 |
+
'dinov2_{:}14'.format("vitl"), source='local', pretrained=False)
|
| 197 |
+
|
| 198 |
+
state_dict = torch.load("models/monoD/zoeDepth/ckpts/dinov2_vitl14_pretrain.pth")
|
| 199 |
+
self.dinov2_vitl14.load_state_dict(state_dict, strict=True)
|
| 200 |
+
|
| 201 |
+
self.amp = amp
|
| 202 |
+
self.amp_dtype = amp_dtype
|
| 203 |
+
if self.amp:
|
| 204 |
+
self.dinov2_vitl14 = self.dinov2_vitl14.to(self.amp_dtype)
|
| 205 |
+
|
| 206 |
+
def forward(self, x, upsample = False):
|
| 207 |
+
B,C,H,W = x.shape
|
| 208 |
+
mean_ = torch.tensor([0.485, 0.456, 0.406],
|
| 209 |
+
device=x.device).view(1, 3, 1, 1)
|
| 210 |
+
std_ = torch.tensor([0.229, 0.224, 0.225],
|
| 211 |
+
device=x.device).view(1, 3, 1, 1)
|
| 212 |
+
x = (x+1)/2
|
| 213 |
+
x = (x - mean_)/std_
|
| 214 |
+
h_re, w_re = 560, 560
|
| 215 |
+
x_resize = F.interpolate(x, size=(h_re, w_re),
|
| 216 |
+
mode='bilinear', align_corners=True)
|
| 217 |
+
if not upsample:
|
| 218 |
+
with torch.no_grad():
|
| 219 |
+
dinov2_features_16 = self.dinov2_vitl14.forward_features(x_resize.to(self.amp_dtype))
|
| 220 |
+
features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,h_re//14, w_re//14)
|
| 221 |
+
del dinov2_features_16
|
| 222 |
+
features_16 = F.interpolate(features_16, size=(H//8, W//8), mode="bilinear", align_corners=True)
|
| 223 |
+
return features_16
|
| 224 |
+
|
| 225 |
+
class AttnBlock(nn.Module):
|
| 226 |
+
"""
|
| 227 |
+
A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.
|
| 228 |
+
"""
|
| 229 |
+
|
| 230 |
+
def __init__(self, hidden_size, num_heads, mlp_ratio=4.0,
|
| 231 |
+
flash=False, ckpt_fwd=False, debug=False, **block_kwargs):
|
| 232 |
+
super().__init__()
|
| 233 |
+
self.debug=debug
|
| 234 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 235 |
+
self.flash=flash
|
| 236 |
+
|
| 237 |
+
self.attn = Attention(
|
| 238 |
+
hidden_size, num_heads=num_heads, qkv_bias=True, flash=flash,
|
| 239 |
+
**block_kwargs
|
| 240 |
+
)
|
| 241 |
+
self.ls = LayerScale(hidden_size, init_values=0.005)
|
| 242 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 243 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 244 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 245 |
+
self.mlp = Mlp(
|
| 246 |
+
in_features=hidden_size,
|
| 247 |
+
hidden_features=mlp_hidden_dim,
|
| 248 |
+
act_layer=approx_gelu,
|
| 249 |
+
)
|
| 250 |
+
self.ckpt_fwd = ckpt_fwd
|
| 251 |
+
def forward(self, x):
|
| 252 |
+
if self.debug:
|
| 253 |
+
print(x.max(), x.min(), x.mean())
|
| 254 |
+
if self.ckpt_fwd:
|
| 255 |
+
x = x + checkpoint(self.attn, self.norm1(x), use_reentrant=False)
|
| 256 |
+
else:
|
| 257 |
+
x = x + self.attn(self.norm1(x))
|
| 258 |
+
|
| 259 |
+
x = x + self.ls(self.mlp(self.norm2(x)))
|
| 260 |
+
return x
|
| 261 |
+
|
| 262 |
+
class CrossAttnBlock(nn.Module):
|
| 263 |
+
def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, head_dim=48,
|
| 264 |
+
flash=False, ckpt_fwd=False, **block_kwargs):
|
| 265 |
+
super().__init__()
|
| 266 |
+
self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 267 |
+
self.norm_context = nn.LayerNorm(hidden_size)
|
| 268 |
+
|
| 269 |
+
self.cross_attn = Attention(
|
| 270 |
+
hidden_size, context_dim=context_dim, dim_head=head_dim,
|
| 271 |
+
num_heads=num_heads, qkv_bias=True, **block_kwargs, flash=flash,
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 275 |
+
mlp_hidden_dim = int(hidden_size * mlp_ratio)
|
| 276 |
+
approx_gelu = lambda: nn.GELU(approximate="tanh")
|
| 277 |
+
self.mlp = Mlp(
|
| 278 |
+
in_features=hidden_size,
|
| 279 |
+
hidden_features=mlp_hidden_dim,
|
| 280 |
+
act_layer=approx_gelu,
|
| 281 |
+
drop=0,
|
| 282 |
+
)
|
| 283 |
+
self.ckpt_fwd = ckpt_fwd
|
| 284 |
+
def forward(self, x, context):
|
| 285 |
+
if self.ckpt_fwd:
|
| 286 |
+
with autocast():
|
| 287 |
+
x = x + checkpoint(self.cross_attn,
|
| 288 |
+
self.norm1(x), self.norm_context(context), use_reentrant=False)
|
| 289 |
+
else:
|
| 290 |
+
with autocast():
|
| 291 |
+
x = x + self.cross_attn(
|
| 292 |
+
self.norm1(x), self.norm_context(context)
|
| 293 |
+
)
|
| 294 |
+
x = x + self.mlp(self.norm2(x))
|
| 295 |
+
return x
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def bilinear_sampler(img, coords, mode="bilinear", mask=False):
|
| 299 |
+
"""Wrapper for grid_sample, uses pixel coordinates"""
|
| 300 |
+
H, W = img.shape[-2:]
|
| 301 |
+
xgrid, ygrid = coords.split([1, 1], dim=-1)
|
| 302 |
+
# go to 0,1 then 0,2 then -1,1
|
| 303 |
+
xgrid = 2 * xgrid / (W - 1) - 1
|
| 304 |
+
ygrid = 2 * ygrid / (H - 1) - 1
|
| 305 |
+
|
| 306 |
+
grid = torch.cat([xgrid, ygrid], dim=-1)
|
| 307 |
+
img = F.grid_sample(img, grid, align_corners=True, mode=mode)
|
| 308 |
+
|
| 309 |
+
if mask:
|
| 310 |
+
mask = (xgrid > -1) & (ygrid > -1) & (xgrid < 1) & (ygrid < 1)
|
| 311 |
+
return img, mask.float()
|
| 312 |
+
|
| 313 |
+
return img
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class CorrBlock:
|
| 317 |
+
def __init__(self, fmaps, num_levels=4, radius=4, depths_dnG=None):
|
| 318 |
+
B, S, C, H_prev, W_prev = fmaps.shape
|
| 319 |
+
self.S, self.C, self.H, self.W = S, C, H_prev, W_prev
|
| 320 |
+
|
| 321 |
+
self.num_levels = num_levels
|
| 322 |
+
self.radius = radius
|
| 323 |
+
self.fmaps_pyramid = []
|
| 324 |
+
self.depth_pyramid = []
|
| 325 |
+
self.fmaps_pyramid.append(fmaps)
|
| 326 |
+
if depths_dnG is not None:
|
| 327 |
+
self.depth_pyramid.append(depths_dnG)
|
| 328 |
+
for i in range(self.num_levels - 1):
|
| 329 |
+
if depths_dnG is not None:
|
| 330 |
+
depths_dnG_ = depths_dnG.reshape(B * S, 1, H_prev, W_prev)
|
| 331 |
+
depths_dnG_ = F.avg_pool2d(depths_dnG_, 2, stride=2)
|
| 332 |
+
_, _, H, W = depths_dnG_.shape
|
| 333 |
+
depths_dnG = depths_dnG_.reshape(B, S, 1, H, W)
|
| 334 |
+
self.depth_pyramid.append(depths_dnG)
|
| 335 |
+
fmaps_ = fmaps.reshape(B * S, C, H_prev, W_prev)
|
| 336 |
+
fmaps_ = F.avg_pool2d(fmaps_, 2, stride=2)
|
| 337 |
+
_, _, H, W = fmaps_.shape
|
| 338 |
+
fmaps = fmaps_.reshape(B, S, C, H, W)
|
| 339 |
+
H_prev = H
|
| 340 |
+
W_prev = W
|
| 341 |
+
self.fmaps_pyramid.append(fmaps)
|
| 342 |
+
|
| 343 |
+
def sample(self, coords):
|
| 344 |
+
r = self.radius
|
| 345 |
+
B, S, N, D = coords.shape
|
| 346 |
+
assert D == 2
|
| 347 |
+
|
| 348 |
+
H, W = self.H, self.W
|
| 349 |
+
out_pyramid = []
|
| 350 |
+
for i in range(self.num_levels):
|
| 351 |
+
corrs = self.corrs_pyramid[i] # B, S, N, H, W
|
| 352 |
+
_, _, _, H, W = corrs.shape
|
| 353 |
+
|
| 354 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
| 355 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
| 356 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
| 357 |
+
coords.device
|
| 358 |
+
)
|
| 359 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
| 360 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
| 361 |
+
coords_lvl = centroid_lvl + delta_lvl
|
| 362 |
+
corrs = bilinear_sampler(corrs.reshape(B * S * N, 1, H, W), coords_lvl)
|
| 363 |
+
corrs = corrs.view(B, S, N, -1)
|
| 364 |
+
out_pyramid.append(corrs)
|
| 365 |
+
|
| 366 |
+
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
| 367 |
+
return out.contiguous().float()
|
| 368 |
+
|
| 369 |
+
def corr(self, targets):
|
| 370 |
+
B, S, N, C = targets.shape
|
| 371 |
+
assert C == self.C
|
| 372 |
+
assert S == self.S
|
| 373 |
+
|
| 374 |
+
fmap1 = targets
|
| 375 |
+
|
| 376 |
+
self.corrs_pyramid = []
|
| 377 |
+
for fmaps in self.fmaps_pyramid:
|
| 378 |
+
_, _, _, H, W = fmaps.shape
|
| 379 |
+
fmap2s = fmaps.view(B, S, C, H * W)
|
| 380 |
+
corrs = torch.matmul(fmap1, fmap2s)
|
| 381 |
+
corrs = corrs.view(B, S, N, H, W)
|
| 382 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
| 383 |
+
self.corrs_pyramid.append(corrs)
|
| 384 |
+
|
| 385 |
+
def corr_sample(self, targets, coords, coords_dp=None):
|
| 386 |
+
B, S, N, C = targets.shape
|
| 387 |
+
r = self.radius
|
| 388 |
+
Dim_c = (2*r+1)**2
|
| 389 |
+
assert C == self.C
|
| 390 |
+
assert S == self.S
|
| 391 |
+
|
| 392 |
+
out_pyramid = []
|
| 393 |
+
out_pyramid_dp = []
|
| 394 |
+
for i in range(self.num_levels):
|
| 395 |
+
dx = torch.linspace(-r, r, 2 * r + 1)
|
| 396 |
+
dy = torch.linspace(-r, r, 2 * r + 1)
|
| 397 |
+
delta = torch.stack(torch.meshgrid(dy, dx, indexing="ij"), axis=-1).to(
|
| 398 |
+
coords.device
|
| 399 |
+
)
|
| 400 |
+
centroid_lvl = coords.reshape(B * S * N, 1, 1, 2) / 2 ** i
|
| 401 |
+
delta_lvl = delta.view(1, 2 * r + 1, 2 * r + 1, 2)
|
| 402 |
+
coords_lvl = centroid_lvl + delta_lvl
|
| 403 |
+
fmaps = self.fmaps_pyramid[i]
|
| 404 |
+
_, _, _, H, W = fmaps.shape
|
| 405 |
+
fmap2s = fmaps.view(B*S, C, H, W)
|
| 406 |
+
if len(self.depth_pyramid)>0:
|
| 407 |
+
depths_dnG_i = self.depth_pyramid[i]
|
| 408 |
+
depths_dnG_i = depths_dnG_i.view(B*S, 1, H, W)
|
| 409 |
+
dnG_sample = bilinear_sampler(depths_dnG_i, coords_lvl.view(B*S,1,N*Dim_c,2))
|
| 410 |
+
dp_corrs = (dnG_sample.view(B*S,N,-1) - coords_dp[0]).abs()/coords_dp[0]
|
| 411 |
+
out_pyramid_dp.append(dp_corrs)
|
| 412 |
+
fmap2s_sample = bilinear_sampler(fmap2s, coords_lvl.view(B*S,1,N*Dim_c,2))
|
| 413 |
+
fmap2s_sample = fmap2s_sample.permute(0, 3, 1, 2) # B*S, N*Dim_c, C, -1
|
| 414 |
+
corrs = torch.matmul(targets.reshape(B*S*N, 1, -1), fmap2s_sample.reshape(B*S*N, Dim_c, -1).permute(0, 2, 1))
|
| 415 |
+
corrs = corrs / torch.sqrt(torch.tensor(C).float())
|
| 416 |
+
corrs = corrs.view(B, S, N, -1)
|
| 417 |
+
out_pyramid.append(corrs)
|
| 418 |
+
|
| 419 |
+
out = torch.cat(out_pyramid, dim=-1) # B, S, N, LRR*2
|
| 420 |
+
if len(self.depth_pyramid)>0:
|
| 421 |
+
out_dp = torch.cat(out_pyramid_dp, dim=-1)
|
| 422 |
+
self.fcorrD = out_dp.contiguous().float()
|
| 423 |
+
else:
|
| 424 |
+
self.fcorrD = torch.zeros_like(out).contiguous().float()
|
| 425 |
+
return out.contiguous().float()
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
class EUpdateFormer(nn.Module):
|
| 429 |
+
"""
|
| 430 |
+
Transformer model that updates track estimates.
|
| 431 |
+
"""
|
| 432 |
+
|
| 433 |
+
def __init__(
|
| 434 |
+
self,
|
| 435 |
+
space_depth=12,
|
| 436 |
+
time_depth=12,
|
| 437 |
+
input_dim=320,
|
| 438 |
+
hidden_size=384,
|
| 439 |
+
num_heads=8,
|
| 440 |
+
output_dim=130,
|
| 441 |
+
mlp_ratio=4.0,
|
| 442 |
+
vq_depth=3,
|
| 443 |
+
add_space_attn=True,
|
| 444 |
+
add_time_attn=True,
|
| 445 |
+
flash=True
|
| 446 |
+
):
|
| 447 |
+
super().__init__()
|
| 448 |
+
self.out_channels = 2
|
| 449 |
+
self.num_heads = num_heads
|
| 450 |
+
self.hidden_size = hidden_size
|
| 451 |
+
self.add_space_attn = add_space_attn
|
| 452 |
+
self.input_transform = torch.nn.Linear(input_dim, hidden_size, bias=True)
|
| 453 |
+
self.flash = flash
|
| 454 |
+
self.flow_head = nn.Sequential(
|
| 455 |
+
nn.Linear(hidden_size, output_dim, bias=True),
|
| 456 |
+
nn.ReLU(inplace=True),
|
| 457 |
+
nn.Linear(output_dim, output_dim, bias=True),
|
| 458 |
+
nn.ReLU(inplace=True),
|
| 459 |
+
nn.Linear(output_dim, output_dim, bias=True)
|
| 460 |
+
)
|
| 461 |
+
self.norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
|
| 462 |
+
cfg = xLSTMBlockStackConfig(
|
| 463 |
+
mlstm_block=mLSTMBlockConfig(
|
| 464 |
+
mlstm=mLSTMLayerConfig(
|
| 465 |
+
conv1d_kernel_size=4, qkv_proj_blocksize=4, num_heads=4
|
| 466 |
+
)
|
| 467 |
+
),
|
| 468 |
+
slstm_block=sLSTMBlockConfig(
|
| 469 |
+
slstm=sLSTMLayerConfig(
|
| 470 |
+
backend="cuda",
|
| 471 |
+
num_heads=4,
|
| 472 |
+
conv1d_kernel_size=4,
|
| 473 |
+
bias_init="powerlaw_blockdependent",
|
| 474 |
+
),
|
| 475 |
+
feedforward=FeedForwardConfig(proj_factor=1.3, act_fn="gelu"),
|
| 476 |
+
),
|
| 477 |
+
context_length=50,
|
| 478 |
+
num_blocks=7,
|
| 479 |
+
embedding_dim=384,
|
| 480 |
+
slstm_at=[1],
|
| 481 |
+
|
| 482 |
+
)
|
| 483 |
+
self.xlstm_fwd = xLSTMBlockStack(cfg)
|
| 484 |
+
self.xlstm_bwd = xLSTMBlockStack(cfg)
|
| 485 |
+
|
| 486 |
+
self.initialize_weights()
|
| 487 |
+
|
| 488 |
+
def initialize_weights(self):
|
| 489 |
+
def _basic_init(module):
|
| 490 |
+
if isinstance(module, nn.Linear):
|
| 491 |
+
torch.nn.init.xavier_uniform_(module.weight)
|
| 492 |
+
if module.bias is not None:
|
| 493 |
+
nn.init.constant_(module.bias, 0)
|
| 494 |
+
|
| 495 |
+
self.apply(_basic_init)
|
| 496 |
+
|
| 497 |
+
def forward(self,
|
| 498 |
+
input_tensor,
|
| 499 |
+
track_mask=None):
|
| 500 |
+
""" Updating with Transformer
|
| 501 |
+
|
| 502 |
+
Args:
|
| 503 |
+
input_tensor: B, N, T, C
|
| 504 |
+
arap_embed: B, N, T, C
|
| 505 |
+
"""
|
| 506 |
+
B, N, T, C = input_tensor.shape
|
| 507 |
+
x = self.input_transform(input_tensor)
|
| 508 |
+
|
| 509 |
+
track_mask = track_mask.permute(0,2,1,3).float()
|
| 510 |
+
fwd_x = x*track_mask
|
| 511 |
+
bwd_x = x.flip(2)*track_mask.flip(2)
|
| 512 |
+
feat_fwd = self.xlstm_fwd(self.norm(fwd_x.view(B*N, T, -1)))
|
| 513 |
+
feat_bwd = self.xlstm_bwd(self.norm(bwd_x.view(B*N, T, -1)))
|
| 514 |
+
feat = (feat_bwd.flip(1) + feat_fwd).view(B, N, T, -1)
|
| 515 |
+
|
| 516 |
+
flow = self.flow_head(feat)
|
| 517 |
+
|
| 518 |
+
return flow[..., :2], flow[..., 2:]
|
| 519 |
+
|
models/SpaTrackV2/models/camera_transform.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
# Adapted from https://github.com/amyxlase/relpose-plus-plus
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
import numpy as np
|
| 12 |
+
import math
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def bbox_xyxy_to_xywh(xyxy):
|
| 18 |
+
wh = xyxy[2:] - xyxy[:2]
|
| 19 |
+
xywh = np.concatenate([xyxy[:2], wh])
|
| 20 |
+
return xywh
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def adjust_camera_to_bbox_crop_(fl, pp, image_size_wh: torch.Tensor, clamp_bbox_xywh: torch.Tensor):
|
| 24 |
+
focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, image_size_wh)
|
| 25 |
+
|
| 26 |
+
principal_point_px_cropped = principal_point_px - clamp_bbox_xywh[:2]
|
| 27 |
+
|
| 28 |
+
focal_length, principal_point_cropped = _convert_pixels_to_ndc(
|
| 29 |
+
focal_length_px, principal_point_px_cropped, clamp_bbox_xywh[2:]
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
return focal_length, principal_point_cropped
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def adjust_camera_to_image_scale_(fl, pp, original_size_wh: torch.Tensor, new_size_wh: torch.LongTensor):
|
| 36 |
+
focal_length_px, principal_point_px = _convert_ndc_to_pixels(fl, pp, original_size_wh)
|
| 37 |
+
|
| 38 |
+
# now scale and convert from pixels to NDC
|
| 39 |
+
image_size_wh_output = new_size_wh.float()
|
| 40 |
+
scale = (image_size_wh_output / original_size_wh).min(dim=-1, keepdim=True).values
|
| 41 |
+
focal_length_px_scaled = focal_length_px * scale
|
| 42 |
+
principal_point_px_scaled = principal_point_px * scale
|
| 43 |
+
|
| 44 |
+
focal_length_scaled, principal_point_scaled = _convert_pixels_to_ndc(
|
| 45 |
+
focal_length_px_scaled, principal_point_px_scaled, image_size_wh_output
|
| 46 |
+
)
|
| 47 |
+
return focal_length_scaled, principal_point_scaled
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _convert_ndc_to_pixels(focal_length: torch.Tensor, principal_point: torch.Tensor, image_size_wh: torch.Tensor):
|
| 51 |
+
half_image_size = image_size_wh / 2
|
| 52 |
+
rescale = half_image_size.min()
|
| 53 |
+
principal_point_px = half_image_size - principal_point * rescale
|
| 54 |
+
focal_length_px = focal_length * rescale
|
| 55 |
+
return focal_length_px, principal_point_px
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def _convert_pixels_to_ndc(
|
| 59 |
+
focal_length_px: torch.Tensor, principal_point_px: torch.Tensor, image_size_wh: torch.Tensor
|
| 60 |
+
):
|
| 61 |
+
half_image_size = image_size_wh / 2
|
| 62 |
+
rescale = half_image_size.min()
|
| 63 |
+
principal_point = (half_image_size - principal_point_px) / rescale
|
| 64 |
+
focal_length = focal_length_px / rescale
|
| 65 |
+
return focal_length, principal_point
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def normalize_cameras(
|
| 69 |
+
cameras, compute_optical=True, first_camera=True, normalize_trans=True, scale=1.0, points=None, max_norm=False,
|
| 70 |
+
pose_mode="C2W"
|
| 71 |
+
):
|
| 72 |
+
"""
|
| 73 |
+
Normalizes cameras such that
|
| 74 |
+
(1) the optical axes point to the origin and the average distance to the origin is 1
|
| 75 |
+
(2) the first camera is the origin
|
| 76 |
+
(3) the translation vector is normalized
|
| 77 |
+
|
| 78 |
+
TODO: some transforms overlap with others. no need to do so many transforms
|
| 79 |
+
Args:
|
| 80 |
+
cameras (List[camera]).
|
| 81 |
+
"""
|
| 82 |
+
# Let distance from first camera to origin be unit
|
| 83 |
+
new_cameras = cameras.clone()
|
| 84 |
+
scale = 1.0
|
| 85 |
+
|
| 86 |
+
if compute_optical:
|
| 87 |
+
new_cameras, points = compute_optical_transform(new_cameras, points=points)
|
| 88 |
+
if first_camera:
|
| 89 |
+
new_cameras, points = first_camera_transform(new_cameras, points=points, pose_mode=pose_mode)
|
| 90 |
+
if normalize_trans:
|
| 91 |
+
new_cameras, points, scale = normalize_translation(new_cameras,
|
| 92 |
+
points=points, max_norm=max_norm)
|
| 93 |
+
return new_cameras, points, scale
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def compute_optical_transform(new_cameras, points=None):
|
| 97 |
+
"""
|
| 98 |
+
adapted from https://github.com/amyxlase/relpose-plus-plus
|
| 99 |
+
"""
|
| 100 |
+
|
| 101 |
+
new_transform = new_cameras.get_world_to_view_transform()
|
| 102 |
+
p_intersect, dist, p_line_intersect, pp, r = compute_optical_axis_intersection(new_cameras)
|
| 103 |
+
t = Translate(p_intersect)
|
| 104 |
+
scale = dist.squeeze()[0]
|
| 105 |
+
|
| 106 |
+
if points is not None:
|
| 107 |
+
points = t.inverse().transform_points(points)
|
| 108 |
+
points = points / scale
|
| 109 |
+
|
| 110 |
+
# Degenerate case
|
| 111 |
+
if scale == 0:
|
| 112 |
+
scale = torch.norm(new_cameras.T, dim=(0, 1))
|
| 113 |
+
scale = torch.sqrt(scale)
|
| 114 |
+
new_cameras.T = new_cameras.T / scale
|
| 115 |
+
else:
|
| 116 |
+
new_matrix = t.compose(new_transform).get_matrix()
|
| 117 |
+
new_cameras.R = new_matrix[:, :3, :3]
|
| 118 |
+
new_cameras.T = new_matrix[:, 3, :3] / scale
|
| 119 |
+
|
| 120 |
+
return new_cameras, points
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def compute_optical_axis_intersection(cameras):
|
| 124 |
+
centers = cameras.get_camera_center()
|
| 125 |
+
principal_points = cameras.principal_point
|
| 126 |
+
|
| 127 |
+
one_vec = torch.ones((len(cameras), 1))
|
| 128 |
+
optical_axis = torch.cat((principal_points, one_vec), -1)
|
| 129 |
+
|
| 130 |
+
pp = cameras.unproject_points(optical_axis, from_ndc=True, world_coordinates=True)
|
| 131 |
+
|
| 132 |
+
pp2 = pp[torch.arange(pp.shape[0]), torch.arange(pp.shape[0])]
|
| 133 |
+
|
| 134 |
+
directions = pp2 - centers
|
| 135 |
+
centers = centers.unsqueeze(0).unsqueeze(0)
|
| 136 |
+
directions = directions.unsqueeze(0).unsqueeze(0)
|
| 137 |
+
|
| 138 |
+
p_intersect, p_line_intersect, _, r = intersect_skew_line_groups(p=centers, r=directions, mask=None)
|
| 139 |
+
|
| 140 |
+
p_intersect = p_intersect.squeeze().unsqueeze(0)
|
| 141 |
+
dist = (p_intersect - centers).norm(dim=-1)
|
| 142 |
+
|
| 143 |
+
return p_intersect, dist, p_line_intersect, pp2, r
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def intersect_skew_line_groups(p, r, mask):
|
| 147 |
+
# p, r both of shape (B, N, n_intersected_lines, 3)
|
| 148 |
+
# mask of shape (B, N, n_intersected_lines)
|
| 149 |
+
p_intersect, r = intersect_skew_lines_high_dim(p, r, mask=mask)
|
| 150 |
+
_, p_line_intersect = _point_line_distance(p, r, p_intersect[..., None, :].expand_as(p))
|
| 151 |
+
intersect_dist_squared = ((p_line_intersect - p_intersect[..., None, :]) ** 2).sum(dim=-1)
|
| 152 |
+
return p_intersect, p_line_intersect, intersect_dist_squared, r
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def intersect_skew_lines_high_dim(p, r, mask=None):
|
| 156 |
+
# Implements https://en.wikipedia.org/wiki/Skew_lines In more than two dimensions
|
| 157 |
+
dim = p.shape[-1]
|
| 158 |
+
# make sure the heading vectors are l2-normed
|
| 159 |
+
if mask is None:
|
| 160 |
+
mask = torch.ones_like(p[..., 0])
|
| 161 |
+
r = torch.nn.functional.normalize(r, dim=-1)
|
| 162 |
+
|
| 163 |
+
eye = torch.eye(dim, device=p.device, dtype=p.dtype)[None, None]
|
| 164 |
+
I_min_cov = (eye - (r[..., None] * r[..., None, :])) * mask[..., None, None]
|
| 165 |
+
sum_proj = I_min_cov.matmul(p[..., None]).sum(dim=-3)
|
| 166 |
+
p_intersect = torch.linalg.lstsq(I_min_cov.sum(dim=-3), sum_proj).solution[..., 0]
|
| 167 |
+
|
| 168 |
+
if torch.any(torch.isnan(p_intersect)):
|
| 169 |
+
print(p_intersect)
|
| 170 |
+
raise ValueError(f"p_intersect is NaN")
|
| 171 |
+
|
| 172 |
+
return p_intersect, r
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
def _point_line_distance(p1, r1, p2):
|
| 176 |
+
df = p2 - p1
|
| 177 |
+
proj_vector = df - ((df * r1).sum(dim=-1, keepdim=True) * r1)
|
| 178 |
+
line_pt_nearest = p2 - proj_vector
|
| 179 |
+
d = (proj_vector).norm(dim=-1)
|
| 180 |
+
return d, line_pt_nearest
|
| 181 |
+
|
| 182 |
+
|
| 183 |
+
def first_camera_transform(cameras, rotation_only=False,
|
| 184 |
+
points=None, pose_mode="C2W"):
|
| 185 |
+
"""
|
| 186 |
+
Transform so that the first camera is the origin
|
| 187 |
+
"""
|
| 188 |
+
|
| 189 |
+
new_cameras = cameras.clone()
|
| 190 |
+
# new_transform = new_cameras.get_world_to_view_transform()
|
| 191 |
+
|
| 192 |
+
R = cameras.R
|
| 193 |
+
T = cameras.T
|
| 194 |
+
Tran_M = torch.cat([R, T.unsqueeze(-1)], dim=-1) # [B, 3, 4]
|
| 195 |
+
Tran_M = torch.cat([Tran_M,
|
| 196 |
+
torch.tensor([[[0, 0, 0, 1]]], device=Tran_M.device).expand(Tran_M.shape[0], -1, -1)], dim=1)
|
| 197 |
+
if pose_mode == "C2W":
|
| 198 |
+
Tran_M_new = (Tran_M[:1,...].inverse())@Tran_M
|
| 199 |
+
elif pose_mode == "W2C":
|
| 200 |
+
Tran_M_new = Tran_M@(Tran_M[:1,...].inverse())
|
| 201 |
+
|
| 202 |
+
if False:
|
| 203 |
+
tR = Rotate(new_cameras.R[0].unsqueeze(0))
|
| 204 |
+
if rotation_only:
|
| 205 |
+
t = tR.inverse()
|
| 206 |
+
else:
|
| 207 |
+
tT = Translate(new_cameras.T[0].unsqueeze(0))
|
| 208 |
+
t = tR.compose(tT).inverse()
|
| 209 |
+
|
| 210 |
+
if points is not None:
|
| 211 |
+
points = t.inverse().transform_points(points)
|
| 212 |
+
|
| 213 |
+
if pose_mode == "C2W":
|
| 214 |
+
new_matrix = new_transform.compose(t).get_matrix()
|
| 215 |
+
else:
|
| 216 |
+
import ipdb; ipdb.set_trace()
|
| 217 |
+
new_matrix = t.compose(new_transform).get_matrix()
|
| 218 |
+
|
| 219 |
+
new_cameras.R = Tran_M_new[:, :3, :3]
|
| 220 |
+
new_cameras.T = Tran_M_new[:, :3, 3]
|
| 221 |
+
|
| 222 |
+
return new_cameras, points
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def normalize_translation(new_cameras, points=None, max_norm=False):
|
| 226 |
+
t_gt = new_cameras.T.clone()
|
| 227 |
+
t_gt = t_gt[1:, :]
|
| 228 |
+
|
| 229 |
+
if max_norm:
|
| 230 |
+
t_gt_norm = torch.norm(t_gt, dim=(-1))
|
| 231 |
+
t_gt_scale = t_gt_norm.max()
|
| 232 |
+
if t_gt_norm.max() < 0.001:
|
| 233 |
+
t_gt_scale = torch.ones_like(t_gt_scale)
|
| 234 |
+
t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
|
| 235 |
+
else:
|
| 236 |
+
t_gt_norm = torch.norm(t_gt, dim=(0, 1))
|
| 237 |
+
t_gt_scale = t_gt_norm / math.sqrt(len(t_gt))
|
| 238 |
+
t_gt_scale = t_gt_scale / 2
|
| 239 |
+
if t_gt_norm.max() < 0.001:
|
| 240 |
+
t_gt_scale = torch.ones_like(t_gt_scale)
|
| 241 |
+
t_gt_scale = t_gt_scale.clamp(min=0.01, max=1e5)
|
| 242 |
+
|
| 243 |
+
new_cameras.T = new_cameras.T / t_gt_scale
|
| 244 |
+
|
| 245 |
+
if points is not None:
|
| 246 |
+
points = points / t_gt_scale
|
| 247 |
+
|
| 248 |
+
return new_cameras, points, t_gt_scale
|
models/SpaTrackV2/models/depth_refiner/backbone.py
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ---------------------------------------------------------------
|
| 2 |
+
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This work is licensed under the NVIDIA Source Code License
|
| 5 |
+
# ---------------------------------------------------------------
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from functools import partial
|
| 10 |
+
|
| 11 |
+
from timm.layers import DropPath, to_2tuple, trunc_normal_
|
| 12 |
+
from timm.models import register_model
|
| 13 |
+
from timm.models.vision_transformer import _cfg
|
| 14 |
+
import math
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Mlp(nn.Module):
|
| 18 |
+
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
|
| 19 |
+
super().__init__()
|
| 20 |
+
out_features = out_features or in_features
|
| 21 |
+
hidden_features = hidden_features or in_features
|
| 22 |
+
self.fc1 = nn.Linear(in_features, hidden_features)
|
| 23 |
+
self.dwconv = DWConv(hidden_features)
|
| 24 |
+
self.act = act_layer()
|
| 25 |
+
self.fc2 = nn.Linear(hidden_features, out_features)
|
| 26 |
+
self.drop = nn.Dropout(drop)
|
| 27 |
+
|
| 28 |
+
self.apply(self._init_weights)
|
| 29 |
+
|
| 30 |
+
def _init_weights(self, m):
|
| 31 |
+
if isinstance(m, nn.Linear):
|
| 32 |
+
trunc_normal_(m.weight, std=.02)
|
| 33 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 34 |
+
nn.init.constant_(m.bias, 0)
|
| 35 |
+
elif isinstance(m, nn.LayerNorm):
|
| 36 |
+
nn.init.constant_(m.bias, 0)
|
| 37 |
+
nn.init.constant_(m.weight, 1.0)
|
| 38 |
+
elif isinstance(m, nn.Conv2d):
|
| 39 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 40 |
+
fan_out //= m.groups
|
| 41 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 42 |
+
if m.bias is not None:
|
| 43 |
+
m.bias.data.zero_()
|
| 44 |
+
|
| 45 |
+
def forward(self, x, H, W):
|
| 46 |
+
x = self.fc1(x)
|
| 47 |
+
x = self.dwconv(x, H, W)
|
| 48 |
+
x = self.act(x)
|
| 49 |
+
x = self.drop(x)
|
| 50 |
+
x = self.fc2(x)
|
| 51 |
+
x = self.drop(x)
|
| 52 |
+
return x
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class Attention(nn.Module):
|
| 56 |
+
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
|
| 57 |
+
super().__init__()
|
| 58 |
+
assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
|
| 59 |
+
|
| 60 |
+
self.dim = dim
|
| 61 |
+
self.num_heads = num_heads
|
| 62 |
+
head_dim = dim // num_heads
|
| 63 |
+
self.scale = qk_scale or head_dim ** -0.5
|
| 64 |
+
|
| 65 |
+
self.q = nn.Linear(dim, dim, bias=qkv_bias)
|
| 66 |
+
self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
|
| 67 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 68 |
+
self.proj = nn.Linear(dim, dim)
|
| 69 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 70 |
+
|
| 71 |
+
self.sr_ratio = sr_ratio
|
| 72 |
+
if sr_ratio > 1:
|
| 73 |
+
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
|
| 74 |
+
self.norm = nn.LayerNorm(dim)
|
| 75 |
+
|
| 76 |
+
self.apply(self._init_weights)
|
| 77 |
+
|
| 78 |
+
def _init_weights(self, m):
|
| 79 |
+
if isinstance(m, nn.Linear):
|
| 80 |
+
trunc_normal_(m.weight, std=.02)
|
| 81 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 82 |
+
nn.init.constant_(m.bias, 0)
|
| 83 |
+
elif isinstance(m, nn.LayerNorm):
|
| 84 |
+
nn.init.constant_(m.bias, 0)
|
| 85 |
+
nn.init.constant_(m.weight, 1.0)
|
| 86 |
+
elif isinstance(m, nn.Conv2d):
|
| 87 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 88 |
+
fan_out //= m.groups
|
| 89 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 90 |
+
if m.bias is not None:
|
| 91 |
+
m.bias.data.zero_()
|
| 92 |
+
|
| 93 |
+
def forward(self, x, H, W):
|
| 94 |
+
B, N, C = x.shape
|
| 95 |
+
q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
| 96 |
+
|
| 97 |
+
if self.sr_ratio > 1:
|
| 98 |
+
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
|
| 99 |
+
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
|
| 100 |
+
x_ = self.norm(x_)
|
| 101 |
+
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 102 |
+
else:
|
| 103 |
+
kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 104 |
+
k, v = kv[0], kv[1]
|
| 105 |
+
|
| 106 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
| 107 |
+
attn = attn.softmax(dim=-1)
|
| 108 |
+
attn = self.attn_drop(attn)
|
| 109 |
+
|
| 110 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 111 |
+
x = self.proj(x)
|
| 112 |
+
x = self.proj_drop(x)
|
| 113 |
+
|
| 114 |
+
return x
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
class Block(nn.Module):
|
| 118 |
+
|
| 119 |
+
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
| 120 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
|
| 121 |
+
super().__init__()
|
| 122 |
+
self.norm1 = norm_layer(dim)
|
| 123 |
+
self.attn = Attention(
|
| 124 |
+
dim,
|
| 125 |
+
num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 126 |
+
attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
|
| 127 |
+
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
| 128 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
| 129 |
+
self.norm2 = norm_layer(dim)
|
| 130 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 131 |
+
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
| 132 |
+
|
| 133 |
+
self.apply(self._init_weights)
|
| 134 |
+
|
| 135 |
+
def _init_weights(self, m):
|
| 136 |
+
if isinstance(m, nn.Linear):
|
| 137 |
+
trunc_normal_(m.weight, std=.02)
|
| 138 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 139 |
+
nn.init.constant_(m.bias, 0)
|
| 140 |
+
elif isinstance(m, nn.LayerNorm):
|
| 141 |
+
nn.init.constant_(m.bias, 0)
|
| 142 |
+
nn.init.constant_(m.weight, 1.0)
|
| 143 |
+
elif isinstance(m, nn.Conv2d):
|
| 144 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 145 |
+
fan_out //= m.groups
|
| 146 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 147 |
+
if m.bias is not None:
|
| 148 |
+
m.bias.data.zero_()
|
| 149 |
+
|
| 150 |
+
def forward(self, x, H, W):
|
| 151 |
+
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
|
| 152 |
+
x = x + self.drop_path(self.mlp(self.norm2(x), H, W))
|
| 153 |
+
|
| 154 |
+
return x
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class OverlapPatchEmbed(nn.Module):
|
| 158 |
+
""" Image to Patch Embedding
|
| 159 |
+
"""
|
| 160 |
+
|
| 161 |
+
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
| 162 |
+
super().__init__()
|
| 163 |
+
img_size = to_2tuple(img_size)
|
| 164 |
+
patch_size = to_2tuple(patch_size)
|
| 165 |
+
|
| 166 |
+
self.img_size = img_size
|
| 167 |
+
self.patch_size = patch_size
|
| 168 |
+
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
| 169 |
+
self.num_patches = self.H * self.W
|
| 170 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
| 171 |
+
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
| 172 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 173 |
+
|
| 174 |
+
self.apply(self._init_weights)
|
| 175 |
+
|
| 176 |
+
def _init_weights(self, m):
|
| 177 |
+
if isinstance(m, nn.Linear):
|
| 178 |
+
trunc_normal_(m.weight, std=.02)
|
| 179 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 180 |
+
nn.init.constant_(m.bias, 0)
|
| 181 |
+
elif isinstance(m, nn.LayerNorm):
|
| 182 |
+
nn.init.constant_(m.bias, 0)
|
| 183 |
+
nn.init.constant_(m.weight, 1.0)
|
| 184 |
+
elif isinstance(m, nn.Conv2d):
|
| 185 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 186 |
+
fan_out //= m.groups
|
| 187 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 188 |
+
if m.bias is not None:
|
| 189 |
+
m.bias.data.zero_()
|
| 190 |
+
|
| 191 |
+
def forward(self, x):
|
| 192 |
+
x = self.proj(x)
|
| 193 |
+
_, _, H, W = x.shape
|
| 194 |
+
x = x.flatten(2).transpose(1, 2)
|
| 195 |
+
x = self.norm(x)
|
| 196 |
+
|
| 197 |
+
return x, H, W
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class OverlapPatchEmbed43(nn.Module):
|
| 203 |
+
""" Image to Patch Embedding
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
|
| 207 |
+
super().__init__()
|
| 208 |
+
img_size = to_2tuple(img_size)
|
| 209 |
+
patch_size = to_2tuple(patch_size)
|
| 210 |
+
|
| 211 |
+
self.img_size = img_size
|
| 212 |
+
self.patch_size = patch_size
|
| 213 |
+
self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
|
| 214 |
+
self.num_patches = self.H * self.W
|
| 215 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
|
| 216 |
+
padding=(patch_size[0] // 2, patch_size[1] // 2))
|
| 217 |
+
self.norm = nn.LayerNorm(embed_dim)
|
| 218 |
+
|
| 219 |
+
self.apply(self._init_weights)
|
| 220 |
+
|
| 221 |
+
def _init_weights(self, m):
|
| 222 |
+
if isinstance(m, nn.Linear):
|
| 223 |
+
trunc_normal_(m.weight, std=.02)
|
| 224 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 225 |
+
nn.init.constant_(m.bias, 0)
|
| 226 |
+
elif isinstance(m, nn.LayerNorm):
|
| 227 |
+
nn.init.constant_(m.bias, 0)
|
| 228 |
+
nn.init.constant_(m.weight, 1.0)
|
| 229 |
+
elif isinstance(m, nn.Conv2d):
|
| 230 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 231 |
+
fan_out //= m.groups
|
| 232 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 233 |
+
if m.bias is not None:
|
| 234 |
+
m.bias.data.zero_()
|
| 235 |
+
|
| 236 |
+
def forward(self, x):
|
| 237 |
+
if x.shape[1]==4:
|
| 238 |
+
x = self.proj_4c(x)
|
| 239 |
+
else:
|
| 240 |
+
x = self.proj(x)
|
| 241 |
+
_, _, H, W = x.shape
|
| 242 |
+
x = x.flatten(2).transpose(1, 2)
|
| 243 |
+
x = self.norm(x)
|
| 244 |
+
|
| 245 |
+
return x, H, W
|
| 246 |
+
|
| 247 |
+
class MixVisionTransformer(nn.Module):
|
| 248 |
+
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
|
| 249 |
+
num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
|
| 250 |
+
attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
|
| 251 |
+
depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1]):
|
| 252 |
+
super().__init__()
|
| 253 |
+
self.num_classes = num_classes
|
| 254 |
+
self.depths = depths
|
| 255 |
+
|
| 256 |
+
# patch_embed 43
|
| 257 |
+
self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
|
| 258 |
+
embed_dim=embed_dims[0])
|
| 259 |
+
self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
|
| 260 |
+
embed_dim=embed_dims[1])
|
| 261 |
+
self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
|
| 262 |
+
embed_dim=embed_dims[2])
|
| 263 |
+
self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
|
| 264 |
+
embed_dim=embed_dims[3])
|
| 265 |
+
|
| 266 |
+
# transformer encoder
|
| 267 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
|
| 268 |
+
cur = 0
|
| 269 |
+
self.block1 = nn.ModuleList([Block(
|
| 270 |
+
dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 271 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
| 272 |
+
sr_ratio=sr_ratios[0])
|
| 273 |
+
for i in range(depths[0])])
|
| 274 |
+
self.norm1 = norm_layer(embed_dims[0])
|
| 275 |
+
|
| 276 |
+
cur += depths[0]
|
| 277 |
+
self.block2 = nn.ModuleList([Block(
|
| 278 |
+
dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 279 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
| 280 |
+
sr_ratio=sr_ratios[1])
|
| 281 |
+
for i in range(depths[1])])
|
| 282 |
+
self.norm2 = norm_layer(embed_dims[1])
|
| 283 |
+
|
| 284 |
+
cur += depths[1]
|
| 285 |
+
self.block3 = nn.ModuleList([Block(
|
| 286 |
+
dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 287 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
| 288 |
+
sr_ratio=sr_ratios[2])
|
| 289 |
+
for i in range(depths[2])])
|
| 290 |
+
self.norm3 = norm_layer(embed_dims[2])
|
| 291 |
+
|
| 292 |
+
cur += depths[2]
|
| 293 |
+
self.block4 = nn.ModuleList([Block(
|
| 294 |
+
dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
|
| 295 |
+
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
|
| 296 |
+
sr_ratio=sr_ratios[3])
|
| 297 |
+
for i in range(depths[3])])
|
| 298 |
+
self.norm4 = norm_layer(embed_dims[3])
|
| 299 |
+
|
| 300 |
+
# classification head
|
| 301 |
+
# self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()
|
| 302 |
+
|
| 303 |
+
self.apply(self._init_weights)
|
| 304 |
+
|
| 305 |
+
def _init_weights(self, m):
|
| 306 |
+
if isinstance(m, nn.Linear):
|
| 307 |
+
trunc_normal_(m.weight, std=.02)
|
| 308 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 309 |
+
nn.init.constant_(m.bias, 0)
|
| 310 |
+
elif isinstance(m, nn.LayerNorm):
|
| 311 |
+
nn.init.constant_(m.bias, 0)
|
| 312 |
+
nn.init.constant_(m.weight, 1.0)
|
| 313 |
+
elif isinstance(m, nn.Conv2d):
|
| 314 |
+
fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
| 315 |
+
fan_out //= m.groups
|
| 316 |
+
m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
|
| 317 |
+
if m.bias is not None:
|
| 318 |
+
m.bias.data.zero_()
|
| 319 |
+
|
| 320 |
+
def init_weights(self, pretrained=None):
|
| 321 |
+
if isinstance(pretrained, str):
|
| 322 |
+
logger = get_root_logger()
|
| 323 |
+
load_checkpoint(self, pretrained, map_location='cpu', strict=False, logger=logger)
|
| 324 |
+
|
| 325 |
+
def reset_drop_path(self, drop_path_rate):
|
| 326 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
|
| 327 |
+
cur = 0
|
| 328 |
+
for i in range(self.depths[0]):
|
| 329 |
+
self.block1[i].drop_path.drop_prob = dpr[cur + i]
|
| 330 |
+
|
| 331 |
+
cur += self.depths[0]
|
| 332 |
+
for i in range(self.depths[1]):
|
| 333 |
+
self.block2[i].drop_path.drop_prob = dpr[cur + i]
|
| 334 |
+
|
| 335 |
+
cur += self.depths[1]
|
| 336 |
+
for i in range(self.depths[2]):
|
| 337 |
+
self.block3[i].drop_path.drop_prob = dpr[cur + i]
|
| 338 |
+
|
| 339 |
+
cur += self.depths[2]
|
| 340 |
+
for i in range(self.depths[3]):
|
| 341 |
+
self.block4[i].drop_path.drop_prob = dpr[cur + i]
|
| 342 |
+
|
| 343 |
+
def freeze_patch_emb(self):
|
| 344 |
+
self.patch_embed1.requires_grad = False
|
| 345 |
+
|
| 346 |
+
@torch.jit.ignore
|
| 347 |
+
def no_weight_decay(self):
|
| 348 |
+
return {'pos_embed1', 'pos_embed2', 'pos_embed3', 'pos_embed4', 'cls_token'} # has pos_embed may be better
|
| 349 |
+
|
| 350 |
+
def get_classifier(self):
|
| 351 |
+
return self.head
|
| 352 |
+
|
| 353 |
+
def reset_classifier(self, num_classes, global_pool=''):
|
| 354 |
+
self.num_classes = num_classes
|
| 355 |
+
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
| 356 |
+
|
| 357 |
+
def forward_features(self, x):
|
| 358 |
+
B = x.shape[0]
|
| 359 |
+
outs = []
|
| 360 |
+
|
| 361 |
+
# stage 1
|
| 362 |
+
x, H, W = self.patch_embed1(x)
|
| 363 |
+
for i, blk in enumerate(self.block1):
|
| 364 |
+
x = blk(x, H, W)
|
| 365 |
+
x = self.norm1(x)
|
| 366 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
| 367 |
+
outs.append(x)
|
| 368 |
+
|
| 369 |
+
# stage 2
|
| 370 |
+
x, H, W = self.patch_embed2(x)
|
| 371 |
+
for i, blk in enumerate(self.block2):
|
| 372 |
+
x = blk(x, H, W)
|
| 373 |
+
x = self.norm2(x)
|
| 374 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
| 375 |
+
outs.append(x)
|
| 376 |
+
|
| 377 |
+
# stage 3
|
| 378 |
+
x, H, W = self.patch_embed3(x)
|
| 379 |
+
for i, blk in enumerate(self.block3):
|
| 380 |
+
x = blk(x, H, W)
|
| 381 |
+
x = self.norm3(x)
|
| 382 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
| 383 |
+
outs.append(x)
|
| 384 |
+
|
| 385 |
+
# stage 4
|
| 386 |
+
x, H, W = self.patch_embed4(x)
|
| 387 |
+
for i, blk in enumerate(self.block4):
|
| 388 |
+
x = blk(x, H, W)
|
| 389 |
+
x = self.norm4(x)
|
| 390 |
+
x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
|
| 391 |
+
outs.append(x)
|
| 392 |
+
|
| 393 |
+
return outs
|
| 394 |
+
|
| 395 |
+
def forward(self, x):
|
| 396 |
+
if x.dim() == 5:
|
| 397 |
+
x = x.reshape(x.shape[0]*x.shape[1],x.shape[2],x.shape[3],x.shape[4])
|
| 398 |
+
x = self.forward_features(x)
|
| 399 |
+
# x = self.head(x)
|
| 400 |
+
|
| 401 |
+
return x
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
class DWConv(nn.Module):
|
| 405 |
+
def __init__(self, dim=768):
|
| 406 |
+
super(DWConv, self).__init__()
|
| 407 |
+
self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)
|
| 408 |
+
|
| 409 |
+
def forward(self, x, H, W):
|
| 410 |
+
B, N, C = x.shape
|
| 411 |
+
x = x.transpose(1, 2).view(B, C, H, W)
|
| 412 |
+
x = self.dwconv(x)
|
| 413 |
+
x = x.flatten(2).transpose(1, 2)
|
| 414 |
+
|
| 415 |
+
return x
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
|
| 419 |
+
#@BACKBONES.register_module()
|
| 420 |
+
class mit_b0(MixVisionTransformer):
|
| 421 |
+
def __init__(self, **kwargs):
|
| 422 |
+
super(mit_b0, self).__init__(
|
| 423 |
+
patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 424 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
| 425 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
#@BACKBONES.register_module()
|
| 429 |
+
class mit_b1(MixVisionTransformer):
|
| 430 |
+
def __init__(self, **kwargs):
|
| 431 |
+
super(mit_b1, self).__init__(
|
| 432 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 433 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
|
| 434 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
#@BACKBONES.register_module()
|
| 438 |
+
class mit_b2(MixVisionTransformer):
|
| 439 |
+
def __init__(self, **kwargs):
|
| 440 |
+
super(mit_b2, self).__init__(
|
| 441 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 442 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
|
| 443 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
#@BACKBONES.register_module()
|
| 447 |
+
class mit_b3(MixVisionTransformer):
|
| 448 |
+
def __init__(self, **kwargs):
|
| 449 |
+
super(mit_b3, self).__init__(
|
| 450 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 451 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
|
| 452 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
#@BACKBONES.register_module()
|
| 456 |
+
class mit_b4(MixVisionTransformer):
|
| 457 |
+
def __init__(self, **kwargs):
|
| 458 |
+
super(mit_b4, self).__init__(
|
| 459 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 460 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
|
| 461 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
#@BACKBONES.register_module()
|
| 465 |
+
class mit_b5(MixVisionTransformer):
|
| 466 |
+
def __init__(self, **kwargs):
|
| 467 |
+
super(mit_b5, self).__init__(
|
| 468 |
+
patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
|
| 469 |
+
qkv_bias=True, norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
|
| 470 |
+
drop_rate=0.0, drop_path_rate=0.1)
|
| 471 |
+
|
| 472 |
+
|
models/SpaTrackV2/models/depth_refiner/decode_head.py
ADDED
|
@@ -0,0 +1,619 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABCMeta, abstractmethod
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
# from mmcv.cnn import normal_init
|
| 7 |
+
# from mmcv.runner import auto_fp16, force_fp32
|
| 8 |
+
|
| 9 |
+
# from mmseg.core import build_pixel_sampler
|
| 10 |
+
# from mmseg.ops import resize
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class BaseDecodeHead(nn.Module, metaclass=ABCMeta):
|
| 14 |
+
"""Base class for BaseDecodeHead.
|
| 15 |
+
|
| 16 |
+
Args:
|
| 17 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 18 |
+
channels (int): Channels after modules, before conv_seg.
|
| 19 |
+
num_classes (int): Number of classes.
|
| 20 |
+
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
| 21 |
+
conv_cfg (dict|None): Config of conv layers. Default: None.
|
| 22 |
+
norm_cfg (dict|None): Config of norm layers. Default: None.
|
| 23 |
+
act_cfg (dict): Config of activation layers.
|
| 24 |
+
Default: dict(type='ReLU')
|
| 25 |
+
in_index (int|Sequence[int]): Input feature index. Default: -1
|
| 26 |
+
input_transform (str|None): Transformation type of input features.
|
| 27 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 28 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 29 |
+
same size as first one and than concat together.
|
| 30 |
+
Usually used in FCN head of HRNet.
|
| 31 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 32 |
+
a list and passed into decode head.
|
| 33 |
+
None: Only one select feature map is allowed.
|
| 34 |
+
Default: None.
|
| 35 |
+
loss_decode (dict): Config of decode loss.
|
| 36 |
+
Default: dict(type='CrossEntropyLoss').
|
| 37 |
+
ignore_index (int | None): The label index to be ignored. When using
|
| 38 |
+
masked BCE loss, ignore_index should be set to None. Default: 255
|
| 39 |
+
sampler (dict|None): The config of segmentation map sampler.
|
| 40 |
+
Default: None.
|
| 41 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
| 42 |
+
Default: False.
|
| 43 |
+
"""
|
| 44 |
+
|
| 45 |
+
def __init__(self,
|
| 46 |
+
in_channels,
|
| 47 |
+
channels,
|
| 48 |
+
*,
|
| 49 |
+
num_classes,
|
| 50 |
+
dropout_ratio=0.1,
|
| 51 |
+
conv_cfg=None,
|
| 52 |
+
norm_cfg=None,
|
| 53 |
+
act_cfg=dict(type='ReLU'),
|
| 54 |
+
in_index=-1,
|
| 55 |
+
input_transform=None,
|
| 56 |
+
loss_decode=dict(
|
| 57 |
+
type='CrossEntropyLoss',
|
| 58 |
+
use_sigmoid=False,
|
| 59 |
+
loss_weight=1.0),
|
| 60 |
+
decoder_params=None,
|
| 61 |
+
ignore_index=255,
|
| 62 |
+
sampler=None,
|
| 63 |
+
align_corners=False):
|
| 64 |
+
super(BaseDecodeHead, self).__init__()
|
| 65 |
+
self._init_inputs(in_channels, in_index, input_transform)
|
| 66 |
+
self.channels = channels
|
| 67 |
+
self.num_classes = num_classes
|
| 68 |
+
self.dropout_ratio = dropout_ratio
|
| 69 |
+
self.conv_cfg = conv_cfg
|
| 70 |
+
self.norm_cfg = norm_cfg
|
| 71 |
+
self.act_cfg = act_cfg
|
| 72 |
+
self.in_index = in_index
|
| 73 |
+
self.ignore_index = ignore_index
|
| 74 |
+
self.align_corners = align_corners
|
| 75 |
+
|
| 76 |
+
if sampler is not None:
|
| 77 |
+
self.sampler = build_pixel_sampler(sampler, context=self)
|
| 78 |
+
else:
|
| 79 |
+
self.sampler = None
|
| 80 |
+
|
| 81 |
+
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
| 82 |
+
if dropout_ratio > 0:
|
| 83 |
+
self.dropout = nn.Dropout2d(dropout_ratio)
|
| 84 |
+
else:
|
| 85 |
+
self.dropout = None
|
| 86 |
+
self.fp16_enabled = False
|
| 87 |
+
|
| 88 |
+
def extra_repr(self):
|
| 89 |
+
"""Extra repr."""
|
| 90 |
+
s = f'input_transform={self.input_transform}, ' \
|
| 91 |
+
f'ignore_index={self.ignore_index}, ' \
|
| 92 |
+
f'align_corners={self.align_corners}'
|
| 93 |
+
return s
|
| 94 |
+
|
| 95 |
+
def _init_inputs(self, in_channels, in_index, input_transform):
|
| 96 |
+
"""Check and initialize input transforms.
|
| 97 |
+
|
| 98 |
+
The in_channels, in_index and input_transform must match.
|
| 99 |
+
Specifically, when input_transform is None, only single feature map
|
| 100 |
+
will be selected. So in_channels and in_index must be of type int.
|
| 101 |
+
When input_transform
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 105 |
+
in_index (int|Sequence[int]): Input feature index.
|
| 106 |
+
input_transform (str|None): Transformation type of input features.
|
| 107 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 108 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 109 |
+
same size as first one and than concat together.
|
| 110 |
+
Usually used in FCN head of HRNet.
|
| 111 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 112 |
+
a list and passed into decode head.
|
| 113 |
+
None: Only one select feature map is allowed.
|
| 114 |
+
"""
|
| 115 |
+
|
| 116 |
+
if input_transform is not None:
|
| 117 |
+
assert input_transform in ['resize_concat', 'multiple_select']
|
| 118 |
+
self.input_transform = input_transform
|
| 119 |
+
self.in_index = in_index
|
| 120 |
+
if input_transform is not None:
|
| 121 |
+
assert isinstance(in_channels, (list, tuple))
|
| 122 |
+
assert isinstance(in_index, (list, tuple))
|
| 123 |
+
assert len(in_channels) == len(in_index)
|
| 124 |
+
if input_transform == 'resize_concat':
|
| 125 |
+
self.in_channels = sum(in_channels)
|
| 126 |
+
else:
|
| 127 |
+
self.in_channels = in_channels
|
| 128 |
+
else:
|
| 129 |
+
assert isinstance(in_channels, int)
|
| 130 |
+
assert isinstance(in_index, int)
|
| 131 |
+
self.in_channels = in_channels
|
| 132 |
+
|
| 133 |
+
def init_weights(self):
|
| 134 |
+
"""Initialize weights of classification layer."""
|
| 135 |
+
normal_init(self.conv_seg, mean=0, std=0.01)
|
| 136 |
+
|
| 137 |
+
def _transform_inputs(self, inputs):
|
| 138 |
+
"""Transform inputs for decoder.
|
| 139 |
+
|
| 140 |
+
Args:
|
| 141 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 142 |
+
|
| 143 |
+
Returns:
|
| 144 |
+
Tensor: The transformed inputs
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
if self.input_transform == 'resize_concat':
|
| 148 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 149 |
+
upsampled_inputs = [
|
| 150 |
+
resize(
|
| 151 |
+
input=x,
|
| 152 |
+
size=inputs[0].shape[2:],
|
| 153 |
+
mode='bilinear',
|
| 154 |
+
align_corners=self.align_corners) for x in inputs
|
| 155 |
+
]
|
| 156 |
+
inputs = torch.cat(upsampled_inputs, dim=1)
|
| 157 |
+
elif self.input_transform == 'multiple_select':
|
| 158 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 159 |
+
else:
|
| 160 |
+
inputs = inputs[self.in_index]
|
| 161 |
+
|
| 162 |
+
return inputs
|
| 163 |
+
|
| 164 |
+
# @auto_fp16()
|
| 165 |
+
@abstractmethod
|
| 166 |
+
def forward(self, inputs):
|
| 167 |
+
"""Placeholder of forward function."""
|
| 168 |
+
pass
|
| 169 |
+
|
| 170 |
+
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg):
|
| 171 |
+
"""Forward function for training.
|
| 172 |
+
Args:
|
| 173 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 174 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 175 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 176 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 177 |
+
For details on the values of these keys see
|
| 178 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 179 |
+
gt_semantic_seg (Tensor): Semantic segmentation masks
|
| 180 |
+
used if the architecture supports semantic segmentation task.
|
| 181 |
+
train_cfg (dict): The training config.
|
| 182 |
+
|
| 183 |
+
Returns:
|
| 184 |
+
dict[str, Tensor]: a dictionary of loss components
|
| 185 |
+
"""
|
| 186 |
+
seg_logits = self.forward(inputs)
|
| 187 |
+
losses = self.losses(seg_logits, gt_semantic_seg)
|
| 188 |
+
return losses
|
| 189 |
+
|
| 190 |
+
def forward_test(self, inputs, img_metas, test_cfg):
|
| 191 |
+
"""Forward function for testing.
|
| 192 |
+
|
| 193 |
+
Args:
|
| 194 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 195 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 196 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 197 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 198 |
+
For details on the values of these keys see
|
| 199 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 200 |
+
test_cfg (dict): The testing config.
|
| 201 |
+
|
| 202 |
+
Returns:
|
| 203 |
+
Tensor: Output segmentation map.
|
| 204 |
+
"""
|
| 205 |
+
return self.forward(inputs)
|
| 206 |
+
|
| 207 |
+
def cls_seg(self, feat):
|
| 208 |
+
"""Classify each pixel."""
|
| 209 |
+
if self.dropout is not None:
|
| 210 |
+
feat = self.dropout(feat)
|
| 211 |
+
output = self.conv_seg(feat)
|
| 212 |
+
return output
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class BaseDecodeHead_clips(nn.Module, metaclass=ABCMeta):
|
| 216 |
+
"""Base class for BaseDecodeHead_clips.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 220 |
+
channels (int): Channels after modules, before conv_seg.
|
| 221 |
+
num_classes (int): Number of classes.
|
| 222 |
+
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
| 223 |
+
conv_cfg (dict|None): Config of conv layers. Default: None.
|
| 224 |
+
norm_cfg (dict|None): Config of norm layers. Default: None.
|
| 225 |
+
act_cfg (dict): Config of activation layers.
|
| 226 |
+
Default: dict(type='ReLU')
|
| 227 |
+
in_index (int|Sequence[int]): Input feature index. Default: -1
|
| 228 |
+
input_transform (str|None): Transformation type of input features.
|
| 229 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 230 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 231 |
+
same size as first one and than concat together.
|
| 232 |
+
Usually used in FCN head of HRNet.
|
| 233 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 234 |
+
a list and passed into decode head.
|
| 235 |
+
None: Only one select feature map is allowed.
|
| 236 |
+
Default: None.
|
| 237 |
+
loss_decode (dict): Config of decode loss.
|
| 238 |
+
Default: dict(type='CrossEntropyLoss').
|
| 239 |
+
ignore_index (int | None): The label index to be ignored. When using
|
| 240 |
+
masked BCE loss, ignore_index should be set to None. Default: 255
|
| 241 |
+
sampler (dict|None): The config of segmentation map sampler.
|
| 242 |
+
Default: None.
|
| 243 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
| 244 |
+
Default: False.
|
| 245 |
+
"""
|
| 246 |
+
|
| 247 |
+
def __init__(self,
|
| 248 |
+
in_channels,
|
| 249 |
+
channels,
|
| 250 |
+
*,
|
| 251 |
+
num_classes,
|
| 252 |
+
dropout_ratio=0.1,
|
| 253 |
+
conv_cfg=None,
|
| 254 |
+
norm_cfg=None,
|
| 255 |
+
act_cfg=dict(type='ReLU'),
|
| 256 |
+
in_index=-1,
|
| 257 |
+
input_transform=None,
|
| 258 |
+
loss_decode=dict(
|
| 259 |
+
type='CrossEntropyLoss',
|
| 260 |
+
use_sigmoid=False,
|
| 261 |
+
loss_weight=1.0),
|
| 262 |
+
decoder_params=None,
|
| 263 |
+
ignore_index=255,
|
| 264 |
+
sampler=None,
|
| 265 |
+
align_corners=False,
|
| 266 |
+
num_clips=5):
|
| 267 |
+
super(BaseDecodeHead_clips, self).__init__()
|
| 268 |
+
self._init_inputs(in_channels, in_index, input_transform)
|
| 269 |
+
self.channels = channels
|
| 270 |
+
self.num_classes = num_classes
|
| 271 |
+
self.dropout_ratio = dropout_ratio
|
| 272 |
+
self.conv_cfg = conv_cfg
|
| 273 |
+
self.norm_cfg = norm_cfg
|
| 274 |
+
self.act_cfg = act_cfg
|
| 275 |
+
self.in_index = in_index
|
| 276 |
+
self.ignore_index = ignore_index
|
| 277 |
+
self.align_corners = align_corners
|
| 278 |
+
self.num_clips=num_clips
|
| 279 |
+
|
| 280 |
+
if sampler is not None:
|
| 281 |
+
self.sampler = build_pixel_sampler(sampler, context=self)
|
| 282 |
+
else:
|
| 283 |
+
self.sampler = None
|
| 284 |
+
|
| 285 |
+
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
| 286 |
+
if dropout_ratio > 0:
|
| 287 |
+
self.dropout = nn.Dropout2d(dropout_ratio)
|
| 288 |
+
else:
|
| 289 |
+
self.dropout = None
|
| 290 |
+
self.fp16_enabled = False
|
| 291 |
+
|
| 292 |
+
def extra_repr(self):
|
| 293 |
+
"""Extra repr."""
|
| 294 |
+
s = f'input_transform={self.input_transform}, ' \
|
| 295 |
+
f'ignore_index={self.ignore_index}, ' \
|
| 296 |
+
f'align_corners={self.align_corners}'
|
| 297 |
+
return s
|
| 298 |
+
|
| 299 |
+
def _init_inputs(self, in_channels, in_index, input_transform):
|
| 300 |
+
"""Check and initialize input transforms.
|
| 301 |
+
|
| 302 |
+
The in_channels, in_index and input_transform must match.
|
| 303 |
+
Specifically, when input_transform is None, only single feature map
|
| 304 |
+
will be selected. So in_channels and in_index must be of type int.
|
| 305 |
+
When input_transform
|
| 306 |
+
|
| 307 |
+
Args:
|
| 308 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 309 |
+
in_index (int|Sequence[int]): Input feature index.
|
| 310 |
+
input_transform (str|None): Transformation type of input features.
|
| 311 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 312 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 313 |
+
same size as first one and than concat together.
|
| 314 |
+
Usually used in FCN head of HRNet.
|
| 315 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 316 |
+
a list and passed into decode head.
|
| 317 |
+
None: Only one select feature map is allowed.
|
| 318 |
+
"""
|
| 319 |
+
|
| 320 |
+
if input_transform is not None:
|
| 321 |
+
assert input_transform in ['resize_concat', 'multiple_select']
|
| 322 |
+
self.input_transform = input_transform
|
| 323 |
+
self.in_index = in_index
|
| 324 |
+
if input_transform is not None:
|
| 325 |
+
assert isinstance(in_channels, (list, tuple))
|
| 326 |
+
assert isinstance(in_index, (list, tuple))
|
| 327 |
+
assert len(in_channels) == len(in_index)
|
| 328 |
+
if input_transform == 'resize_concat':
|
| 329 |
+
self.in_channels = sum(in_channels)
|
| 330 |
+
else:
|
| 331 |
+
self.in_channels = in_channels
|
| 332 |
+
else:
|
| 333 |
+
assert isinstance(in_channels, int)
|
| 334 |
+
assert isinstance(in_index, int)
|
| 335 |
+
self.in_channels = in_channels
|
| 336 |
+
|
| 337 |
+
def init_weights(self):
|
| 338 |
+
"""Initialize weights of classification layer."""
|
| 339 |
+
normal_init(self.conv_seg, mean=0, std=0.01)
|
| 340 |
+
|
| 341 |
+
def _transform_inputs(self, inputs):
|
| 342 |
+
"""Transform inputs for decoder.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
Tensor: The transformed inputs
|
| 349 |
+
"""
|
| 350 |
+
|
| 351 |
+
if self.input_transform == 'resize_concat':
|
| 352 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 353 |
+
upsampled_inputs = [
|
| 354 |
+
resize(
|
| 355 |
+
input=x,
|
| 356 |
+
size=inputs[0].shape[2:],
|
| 357 |
+
mode='bilinear',
|
| 358 |
+
align_corners=self.align_corners) for x in inputs
|
| 359 |
+
]
|
| 360 |
+
inputs = torch.cat(upsampled_inputs, dim=1)
|
| 361 |
+
elif self.input_transform == 'multiple_select':
|
| 362 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 363 |
+
else:
|
| 364 |
+
inputs = inputs[self.in_index]
|
| 365 |
+
|
| 366 |
+
return inputs
|
| 367 |
+
|
| 368 |
+
# @auto_fp16()
|
| 369 |
+
@abstractmethod
|
| 370 |
+
def forward(self, inputs):
|
| 371 |
+
"""Placeholder of forward function."""
|
| 372 |
+
pass
|
| 373 |
+
|
| 374 |
+
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips):
|
| 375 |
+
"""Forward function for training.
|
| 376 |
+
Args:
|
| 377 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 378 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 379 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 380 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 381 |
+
For details on the values of these keys see
|
| 382 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 383 |
+
gt_semantic_seg (Tensor): Semantic segmentation masks
|
| 384 |
+
used if the architecture supports semantic segmentation task.
|
| 385 |
+
train_cfg (dict): The training config.
|
| 386 |
+
|
| 387 |
+
Returns:
|
| 388 |
+
dict[str, Tensor]: a dictionary of loss components
|
| 389 |
+
"""
|
| 390 |
+
seg_logits = self.forward(inputs,batch_size, num_clips)
|
| 391 |
+
losses = self.losses(seg_logits, gt_semantic_seg)
|
| 392 |
+
return losses
|
| 393 |
+
|
| 394 |
+
def forward_test(self, inputs, img_metas, test_cfg, batch_size, num_clips):
|
| 395 |
+
"""Forward function for testing.
|
| 396 |
+
|
| 397 |
+
Args:
|
| 398 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 399 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 400 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 401 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 402 |
+
For details on the values of these keys see
|
| 403 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 404 |
+
test_cfg (dict): The testing config.
|
| 405 |
+
|
| 406 |
+
Returns:
|
| 407 |
+
Tensor: Output segmentation map.
|
| 408 |
+
"""
|
| 409 |
+
return self.forward(inputs, batch_size, num_clips)
|
| 410 |
+
|
| 411 |
+
def cls_seg(self, feat):
|
| 412 |
+
"""Classify each pixel."""
|
| 413 |
+
if self.dropout is not None:
|
| 414 |
+
feat = self.dropout(feat)
|
| 415 |
+
output = self.conv_seg(feat)
|
| 416 |
+
return output
|
| 417 |
+
|
| 418 |
+
class BaseDecodeHead_clips_flow(nn.Module, metaclass=ABCMeta):
|
| 419 |
+
"""Base class for BaseDecodeHead_clips_flow.
|
| 420 |
+
|
| 421 |
+
Args:
|
| 422 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 423 |
+
channels (int): Channels after modules, before conv_seg.
|
| 424 |
+
num_classes (int): Number of classes.
|
| 425 |
+
dropout_ratio (float): Ratio of dropout layer. Default: 0.1.
|
| 426 |
+
conv_cfg (dict|None): Config of conv layers. Default: None.
|
| 427 |
+
norm_cfg (dict|None): Config of norm layers. Default: None.
|
| 428 |
+
act_cfg (dict): Config of activation layers.
|
| 429 |
+
Default: dict(type='ReLU')
|
| 430 |
+
in_index (int|Sequence[int]): Input feature index. Default: -1
|
| 431 |
+
input_transform (str|None): Transformation type of input features.
|
| 432 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 433 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 434 |
+
same size as first one and than concat together.
|
| 435 |
+
Usually used in FCN head of HRNet.
|
| 436 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 437 |
+
a list and passed into decode head.
|
| 438 |
+
None: Only one select feature map is allowed.
|
| 439 |
+
Default: None.
|
| 440 |
+
loss_decode (dict): Config of decode loss.
|
| 441 |
+
Default: dict(type='CrossEntropyLoss').
|
| 442 |
+
ignore_index (int | None): The label index to be ignored. When using
|
| 443 |
+
masked BCE loss, ignore_index should be set to None. Default: 255
|
| 444 |
+
sampler (dict|None): The config of segmentation map sampler.
|
| 445 |
+
Default: None.
|
| 446 |
+
align_corners (bool): align_corners argument of F.interpolate.
|
| 447 |
+
Default: False.
|
| 448 |
+
"""
|
| 449 |
+
|
| 450 |
+
def __init__(self,
|
| 451 |
+
in_channels,
|
| 452 |
+
channels,
|
| 453 |
+
*,
|
| 454 |
+
num_classes,
|
| 455 |
+
dropout_ratio=0.1,
|
| 456 |
+
conv_cfg=None,
|
| 457 |
+
norm_cfg=None,
|
| 458 |
+
act_cfg=dict(type='ReLU'),
|
| 459 |
+
in_index=-1,
|
| 460 |
+
input_transform=None,
|
| 461 |
+
loss_decode=dict(
|
| 462 |
+
type='CrossEntropyLoss',
|
| 463 |
+
use_sigmoid=False,
|
| 464 |
+
loss_weight=1.0),
|
| 465 |
+
decoder_params=None,
|
| 466 |
+
ignore_index=255,
|
| 467 |
+
sampler=None,
|
| 468 |
+
align_corners=False,
|
| 469 |
+
num_clips=5):
|
| 470 |
+
super(BaseDecodeHead_clips_flow, self).__init__()
|
| 471 |
+
self._init_inputs(in_channels, in_index, input_transform)
|
| 472 |
+
self.channels = channels
|
| 473 |
+
self.num_classes = num_classes
|
| 474 |
+
self.dropout_ratio = dropout_ratio
|
| 475 |
+
self.conv_cfg = conv_cfg
|
| 476 |
+
self.norm_cfg = norm_cfg
|
| 477 |
+
self.act_cfg = act_cfg
|
| 478 |
+
self.in_index = in_index
|
| 479 |
+
self.ignore_index = ignore_index
|
| 480 |
+
self.align_corners = align_corners
|
| 481 |
+
self.num_clips=num_clips
|
| 482 |
+
|
| 483 |
+
if sampler is not None:
|
| 484 |
+
self.sampler = build_pixel_sampler(sampler, context=self)
|
| 485 |
+
else:
|
| 486 |
+
self.sampler = None
|
| 487 |
+
|
| 488 |
+
self.conv_seg = nn.Conv2d(channels, num_classes, kernel_size=1)
|
| 489 |
+
if dropout_ratio > 0:
|
| 490 |
+
self.dropout = nn.Dropout2d(dropout_ratio)
|
| 491 |
+
else:
|
| 492 |
+
self.dropout = None
|
| 493 |
+
self.fp16_enabled = False
|
| 494 |
+
|
| 495 |
+
def extra_repr(self):
|
| 496 |
+
"""Extra repr."""
|
| 497 |
+
s = f'input_transform={self.input_transform}, ' \
|
| 498 |
+
f'ignore_index={self.ignore_index}, ' \
|
| 499 |
+
f'align_corners={self.align_corners}'
|
| 500 |
+
return s
|
| 501 |
+
|
| 502 |
+
def _init_inputs(self, in_channels, in_index, input_transform):
|
| 503 |
+
"""Check and initialize input transforms.
|
| 504 |
+
|
| 505 |
+
The in_channels, in_index and input_transform must match.
|
| 506 |
+
Specifically, when input_transform is None, only single feature map
|
| 507 |
+
will be selected. So in_channels and in_index must be of type int.
|
| 508 |
+
When input_transform
|
| 509 |
+
|
| 510 |
+
Args:
|
| 511 |
+
in_channels (int|Sequence[int]): Input channels.
|
| 512 |
+
in_index (int|Sequence[int]): Input feature index.
|
| 513 |
+
input_transform (str|None): Transformation type of input features.
|
| 514 |
+
Options: 'resize_concat', 'multiple_select', None.
|
| 515 |
+
'resize_concat': Multiple feature maps will be resize to the
|
| 516 |
+
same size as first one and than concat together.
|
| 517 |
+
Usually used in FCN head of HRNet.
|
| 518 |
+
'multiple_select': Multiple feature maps will be bundle into
|
| 519 |
+
a list and passed into decode head.
|
| 520 |
+
None: Only one select feature map is allowed.
|
| 521 |
+
"""
|
| 522 |
+
|
| 523 |
+
if input_transform is not None:
|
| 524 |
+
assert input_transform in ['resize_concat', 'multiple_select']
|
| 525 |
+
self.input_transform = input_transform
|
| 526 |
+
self.in_index = in_index
|
| 527 |
+
if input_transform is not None:
|
| 528 |
+
assert isinstance(in_channels, (list, tuple))
|
| 529 |
+
assert isinstance(in_index, (list, tuple))
|
| 530 |
+
assert len(in_channels) == len(in_index)
|
| 531 |
+
if input_transform == 'resize_concat':
|
| 532 |
+
self.in_channels = sum(in_channels)
|
| 533 |
+
else:
|
| 534 |
+
self.in_channels = in_channels
|
| 535 |
+
else:
|
| 536 |
+
assert isinstance(in_channels, int)
|
| 537 |
+
assert isinstance(in_index, int)
|
| 538 |
+
self.in_channels = in_channels
|
| 539 |
+
|
| 540 |
+
def init_weights(self):
|
| 541 |
+
"""Initialize weights of classification layer."""
|
| 542 |
+
normal_init(self.conv_seg, mean=0, std=0.01)
|
| 543 |
+
|
| 544 |
+
def _transform_inputs(self, inputs):
|
| 545 |
+
"""Transform inputs for decoder.
|
| 546 |
+
|
| 547 |
+
Args:
|
| 548 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 549 |
+
|
| 550 |
+
Returns:
|
| 551 |
+
Tensor: The transformed inputs
|
| 552 |
+
"""
|
| 553 |
+
|
| 554 |
+
if self.input_transform == 'resize_concat':
|
| 555 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 556 |
+
upsampled_inputs = [
|
| 557 |
+
resize(
|
| 558 |
+
input=x,
|
| 559 |
+
size=inputs[0].shape[2:],
|
| 560 |
+
mode='bilinear',
|
| 561 |
+
align_corners=self.align_corners) for x in inputs
|
| 562 |
+
]
|
| 563 |
+
inputs = torch.cat(upsampled_inputs, dim=1)
|
| 564 |
+
elif self.input_transform == 'multiple_select':
|
| 565 |
+
inputs = [inputs[i] for i in self.in_index]
|
| 566 |
+
else:
|
| 567 |
+
inputs = inputs[self.in_index]
|
| 568 |
+
|
| 569 |
+
return inputs
|
| 570 |
+
|
| 571 |
+
# @auto_fp16()
|
| 572 |
+
@abstractmethod
|
| 573 |
+
def forward(self, inputs):
|
| 574 |
+
"""Placeholder of forward function."""
|
| 575 |
+
pass
|
| 576 |
+
|
| 577 |
+
def forward_train(self, inputs, img_metas, gt_semantic_seg, train_cfg,batch_size, num_clips,img=None):
|
| 578 |
+
"""Forward function for training.
|
| 579 |
+
Args:
|
| 580 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 581 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 582 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 583 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 584 |
+
For details on the values of these keys see
|
| 585 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 586 |
+
gt_semantic_seg (Tensor): Semantic segmentation masks
|
| 587 |
+
used if the architecture supports semantic segmentation task.
|
| 588 |
+
train_cfg (dict): The training config.
|
| 589 |
+
|
| 590 |
+
Returns:
|
| 591 |
+
dict[str, Tensor]: a dictionary of loss components
|
| 592 |
+
"""
|
| 593 |
+
seg_logits = self.forward(inputs,batch_size, num_clips,img)
|
| 594 |
+
losses = self.losses(seg_logits, gt_semantic_seg)
|
| 595 |
+
return losses
|
| 596 |
+
|
| 597 |
+
def forward_test(self, inputs, img_metas, test_cfg, batch_size=None, num_clips=None, img=None):
|
| 598 |
+
"""Forward function for testing.
|
| 599 |
+
|
| 600 |
+
Args:
|
| 601 |
+
inputs (list[Tensor]): List of multi-level img features.
|
| 602 |
+
img_metas (list[dict]): List of image info dict where each dict
|
| 603 |
+
has: 'img_shape', 'scale_factor', 'flip', and may also contain
|
| 604 |
+
'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
|
| 605 |
+
For details on the values of these keys see
|
| 606 |
+
`mmseg/datasets/pipelines/formatting.py:Collect`.
|
| 607 |
+
test_cfg (dict): The testing config.
|
| 608 |
+
|
| 609 |
+
Returns:
|
| 610 |
+
Tensor: Output segmentation map.
|
| 611 |
+
"""
|
| 612 |
+
return self.forward(inputs, batch_size, num_clips,img)
|
| 613 |
+
|
| 614 |
+
def cls_seg(self, feat):
|
| 615 |
+
"""Classify each pixel."""
|
| 616 |
+
if self.dropout is not None:
|
| 617 |
+
feat = self.dropout(feat)
|
| 618 |
+
output = self.conv_seg(feat)
|
| 619 |
+
return output
|
models/SpaTrackV2/models/depth_refiner/depth_refiner.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from models.monoD.depth_anything_v2.dinov2_layers.patch_embed import PatchEmbed
|
| 5 |
+
from models.SpaTrackV2.models.depth_refiner.backbone import mit_b3
|
| 6 |
+
from models.SpaTrackV2.models.depth_refiner.stablizer import Stabilization_Network_Cross_Attention
|
| 7 |
+
from einops import rearrange
|
| 8 |
+
class TrackStablizer(nn.Module):
|
| 9 |
+
def __init__(self):
|
| 10 |
+
super().__init__()
|
| 11 |
+
|
| 12 |
+
self.backbone = mit_b3()
|
| 13 |
+
|
| 14 |
+
old_conv = self.backbone.patch_embed1.proj
|
| 15 |
+
new_conv = nn.Conv2d(old_conv.in_channels + 4, old_conv.out_channels, kernel_size=old_conv.kernel_size, stride=old_conv.stride, padding=old_conv.padding)
|
| 16 |
+
|
| 17 |
+
new_conv.weight[:, :3, :, :].data.copy_(old_conv.weight.clone())
|
| 18 |
+
self.backbone.patch_embed1.proj = new_conv
|
| 19 |
+
|
| 20 |
+
self.Track_Stabilizer = Stabilization_Network_Cross_Attention(in_channels=[64, 128, 320, 512],
|
| 21 |
+
in_index=[0, 1, 2, 3],
|
| 22 |
+
feature_strides=[4, 8, 16, 32],
|
| 23 |
+
channels=128,
|
| 24 |
+
dropout_ratio=0.1,
|
| 25 |
+
num_classes=1,
|
| 26 |
+
align_corners=False,
|
| 27 |
+
decoder_params=dict(embed_dim=256, depths=4),
|
| 28 |
+
num_clips=16,
|
| 29 |
+
norm_cfg = dict(type='SyncBN', requires_grad=True))
|
| 30 |
+
|
| 31 |
+
self.edge_conv = nn.Sequential(nn.Conv2d(in_channels=4, out_channels=64, kernel_size=3, padding=1, stride=1, bias=True),\
|
| 32 |
+
nn.ReLU(inplace=True))
|
| 33 |
+
self.edge_conv1 = nn.Sequential(nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1, stride=2, bias=True),\
|
| 34 |
+
nn.ReLU(inplace=True))
|
| 35 |
+
self.success = False
|
| 36 |
+
self.x = None
|
| 37 |
+
|
| 38 |
+
def buffer_forward(self, inputs, num_clips=16):
|
| 39 |
+
"""
|
| 40 |
+
buffer forward for getting the pointmap and image features
|
| 41 |
+
"""
|
| 42 |
+
B, T, C, H, W = inputs.shape
|
| 43 |
+
self.x = self.backbone(inputs)
|
| 44 |
+
scale, shift = self.Track_Stabilizer.buffer_forward(self.x, num_clips=num_clips)
|
| 45 |
+
self.success = True
|
| 46 |
+
return scale, shift
|
| 47 |
+
|
| 48 |
+
def forward(self, inputs, tracks, tracks_uvd, num_clips=16, imgs=None, vis_track=None):
|
| 49 |
+
|
| 50 |
+
"""
|
| 51 |
+
Args:
|
| 52 |
+
inputs: [B, T, C, H, W], RGB + PointMap + Mask
|
| 53 |
+
tracks: [B, T, N, 4], 3D tracks in camera coordinate + visibility
|
| 54 |
+
num_clips: int, number of clips to use
|
| 55 |
+
"""
|
| 56 |
+
B, T, C, H, W = inputs.shape
|
| 57 |
+
edge_feat = self.edge_conv(inputs.view(B*T,4,H,W))
|
| 58 |
+
edge_feat1 = self.edge_conv1(edge_feat)
|
| 59 |
+
|
| 60 |
+
if not self.success:
|
| 61 |
+
scale, shift = self.Track_Stabilizer.buffer_forward(self.x,num_clips=num_clips)
|
| 62 |
+
self.success = True
|
| 63 |
+
update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
|
| 64 |
+
else:
|
| 65 |
+
update = self.Track_Stabilizer(self.x,edge_feat,edge_feat1,tracks,tracks_uvd,num_clips=num_clips, imgs=imgs, vis_track=vis_track)
|
| 66 |
+
|
| 67 |
+
return update
|
| 68 |
+
|
| 69 |
+
def reset_success(self):
|
| 70 |
+
self.success = False
|
| 71 |
+
self.x = None
|
| 72 |
+
self.Track_Stabilizer.reset_success()
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
if __name__ == "__main__":
|
| 76 |
+
# Create test input tensors
|
| 77 |
+
batch_size = 1
|
| 78 |
+
seq_len = 16
|
| 79 |
+
channels = 7 # 3 for RGB + 3 for PointMap + 1 for Mask
|
| 80 |
+
height = 384
|
| 81 |
+
width = 512
|
| 82 |
+
|
| 83 |
+
# Create random input tensor with shape [B, T, C, H, W]
|
| 84 |
+
inputs = torch.randn(batch_size, seq_len, channels, height, width)
|
| 85 |
+
|
| 86 |
+
# Create random tracks
|
| 87 |
+
tracks = torch.randn(batch_size, seq_len, 1024, 4)
|
| 88 |
+
|
| 89 |
+
# Create random test images
|
| 90 |
+
test_imgs = torch.randn(batch_size, seq_len, 3, height, width)
|
| 91 |
+
|
| 92 |
+
# Initialize model and move to GPU
|
| 93 |
+
model = TrackStablizer().cuda()
|
| 94 |
+
|
| 95 |
+
# Move inputs to GPU and run forward pass
|
| 96 |
+
inputs = inputs.cuda()
|
| 97 |
+
tracks = tracks.cuda()
|
| 98 |
+
outputs = model.buffer_forward(inputs, num_clips=seq_len)
|
| 99 |
+
import time
|
| 100 |
+
start_time = time.time()
|
| 101 |
+
outputs = model(inputs, tracks, num_clips=seq_len)
|
| 102 |
+
end_time = time.time()
|
| 103 |
+
print(f"Time taken: {end_time - start_time} seconds")
|
| 104 |
+
import pdb; pdb.set_trace()
|
| 105 |
+
# # Print shapes for verification
|
| 106 |
+
# print(f"Input shape: {inputs.shape}")
|
| 107 |
+
# print(f"Output shape: {outputs.shape}")
|
| 108 |
+
|
| 109 |
+
# # Basic tests
|
| 110 |
+
# assert outputs.shape[0] == batch_size, "Batch size mismatch"
|
| 111 |
+
# assert len(outputs.shape) == 4, "Output should be 4D: [B,C,H,W]"
|
| 112 |
+
# assert torch.all(outputs >= 0), "Output should be non-negative after ReLU"
|
| 113 |
+
|
| 114 |
+
# print("All tests passed!")
|
| 115 |
+
|
models/SpaTrackV2/models/depth_refiner/network.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
'''
|
| 4 |
+
Author: Ke Xian
|
| 5 |
+
Email: kexian@hust.edu.cn
|
| 6 |
+
Date: 2020/07/20
|
| 7 |
+
'''
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.init as init
|
| 12 |
+
|
| 13 |
+
# ==============================================================================================================
|
| 14 |
+
|
| 15 |
+
class FTB(nn.Module):
|
| 16 |
+
def __init__(self, inchannels, midchannels=512):
|
| 17 |
+
super(FTB, self).__init__()
|
| 18 |
+
self.in1 = inchannels
|
| 19 |
+
self.mid = midchannels
|
| 20 |
+
|
| 21 |
+
self.conv1 = nn.Conv2d(in_channels=self.in1, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True)
|
| 22 |
+
self.conv_branch = nn.Sequential(nn.ReLU(inplace=True),\
|
| 23 |
+
nn.Conv2d(in_channels=self.mid, out_channels=self.mid, kernel_size=3, padding=1, stride=1, bias=True),\
|
| 24 |
+
#nn.BatchNorm2d(num_features=self.mid),\
|
| 25 |
+
nn.ReLU(inplace=True),\
|
| 26 |
+
nn.Conv2d(in_channels=self.mid, out_channels= self.mid, kernel_size=3, padding=1, stride=1, bias=True))
|
| 27 |
+
self.relu = nn.ReLU(inplace=True)
|
| 28 |
+
|
| 29 |
+
self.init_params()
|
| 30 |
+
|
| 31 |
+
def forward(self, x):
|
| 32 |
+
x = self.conv1(x)
|
| 33 |
+
x = x + self.conv_branch(x)
|
| 34 |
+
x = self.relu(x)
|
| 35 |
+
|
| 36 |
+
return x
|
| 37 |
+
|
| 38 |
+
def init_params(self):
|
| 39 |
+
for m in self.modules():
|
| 40 |
+
if isinstance(m, nn.Conv2d):
|
| 41 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 42 |
+
init.normal_(m.weight, std=0.01)
|
| 43 |
+
# init.xavier_normal_(m.weight)
|
| 44 |
+
if m.bias is not None:
|
| 45 |
+
init.constant_(m.bias, 0)
|
| 46 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 47 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 48 |
+
init.normal_(m.weight, std=0.01)
|
| 49 |
+
# init.xavier_normal_(m.weight)
|
| 50 |
+
if m.bias is not None:
|
| 51 |
+
init.constant_(m.bias, 0)
|
| 52 |
+
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
| 53 |
+
init.constant_(m.weight, 1)
|
| 54 |
+
init.constant_(m.bias, 0)
|
| 55 |
+
elif isinstance(m, nn.Linear):
|
| 56 |
+
init.normal_(m.weight, std=0.01)
|
| 57 |
+
if m.bias is not None:
|
| 58 |
+
init.constant_(m.bias, 0)
|
| 59 |
+
|
| 60 |
+
class ATA(nn.Module):
|
| 61 |
+
def __init__(self, inchannels, reduction = 8):
|
| 62 |
+
super(ATA, self).__init__()
|
| 63 |
+
self.inchannels = inchannels
|
| 64 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
| 65 |
+
self.fc = nn.Sequential(nn.Linear(self.inchannels*2, self.inchannels // reduction),
|
| 66 |
+
nn.ReLU(inplace=True),
|
| 67 |
+
nn.Linear(self.inchannels // reduction, self.inchannels),
|
| 68 |
+
nn.Sigmoid())
|
| 69 |
+
self.init_params()
|
| 70 |
+
|
| 71 |
+
def forward(self, low_x, high_x):
|
| 72 |
+
n, c, _, _ = low_x.size()
|
| 73 |
+
x = torch.cat([low_x, high_x], 1)
|
| 74 |
+
x = self.avg_pool(x)
|
| 75 |
+
x = x.view(n, -1)
|
| 76 |
+
x = self.fc(x).view(n,c,1,1)
|
| 77 |
+
x = low_x * x + high_x
|
| 78 |
+
|
| 79 |
+
return x
|
| 80 |
+
|
| 81 |
+
def init_params(self):
|
| 82 |
+
for m in self.modules():
|
| 83 |
+
if isinstance(m, nn.Conv2d):
|
| 84 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 85 |
+
#init.normal(m.weight, std=0.01)
|
| 86 |
+
init.xavier_normal_(m.weight)
|
| 87 |
+
if m.bias is not None:
|
| 88 |
+
init.constant_(m.bias, 0)
|
| 89 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 90 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 91 |
+
#init.normal_(m.weight, std=0.01)
|
| 92 |
+
init.xavier_normal_(m.weight)
|
| 93 |
+
if m.bias is not None:
|
| 94 |
+
init.constant_(m.bias, 0)
|
| 95 |
+
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
| 96 |
+
init.constant_(m.weight, 1)
|
| 97 |
+
init.constant_(m.bias, 0)
|
| 98 |
+
elif isinstance(m, nn.Linear):
|
| 99 |
+
init.normal_(m.weight, std=0.01)
|
| 100 |
+
if m.bias is not None:
|
| 101 |
+
init.constant_(m.bias, 0)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class FFM(nn.Module):
|
| 105 |
+
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
|
| 106 |
+
super(FFM, self).__init__()
|
| 107 |
+
self.inchannels = inchannels
|
| 108 |
+
self.midchannels = midchannels
|
| 109 |
+
self.outchannels = outchannels
|
| 110 |
+
self.upfactor = upfactor
|
| 111 |
+
|
| 112 |
+
self.ftb1 = FTB(inchannels=self.inchannels, midchannels=self.midchannels)
|
| 113 |
+
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
|
| 114 |
+
|
| 115 |
+
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
|
| 116 |
+
|
| 117 |
+
self.init_params()
|
| 118 |
+
#self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
| 119 |
+
#self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
| 120 |
+
#self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
| 121 |
+
|
| 122 |
+
def forward(self, low_x, high_x):
|
| 123 |
+
x = self.ftb1(low_x)
|
| 124 |
+
|
| 125 |
+
'''
|
| 126 |
+
x = torch.cat((x,high_x),1)
|
| 127 |
+
if x.shape[2] == 12:
|
| 128 |
+
x = self.p1(x)
|
| 129 |
+
elif x.shape[2] == 24:
|
| 130 |
+
x = self.p2(x)
|
| 131 |
+
elif x.shape[2] == 48:
|
| 132 |
+
x = self.p3(x)
|
| 133 |
+
'''
|
| 134 |
+
x = x + high_x ###high_x
|
| 135 |
+
x = self.ftb2(x)
|
| 136 |
+
x = self.upsample(x)
|
| 137 |
+
|
| 138 |
+
return x
|
| 139 |
+
|
| 140 |
+
def init_params(self):
|
| 141 |
+
for m in self.modules():
|
| 142 |
+
if isinstance(m, nn.Conv2d):
|
| 143 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 144 |
+
init.normal_(m.weight, std=0.01)
|
| 145 |
+
#init.xavier_normal_(m.weight)
|
| 146 |
+
if m.bias is not None:
|
| 147 |
+
init.constant_(m.bias, 0)
|
| 148 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 149 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 150 |
+
init.normal_(m.weight, std=0.01)
|
| 151 |
+
#init.xavier_normal_(m.weight)
|
| 152 |
+
if m.bias is not None:
|
| 153 |
+
init.constant_(m.bias, 0)
|
| 154 |
+
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
| 155 |
+
init.constant_(m.weight, 1)
|
| 156 |
+
init.constant_(m.bias, 0)
|
| 157 |
+
elif isinstance(m, nn.Linear):
|
| 158 |
+
init.normal_(m.weight, std=0.01)
|
| 159 |
+
if m.bias is not None:
|
| 160 |
+
init.constant_(m.bias, 0)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class noFFM(nn.Module):
|
| 165 |
+
def __init__(self, inchannels, midchannels, outchannels, upfactor=2):
|
| 166 |
+
super(noFFM, self).__init__()
|
| 167 |
+
self.inchannels = inchannels
|
| 168 |
+
self.midchannels = midchannels
|
| 169 |
+
self.outchannels = outchannels
|
| 170 |
+
self.upfactor = upfactor
|
| 171 |
+
|
| 172 |
+
self.ftb2 = FTB(inchannels=self.midchannels, midchannels=self.outchannels)
|
| 173 |
+
|
| 174 |
+
self.upsample = nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True)
|
| 175 |
+
|
| 176 |
+
self.init_params()
|
| 177 |
+
#self.p1 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
| 178 |
+
#self.p2 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
| 179 |
+
#self.p3 = nn.Conv2d(512, 256, kernel_size=1, padding=0, bias=False)
|
| 180 |
+
|
| 181 |
+
def forward(self, low_x, high_x):
|
| 182 |
+
|
| 183 |
+
#x = self.ftb1(low_x)
|
| 184 |
+
x = high_x ###high_x
|
| 185 |
+
x = self.ftb2(x)
|
| 186 |
+
x = self.upsample(x)
|
| 187 |
+
|
| 188 |
+
return x
|
| 189 |
+
|
| 190 |
+
def init_params(self):
|
| 191 |
+
for m in self.modules():
|
| 192 |
+
if isinstance(m, nn.Conv2d):
|
| 193 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 194 |
+
init.normal_(m.weight, std=0.01)
|
| 195 |
+
#init.xavier_normal_(m.weight)
|
| 196 |
+
if m.bias is not None:
|
| 197 |
+
init.constant_(m.bias, 0)
|
| 198 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 199 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 200 |
+
init.normal_(m.weight, std=0.01)
|
| 201 |
+
#init.xavier_normal_(m.weight)
|
| 202 |
+
if m.bias is not None:
|
| 203 |
+
init.constant_(m.bias, 0)
|
| 204 |
+
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
| 205 |
+
init.constant_(m.weight, 1)
|
| 206 |
+
init.constant_(m.bias, 0)
|
| 207 |
+
elif isinstance(m, nn.Linear):
|
| 208 |
+
init.normal_(m.weight, std=0.01)
|
| 209 |
+
if m.bias is not None:
|
| 210 |
+
init.constant_(m.bias, 0)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
class AO(nn.Module):
|
| 216 |
+
# Adaptive output module
|
| 217 |
+
def __init__(self, inchannels, outchannels, upfactor=2):
|
| 218 |
+
super(AO, self).__init__()
|
| 219 |
+
self.inchannels = inchannels
|
| 220 |
+
self.outchannels = outchannels
|
| 221 |
+
self.upfactor = upfactor
|
| 222 |
+
|
| 223 |
+
"""
|
| 224 |
+
self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
|
| 225 |
+
nn.BatchNorm2d(num_features=self.inchannels//2),\
|
| 226 |
+
nn.ReLU(inplace=True),\
|
| 227 |
+
nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=3, padding=1, stride=1, bias=True),\
|
| 228 |
+
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True) )#,\
|
| 229 |
+
#nn.ReLU(inplace=True)) ## get positive values
|
| 230 |
+
"""
|
| 231 |
+
self.adapt_conv = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.inchannels//2, kernel_size=3, padding=1, stride=1, bias=True),\
|
| 232 |
+
#nn.BatchNorm2d(num_features=self.inchannels//2),\
|
| 233 |
+
nn.ReLU(inplace=True),\
|
| 234 |
+
nn.Upsample(scale_factor=self.upfactor, mode='bilinear', align_corners=True), \
|
| 235 |
+
nn.Conv2d(in_channels=self.inchannels//2, out_channels=self.outchannels, kernel_size=1, padding=0, stride=1))
|
| 236 |
+
|
| 237 |
+
#nn.ReLU(inplace=True)) ## get positive values
|
| 238 |
+
|
| 239 |
+
self.init_params()
|
| 240 |
+
|
| 241 |
+
def forward(self, x):
|
| 242 |
+
x = self.adapt_conv(x)
|
| 243 |
+
return x
|
| 244 |
+
|
| 245 |
+
def init_params(self):
|
| 246 |
+
for m in self.modules():
|
| 247 |
+
if isinstance(m, nn.Conv2d):
|
| 248 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 249 |
+
init.normal_(m.weight, std=0.01)
|
| 250 |
+
#init.xavier_normal_(m.weight)
|
| 251 |
+
if m.bias is not None:
|
| 252 |
+
init.constant_(m.bias, 0)
|
| 253 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 254 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 255 |
+
init.normal_(m.weight, std=0.01)
|
| 256 |
+
#init.xavier_normal_(m.weight)
|
| 257 |
+
if m.bias is not None:
|
| 258 |
+
init.constant_(m.bias, 0)
|
| 259 |
+
elif isinstance(m, nn.BatchNorm2d): #nn.Batchnorm2d
|
| 260 |
+
init.constant_(m.weight, 1)
|
| 261 |
+
init.constant_(m.bias, 0)
|
| 262 |
+
elif isinstance(m, nn.Linear):
|
| 263 |
+
init.normal_(m.weight, std=0.01)
|
| 264 |
+
if m.bias is not None:
|
| 265 |
+
init.constant_(m.bias, 0)
|
| 266 |
+
|
| 267 |
+
class ASPP(nn.Module):
|
| 268 |
+
def __init__(self, inchannels=256, planes=128, rates = [1, 6, 12, 18]):
|
| 269 |
+
super(ASPP, self).__init__()
|
| 270 |
+
self.inchannels = inchannels
|
| 271 |
+
self.planes = planes
|
| 272 |
+
self.rates = rates
|
| 273 |
+
self.kernel_sizes = []
|
| 274 |
+
self.paddings = []
|
| 275 |
+
for rate in self.rates:
|
| 276 |
+
if rate == 1:
|
| 277 |
+
self.kernel_sizes.append(1)
|
| 278 |
+
self.paddings.append(0)
|
| 279 |
+
else:
|
| 280 |
+
self.kernel_sizes.append(3)
|
| 281 |
+
self.paddings.append(rate)
|
| 282 |
+
self.atrous_0 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[0],
|
| 283 |
+
stride=1, padding=self.paddings[0], dilation=self.rates[0], bias=True),
|
| 284 |
+
nn.ReLU(inplace=True),
|
| 285 |
+
nn.BatchNorm2d(num_features=self.planes)
|
| 286 |
+
)
|
| 287 |
+
self.atrous_1 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[1],
|
| 288 |
+
stride=1, padding=self.paddings[1], dilation=self.rates[1], bias=True),
|
| 289 |
+
nn.ReLU(inplace=True),
|
| 290 |
+
nn.BatchNorm2d(num_features=self.planes),
|
| 291 |
+
)
|
| 292 |
+
self.atrous_2 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[2],
|
| 293 |
+
stride=1, padding=self.paddings[2], dilation=self.rates[2], bias=True),
|
| 294 |
+
nn.ReLU(inplace=True),
|
| 295 |
+
nn.BatchNorm2d(num_features=self.planes),
|
| 296 |
+
)
|
| 297 |
+
self.atrous_3 = nn.Sequential(nn.Conv2d(in_channels=self.inchannels, out_channels=self.planes, kernel_size=self.kernel_sizes[3],
|
| 298 |
+
stride=1, padding=self.paddings[3], dilation=self.rates[3], bias=True),
|
| 299 |
+
nn.ReLU(inplace=True),
|
| 300 |
+
nn.BatchNorm2d(num_features=self.planes),
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
#self.conv = nn.Conv2d(in_channels=self.planes * 4, out_channels=self.inchannels, kernel_size=3, padding=1, stride=1, bias=True)
|
| 304 |
+
def forward(self, x):
|
| 305 |
+
x = torch.cat([self.atrous_0(x), self.atrous_1(x), self.atrous_2(x), self.atrous_3(x)],1)
|
| 306 |
+
#x = self.conv(x)
|
| 307 |
+
|
| 308 |
+
return x
|
| 309 |
+
|
| 310 |
+
# ==============================================================================================================
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
class ResidualConv(nn.Module):
|
| 314 |
+
def __init__(self, inchannels):
|
| 315 |
+
super(ResidualConv, self).__init__()
|
| 316 |
+
#nn.BatchNorm2d
|
| 317 |
+
self.conv = nn.Sequential(
|
| 318 |
+
#nn.BatchNorm2d(num_features=inchannels),
|
| 319 |
+
nn.ReLU(inplace=False),
|
| 320 |
+
#nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=3, padding=1, stride=1, groups=inchannels,bias=True),
|
| 321 |
+
#nn.Conv2d(in_channels=inchannels, out_channels=inchannels, kernel_size=1, padding=0, stride=1, groups=1,bias=True)
|
| 322 |
+
nn.Conv2d(in_channels=inchannels, out_channels=inchannels//2, kernel_size=3, padding=1, stride=1, bias=False),
|
| 323 |
+
nn.BatchNorm2d(num_features=inchannels//2),
|
| 324 |
+
nn.ReLU(inplace=False),
|
| 325 |
+
nn.Conv2d(in_channels=inchannels//2, out_channels=inchannels, kernel_size=3, padding=1, stride=1, bias=False)
|
| 326 |
+
)
|
| 327 |
+
self.init_params()
|
| 328 |
+
|
| 329 |
+
def forward(self, x):
|
| 330 |
+
x = self.conv(x)+x
|
| 331 |
+
return x
|
| 332 |
+
|
| 333 |
+
def init_params(self):
|
| 334 |
+
for m in self.modules():
|
| 335 |
+
if isinstance(m, nn.Conv2d):
|
| 336 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 337 |
+
init.normal_(m.weight, std=0.01)
|
| 338 |
+
#init.xavier_normal_(m.weight)
|
| 339 |
+
if m.bias is not None:
|
| 340 |
+
init.constant_(m.bias, 0)
|
| 341 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 342 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 343 |
+
init.normal_(m.weight, std=0.01)
|
| 344 |
+
#init.xavier_normal_(m.weight)
|
| 345 |
+
if m.bias is not None:
|
| 346 |
+
init.constant_(m.bias, 0)
|
| 347 |
+
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
| 348 |
+
init.constant_(m.weight, 1)
|
| 349 |
+
init.constant_(m.bias, 0)
|
| 350 |
+
elif isinstance(m, nn.Linear):
|
| 351 |
+
init.normal_(m.weight, std=0.01)
|
| 352 |
+
if m.bias is not None:
|
| 353 |
+
init.constant_(m.bias, 0)
|
| 354 |
+
|
| 355 |
+
|
| 356 |
+
class FeatureFusion(nn.Module):
|
| 357 |
+
def __init__(self, inchannels, outchannels):
|
| 358 |
+
super(FeatureFusion, self).__init__()
|
| 359 |
+
self.conv = ResidualConv(inchannels=inchannels)
|
| 360 |
+
#nn.BatchNorm2d
|
| 361 |
+
self.up = nn.Sequential(ResidualConv(inchannels=inchannels),
|
| 362 |
+
nn.ConvTranspose2d(in_channels=inchannels, out_channels=outchannels, kernel_size=3,stride=2, padding=1, output_padding=1),
|
| 363 |
+
nn.BatchNorm2d(num_features=outchannels),
|
| 364 |
+
nn.ReLU(inplace=True))
|
| 365 |
+
|
| 366 |
+
def forward(self, lowfeat, highfeat):
|
| 367 |
+
return self.up(highfeat + self.conv(lowfeat))
|
| 368 |
+
|
| 369 |
+
def init_params(self):
|
| 370 |
+
for m in self.modules():
|
| 371 |
+
if isinstance(m, nn.Conv2d):
|
| 372 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 373 |
+
init.normal_(m.weight, std=0.01)
|
| 374 |
+
#init.xavier_normal_(m.weight)
|
| 375 |
+
if m.bias is not None:
|
| 376 |
+
init.constant_(m.bias, 0)
|
| 377 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 378 |
+
#init.kaiming_normal_(m.weight, mode='fan_out')
|
| 379 |
+
init.normal_(m.weight, std=0.01)
|
| 380 |
+
#init.xavier_normal_(m.weight)
|
| 381 |
+
if m.bias is not None:
|
| 382 |
+
init.constant_(m.bias, 0)
|
| 383 |
+
elif isinstance(m, nn.BatchNorm2d): #nn.BatchNorm2d
|
| 384 |
+
init.constant_(m.weight, 1)
|
| 385 |
+
init.constant_(m.bias, 0)
|
| 386 |
+
elif isinstance(m, nn.Linear):
|
| 387 |
+
init.normal_(m.weight, std=0.01)
|
| 388 |
+
if m.bias is not None:
|
| 389 |
+
init.constant_(m.bias, 0)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
class SenceUnderstand(nn.Module):
|
| 393 |
+
def __init__(self, channels):
|
| 394 |
+
super(SenceUnderstand, self).__init__()
|
| 395 |
+
self.channels = channels
|
| 396 |
+
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, padding=1),
|
| 397 |
+
nn.ReLU(inplace = True))
|
| 398 |
+
self.pool = nn.AdaptiveAvgPool2d(8)
|
| 399 |
+
self.fc = nn.Sequential(nn.Linear(512*8*8, self.channels),
|
| 400 |
+
nn.ReLU(inplace = True))
|
| 401 |
+
self.conv2 = nn.Sequential(nn.Conv2d(in_channels=self.channels, out_channels=self.channels, kernel_size=1, padding=0),
|
| 402 |
+
nn.ReLU(inplace=True))
|
| 403 |
+
self.initial_params()
|
| 404 |
+
|
| 405 |
+
def forward(self, x):
|
| 406 |
+
n,c,h,w = x.size()
|
| 407 |
+
x = self.conv1(x)
|
| 408 |
+
x = self.pool(x)
|
| 409 |
+
x = x.view(n,-1)
|
| 410 |
+
x = self.fc(x)
|
| 411 |
+
x = x.view(n, self.channels, 1, 1)
|
| 412 |
+
x = self.conv2(x)
|
| 413 |
+
x = x.repeat(1,1,h,w)
|
| 414 |
+
return x
|
| 415 |
+
|
| 416 |
+
def initial_params(self, dev=0.01):
|
| 417 |
+
for m in self.modules():
|
| 418 |
+
if isinstance(m, nn.Conv2d):
|
| 419 |
+
#print torch.sum(m.weight)
|
| 420 |
+
m.weight.data.normal_(0, dev)
|
| 421 |
+
if m.bias is not None:
|
| 422 |
+
m.bias.data.fill_(0)
|
| 423 |
+
elif isinstance(m, nn.ConvTranspose2d):
|
| 424 |
+
#print torch.sum(m.weight)
|
| 425 |
+
m.weight.data.normal_(0, dev)
|
| 426 |
+
if m.bias is not None:
|
| 427 |
+
m.bias.data.fill_(0)
|
| 428 |
+
elif isinstance(m, nn.Linear):
|
| 429 |
+
m.weight.data.normal_(0, dev)
|
models/SpaTrackV2/models/depth_refiner/stablilization_attention.py
ADDED
|
@@ -0,0 +1,1187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|