iMihayo committed on
Commit 0c2efcc · verified · 1 Parent(s): 2e6ea8a

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. run_scripts/ffn_q2a/aloha/aloha_robotwin2_ffn_25_fredf.sh +6 -6
  2. run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25.sh +104 -0
  3. run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_inter.sh +104 -0
  4. run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_vggt.sh +108 -0
  5. run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_vr.sh +104 -0
  6. vggt/heads/__pycache__/camera_head.cpython-310.pyc +0 -0
  7. vggt/heads/__pycache__/dpt_head.cpython-310.pyc +0 -0
  8. vggt/heads/__pycache__/head_act.cpython-310.pyc +0 -0
  9. vggt/heads/__pycache__/track_head.cpython-310.pyc +0 -0
  10. vggt/heads/__pycache__/utils.cpython-310.pyc +0 -0
  11. vggt/heads/camera_head.py +162 -0
  12. vggt/heads/dpt_head.py +497 -0
  13. vggt/heads/head_act.py +125 -0
  14. vggt/heads/track_head.py +108 -0
  15. vggt/heads/track_modules/__init__.py +5 -0
  16. vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc +0 -0
  17. vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc +0 -0
  18. vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc +0 -0
  19. vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc +0 -0
  20. vggt/heads/track_modules/modules.py +218 -0
  21. vggt/heads/track_modules/utils.py +226 -0
  22. vggt/heads/utils.py +108 -0
  23. vggt/layers/__pycache__/__init__.cpython-310.pyc +0 -0
  24. vggt/layers/__pycache__/attention.cpython-310.pyc +0 -0
  25. vggt/layers/__pycache__/drop_path.cpython-310.pyc +0 -0
  26. vggt/layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
  27. vggt/layers/__pycache__/mlp.cpython-310.pyc +0 -0
  28. vggt/layers/__pycache__/patch_embed.cpython-310.pyc +0 -0
  29. vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
  30. vggt/layers/block.py +259 -0
  31. vggt/layers/mlp.py +40 -0
  32. vggt/layers/rope.py +188 -0
  33. vggt/layers/vision_transformer.py +407 -0
  34. vggt/models/__pycache__/aggregator.cpython-310.pyc +0 -0
  35. vggt/utils/__pycache__/geometry.cpython-310.pyc +0 -0
  36. vggt/utils/__pycache__/load_fn.cpython-310.pyc +0 -0
  37. vggt/utils/__pycache__/rotation.cpython-310.pyc +0 -0
  38. vggt/utils/geometry.py +166 -0
  39. vggt/utils/load_fn.py +111 -0
  40. vggt/utils/pose_enc.py +130 -0
  41. vggt/utils/visual_track.py +239 -0
  42. wandb/offline-run-20250711_184611-k8qgu560/files/requirements.txt +199 -0
  43. wandb/offline-run-20250711_184611-k8qgu560/files/wandb-metadata.json +133 -0
  44. wandb/offline-run-20250711_184611-k8qgu560/logs/debug-internal.log +7 -0
  45. wandb/offline-run-20250711_184611-k8qgu560/logs/debug.log +0 -0
  46. wandb/offline-run-20250711_211915-s4epglyq/files/requirements.txt +199 -0
  47. wandb/offline-run-20250711_211915-s4epglyq/files/wandb-metadata.json +137 -0
  48. wandb/offline-run-20250711_211915-s4epglyq/logs/debug.log +0 -0
  49. wandb/offline-run-20250711_211915-s4epglyq/run-s4epglyq.wandb +0 -0
  50. wandb/offline-run-20250711_212208-i2mclkeg/files/requirements.txt +199 -0
run_scripts/ffn_q2a/aloha/aloha_robotwin2_ffn_25_fredf.sh CHANGED
@@ -11,8 +11,8 @@ use_multi_scaling=False
 mlp_type=ffn
 decoder_num_blocks=2
 robot_platform=aloha
-proj_type=gelu_linear
-ffn_type=gelu
+proj_type=relu_linear
+ffn_type=relu
 expand_inner_ratio=1
 linear_drop_ratio=0.1
 multi_queries_num=25
@@ -29,9 +29,9 @@ wandb_log_freq=1
 use_proprio=True
 use_diffusion=False
 use_film=True
-num_steps_before_decay=20000
-save_freq=5000
-max_steps=50000
+num_steps_before_decay=10000
+save_freq=10000
+max_steps=20000
 vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
 data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
 dataset_name=aloha_agilex_robotwin2_benchmark
@@ -76,7 +76,7 @@ WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-pe
 --num_images_in_input "$num_images_in_input" \
 --use_proprio "$use_proprio" \
 --batch_size "$batch_size" \
---learning_rate 5e-5 \
+--learning_rate 1e-4 \
 --num_steps_before_decay "$num_steps_before_decay" \
 --max_steps "$max_steps" \
 --save_freq "$save_freq" \
run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25.sh ADDED
@@ -0,0 +1,104 @@
+#========== settings ==========#
+PROJECT_PATH=simvla_twin2
+ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
+#========== !NOTE! ==========#
+RUN_MODE=simvla_ts_25
+use_predict_future_prop=False
+batch_size=4
+use_action_ts_head=True
+use_one_embed=True
+use_multi_scaling=False
+mlp_type=ffn
+decoder_num_blocks=1
+robot_platform=aloha
+proj_type=onlynorm
+ffn_type=swiglu
+expand_inner_ratio=1
+linear_drop_ratio=0.0
+multi_queries_num=14
+multi_query_norm_type=layernorm
+action_norm=layernorm
+use_patch_wise_loss=True
+MODE=${RUN_MODE}_inner${expand_inner_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
+#========== !NOTE! ==========#
+use_l1_regression=True
+num_images_in_input=3
+wandb_entity=chenghaha
+wandb_project=robotwin
+wandb_log_freq=1
+use_proprio=True
+use_diffusion=False
+use_film=True
+num_steps_before_decay=1000
+save_freq=2000
+max_steps=2000
+vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
+data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
+dataset_name=grab_roller_aloha_agilex_50
+run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
+#========== get run_id ==========#
+note_parts=("${MODE}")
+
+# if [ "$use_l1_regression" = "True" ]; then
+# note_parts+=("L1_regression")
+# fi
+
+# if [ "$num_images_in_input" == 1 ]; then
+# note_parts+=("3rd_person_img")
+# else
+# note_parts+=("3rd_person_img_and_wrist")
+# fi
+
+# if [ "$use_l1_regression" = "True" ]; then
+# note_parts+=("proprio_state")
+# fi
+
+# if [ "$use_film" = "True" ]; then
+# note_parts+=("Film")
+# fi
+note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
+run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
+
+#========== enter environment ==========#
+conda activate openvla-oft
+cd $ROOT_PATH/vla_projects/$PROJECT_PATH
+export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
+
+#========== run ==========#
+WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
+--vla_path "$vla_path" \
+--data_root_dir "$data_root_dir" \
+--dataset_name "$dataset_name" \
+--run_root_dir "$run_root_dir" \
+--use_l1_regression "$use_l1_regression" \
+--use_diffusion "$use_diffusion" \
+--use_film "$use_film" \
+--num_images_in_input "$num_images_in_input" \
+--use_proprio "$use_proprio" \
+--batch_size "$batch_size" \
+--learning_rate 1e-4 \
+--num_steps_before_decay "$num_steps_before_decay" \
+--max_steps "$max_steps" \
+--save_freq "$save_freq" \
+--save_latest_checkpoint_only False \
+--image_aug True \
+--lora_rank 32 \
+--wandb_entity "$wandb_entity" \
+--wandb_project "$wandb_project" \
+--wandb_log_freq "$wandb_log_freq" \
+--run_id_note "$run_id_note_value" \
+--use_predict_future_prop "$use_predict_future_prop" \
+--use_action_ts_head "$use_action_ts_head" \
+--use_one_embed "$use_one_embed" \
+--use_multi_scaling "$use_multi_scaling" \
+--mlp_type "$mlp_type" \
+--decoder_num_blocks "$decoder_num_blocks" \
+--robot_platform "$robot_platform" \
+--proj_type "$proj_type" \
+--ffn_type "$ffn_type" \
+--expand_inner_ratio "$expand_inner_ratio" \
+--linear_drop_ratio "$linear_drop_ratio" \
+--multi_query_norm_type "$multi_query_norm_type" \
+--multi_queries_num "$multi_queries_num" \
+--action_norm "$action_norm" \
+--use_patch_wise_loss "$use_patch_wise_loss"
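Note that these launch scripts pass every boolean switch to the trainer as the literal strings "True"/"False" (bash has no boolean type). Below is a minimal, illustrative Python sketch of one way such strings can be coerced back to booleans on the receiving side; the helper name str2bool is hypothetical and is not necessarily how vla-scripts/finetune.py actually parses its flags.

import argparse

def str2bool(v: str) -> bool:
    # Accept the exact strings emitted by the run scripts ("True"/"False") plus common variants.
    if v.lower() in ("true", "1", "yes"):
        return True
    if v.lower() in ("false", "0", "no"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {v!r}")

parser = argparse.ArgumentParser()
parser.add_argument("--use_l1_regression", type=str2bool, default=True)
print(parser.parse_args(["--use_l1_regression", "False"]).use_l1_regression)  # False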
run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_inter.sh ADDED
@@ -0,0 +1,104 @@
+#========== settings ==========#
+PROJECT_PATH=simvla_twin2
+ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
+#========== !NOTE! ==========#
+RUN_MODE=simvla_vr_25
+use_predict_future_prop=False
+batch_size=4
+use_action_ts_head=True
+use_one_embed=True
+use_multi_scaling=False
+mlp_type=ffn
+decoder_num_blocks=4
+robot_platform=aloha
+proj_type=onlynorm
+ffn_type=swiglu
+expand_inner_ratio=1
+linear_drop_ratio=0.0
+multi_queries_num=25
+multi_query_norm_type=layernorm
+action_norm=layernorm
+use_keyframe_prediction=True
+MODE=${RUN_MODE}_inner${expand_inner_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
+#========== !NOTE! ==========#
+use_l1_regression=True
+num_images_in_input=3
+wandb_entity=chenghaha
+wandb_project=robotwin
+wandb_log_freq=1
+use_proprio=True
+use_diffusion=False
+use_film=True
+num_steps_before_decay=1000
+save_freq=2000
+max_steps=2000
+vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
+data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
+dataset_name=grab_roller_aloha_agilex_50
+run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
+#========== get run_id ==========#
+note_parts=("${MODE}")
+
+# if [ "$use_l1_regression" = "True" ]; then
+# note_parts+=("L1_regression")
+# fi
+
+# if [ "$num_images_in_input" == 1 ]; then
+# note_parts+=("3rd_person_img")
+# else
+# note_parts+=("3rd_person_img_and_wrist")
+# fi
+
+# if [ "$use_l1_regression" = "True" ]; then
+# note_parts+=("proprio_state")
+# fi
+
+# if [ "$use_film" = "True" ]; then
+# note_parts+=("Film")
+# fi
+note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
+run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
+
+#========== enter environment ==========#
+conda activate openvla-oft
+cd $ROOT_PATH/vla_projects/$PROJECT_PATH
+export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
+
+#========== run ==========#
+WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
+--vla_path "$vla_path" \
+--data_root_dir "$data_root_dir" \
+--dataset_name "$dataset_name" \
+--run_root_dir "$run_root_dir" \
+--use_l1_regression "$use_l1_regression" \
+--use_diffusion "$use_diffusion" \
+--use_film "$use_film" \
+--num_images_in_input "$num_images_in_input" \
+--use_proprio "$use_proprio" \
+--batch_size "$batch_size" \
+--learning_rate 1e-4 \
+--num_steps_before_decay "$num_steps_before_decay" \
+--max_steps "$max_steps" \
+--save_freq "$save_freq" \
+--save_latest_checkpoint_only False \
+--image_aug True \
+--lora_rank 32 \
+--wandb_entity "$wandb_entity" \
+--wandb_project "$wandb_project" \
+--wandb_log_freq "$wandb_log_freq" \
+--run_id_note "$run_id_note_value" \
+--use_predict_future_prop "$use_predict_future_prop" \
+--use_action_ts_head "$use_action_ts_head" \
+--use_one_embed "$use_one_embed" \
+--use_multi_scaling "$use_multi_scaling" \
+--mlp_type "$mlp_type" \
+--decoder_num_blocks "$decoder_num_blocks" \
+--robot_platform "$robot_platform" \
+--proj_type "$proj_type" \
+--ffn_type "$ffn_type" \
+--expand_inner_ratio "$expand_inner_ratio" \
+--linear_drop_ratio "$linear_drop_ratio" \
+--multi_query_norm_type "$multi_query_norm_type" \
+--multi_queries_num "$multi_queries_num" \
+--action_norm "$action_norm" \
+--use_keyframe_prediction "$use_keyframe_prediction"
run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_vggt.sh ADDED
@@ -0,0 +1,108 @@
+#========== settings ==========#
+PROJECT_PATH=simvla_twin2
+ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
+#========== !NOTE! ==========#
+RUN_MODE=simvla_vggt_25
+use_predict_future_prop=False
+batch_size=4
+use_action_ts_head=True
+use_one_embed=True
+use_multi_scaling=False
+mlp_type=ffn
+decoder_num_blocks=2
+robot_platform=aloha
+proj_type=gelu_linear
+ffn_type=gelu
+expand_inner_ratio=1
+linear_drop_ratio=0.0
+multi_queries_num=25
+multi_query_norm_type=layernorm
+action_norm=layernorm
+use_fredf=False
+MODE=${RUN_MODE}_inner${expand_inner_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
+#========== !NOTE! ==========#
+use_l1_regression=True
+num_images_in_input=3
+wandb_entity=chenghaha
+wandb_project=robotwin
+wandb_log_freq=1
+use_proprio=True
+use_diffusion=False
+use_film=True
+num_steps_before_decay=1000
+save_freq=2000
+max_steps=2000
+vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
+data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
+dataset_name=grab_roller_aloha_agilex_50
+run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
+vggt_model_path=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/facebook/VGGT-1B/model.pt
+use_3d_visual_regression=True
+#========== get run_id ==========#
+note_parts=("${MODE}")
+
+# if [ "$use_l1_regression" = "True" ]; then
+# note_parts+=("L1_regression")
+# fi
+
+# if [ "$num_images_in_input" == 1 ]; then
+# note_parts+=("3rd_person_img")
+# else
+# note_parts+=("3rd_person_img_and_wrist")
+# fi
+
+# if [ "$use_l1_regression" = "True" ]; then
+# note_parts+=("proprio_state")
+# fi
+
+# if [ "$use_film" = "True" ]; then
+# note_parts+=("Film")
+# fi
+note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
+run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
+
+#========== enter environment ==========#
+conda activate openvla-oft
+cd $ROOT_PATH/vla_projects/$PROJECT_PATH
+export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
+
+#========== run ==========#
+WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune_3d.py \
+--vla_path "$vla_path" \
+--data_root_dir "$data_root_dir" \
+--dataset_name "$dataset_name" \
+--run_root_dir "$run_root_dir" \
+--use_l1_regression "$use_l1_regression" \
+--use_diffusion "$use_diffusion" \
+--use_film "$use_film" \
+--num_images_in_input "$num_images_in_input" \
+--use_proprio "$use_proprio" \
+--batch_size "$batch_size" \
+--learning_rate 1e-4 \
+--num_steps_before_decay "$num_steps_before_decay" \
+--max_steps "$max_steps" \
+--save_freq "$save_freq" \
+--save_latest_checkpoint_only False \
+--image_aug True \
+--lora_rank 32 \
+--wandb_entity "$wandb_entity" \
+--wandb_project "$wandb_project" \
+--wandb_log_freq "$wandb_log_freq" \
+--run_id_note "$run_id_note_value" \
+--use_predict_future_prop "$use_predict_future_prop" \
+--use_action_ts_head "$use_action_ts_head" \
+--use_one_embed "$use_one_embed" \
+--use_multi_scaling "$use_multi_scaling" \
+--mlp_type "$mlp_type" \
+--decoder_num_blocks "$decoder_num_blocks" \
+--robot_platform "$robot_platform" \
+--proj_type "$proj_type" \
+--ffn_type "$ffn_type" \
+--expand_inner_ratio "$expand_inner_ratio" \
+--linear_drop_ratio "$linear_drop_ratio" \
+--multi_query_norm_type "$multi_query_norm_type" \
+--multi_queries_num "$multi_queries_num" \
+--action_norm "$action_norm" \
+--use_fredf "$use_fredf" \
+--vggt_model_path "$vggt_model_path" \
+--use_3d_visual_regression "$use_3d_visual_regression"
run_scripts/ffn_q2a/aloha/test_aloha_robotwin2_ffn_25_vr.sh ADDED
@@ -0,0 +1,104 @@
+#========== settings ==========#
+PROJECT_PATH=simvla_twin2
+ROOT_PATH=/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137
+#========== !NOTE! ==========#
+RUN_MODE=simvla_vr_25
+use_predict_future_prop=False
+batch_size=4
+use_action_ts_head=True
+use_one_embed=True
+use_multi_scaling=False
+mlp_type=ffn
+decoder_num_blocks=4
+robot_platform=aloha
+proj_type=onlynorm
+ffn_type=swiglu
+expand_inner_ratio=1
+linear_drop_ratio=0.0
+multi_queries_num=25
+multi_query_norm_type=layernorm
+action_norm=layernorm
+use_visual_regression=True
+MODE=${RUN_MODE}_inner${expand_inner_ratio}_proj_type_${proj_type}_ffn_type_${ffn_type}_mlp_${mlp_type}_decoder_num_blocks_${decoder_num_blocks}
+#========== !NOTE! ==========#
+use_l1_regression=True
+num_images_in_input=3
+wandb_entity=chenghaha
+wandb_project=robotwin
+wandb_log_freq=1
+use_proprio=True
+use_diffusion=False
+use_film=True
+num_steps_before_decay=1000
+save_freq=2000
+max_steps=2000
+vla_path=$ROOT_PATH/ai_models/openvla/openvla-7b
+data_root_dir=$ROOT_PATH/datasets/TianxingChen/RoboTwin2.0/tfds
+dataset_name=place_dual_shoes_aloha_agilex_50
+run_root_dir=$ROOT_PATH/vla_projects/$PROJECT_PATH/results/$RUN_MODE
+#========== get run_id ==========#
+note_parts=("${MODE}")
+
+# if [ "$use_l1_regression" = "True" ]; then
+# note_parts+=("L1_regression")
+# fi
+
+# if [ "$num_images_in_input" == 1 ]; then
+# note_parts+=("3rd_person_img")
+# else
+# note_parts+=("3rd_person_img_and_wrist")
+# fi
+
+# if [ "$use_l1_regression" = "True" ]; then
+# note_parts+=("proprio_state")
+# fi
+
+# if [ "$use_film" = "True" ]; then
+# note_parts+=("Film")
+# fi
+note_parts+=("M$max_steps-F$save_freq-D$num_steps_before_decay")
+run_id_note_value=$(IFS='--'; echo "${note_parts[*]}")
+
+#========== enter environment ==========#
+conda activate openvla-oft
+cd $ROOT_PATH/vla_projects/$PROJECT_PATH
+export PYTHONPATH=$ROOT_PATH/vla_projects/$PROJECT_PATH
+
+#========== run ==========#
+WANDB_CONSOLE=off WANDB_MODE=offline torchrun --standalone --nnodes 1 --nproc-per-node 4 vla-scripts/finetune.py \
+--vla_path "$vla_path" \
+--data_root_dir "$data_root_dir" \
+--dataset_name "$dataset_name" \
+--run_root_dir "$run_root_dir" \
+--use_l1_regression "$use_l1_regression" \
+--use_diffusion "$use_diffusion" \
+--use_film "$use_film" \
+--num_images_in_input "$num_images_in_input" \
+--use_proprio "$use_proprio" \
+--batch_size "$batch_size" \
+--learning_rate 1e-4 \
+--num_steps_before_decay "$num_steps_before_decay" \
+--max_steps "$max_steps" \
+--save_freq "$save_freq" \
+--save_latest_checkpoint_only False \
+--image_aug True \
+--lora_rank 32 \
+--wandb_entity "$wandb_entity" \
+--wandb_project "$wandb_project" \
+--wandb_log_freq "$wandb_log_freq" \
+--run_id_note "$run_id_note_value" \
+--use_predict_future_prop "$use_predict_future_prop" \
+--use_action_ts_head "$use_action_ts_head" \
+--use_one_embed "$use_one_embed" \
+--use_multi_scaling "$use_multi_scaling" \
+--mlp_type "$mlp_type" \
+--decoder_num_blocks "$decoder_num_blocks" \
+--robot_platform "$robot_platform" \
+--proj_type "$proj_type" \
+--ffn_type "$ffn_type" \
+--expand_inner_ratio "$expand_inner_ratio" \
+--linear_drop_ratio "$linear_drop_ratio" \
+--multi_query_norm_type "$multi_query_norm_type" \
+--multi_queries_num "$multi_queries_num" \
+--action_norm "$action_norm" \
+--use_visual_regression "$use_visual_regression"
vggt/heads/__pycache__/camera_head.cpython-310.pyc ADDED
Binary file (4.37 kB).
 
vggt/heads/__pycache__/dpt_head.cpython-310.pyc ADDED
Binary file (12.7 kB).
 
vggt/heads/__pycache__/head_act.cpython-310.pyc ADDED
Binary file (3.2 kB).
 
vggt/heads/__pycache__/track_head.cpython-310.pyc ADDED
Binary file (3.52 kB).
 
vggt/heads/__pycache__/utils.cpython-310.pyc ADDED
Binary file (3.23 kB).
 
vggt/heads/camera_head.py ADDED
@@ -0,0 +1,162 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import math
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from vggt.layers import Mlp
+from vggt.layers.block import Block
+from vggt.heads.head_act import activate_pose
+
+
+class CameraHead(nn.Module):
+    """
+    CameraHead predicts camera parameters from token representations using iterative refinement.
+
+    It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
+    """
+
+    def __init__(
+        self,
+        dim_in: int = 2048,
+        trunk_depth: int = 4,
+        pose_encoding_type: str = "absT_quaR_FoV",
+        num_heads: int = 16,
+        mlp_ratio: int = 4,
+        init_values: float = 0.01,
+        trans_act: str = "linear",
+        quat_act: str = "linear",
+        fl_act: str = "relu",  # Field of view activations: ensures FOV values are positive.
+    ):
+        super().__init__()
+
+        if pose_encoding_type == "absT_quaR_FoV":
+            self.target_dim = 9
+        else:
+            raise ValueError(f"Unsupported camera encoding type: {pose_encoding_type}")
+
+        self.trans_act = trans_act
+        self.quat_act = quat_act
+        self.fl_act = fl_act
+        self.trunk_depth = trunk_depth
+
+        # Build the trunk using a sequence of transformer blocks.
+        self.trunk = nn.Sequential(
+            *[
+                Block(
+                    dim=dim_in,
+                    num_heads=num_heads,
+                    mlp_ratio=mlp_ratio,
+                    init_values=init_values,
+                )
+                for _ in range(trunk_depth)
+            ]
+        )
+
+        # Normalizations for camera token and trunk output.
+        self.token_norm = nn.LayerNorm(dim_in)
+        self.trunk_norm = nn.LayerNorm(dim_in)
+
+        # Learnable empty camera pose token.
+        self.empty_pose_tokens = nn.Parameter(torch.zeros(1, 1, self.target_dim))
+        self.embed_pose = nn.Linear(self.target_dim, dim_in)
+
+        # Module for producing modulation parameters: shift, scale, and a gate.
+        self.poseLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(dim_in, 3 * dim_in, bias=True))
+
+        # Adaptive layer normalization without affine parameters.
+        self.adaln_norm = nn.LayerNorm(dim_in, elementwise_affine=False, eps=1e-6)
+        self.pose_branch = Mlp(
+            in_features=dim_in,
+            hidden_features=dim_in // 2,
+            out_features=self.target_dim,
+            drop=0,
+        )
+
+    def forward(self, aggregated_tokens_list: list, num_iterations: int = 4) -> list:
+        """
+        Forward pass to predict camera parameters.
+
+        Args:
+            aggregated_tokens_list (list): List of token tensors from the network;
+                the last tensor is used for prediction.
+            num_iterations (int, optional): Number of iterative refinement steps. Defaults to 4.
+
+        Returns:
+            list: A list of predicted camera encodings (post-activation) from each iteration.
+        """
+        # Use tokens from the last block for camera prediction.
+        tokens = aggregated_tokens_list[-1]
+
+        # Extract the camera tokens
+        pose_tokens = tokens[:, :, 0]
+        pose_tokens = self.token_norm(pose_tokens)
+
+        pred_pose_enc_list = self.trunk_fn(pose_tokens, num_iterations)
+        return pred_pose_enc_list
+
+    def trunk_fn(self, pose_tokens: torch.Tensor, num_iterations: int) -> list:
+        """
+        Iteratively refine camera pose predictions.
+
+        Args:
+            pose_tokens (torch.Tensor): Normalized camera tokens with shape [B, 1, C].
+            num_iterations (int): Number of refinement iterations.
+
+        Returns:
+            list: List of activated camera encodings from each iteration.
+        """
+        B, S, C = pose_tokens.shape  # S is expected to be 1.
+        pred_pose_enc = None
+        pred_pose_enc_list = []
+
+        for _ in range(num_iterations):
+            # Use a learned empty pose for the first iteration.
+            if pred_pose_enc is None:
+                module_input = self.embed_pose(self.empty_pose_tokens.expand(B, S, -1))
+            else:
+                # Detach the previous prediction to avoid backprop through time.
+                pred_pose_enc = pred_pose_enc.detach()
+                module_input = self.embed_pose(pred_pose_enc)
+
+            # Generate modulation parameters and split them into shift, scale, and gate components.
+            shift_msa, scale_msa, gate_msa = self.poseLN_modulation(module_input).chunk(3, dim=-1)
+
+            # Adaptive layer normalization and modulation.
+            pose_tokens_modulated = gate_msa * modulate(self.adaln_norm(pose_tokens), shift_msa, scale_msa)
+            pose_tokens_modulated = pose_tokens_modulated + pose_tokens
+
+            pose_tokens_modulated = self.trunk(pose_tokens_modulated)
+            # Compute the delta update for the pose encoding.
+            pred_pose_enc_delta = self.pose_branch(self.trunk_norm(pose_tokens_modulated))
+
+            if pred_pose_enc is None:
+                pred_pose_enc = pred_pose_enc_delta
+            else:
+                pred_pose_enc = pred_pose_enc + pred_pose_enc_delta
+
+            # Apply final activation functions for translation, quaternion, and field-of-view.
+            activated_pose = activate_pose(
+                pred_pose_enc,
+                trans_act=self.trans_act,
+                quat_act=self.quat_act,
+                fl_act=self.fl_act,
+            )
+            pred_pose_enc_list.append(activated_pose)
+
+        return pred_pose_enc_list
+
+
+def modulate(x: torch.Tensor, shift: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
+    """
+    Modulate the input tensor using scaling and shifting parameters.
+    """
+    # modified from https://github.com/facebookresearch/DiT/blob/796c29e532f47bba17c5b9c5eb39b9354b8b7c64/models.py#L19
+    return x * (1 + scale) + shift
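A minimal shape-check sketch for the CameraHead above (assumes the vggt package from this commit is on PYTHONPATH; the dummy token shapes are illustrative, with the camera token at index 0 of each frame):

import torch
from vggt.heads.camera_head import CameraHead

head = CameraHead(dim_in=2048, trunk_depth=4)
# List of aggregated token tensors [B, S, N_tokens, C]; only the last entry is used.
tokens = [torch.randn(2, 1, 5, 2048)]
with torch.no_grad():
    poses = head(tokens, num_iterations=4)
print(len(poses), poses[-1].shape)  # 4 refinement steps, each [B, S, 9] (absT_quaR_FoV encoding)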
vggt/heads/dpt_head.py ADDED
@@ -0,0 +1,497 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ # Inspired by https://github.com/DepthAnything/Depth-Anything-V2
9
+
10
+
11
+ import os
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ from .head_act import activate_head
18
+ from .utils import create_uv_grid, position_grid_to_embed
19
+
20
+
21
+ class DPTHead(nn.Module):
22
+ """
23
+ DPT Head for dense prediction tasks.
24
+
25
+ This implementation follows the architecture described in "Vision Transformers for Dense Prediction"
26
+ (https://arxiv.org/abs/2103.13413). The DPT head processes features from a vision transformer
27
+ backbone and produces dense predictions by fusing multi-scale features.
28
+
29
+ Args:
30
+ dim_in (int): Input dimension (channels).
31
+ patch_size (int, optional): Patch size. Default is 14.
32
+ output_dim (int, optional): Number of output channels. Default is 4.
33
+ activation (str, optional): Activation type. Default is "inv_log".
34
+ conf_activation (str, optional): Confidence activation type. Default is "expp1".
35
+ features (int, optional): Feature channels for intermediate representations. Default is 256.
36
+ out_channels (List[int], optional): Output channels for each intermediate layer.
37
+ intermediate_layer_idx (List[int], optional): Indices of layers from aggregated tokens used for DPT.
38
+ pos_embed (bool, optional): Whether to use positional embedding. Default is True.
39
+ feature_only (bool, optional): If True, return features only without the last several layers and activation head. Default is False.
40
+ down_ratio (int, optional): Downscaling factor for the output resolution. Default is 1.
41
+ """
42
+
43
+ def __init__(
44
+ self,
45
+ dim_in: int,
46
+ patch_size: int = 14,
47
+ output_dim: int = 4,
48
+ activation: str = "inv_log",
49
+ conf_activation: str = "expp1",
50
+ features: int = 256,
51
+ out_channels: List[int] = [256, 512, 1024, 1024],
52
+ intermediate_layer_idx: List[int] = [4, 11, 17, 23],
53
+ pos_embed: bool = True,
54
+ feature_only: bool = False,
55
+ down_ratio: int = 1,
56
+ ) -> None:
57
+ super(DPTHead, self).__init__()
58
+ self.patch_size = patch_size
59
+ self.activation = activation
60
+ self.conf_activation = conf_activation
61
+ self.pos_embed = pos_embed
62
+ self.feature_only = feature_only
63
+ self.down_ratio = down_ratio
64
+ self.intermediate_layer_idx = intermediate_layer_idx
65
+
66
+ self.norm = nn.LayerNorm(dim_in)
67
+
68
+ # Projection layers for each output channel from tokens.
69
+ self.projects = nn.ModuleList(
70
+ [
71
+ nn.Conv2d(
72
+ in_channels=dim_in,
73
+ out_channels=oc,
74
+ kernel_size=1,
75
+ stride=1,
76
+ padding=0,
77
+ )
78
+ for oc in out_channels
79
+ ]
80
+ )
81
+
82
+ # Resize layers for upsampling feature maps.
83
+ self.resize_layers = nn.ModuleList(
84
+ [
85
+ nn.ConvTranspose2d(
86
+ in_channels=out_channels[0], out_channels=out_channels[0], kernel_size=4, stride=4, padding=0
87
+ ),
88
+ nn.ConvTranspose2d(
89
+ in_channels=out_channels[1], out_channels=out_channels[1], kernel_size=2, stride=2, padding=0
90
+ ),
91
+ nn.Identity(),
92
+ nn.Conv2d(
93
+ in_channels=out_channels[3], out_channels=out_channels[3], kernel_size=3, stride=2, padding=1
94
+ ),
95
+ ]
96
+ )
97
+
98
+ self.scratch = _make_scratch(
99
+ out_channels,
100
+ features,
101
+ expand=False,
102
+ )
103
+
104
+ # Attach additional modules to scratch.
105
+ self.scratch.stem_transpose = None
106
+ self.scratch.refinenet1 = _make_fusion_block(features)
107
+ self.scratch.refinenet2 = _make_fusion_block(features)
108
+ self.scratch.refinenet3 = _make_fusion_block(features)
109
+ self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
110
+
111
+ head_features_1 = features
112
+ head_features_2 = 32
113
+
114
+ if feature_only:
115
+ self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1, kernel_size=3, stride=1, padding=1)
116
+ else:
117
+ self.scratch.output_conv1 = nn.Conv2d(
118
+ head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
119
+ )
120
+ conv2_in_channels = head_features_1 // 2
121
+
122
+ self.scratch.output_conv2 = nn.Sequential(
123
+ nn.Conv2d(conv2_in_channels, head_features_2, kernel_size=3, stride=1, padding=1),
124
+ nn.ReLU(inplace=True),
125
+ nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
126
+ )
127
+
128
+ def forward(
129
+ self,
130
+ aggregated_tokens_list: List[torch.Tensor],
131
+ images: torch.Tensor,
132
+ patch_start_idx: int,
133
+ frames_chunk_size: int = 8,
134
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
135
+ """
136
+ Forward pass through the DPT head, supports processing by chunking frames.
137
+ Args:
138
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
139
+ images (Tensor): Input images with shape [B, S, 3, H, W], in range [0, 1].
140
+ patch_start_idx (int): Starting index for patch tokens in the token sequence.
141
+ Used to separate patch tokens from other tokens (e.g., camera or register tokens).
142
+ frames_chunk_size (int, optional): Number of frames to process in each chunk.
143
+ If None or larger than S, all frames are processed at once. Default: 8.
144
+
145
+ Returns:
146
+ Tensor or Tuple[Tensor, Tensor]:
147
+ - If feature_only=True: Feature maps with shape [B, S, C, H, W]
148
+ - Otherwise: Tuple of (predictions, confidence) both with shape [B, S, 1, H, W]
149
+ """
150
+ B, S, _, H, W = images.shape
151
+
152
+ # If frames_chunk_size is not specified or greater than S, process all frames at once
153
+ if frames_chunk_size is None or frames_chunk_size >= S:
154
+ return self._forward_impl(aggregated_tokens_list, images, patch_start_idx)
155
+
156
+ # Otherwise, process frames in chunks to manage memory usage
157
+ assert frames_chunk_size > 0
158
+
159
+ # Process frames in batches
160
+ all_preds = []
161
+ all_conf = []
162
+
163
+ for frames_start_idx in range(0, S, frames_chunk_size):
164
+ frames_end_idx = min(frames_start_idx + frames_chunk_size, S)
165
+
166
+ # Process batch of frames
167
+ if self.feature_only:
168
+ chunk_output = self._forward_impl(
169
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
170
+ )
171
+ all_preds.append(chunk_output)
172
+ else:
173
+ chunk_preds, chunk_conf = self._forward_impl(
174
+ aggregated_tokens_list, images, patch_start_idx, frames_start_idx, frames_end_idx
175
+ )
176
+ all_preds.append(chunk_preds)
177
+ all_conf.append(chunk_conf)
178
+
179
+ # Concatenate results along the sequence dimension
180
+ if self.feature_only:
181
+ return torch.cat(all_preds, dim=1)
182
+ else:
183
+ return torch.cat(all_preds, dim=1), torch.cat(all_conf, dim=1)
184
+
185
+ def _forward_impl(
186
+ self,
187
+ aggregated_tokens_list: List[torch.Tensor],
188
+ images: torch.Tensor,
189
+ patch_start_idx: int,
190
+ frames_start_idx: int = None,
191
+ frames_end_idx: int = None,
192
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
193
+ """
194
+ Implementation of the forward pass through the DPT head.
195
+
196
+ This method processes a specific chunk of frames from the sequence.
197
+
198
+ Args:
199
+ aggregated_tokens_list (List[Tensor]): List of token tensors from different transformer layers.
200
+ images (Tensor): Input images with shape [B, S, 3, H, W].
201
+ patch_start_idx (int): Starting index for patch tokens.
202
+ frames_start_idx (int, optional): Starting index for frames to process.
203
+ frames_end_idx (int, optional): Ending index for frames to process.
204
+
205
+ Returns:
206
+ Tensor or Tuple[Tensor, Tensor]: Feature maps or (predictions, confidence).
207
+ """
208
+ if frames_start_idx is not None and frames_end_idx is not None:
209
+ images = images[:, frames_start_idx:frames_end_idx]
210
+
211
+ B, S, _, H, W = images.shape
212
+
213
+ patch_h, patch_w = H // self.patch_size, W // self.patch_size
214
+
215
+ out = []
216
+ dpt_idx = 0
217
+
218
+ for layer_idx in self.intermediate_layer_idx:
219
+ x = aggregated_tokens_list[layer_idx][:, :, patch_start_idx:]
220
+
221
+ # Select frames if processing a chunk
222
+ if frames_start_idx is not None and frames_end_idx is not None:
223
+ x = x[:, frames_start_idx:frames_end_idx]
224
+
225
+ x = x.view(B * S, -1, x.shape[-1])
226
+
227
+ x = self.norm(x)
228
+
229
+ x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
230
+
231
+ x = self.projects[dpt_idx](x)
232
+ if self.pos_embed:
233
+ x = self._apply_pos_embed(x, W, H)
234
+ x = self.resize_layers[dpt_idx](x)
235
+
236
+ out.append(x)
237
+ dpt_idx += 1
238
+
239
+ # Fuse features from multiple layers.
240
+ out = self.scratch_forward(out)
241
+ # Interpolate fused output to match target image resolution.
242
+ out = custom_interpolate(
243
+ out,
244
+ (int(patch_h * self.patch_size / self.down_ratio), int(patch_w * self.patch_size / self.down_ratio)),
245
+ mode="bilinear",
246
+ align_corners=True,
247
+ )
248
+
249
+ if self.pos_embed:
250
+ out = self._apply_pos_embed(out, W, H)
251
+
252
+ if self.feature_only:
253
+ return out.view(B, S, *out.shape[1:])
254
+
255
+ out = self.scratch.output_conv2(out)
256
+ preds, conf = activate_head(out, activation=self.activation, conf_activation=self.conf_activation)
257
+
258
+ preds = preds.view(B, S, *preds.shape[1:])
259
+ conf = conf.view(B, S, *conf.shape[1:])
260
+ return preds, conf
261
+
262
+ def _apply_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
263
+ """
264
+ Apply positional embedding to tensor x.
265
+ """
266
+ patch_w = x.shape[-1]
267
+ patch_h = x.shape[-2]
268
+ pos_embed = create_uv_grid(patch_w, patch_h, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
269
+ pos_embed = position_grid_to_embed(pos_embed, x.shape[1])
270
+ pos_embed = pos_embed * ratio
271
+ pos_embed = pos_embed.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
272
+ return x + pos_embed
273
+
274
+ def scratch_forward(self, features: List[torch.Tensor]) -> torch.Tensor:
275
+ """
276
+ Forward pass through the fusion blocks.
277
+
278
+ Args:
279
+ features (List[Tensor]): List of feature maps from different layers.
280
+
281
+ Returns:
282
+ Tensor: Fused feature map.
283
+ """
284
+ layer_1, layer_2, layer_3, layer_4 = features
285
+
286
+ layer_1_rn = self.scratch.layer1_rn(layer_1)
287
+ layer_2_rn = self.scratch.layer2_rn(layer_2)
288
+ layer_3_rn = self.scratch.layer3_rn(layer_3)
289
+ layer_4_rn = self.scratch.layer4_rn(layer_4)
290
+
291
+ out = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
292
+ del layer_4_rn, layer_4
293
+
294
+ out = self.scratch.refinenet3(out, layer_3_rn, size=layer_2_rn.shape[2:])
295
+ del layer_3_rn, layer_3
296
+
297
+ out = self.scratch.refinenet2(out, layer_2_rn, size=layer_1_rn.shape[2:])
298
+ del layer_2_rn, layer_2
299
+
300
+ out = self.scratch.refinenet1(out, layer_1_rn)
301
+ del layer_1_rn, layer_1
302
+
303
+ out = self.scratch.output_conv1(out)
304
+ return out
305
+
306
+
307
+ ################################################################################
308
+ # Modules
309
+ ################################################################################
310
+
311
+
312
+ def _make_fusion_block(features: int, size: int = None, has_residual: bool = True, groups: int = 1) -> nn.Module:
313
+ return FeatureFusionBlock(
314
+ features,
315
+ nn.ReLU(inplace=True),
316
+ deconv=False,
317
+ bn=False,
318
+ expand=False,
319
+ align_corners=True,
320
+ size=size,
321
+ has_residual=has_residual,
322
+ groups=groups,
323
+ )
324
+
325
+
326
+ def _make_scratch(in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False) -> nn.Module:
327
+ scratch = nn.Module()
328
+ out_shape1 = out_shape
329
+ out_shape2 = out_shape
330
+ out_shape3 = out_shape
331
+ if len(in_shape) >= 4:
332
+ out_shape4 = out_shape
333
+
334
+ if expand:
335
+ out_shape1 = out_shape
336
+ out_shape2 = out_shape * 2
337
+ out_shape3 = out_shape * 4
338
+ if len(in_shape) >= 4:
339
+ out_shape4 = out_shape * 8
340
+
341
+ scratch.layer1_rn = nn.Conv2d(
342
+ in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
343
+ )
344
+ scratch.layer2_rn = nn.Conv2d(
345
+ in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
346
+ )
347
+ scratch.layer3_rn = nn.Conv2d(
348
+ in_shape[2], out_shape3, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
349
+ )
350
+ if len(in_shape) >= 4:
351
+ scratch.layer4_rn = nn.Conv2d(
352
+ in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups
353
+ )
354
+ return scratch
355
+
356
+
357
+ class ResidualConvUnit(nn.Module):
358
+ """Residual convolution module."""
359
+
360
+ def __init__(self, features, activation, bn, groups=1):
361
+ """Init.
362
+
363
+ Args:
364
+ features (int): number of features
365
+ """
366
+ super().__init__()
367
+
368
+ self.bn = bn
369
+ self.groups = groups
370
+ self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
371
+ self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups)
372
+
373
+ self.norm1 = None
374
+ self.norm2 = None
375
+
376
+ self.activation = activation
377
+ self.skip_add = nn.quantized.FloatFunctional()
378
+
379
+ def forward(self, x):
380
+ """Forward pass.
381
+
382
+ Args:
383
+ x (tensor): input
384
+
385
+ Returns:
386
+ tensor: output
387
+ """
388
+
389
+ out = self.activation(x)
390
+ out = self.conv1(out)
391
+ if self.norm1 is not None:
392
+ out = self.norm1(out)
393
+
394
+ out = self.activation(out)
395
+ out = self.conv2(out)
396
+ if self.norm2 is not None:
397
+ out = self.norm2(out)
398
+
399
+ return self.skip_add.add(out, x)
400
+
401
+
402
+ class FeatureFusionBlock(nn.Module):
403
+ """Feature fusion block."""
404
+
405
+ def __init__(
406
+ self,
407
+ features,
408
+ activation,
409
+ deconv=False,
410
+ bn=False,
411
+ expand=False,
412
+ align_corners=True,
413
+ size=None,
414
+ has_residual=True,
415
+ groups=1,
416
+ ):
417
+ """Init.
418
+
419
+ Args:
420
+ features (int): number of features
421
+ """
422
+ super(FeatureFusionBlock, self).__init__()
423
+
424
+ self.deconv = deconv
425
+ self.align_corners = align_corners
426
+ self.groups = groups
427
+ self.expand = expand
428
+ out_features = features
429
+ if self.expand == True:
430
+ out_features = features // 2
431
+
432
+ self.out_conv = nn.Conv2d(
433
+ features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=self.groups
434
+ )
435
+
436
+ if has_residual:
437
+ self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=self.groups)
438
+
439
+ self.has_residual = has_residual
440
+ self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=self.groups)
441
+
442
+ self.skip_add = nn.quantized.FloatFunctional()
443
+ self.size = size
444
+
445
+ def forward(self, *xs, size=None):
446
+ """Forward pass.
447
+
448
+ Returns:
449
+ tensor: output
450
+ """
451
+ output = xs[0]
452
+
453
+ if self.has_residual:
454
+ res = self.resConfUnit1(xs[1])
455
+ output = self.skip_add.add(output, res)
456
+
457
+ output = self.resConfUnit2(output)
458
+
459
+ if (size is None) and (self.size is None):
460
+ modifier = {"scale_factor": 2}
461
+ elif size is None:
462
+ modifier = {"size": self.size}
463
+ else:
464
+ modifier = {"size": size}
465
+
466
+ output = custom_interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners)
467
+ output = self.out_conv(output)
468
+
469
+ return output
470
+
471
+
472
+ def custom_interpolate(
473
+ x: torch.Tensor,
474
+ size: Tuple[int, int] = None,
475
+ scale_factor: float = None,
476
+ mode: str = "bilinear",
477
+ align_corners: bool = True,
478
+ ) -> torch.Tensor:
479
+ """
480
+ Custom interpolate to avoid INT_MAX issues in nn.functional.interpolate.
481
+ """
482
+ if size is None:
483
+ size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
484
+
485
+ INT_MAX = 1610612736
486
+
487
+ input_elements = size[0] * size[1] * x.shape[0] * x.shape[1]
488
+
489
+ if input_elements > INT_MAX:
490
+ chunks = torch.chunk(x, chunks=(input_elements // INT_MAX) + 1, dim=0)
491
+ interpolated_chunks = [
492
+ nn.functional.interpolate(chunk, size=size, mode=mode, align_corners=align_corners) for chunk in chunks
493
+ ]
494
+ x = torch.cat(interpolated_chunks, dim=0)
495
+ return x.contiguous()
496
+ else:
497
+ return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
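A hedged shape sketch for the DPTHead above (assumes vggt is importable; the 24-layer token list, the 5 leading camera/register tokens, and the reduced channel widths are illustrative choices, not values used by this repo's training code):

import torch
from vggt.heads.dpt_head import DPTHead

B, S, C, H, W = 1, 2, 256, 224, 224
patch_start_idx = 5                      # register/camera tokens precede the patch tokens
patch_h, patch_w = H // 14, W // 14
tokens = [torch.randn(B, S, patch_start_idx + patch_h * patch_w, C) for _ in range(24)]
head = DPTHead(dim_in=C, output_dim=4, features=64, out_channels=[48, 96, 192, 384])
with torch.no_grad():
    preds, conf = head(tokens, torch.rand(B, S, 3, H, W), patch_start_idx)
print(preds.shape, conf.shape)           # [B, S, H, W, 3] point map and [B, S, H, W] confidence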
vggt/heads/head_act.py ADDED
@@ -0,0 +1,125 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import torch
+import torch.nn.functional as F
+
+
+def activate_pose(pred_pose_enc, trans_act="linear", quat_act="linear", fl_act="linear"):
+    """
+    Activate pose parameters with specified activation functions.
+
+    Args:
+        pred_pose_enc: Tensor containing encoded pose parameters [translation, quaternion, focal length]
+        trans_act: Activation type for translation component
+        quat_act: Activation type for quaternion component
+        fl_act: Activation type for focal length component
+
+    Returns:
+        Activated pose parameters tensor
+    """
+    T = pred_pose_enc[..., :3]
+    quat = pred_pose_enc[..., 3:7]
+    fl = pred_pose_enc[..., 7:]  # or fov
+
+    T = base_pose_act(T, trans_act)
+    quat = base_pose_act(quat, quat_act)
+    fl = base_pose_act(fl, fl_act)  # or fov
+
+    pred_pose_enc = torch.cat([T, quat, fl], dim=-1)
+
+    return pred_pose_enc
+
+
+def base_pose_act(pose_enc, act_type="linear"):
+    """
+    Apply basic activation function to pose parameters.
+
+    Args:
+        pose_enc: Tensor containing encoded pose parameters
+        act_type: Activation type ("linear", "inv_log", "exp", "relu")
+
+    Returns:
+        Activated pose parameters
+    """
+    if act_type == "linear":
+        return pose_enc
+    elif act_type == "inv_log":
+        return inverse_log_transform(pose_enc)
+    elif act_type == "exp":
+        return torch.exp(pose_enc)
+    elif act_type == "relu":
+        return F.relu(pose_enc)
+    else:
+        raise ValueError(f"Unknown act_type: {act_type}")
+
+
+def activate_head(out, activation="norm_exp", conf_activation="expp1"):
+    """
+    Process network output to extract 3D points and confidence values.
+
+    Args:
+        out: Network output tensor (B, C, H, W)
+        activation: Activation type for 3D points
+        conf_activation: Activation type for confidence values
+
+    Returns:
+        Tuple of (3D points tensor, confidence tensor)
+    """
+    # Move channels from last dim to the 4th dimension => (B, H, W, C)
+    fmap = out.permute(0, 2, 3, 1)  # B,H,W,C expected
+
+    # Split into xyz (first C-1 channels) and confidence (last channel)
+    xyz = fmap[:, :, :, :-1]
+    conf = fmap[:, :, :, -1]
+
+    if activation == "norm_exp":
+        d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+        xyz_normed = xyz / d
+        pts3d = xyz_normed * torch.expm1(d)
+    elif activation == "norm":
+        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
+    elif activation == "exp":
+        pts3d = torch.exp(xyz)
+    elif activation == "relu":
+        pts3d = F.relu(xyz)
+    elif activation == "inv_log":
+        pts3d = inverse_log_transform(xyz)
+    elif activation == "xy_inv_log":
+        xy, z = xyz.split([2, 1], dim=-1)
+        z = inverse_log_transform(z)
+        pts3d = torch.cat([xy * z, z], dim=-1)
+    elif activation == "sigmoid":
+        pts3d = torch.sigmoid(xyz)
+    elif activation == "linear":
+        pts3d = xyz
+    else:
+        raise ValueError(f"Unknown activation: {activation}")
+
+    if conf_activation == "expp1":
+        conf_out = 1 + conf.exp()
+    elif conf_activation == "expp0":
+        conf_out = conf.exp()
+    elif conf_activation == "sigmoid":
+        conf_out = torch.sigmoid(conf)
+    else:
+        raise ValueError(f"Unknown conf_activation: {conf_activation}")
+
+    return pts3d, conf_out
+
+
+def inverse_log_transform(y):
+    """
+    Apply inverse log transform: sign(y) * (exp(|y|) - 1)
+
+    Args:
+        y: Input tensor
+
+    Returns:
+        Transformed tensor
+    """
+    return torch.sign(y) * (torch.expm1(torch.abs(y)))
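A quick, self-contained check of the activation helpers above (the shapes are illustrative):

import torch
from vggt.heads.head_act import activate_head, inverse_log_transform

out = torch.randn(2, 4, 8, 8)  # (B, C, H, W): 3 xyz channels followed by 1 confidence channel
pts3d, conf = activate_head(out, activation="norm_exp", conf_activation="expp1")
print(pts3d.shape, conf.shape)      # torch.Size([2, 8, 8, 3]) torch.Size([2, 8, 8])
print(torch.all(conf >= 1).item())  # True: "expp1" maps confidence to 1 + exp(x)
x = torch.tensor([-1.0, 0.0, 2.0])
print(torch.allclose(inverse_log_transform(x), torch.sign(x) * torch.expm1(x.abs())))  # True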
vggt/heads/track_head.py ADDED
@@ -0,0 +1,108 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch.nn as nn
+from .dpt_head import DPTHead
+from .track_modules.base_track_predictor import BaseTrackerPredictor
+
+
+class TrackHead(nn.Module):
+    """
+    Track head that uses DPT head to process tokens and BaseTrackerPredictor for tracking.
+    The tracking is performed iteratively, refining predictions over multiple iterations.
+    """
+
+    def __init__(
+        self,
+        dim_in,
+        patch_size=14,
+        features=128,
+        iters=4,
+        predict_conf=True,
+        stride=2,
+        corr_levels=7,
+        corr_radius=4,
+        hidden_size=384,
+    ):
+        """
+        Initialize the TrackHead module.
+
+        Args:
+            dim_in (int): Input dimension of tokens from the backbone.
+            patch_size (int): Size of image patches used in the vision transformer.
+            features (int): Number of feature channels in the feature extractor output.
+            iters (int): Number of refinement iterations for tracking predictions.
+            predict_conf (bool): Whether to predict confidence scores for tracked points.
+            stride (int): Stride value for the tracker predictor.
+            corr_levels (int): Number of correlation pyramid levels
+            corr_radius (int): Radius for correlation computation, controlling the search area.
+            hidden_size (int): Size of hidden layers in the tracker network.
+        """
+        super().__init__()
+
+        self.patch_size = patch_size
+
+        # Feature extractor based on DPT architecture
+        # Processes tokens into feature maps for tracking
+        self.feature_extractor = DPTHead(
+            dim_in=dim_in,
+            patch_size=patch_size,
+            features=features,
+            feature_only=True,  # Only output features, no activation
+            down_ratio=2,  # Reduces spatial dimensions by factor of 2
+            pos_embed=False,
+        )
+
+        # Tracker module that predicts point trajectories
+        # Takes feature maps and predicts coordinates and visibility
+        self.tracker = BaseTrackerPredictor(
+            latent_dim=features,  # Match the output_dim of feature extractor
+            predict_conf=predict_conf,
+            stride=stride,
+            corr_levels=corr_levels,
+            corr_radius=corr_radius,
+            hidden_size=hidden_size,
+        )
+
+        self.iters = iters
+
+    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
+        """
+        Forward pass of the TrackHead.
+
+        Args:
+            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
+            images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
+                B = batch size, S = sequence length.
+            patch_start_idx (int): Starting index for patch tokens.
+            query_points (torch.Tensor, optional): Initial query points to track.
+                If None, points are initialized by the tracker.
+            iters (int, optional): Number of refinement iterations. If None, uses self.iters.
+
+        Returns:
+            tuple:
+                - coord_preds (torch.Tensor): Predicted coordinates for tracked points.
+                - vis_scores (torch.Tensor): Visibility scores for tracked points.
+                - conf_scores (torch.Tensor): Confidence scores for tracked points (if predict_conf=True).
+        """
+        B, S, _, H, W = images.shape
+
+        # Extract features from tokens
+        # feature_maps has shape (B, S, C, H//2, W//2) due to down_ratio=2
+        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)
+
+        # Use default iterations if not specified
+        if iters is None:
+            iters = self.iters
+
+        # Perform tracking using the extracted features
+        coord_preds, vis_scores, conf_scores = self.tracker(
+            query_points=query_points,
+            fmaps=feature_maps,
+            iters=iters,
+        )
+
+        return coord_preds, vis_scores, conf_scores
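A construction-only sketch for TrackHead (assumes the full vggt package, including track_modules.base_track_predictor, is importable; running forward() additionally needs aggregated tokens, images, and query points as documented above):

from vggt.heads.track_head import TrackHead

head = TrackHead(dim_in=2048, features=128, iters=4, predict_conf=True)
print(type(head.feature_extractor).__name__, type(head.tracker).__name__)  # DPTHead BaseTrackerPredictor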
vggt/heads/track_modules/__init__.py ADDED
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
vggt/heads/track_modules/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (250 Bytes).
 
vggt/heads/track_modules/__pycache__/base_track_predictor.cpython-310.pyc ADDED
Binary file (4.36 kB).
 
vggt/heads/track_modules/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (6.67 kB).
 
vggt/heads/track_modules/__pycache__/modules.cpython-310.pyc ADDED
Binary file (5.36 kB).
 
vggt/heads/track_modules/modules.py ADDED
@@ -0,0 +1,218 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from functools import partial
12
+ from typing import Callable
13
+ import collections
14
+ from torch import Tensor
15
+ from itertools import repeat
16
+
17
+
18
+ # From PyTorch internals
19
+ def _ntuple(n):
20
+ def parse(x):
21
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
22
+ return tuple(x)
23
+ return tuple(repeat(x, n))
24
+
25
+ return parse
26
+
27
+
28
+ def exists(val):
29
+ return val is not None
30
+
31
+
32
+ def default(val, d):
33
+ return val if exists(val) else d
34
+
35
+
36
+ to_2tuple = _ntuple(2)
37
+
38
+
39
+ class ResidualBlock(nn.Module):
40
+ """
41
+ ResidualBlock: construct a block of two conv layers with residual connections
42
+ """
43
+
44
+ def __init__(self, in_planes, planes, norm_fn="group", stride=1, kernel_size=3):
45
+ super(ResidualBlock, self).__init__()
46
+
47
+ self.conv1 = nn.Conv2d(
48
+ in_planes,
49
+ planes,
50
+ kernel_size=kernel_size,
51
+ padding=1,
52
+ stride=stride,
53
+ padding_mode="zeros",
54
+ )
55
+ self.conv2 = nn.Conv2d(
56
+ planes,
57
+ planes,
58
+ kernel_size=kernel_size,
59
+ padding=1,
60
+ padding_mode="zeros",
61
+ )
62
+ self.relu = nn.ReLU(inplace=True)
63
+
64
+ num_groups = planes // 8
65
+
66
+ if norm_fn == "group":
67
+ self.norm1 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
68
+ self.norm2 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
69
+ if not stride == 1:
70
+ self.norm3 = nn.GroupNorm(num_groups=num_groups, num_channels=planes)
71
+
72
+ elif norm_fn == "batch":
73
+ self.norm1 = nn.BatchNorm2d(planes)
74
+ self.norm2 = nn.BatchNorm2d(planes)
75
+ if not stride == 1:
76
+ self.norm3 = nn.BatchNorm2d(planes)
77
+
78
+ elif norm_fn == "instance":
79
+ self.norm1 = nn.InstanceNorm2d(planes)
80
+ self.norm2 = nn.InstanceNorm2d(planes)
81
+ if not stride == 1:
82
+ self.norm3 = nn.InstanceNorm2d(planes)
83
+
84
+ elif norm_fn == "none":
85
+ self.norm1 = nn.Sequential()
86
+ self.norm2 = nn.Sequential()
87
+ if not stride == 1:
88
+ self.norm3 = nn.Sequential()
89
+ else:
90
+ raise NotImplementedError
91
+
92
+ if stride == 1:
93
+ self.downsample = None
94
+ else:
95
+ self.downsample = nn.Sequential(
96
+ nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride),
97
+ self.norm3,
98
+ )
99
+
100
+ def forward(self, x):
101
+ y = x
102
+ y = self.relu(self.norm1(self.conv1(y)))
103
+ y = self.relu(self.norm2(self.conv2(y)))
104
+
105
+ if self.downsample is not None:
106
+ x = self.downsample(x)
107
+
108
+ return self.relu(x + y)
109
+
110
+
111
+ class Mlp(nn.Module):
112
+ """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
113
+
114
+ def __init__(
115
+ self,
116
+ in_features,
117
+ hidden_features=None,
118
+ out_features=None,
119
+ act_layer=nn.GELU,
120
+ norm_layer=None,
121
+ bias=True,
122
+ drop=0.0,
123
+ use_conv=False,
124
+ ):
125
+ super().__init__()
126
+ out_features = out_features or in_features
127
+ hidden_features = hidden_features or in_features
128
+ bias = to_2tuple(bias)
129
+ drop_probs = to_2tuple(drop)
130
+ linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
131
+
132
+ self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
133
+ self.act = act_layer()
134
+ self.drop1 = nn.Dropout(drop_probs[0])
135
+ self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
136
+ self.drop2 = nn.Dropout(drop_probs[1])
137
+
138
+ def forward(self, x):
139
+ x = self.fc1(x)
140
+ x = self.act(x)
141
+ x = self.drop1(x)
142
+ x = self.fc2(x)
143
+ x = self.drop2(x)
144
+ return x
145
+
146
+
147
+ class AttnBlock(nn.Module):
148
+ def __init__(
149
+ self,
150
+ hidden_size,
151
+ num_heads,
152
+ attn_class: Callable[..., nn.Module] = nn.MultiheadAttention,
153
+ mlp_ratio=4.0,
154
+ **block_kwargs
155
+ ):
156
+ """
157
+ Self attention block
158
+ """
159
+ super().__init__()
160
+
161
+ self.norm1 = nn.LayerNorm(hidden_size)
162
+ self.norm2 = nn.LayerNorm(hidden_size)
163
+
164
+ self.attn = attn_class(embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs)
165
+
166
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
167
+
168
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
169
+
170
+ def forward(self, x, mask=None):
171
+ # Prepare the mask for PyTorch's attention (it expects a different format)
172
+ # attn_mask = mask if mask is not None else None
173
+ # Normalize before attention
174
+ x = self.norm1(x)
175
+
176
+ # PyTorch's MultiheadAttention returns attn_output, attn_output_weights
177
+ # attn_output, _ = self.attn(x, x, x, attn_mask=attn_mask)
178
+
179
+ attn_output, _ = self.attn(x, x, x)
180
+
181
+ # Add & Norm
182
+ x = x + attn_output
183
+ x = x + self.mlp(self.norm2(x))
184
+ return x
185
+
186
+
187
+ class CrossAttnBlock(nn.Module):
188
+ def __init__(self, hidden_size, context_dim, num_heads=1, mlp_ratio=4.0, **block_kwargs):
189
+ """
190
+ Cross attention block
191
+ """
192
+ super().__init__()
193
+
194
+ self.norm1 = nn.LayerNorm(hidden_size)
195
+ self.norm_context = nn.LayerNorm(hidden_size)
196
+ self.norm2 = nn.LayerNorm(hidden_size)
197
+
198
+ self.cross_attn = nn.MultiheadAttention(
199
+ embed_dim=hidden_size, num_heads=num_heads, batch_first=True, **block_kwargs
200
+ )
201
+
202
+ mlp_hidden_dim = int(hidden_size * mlp_ratio)
203
+
204
+ self.mlp = Mlp(in_features=hidden_size, hidden_features=mlp_hidden_dim, drop=0)
205
+
206
+ def forward(self, x, context, mask=None):
207
+ # Normalize inputs
208
+ x = self.norm1(x)
209
+ context = self.norm_context(context)
210
+
211
+ # Apply cross attention
212
+ # Note: nn.MultiheadAttention returns attn_output, attn_output_weights
213
+ attn_output, _ = self.cross_attn(x, context, context, attn_mask=mask)
214
+
215
+ # Add & Norm
216
+ x = x + attn_output
217
+ x = x + self.mlp(self.norm2(x))
218
+ return x
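A minimal usage sketch for the blocks above (illustrative only, not part of the commit; shapes and sizes are assumptions). Note that `CrossAttnBlock` normalizes the context with `nn.LayerNorm(hidden_size)`, so the context features are assumed to share the hidden size:

import torch
from vggt.heads.track_modules.modules import AttnBlock, CrossAttnBlock

hidden_size, num_heads = 384, 8
self_block = AttnBlock(hidden_size, num_heads)
cross_block = CrossAttnBlock(hidden_size, context_dim=hidden_size, num_heads=num_heads)

tokens = torch.randn(2, 64, hidden_size)    # (B, N, C) query tokens
context = torch.randn(2, 128, hidden_size)  # (B, M, C) context tokens

out = self_block(tokens)                    # pre-norm self-attention + MLP, shape preserved
out = cross_block(out, context)             # attends to `context`, still (2, 64, 384)
print(out.shape)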
vggt/heads/track_modules/utils.py ADDED
@@ -0,0 +1,226 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from https://github.com/facebookresearch/vggsfm
8
+ # and https://github.com/facebookresearch/co-tracker/tree/main
9
+
10
+
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+
15
+ from typing import Optional, Tuple, Union
16
+
17
+
18
+ def get_2d_sincos_pos_embed(embed_dim: int, grid_size: Union[int, Tuple[int, int]], return_grid=False) -> torch.Tensor:
19
+ """
20
+ This function initializes a grid and generates a 2D positional embedding using sine and cosine functions.
21
+ It is a wrapper of get_2d_sincos_pos_embed_from_grid.
22
+ Args:
23
+ - embed_dim: The embedding dimension.
24
+ - grid_size: The grid size.
25
+ Returns:
26
+ - pos_embed: The generated 2D positional embedding.
27
+ """
28
+ if isinstance(grid_size, tuple):
29
+ grid_size_h, grid_size_w = grid_size
30
+ else:
31
+ grid_size_h = grid_size_w = grid_size
32
+ grid_h = torch.arange(grid_size_h, dtype=torch.float)
33
+ grid_w = torch.arange(grid_size_w, dtype=torch.float)
34
+ grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
35
+ grid = torch.stack(grid, dim=0)
36
+ grid = grid.reshape([2, 1, grid_size_h, grid_size_w])
37
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
38
+ if return_grid:
39
+ return (
40
+ pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2),
41
+ grid,
42
+ )
43
+ return pos_embed.reshape(1, grid_size_h, grid_size_w, -1).permute(0, 3, 1, 2)
44
+
45
+
46
+ def get_2d_sincos_pos_embed_from_grid(embed_dim: int, grid: torch.Tensor) -> torch.Tensor:
47
+ """
48
+ This function generates a 2D positional embedding from a given grid using sine and cosine functions.
49
+
50
+ Args:
51
+ - embed_dim: The embedding dimension.
52
+ - grid: The grid to generate the embedding from.
53
+
54
+ Returns:
55
+ - emb: The generated 2D positional embedding.
56
+ """
57
+ assert embed_dim % 2 == 0
58
+
59
+ # use half of dimensions to encode grid_h
60
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
61
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
62
+
63
+ emb = torch.cat([emb_h, emb_w], dim=2) # (H*W, D)
64
+ return emb
65
+
66
+
67
+ def get_1d_sincos_pos_embed_from_grid(embed_dim: int, pos: torch.Tensor) -> torch.Tensor:
68
+ """
69
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
70
+
71
+ Args:
72
+ - embed_dim: The embedding dimension.
73
+ - pos: The position to generate the embedding from.
74
+
75
+ Returns:
76
+ - emb: The generated 1D positional embedding.
77
+ """
78
+ assert embed_dim % 2 == 0
79
+ omega = torch.arange(embed_dim // 2, dtype=torch.double)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = torch.sin(out) # (M, D/2)
87
+ emb_cos = torch.cos(out) # (M, D/2)
88
+
89
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
90
+ return emb[None].float()
91
+
92
+
93
+ def get_2d_embedding(xy: torch.Tensor, C: int, cat_coords: bool = True) -> torch.Tensor:
94
+ """
95
+ This function generates a 2D positional embedding from given coordinates using sine and cosine functions.
96
+
97
+ Args:
98
+ - xy: The coordinates to generate the embedding from.
99
+ - C: The size of the embedding.
100
+ - cat_coords: A flag to indicate whether to concatenate the original coordinates to the embedding.
101
+
102
+ Returns:
103
+ - pe: The generated 2D positional embedding.
104
+ """
105
+ B, N, D = xy.shape
106
+ assert D == 2
107
+
108
+ x = xy[:, :, 0:1]
109
+ y = xy[:, :, 1:2]
110
+ div_term = (torch.arange(0, C, 2, device=xy.device, dtype=torch.float32) * (1000.0 / C)).reshape(1, 1, int(C / 2))
111
+
112
+ pe_x = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
113
+ pe_y = torch.zeros(B, N, C, device=xy.device, dtype=torch.float32)
114
+
115
+ pe_x[:, :, 0::2] = torch.sin(x * div_term)
116
+ pe_x[:, :, 1::2] = torch.cos(x * div_term)
117
+
118
+ pe_y[:, :, 0::2] = torch.sin(y * div_term)
119
+ pe_y[:, :, 1::2] = torch.cos(y * div_term)
120
+
121
+ pe = torch.cat([pe_x, pe_y], dim=2) # (B, N, C*2)
+ if cat_coords:
+ pe = torch.cat([xy, pe], dim=2) # (B, N, C*2+2)
124
+ return pe
125
+
126
+
127
+ def bilinear_sampler(input, coords, align_corners=True, padding_mode="border"):
128
+ r"""Sample a tensor using bilinear interpolation
129
+
130
+ `bilinear_sampler(input, coords)` samples a tensor :attr:`input` at
131
+ coordinates :attr:`coords` using bilinear interpolation. It is the same
132
+ as `torch.nn.functional.grid_sample()` but with a different coordinate
133
+ convention.
134
+
135
+ The input tensor is assumed to be of shape :math:`(B, C, H, W)`, where
136
+ :math:`B` is the batch size, :math:`C` is the number of channels,
137
+ :math:`H` is the height of the image, and :math:`W` is the width of the
138
+ image. The tensor :attr:`coords` of shape :math:`(B, H_o, W_o, 2)` is
139
+ interpreted as an array of 2D point coordinates :math:`(x_i,y_i)`.
140
+
141
+ Alternatively, the input tensor can be of size :math:`(B, C, T, H, W)`,
142
+ in which case sample points are triplets :math:`(t_i,x_i,y_i)`. Note
143
+ that in this case the order of the components is slightly different
144
+ from `grid_sample()`, which would expect :math:`(x_i,y_i,t_i)`.
145
+
146
+ If `align_corners` is `True`, the coordinate :math:`x` is assumed to be
147
+ in the range :math:`[0,W-1]`, with 0 corresponding to the center of the
148
+ left-most image pixel and :math:`W-1` to the center of the right-most
149
+ pixel.
150
+
151
+ If `align_corners` is `False`, the coordinate :math:`x` is assumed to
152
+ be in the range :math:`[0,W]`, with 0 corresponding to the left edge of
153
+ the left-most pixel and :math:`W` to the right edge of the right-most
154
+ pixel.
155
+
156
+ Similar conventions apply to the :math:`y` for the range
157
+ :math:`[0,H-1]` and :math:`[0,H]` and to :math:`t` for the range
158
+ :math:`[0,T-1]` and :math:`[0,T]`.
159
+
160
+ Args:
161
+ input (Tensor): batch of input images.
162
+ coords (Tensor): batch of coordinates.
163
+ align_corners (bool, optional): Coordinate convention. Defaults to `True`.
164
+ padding_mode (str, optional): Padding mode. Defaults to `"border"`.
165
+
166
+ Returns:
167
+ Tensor: sampled points.
168
+ """
169
+ coords = coords.detach().clone()
170
+ ############################################################
171
+ # IMPORTANT:
172
+ coords = coords.to(input.device).to(input.dtype)
173
+ ############################################################
174
+
175
+ sizes = input.shape[2:]
176
+
177
+ assert len(sizes) in [2, 3]
178
+
179
+ if len(sizes) == 3:
180
+ # t x y -> x y t to match dimensions T H W in grid_sample
181
+ coords = coords[..., [1, 2, 0]]
182
+
183
+ if align_corners:
184
+ scale = torch.tensor(
185
+ [2 / max(size - 1, 1) for size in reversed(sizes)], device=coords.device, dtype=coords.dtype
186
+ )
187
+ else:
188
+ scale = torch.tensor([2 / size for size in reversed(sizes)], device=coords.device, dtype=coords.dtype)
189
+
190
+ coords.mul_(scale) # coords = coords * scale
191
+ coords.sub_(1) # coords = coords - 1
192
+
193
+ return F.grid_sample(input, coords, align_corners=align_corners, padding_mode=padding_mode)
194
+
195
+
196
+ def sample_features4d(input, coords):
197
+ r"""Sample spatial features
198
+
199
+ `sample_features4d(input, coords)` samples the spatial features
200
+ :attr:`input` represented by a 4D tensor :math:`(B, C, H, W)`.
201
+
202
+ The field is sampled at coordinates :attr:`coords` using bilinear
203
+ interpolation. :attr:`coords` is assumed to be of shape :math:`(B, R,
204
+ 2)`, where each sample has the format :math:`(x_i, y_i)`. This uses the
205
+ same convention as :func:`bilinear_sampler` with `align_corners=True`.
206
+
207
+ The output tensor has one feature per point, and has shape :math:`(B,
208
+ R, C)`.
209
+
210
+ Args:
211
+ input (Tensor): spatial features.
212
+ coords (Tensor): points.
213
+
214
+ Returns:
215
+ Tensor: sampled features.
216
+ """
217
+
218
+ B, _, _, _ = input.shape
219
+
220
+ # B R 2 -> B R 1 2
221
+ coords = coords.unsqueeze(2)
222
+
223
+ # B C R 1
224
+ feats = bilinear_sampler(input, coords)
225
+
226
+ return feats.permute(0, 2, 1, 3).view(B, -1, feats.shape[1] * feats.shape[3]) # B C R 1 -> B R C
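A short sketch (illustrative only, not part of the commit) of the two helpers the tracker relies on: bilinear feature sampling at query points and a sinusoidal grid embedding:

import torch
from vggt.heads.track_modules.utils import get_2d_sincos_pos_embed, sample_features4d

fmap = torch.randn(2, 128, 32, 32)       # (B, C, H, W) feature map
points = torch.rand(2, 50, 2) * 31.0     # (B, R, 2) pixel coordinates (x, y)

feats = sample_features4d(fmap, points)  # bilinear sampling -> (2, 50, 128)
pos = get_2d_sincos_pos_embed(embed_dim=64, grid_size=(32, 32))  # (1, 64, 32, 32)
print(feats.shape, pos.shape)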
vggt/heads/utils.py ADDED
@@ -0,0 +1,108 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import torch
8
+ import torch.nn as nn
9
+
10
+
11
+ def position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100) -> torch.Tensor:
12
+ """
13
+ Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
14
+
15
+ Args:
16
+ pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
17
+ embed_dim: Output channel dimension for embeddings
18
+
19
+ Returns:
20
+ Tensor of shape (H, W, embed_dim) with positional embeddings
21
+ """
22
+ H, W, grid_dim = pos_grid.shape
23
+ assert grid_dim == 2
24
+ pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
25
+
26
+ # Process x and y coordinates separately
27
+ emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [H*W, D/2]
+ emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [H*W, D/2]
+
+ # Combine and reshape
+ emb = torch.cat([emb_x, emb_y], dim=-1) # [H*W, D]
32
+
33
+ return emb.view(H, W, embed_dim) # [H, W, D]
34
+
35
+
36
+ def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
37
+ """
38
+ This function generates a 1D positional embedding from a given grid using sine and cosine functions.
39
+
40
+ Args:
41
+ - embed_dim: The embedding dimension.
42
+ - pos: The position to generate the embedding from.
43
+
44
+ Returns:
45
+ - emb: The generated 1D positional embedding.
46
+ """
47
+ assert embed_dim % 2 == 0
48
+ omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
49
+ omega /= embed_dim / 2.0
50
+ omega = 1.0 / omega_0**omega # (D/2,)
51
+
52
+ pos = pos.reshape(-1) # (M,)
53
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
54
+
55
+ emb_sin = torch.sin(out) # (M, D/2)
56
+ emb_cos = torch.cos(out) # (M, D/2)
57
+
58
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
59
+ return emb.float()
60
+
61
+
62
+ # Inspired by https://github.com/microsoft/moge
63
+
64
+
65
+ def create_uv_grid(
66
+ width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None
67
+ ) -> torch.Tensor:
68
+ """
69
+ Create a normalized UV grid of shape (width, height, 2).
70
+
71
+ The grid spans horizontally and vertically according to an aspect ratio,
72
+ ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
73
+ corner is at (x_span, y_span), normalized by the diagonal of the plane.
74
+
75
+ Args:
76
+ width (int): Number of points horizontally.
77
+ height (int): Number of points vertically.
78
+ aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
79
+ dtype (torch.dtype, optional): Data type of the resulting tensor.
80
+ device (torch.device, optional): Device on which the tensor is created.
81
+
82
+ Returns:
83
+ torch.Tensor: A (width, height, 2) tensor of UV coordinates.
84
+ """
85
+ # Derive aspect ratio if not explicitly provided
86
+ if aspect_ratio is None:
87
+ aspect_ratio = float(width) / float(height)
88
+
89
+ # Compute normalized spans for X and Y
90
+ diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
91
+ span_x = aspect_ratio / diag_factor
92
+ span_y = 1.0 / diag_factor
93
+
94
+ # Establish the linspace boundaries
95
+ left_x = -span_x * (width - 1) / width
96
+ right_x = span_x * (width - 1) / width
97
+ top_y = -span_y * (height - 1) / height
98
+ bottom_y = span_y * (height - 1) / height
99
+
100
+ # Generate 1D coordinates
101
+ x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
102
+ y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
103
+
104
+ # Create 2D meshgrid (width x height) and stack into UV
105
+ uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
106
+ uv_grid = torch.stack((uu, vv), dim=-1)
107
+
108
+ return uv_grid
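A small sketch (assumed square grid, not part of the commit) chaining the two helpers: build a normalized UV grid, then lift it to a sinusoidal positional embedding:

import torch
from vggt.heads.utils import create_uv_grid, position_grid_to_embed

uv = create_uv_grid(width=32, height=32, dtype=torch.float32)  # (32, 32, 2), values span roughly [-0.7, 0.7]
pe = position_grid_to_embed(uv, embed_dim=128)                 # (32, 32, 128) sin/cos embedding
print(uv.shape, pe.shape)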
vggt/layers/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (479 Bytes). View file
 
vggt/layers/__pycache__/attention.cpython-310.pyc ADDED
Binary file (2.87 kB). View file
 
vggt/layers/__pycache__/drop_path.cpython-310.pyc ADDED
Binary file (1.28 kB). View file
 
vggt/layers/__pycache__/layer_scale.cpython-310.pyc ADDED
Binary file (1.08 kB). View file
 
vggt/layers/__pycache__/mlp.cpython-310.pyc ADDED
Binary file (1.27 kB). View file
 
vggt/layers/__pycache__/patch_embed.cpython-310.pyc ADDED
Binary file (2.72 kB). View file
 
vggt/layers/__pycache__/swiglu_ffn.cpython-310.pyc ADDED
Binary file (2.09 kB). View file
 
vggt/layers/block.py ADDED
@@ -0,0 +1,259 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ import logging
11
+ import os
12
+ from typing import Callable, List, Any, Tuple, Dict
13
+ import warnings
14
+
15
+ import torch
16
+ from torch import nn, Tensor
17
+
18
+ from .attention import Attention
19
+ from .drop_path import DropPath
20
+ from .layer_scale import LayerScale
21
+ from .mlp import Mlp
22
+
23
+
24
+ XFORMERS_AVAILABLE = False
25
+
26
+
27
+ class Block(nn.Module):
28
+ def __init__(
29
+ self,
30
+ dim: int,
31
+ num_heads: int,
32
+ mlp_ratio: float = 4.0,
33
+ qkv_bias: bool = True,
34
+ proj_bias: bool = True,
35
+ ffn_bias: bool = True,
36
+ drop: float = 0.0,
37
+ attn_drop: float = 0.0,
38
+ init_values=None,
39
+ drop_path: float = 0.0,
40
+ act_layer: Callable[..., nn.Module] = nn.GELU,
41
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
42
+ attn_class: Callable[..., nn.Module] = Attention,
43
+ ffn_layer: Callable[..., nn.Module] = Mlp,
44
+ qk_norm: bool = False,
45
+ fused_attn: bool = True, # use F.scaled_dot_product_attention or not
46
+ rope=None,
47
+ ) -> None:
48
+ super().__init__()
49
+
50
+ self.norm1 = norm_layer(dim)
51
+
52
+ self.attn = attn_class(
53
+ dim,
54
+ num_heads=num_heads,
55
+ qkv_bias=qkv_bias,
56
+ proj_bias=proj_bias,
57
+ attn_drop=attn_drop,
58
+ proj_drop=drop,
59
+ qk_norm=qk_norm,
60
+ fused_attn=fused_attn,
61
+ rope=rope,
62
+ )
63
+
64
+ self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
65
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
66
+
67
+ self.norm2 = norm_layer(dim)
68
+ mlp_hidden_dim = int(dim * mlp_ratio)
69
+ self.mlp = ffn_layer(
70
+ in_features=dim,
71
+ hidden_features=mlp_hidden_dim,
72
+ act_layer=act_layer,
73
+ drop=drop,
74
+ bias=ffn_bias,
75
+ )
76
+ self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
77
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
78
+
79
+ self.sample_drop_ratio = drop_path
80
+
81
+ def forward(self, x: Tensor, pos=None) -> Tensor:
82
+ def attn_residual_func(x: Tensor, pos=None) -> Tensor:
83
+ return self.ls1(self.attn(self.norm1(x), pos=pos))
84
+
85
+ def ffn_residual_func(x: Tensor) -> Tensor:
86
+ return self.ls2(self.mlp(self.norm2(x)))
87
+
88
+ if self.training and self.sample_drop_ratio > 0.1:
89
+ # the overhead is compensated only for a drop path rate larger than 0.1
90
+ x = drop_add_residual_stochastic_depth(
91
+ x,
92
+ pos=pos,
93
+ residual_func=attn_residual_func,
94
+ sample_drop_ratio=self.sample_drop_ratio,
95
+ )
96
+ x = drop_add_residual_stochastic_depth(
97
+ x,
98
+ residual_func=ffn_residual_func,
99
+ sample_drop_ratio=self.sample_drop_ratio,
100
+ )
101
+ elif self.training and self.sample_drop_ratio > 0.0:
102
+ x = x + self.drop_path1(attn_residual_func(x, pos=pos))
103
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
104
+ else:
105
+ x = x + attn_residual_func(x, pos=pos)
106
+ x = x + ffn_residual_func(x)
107
+ return x
108
+
109
+
110
+ def drop_add_residual_stochastic_depth(
111
+ x: Tensor,
112
+ residual_func: Callable[[Tensor], Tensor],
113
+ sample_drop_ratio: float = 0.0,
114
+ pos=None,
115
+ ) -> Tensor:
116
+ # 1) extract subset using permutation
117
+ b, n, d = x.shape
118
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
119
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
120
+ x_subset = x[brange]
121
+
122
+ # 2) apply residual_func to get residual
123
+ if pos is not None:
124
+ # if necessary, apply rope to the subset
125
+ pos = pos[brange]
126
+ residual = residual_func(x_subset, pos=pos)
127
+ else:
128
+ residual = residual_func(x_subset)
129
+
130
+ x_flat = x.flatten(1)
131
+ residual = residual.flatten(1)
132
+
133
+ residual_scale_factor = b / sample_subset_size
134
+
135
+ # 3) add the residual
136
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
137
+ return x_plus_residual.view_as(x)
138
+
139
+
140
+ def get_branges_scales(x, sample_drop_ratio=0.0):
141
+ b, n, d = x.shape
142
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
143
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
144
+ residual_scale_factor = b / sample_subset_size
145
+ return brange, residual_scale_factor
146
+
147
+
148
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
149
+ if scaling_vector is None:
150
+ x_flat = x.flatten(1)
151
+ residual = residual.flatten(1)
152
+ x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
153
+ else:
154
+ x_plus_residual = scaled_index_add(
155
+ x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
156
+ )
157
+ return x_plus_residual
158
+
159
+
160
+ attn_bias_cache: Dict[Tuple, Any] = {}
161
+
162
+
163
+ def get_attn_bias_and_cat(x_list, branges=None):
164
+ """
165
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
166
+ """
167
+ batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
168
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
169
+ if all_shapes not in attn_bias_cache.keys():
170
+ seqlens = []
171
+ for b, x in zip(batch_sizes, x_list):
172
+ for _ in range(b):
173
+ seqlens.append(x.shape[1])
174
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
175
+ attn_bias._batch_sizes = batch_sizes
176
+ attn_bias_cache[all_shapes] = attn_bias
177
+
178
+ if branges is not None:
179
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
180
+ else:
181
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
182
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
183
+
184
+ return attn_bias_cache[all_shapes], cat_tensors
185
+
186
+
187
+ def drop_add_residual_stochastic_depth_list(
188
+ x_list: List[Tensor],
189
+ residual_func: Callable[[Tensor, Any], Tensor],
190
+ sample_drop_ratio: float = 0.0,
191
+ scaling_vector=None,
192
+ ) -> Tensor:
193
+ # 1) generate random set of indices for dropping samples in the batch
194
+ branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
195
+ branges = [s[0] for s in branges_scales]
196
+ residual_scale_factors = [s[1] for s in branges_scales]
197
+
198
+ # 2) get attention bias and index+concat the tensors
199
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
200
+
201
+ # 3) apply residual_func to get residual, and split the result
202
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
203
+
204
+ outputs = []
205
+ for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
206
+ outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
207
+ return outputs
208
+
209
+
210
+ class NestedTensorBlock(Block):
211
+ def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
212
+ """
213
+ x_list contains a list of tensors to nest together and run
214
+ """
215
+ assert isinstance(self.attn, MemEffAttention)
216
+
217
+ if self.training and self.sample_drop_ratio > 0.0:
218
+
219
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
220
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
221
+
222
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
223
+ return self.mlp(self.norm2(x))
224
+
225
+ x_list = drop_add_residual_stochastic_depth_list(
226
+ x_list,
227
+ residual_func=attn_residual_func,
228
+ sample_drop_ratio=self.sample_drop_ratio,
229
+ scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
230
+ )
231
+ x_list = drop_add_residual_stochastic_depth_list(
232
+ x_list,
233
+ residual_func=ffn_residual_func,
234
+ sample_drop_ratio=self.sample_drop_ratio,
235
+ scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
236
+ )
237
+ return x_list
238
+ else:
239
+
240
+ def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
241
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
242
+
243
+ def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
244
+ return self.ls2(self.mlp(self.norm2(x)))
245
+
246
+ attn_bias, x = get_attn_bias_and_cat(x_list)
247
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
248
+ x = x + ffn_residual_func(x)
249
+ return attn_bias.split(x)
250
+
251
+ def forward(self, x_or_x_list):
252
+ if isinstance(x_or_x_list, Tensor):
253
+ return super().forward(x_or_x_list)
254
+ elif isinstance(x_or_x_list, list):
255
+ if not XFORMERS_AVAILABLE:
256
+ raise AssertionError("xFormers is required for using nested tensors")
257
+ return self.forward_nested(x_or_x_list)
258
+ else:
259
+ raise AssertionError
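A minimal sketch of running a single `Block` (not part of the commit; it assumes the `Attention` class in vggt/layers/attention.py accepts `pos=None` when no RoPE is configured, which is the interface `Block.forward` relies on):

import torch
from vggt.layers.block import Block

block = Block(dim=384, num_heads=6)  # defaults: fused attention, no RoPE, no LayerScale
x = torch.randn(2, 197, 384)         # (B, N, D) token sequence
y = block(x)                         # pre-norm attention + MLP residuals, shape preserved
print(y.shape)                       # torch.Size([2, 197, 384])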
vggt/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
+ class Mlp(nn.Module):
17
+ def __init__(
18
+ self,
19
+ in_features: int,
20
+ hidden_features: Optional[int] = None,
21
+ out_features: Optional[int] = None,
22
+ act_layer: Callable[..., nn.Module] = nn.GELU,
23
+ drop: float = 0.0,
24
+ bias: bool = True,
25
+ ) -> None:
26
+ super().__init__()
27
+ out_features = out_features or in_features
28
+ hidden_features = hidden_features or in_features
29
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
30
+ self.act = act_layer()
31
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
32
+ self.drop = nn.Dropout(drop)
33
+
34
+ def forward(self, x: Tensor) -> Tensor:
35
+ x = self.fc1(x)
36
+ x = self.act(x)
37
+ x = self.drop(x)
38
+ x = self.fc2(x)
39
+ x = self.drop(x)
40
+ return x
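For completeness, a one-line usage sketch (not part of the commit); the output dimension defaults back to `in_features`:

import torch
from vggt.layers.mlp import Mlp

mlp = Mlp(in_features=256, hidden_features=1024, drop=0.1)
x = torch.randn(4, 16, 256)
print(mlp(x).shape)  # torch.Size([4, 16, 256])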
vggt/layers/rope.py ADDED
@@ -0,0 +1,188 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # Implementation of 2D Rotary Position Embeddings (RoPE).
8
+
9
+ # This module provides a clean implementation of 2D Rotary Position Embeddings,
10
+ # which extends the original RoPE concept to handle 2D spatial positions.
11
+
12
+ # Inspired by:
13
+ # https://github.com/meta-llama/codellama/blob/main/llama/model.py
14
+ # https://github.com/naver-ai/rope-vit
15
+
16
+
17
+ import numpy as np
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+ from typing import Dict, Tuple
22
+
23
+
24
+ class PositionGetter:
25
+ """Generates and caches 2D spatial positions for patches in a grid.
26
+
27
+ This class efficiently manages the generation of spatial coordinates for patches
28
+ in a 2D grid, caching results to avoid redundant computations.
29
+
30
+ Attributes:
31
+ position_cache: Dictionary storing precomputed position tensors for different
32
+ grid dimensions.
33
+ """
34
+
35
+ def __init__(self):
36
+ """Initializes the position generator with an empty cache."""
37
+ self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
38
+
39
+ def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor:
40
+ """Generates spatial positions for a batch of patches.
41
+
42
+ Args:
43
+ batch_size: Number of samples in the batch.
44
+ height: Height of the grid in patches.
45
+ width: Width of the grid in patches.
46
+ device: Target device for the position tensor.
47
+
48
+ Returns:
49
+ Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
50
+ for each position in the grid, repeated for each batch item.
51
+ """
52
+ if (height, width) not in self.position_cache:
53
+ y_coords = torch.arange(height, device=device)
54
+ x_coords = torch.arange(width, device=device)
55
+ positions = torch.cartesian_prod(y_coords, x_coords)
56
+ self.position_cache[height, width] = positions
57
+
58
+ cached_positions = self.position_cache[height, width]
59
+ return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
60
+
61
+
62
+ class RotaryPositionEmbedding2D(nn.Module):
63
+ """2D Rotary Position Embedding implementation.
64
+
65
+ This module applies rotary position embeddings to input tokens based on their
66
+ 2D spatial positions. It handles the position-dependent rotation of features
67
+ separately for vertical and horizontal dimensions.
68
+
69
+ Args:
70
+ frequency: Base frequency for the position embeddings. Default: 100.0
71
+ scaling_factor: Scaling factor for frequency computation. Default: 1.0
72
+
73
+ Attributes:
74
+ base_frequency: Base frequency for computing position embeddings.
75
+ scaling_factor: Factor to scale the computed frequencies.
76
+ frequency_cache: Cache for storing precomputed frequency components.
77
+ """
78
+
79
+ def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
80
+ """Initializes the 2D RoPE module."""
81
+ super().__init__()
82
+ self.base_frequency = frequency
83
+ self.scaling_factor = scaling_factor
84
+ self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
85
+
86
+ def _compute_frequency_components(
87
+ self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
88
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
89
+ """Computes frequency components for rotary embeddings.
90
+
91
+ Args:
92
+ dim: Feature dimension (must be even).
93
+ seq_len: Maximum sequence length.
94
+ device: Target device for computations.
95
+ dtype: Data type for the computed tensors.
96
+
97
+ Returns:
98
+ Tuple of (cosine, sine) tensors for frequency components.
99
+ """
100
+ cache_key = (dim, seq_len, device, dtype)
101
+ if cache_key not in self.frequency_cache:
102
+ # Compute frequency bands
103
+ exponents = torch.arange(0, dim, 2, device=device).float() / dim
104
+ inv_freq = 1.0 / (self.base_frequency**exponents)
105
+
106
+ # Generate position-dependent frequencies
107
+ positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
108
+ angles = torch.einsum("i,j->ij", positions, inv_freq)
109
+
110
+ # Compute and cache frequency components
111
+ angles = angles.to(dtype)
112
+ angles = torch.cat((angles, angles), dim=-1)
113
+ cos_components = angles.cos().to(dtype)
114
+ sin_components = angles.sin().to(dtype)
115
+ self.frequency_cache[cache_key] = (cos_components, sin_components)
116
+
117
+ return self.frequency_cache[cache_key]
118
+
119
+ @staticmethod
120
+ def _rotate_features(x: torch.Tensor) -> torch.Tensor:
121
+ """Performs feature rotation by splitting and recombining feature dimensions.
122
+
123
+ Args:
124
+ x: Input tensor to rotate.
125
+
126
+ Returns:
127
+ Rotated feature tensor.
128
+ """
129
+ feature_dim = x.shape[-1]
130
+ x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
131
+ return torch.cat((-x2, x1), dim=-1)
132
+
133
+ def _apply_1d_rope(
134
+ self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor
135
+ ) -> torch.Tensor:
136
+ """Applies 1D rotary position embeddings along one dimension.
137
+
138
+ Args:
139
+ tokens: Input token features.
140
+ positions: Position indices.
141
+ cos_comp: Cosine components for rotation.
142
+ sin_comp: Sine components for rotation.
143
+
144
+ Returns:
145
+ Tokens with applied rotary position embeddings.
146
+ """
147
+ # Embed positions with frequency components
148
+ cos = F.embedding(positions, cos_comp)[:, None, :, :]
149
+ sin = F.embedding(positions, sin_comp)[:, None, :, :]
150
+
151
+ # Apply rotation
152
+ return (tokens * cos) + (self._rotate_features(tokens) * sin)
153
+
154
+ def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
155
+ """Applies 2D rotary position embeddings to input tokens.
156
+
157
+ Args:
158
+ tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
159
+ The feature dimension (dim) must be divisible by 4.
160
+ positions: Position tensor of shape (batch_size, n_tokens, 2) containing
161
+ the y and x coordinates for each token.
162
+
163
+ Returns:
164
+ Tensor of same shape as input with applied 2D rotary position embeddings.
165
+
166
+ Raises:
167
+ AssertionError: If input dimensions are invalid or positions are malformed.
168
+ """
169
+ # Validate inputs
170
+ assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
171
+ assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)"
172
+
173
+ # Compute feature dimension for each spatial direction
174
+ feature_dim = tokens.size(-1) // 2
175
+
176
+ # Get frequency components
177
+ max_position = int(positions.max()) + 1
178
+ cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype)
179
+
180
+ # Split features for vertical and horizontal processing
181
+ vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
182
+
183
+ # Apply RoPE separately for each dimension
184
+ vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp)
185
+ horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp)
186
+
187
+ # Combine processed features
188
+ return torch.cat((vertical_features, horizontal_features), dim=-1)
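A usage sketch (not part of the commit) pairing `PositionGetter` with `RotaryPositionEmbedding2D` on a 14x14 patch grid; the per-head feature dimension must be divisible by 4:

import torch
from vggt.layers.rope import PositionGetter, RotaryPositionEmbedding2D

batch, heads, dim = 2, 8, 64  # dim is the per-head feature size
grid_h = grid_w = 14

get_positions = PositionGetter()
rope = RotaryPositionEmbedding2D(frequency=100.0)

tokens = torch.randn(batch, heads, grid_h * grid_w, dim)         # (B, n_heads, N, D)
positions = get_positions(batch, grid_h, grid_w, tokens.device)  # (B, N, 2) y,x indices

rotated = rope(tokens, positions)  # same shape, position-dependent rotation applied
print(rotated.shape)               # torch.Size([2, 8, 196, 64])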
vggt/layers/vision_transformer.py ADDED
@@ -0,0 +1,407 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.checkpoint import checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+ from . import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
20
+
21
+ logger = logging.getLogger("dinov2")
22
+
23
+
24
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
25
+ if not depth_first and include_root:
26
+ fn(module=module, name=name)
27
+ for child_name, child_module in module.named_children():
28
+ child_name = ".".join((name, child_name)) if name else child_name
29
+ named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
30
+ if depth_first and include_root:
31
+ fn(module=module, name=name)
32
+ return module
33
+
34
+
35
+ class BlockChunk(nn.ModuleList):
36
+ def forward(self, x):
37
+ for b in self:
38
+ x = b(x)
39
+ return x
40
+
41
+
42
+ class DinoVisionTransformer(nn.Module):
43
+ def __init__(
44
+ self,
45
+ img_size=224,
46
+ patch_size=16,
47
+ in_chans=3,
48
+ embed_dim=768,
49
+ depth=12,
50
+ num_heads=12,
51
+ mlp_ratio=4.0,
52
+ qkv_bias=True,
53
+ ffn_bias=True,
54
+ proj_bias=True,
55
+ drop_path_rate=0.0,
56
+ drop_path_uniform=False,
57
+ init_values=None, # for layerscale: None or 0 => no layerscale
58
+ embed_layer=PatchEmbed,
59
+ act_layer=nn.GELU,
60
+ block_fn=Block,
61
+ ffn_layer="mlp",
62
+ block_chunks=1,
63
+ num_register_tokens=0,
64
+ interpolate_antialias=False,
65
+ interpolate_offset=0.1,
66
+ qk_norm=False,
67
+ ):
68
+ """
69
+ Args:
70
+ img_size (int, tuple): input image size
71
+ patch_size (int, tuple): patch size
72
+ in_chans (int): number of input channels
73
+ embed_dim (int): embedding dimension
74
+ depth (int): depth of transformer
75
+ num_heads (int): number of attention heads
76
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
77
+ qkv_bias (bool): enable bias for qkv if True
78
+ proj_bias (bool): enable bias for proj in attn if True
79
+ ffn_bias (bool): enable bias for ffn if True
80
+ drop_path_rate (float): stochastic depth rate
81
+ drop_path_uniform (bool): apply uniform drop rate across blocks
82
+ weight_init (str): weight init scheme
83
+ init_values (float): layer-scale init values
84
+ embed_layer (nn.Module): patch embedding layer
85
+ act_layer (nn.Module): MLP activation layer
86
+ block_fn (nn.Module): transformer block class
87
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
88
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
89
+ num_register_tokens: (int) number of extra cls tokens (so-called "registers")
90
+ interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
91
+ interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
92
+ """
93
+ super().__init__()
94
+ norm_layer = partial(nn.LayerNorm, eps=1e-6)
95
+
96
+ # tricky but makes it work
97
+ self.use_checkpoint = False
98
+ #
99
+
100
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
101
+ self.num_tokens = 1
102
+ self.n_blocks = depth
103
+ self.num_heads = num_heads
104
+ self.patch_size = patch_size
105
+ self.num_register_tokens = num_register_tokens
106
+ self.interpolate_antialias = interpolate_antialias
107
+ self.interpolate_offset = interpolate_offset
108
+
109
+ self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
110
+ num_patches = self.patch_embed.num_patches
111
+
112
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
113
+ self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
114
+ assert num_register_tokens >= 0
115
+ self.register_tokens = (
116
+ nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
117
+ )
118
+
119
+ if drop_path_uniform is True:
120
+ dpr = [drop_path_rate] * depth
121
+ else:
122
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
123
+
124
+ if ffn_layer == "mlp":
125
+ logger.info("using MLP layer as FFN")
126
+ ffn_layer = Mlp
127
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
128
+ logger.info("using SwiGLU layer as FFN")
129
+ ffn_layer = SwiGLUFFNFused
130
+ elif ffn_layer == "identity":
131
+ logger.info("using Identity layer as FFN")
132
+
133
+ def f(*args, **kwargs):
134
+ return nn.Identity()
135
+
136
+ ffn_layer = f
137
+ else:
138
+ raise NotImplementedError
139
+
140
+ blocks_list = [
141
+ block_fn(
142
+ dim=embed_dim,
143
+ num_heads=num_heads,
144
+ mlp_ratio=mlp_ratio,
145
+ qkv_bias=qkv_bias,
146
+ proj_bias=proj_bias,
147
+ ffn_bias=ffn_bias,
148
+ drop_path=dpr[i],
149
+ norm_layer=norm_layer,
150
+ act_layer=act_layer,
151
+ ffn_layer=ffn_layer,
152
+ init_values=init_values,
153
+ qk_norm=qk_norm,
154
+ )
155
+ for i in range(depth)
156
+ ]
157
+ if block_chunks > 0:
158
+ self.chunked_blocks = True
159
+ chunked_blocks = []
160
+ chunksize = depth // block_chunks
161
+ for i in range(0, depth, chunksize):
162
+ # this is to keep the block index consistent if we chunk the block list
163
+ chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
164
+ self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
165
+ else:
166
+ self.chunked_blocks = False
167
+ self.blocks = nn.ModuleList(blocks_list)
168
+
169
+ self.norm = norm_layer(embed_dim)
170
+ self.head = nn.Identity()
171
+
172
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
173
+
174
+ self.init_weights()
175
+
176
+ def init_weights(self):
177
+ trunc_normal_(self.pos_embed, std=0.02)
178
+ nn.init.normal_(self.cls_token, std=1e-6)
179
+ if self.register_tokens is not None:
180
+ nn.init.normal_(self.register_tokens, std=1e-6)
181
+ named_apply(init_weights_vit_timm, self)
182
+
183
+ def interpolate_pos_encoding(self, x, w, h):
184
+ previous_dtype = x.dtype
185
+ npatch = x.shape[1] - 1
186
+ N = self.pos_embed.shape[1] - 1
187
+ if npatch == N and w == h:
188
+ return self.pos_embed
189
+ pos_embed = self.pos_embed.float()
190
+ class_pos_embed = pos_embed[:, 0]
191
+ patch_pos_embed = pos_embed[:, 1:]
192
+ dim = x.shape[-1]
193
+ w0 = w // self.patch_size
194
+ h0 = h // self.patch_size
195
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
196
+ assert N == M * M
197
+ kwargs = {}
198
+ if self.interpolate_offset:
199
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
200
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
201
+ sx = float(w0 + self.interpolate_offset) / M
202
+ sy = float(h0 + self.interpolate_offset) / M
203
+ kwargs["scale_factor"] = (sx, sy)
204
+ else:
205
+ # Simply specify an output size instead of a scale factor
206
+ kwargs["size"] = (w0, h0)
207
+ patch_pos_embed = nn.functional.interpolate(
208
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
209
+ mode="bicubic",
210
+ antialias=self.interpolate_antialias,
211
+ **kwargs,
212
+ )
213
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
214
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
215
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
216
+
217
+ def prepare_tokens_with_masks(self, x, masks=None):
218
+ B, nc, w, h = x.shape
219
+ x = self.patch_embed(x)
220
+ if masks is not None:
221
+ x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
222
+
223
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
224
+ x = x + self.interpolate_pos_encoding(x, w, h)
225
+
226
+ if self.register_tokens is not None:
227
+ x = torch.cat(
228
+ (
229
+ x[:, :1],
230
+ self.register_tokens.expand(x.shape[0], -1, -1),
231
+ x[:, 1:],
232
+ ),
233
+ dim=1,
234
+ )
235
+
236
+ return x
237
+
238
+ def forward_features_list(self, x_list, masks_list):
239
+ x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
240
+
241
+ for blk in self.blocks:
242
+ if self.use_checkpoint:
243
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
244
+ else:
245
+ x = blk(x)
246
+
247
+ all_x = x
248
+ output = []
249
+ for x, masks in zip(all_x, masks_list):
250
+ x_norm = self.norm(x)
251
+ output.append(
252
+ {
253
+ "x_norm_clstoken": x_norm[:, 0],
254
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
255
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
256
+ "x_prenorm": x,
257
+ "masks": masks,
258
+ }
259
+ )
260
+ return output
261
+
262
+ def forward_features(self, x, masks=None):
263
+ if isinstance(x, list):
264
+ return self.forward_features_list(x, masks)
265
+
266
+ x = self.prepare_tokens_with_masks(x, masks)
267
+
268
+ for blk in self.blocks:
269
+ if self.use_checkpoint:
270
+ x = checkpoint(blk, x, use_reentrant=self.use_reentrant)
271
+ else:
272
+ x = blk(x)
273
+
274
+ x_norm = self.norm(x)
275
+ return {
276
+ "x_norm_clstoken": x_norm[:, 0],
277
+ "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
278
+ "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
279
+ "x_prenorm": x,
280
+ "masks": masks,
281
+ }
282
+
283
+ def _get_intermediate_layers_not_chunked(self, x, n=1):
284
+ x = self.prepare_tokens_with_masks(x)
285
+ # If n is an int, take the n last blocks. If it's a list, take them
286
+ output, total_block_len = [], len(self.blocks)
287
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
288
+ for i, blk in enumerate(self.blocks):
289
+ x = blk(x)
290
+ if i in blocks_to_take:
291
+ output.append(x)
292
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
293
+ return output
294
+
295
+ def _get_intermediate_layers_chunked(self, x, n=1):
296
+ x = self.prepare_tokens_with_masks(x)
297
+ output, i, total_block_len = [], 0, len(self.blocks[-1])
298
+ # If n is an int, take the n last blocks. If it's a list, take them
299
+ blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
300
+ for block_chunk in self.blocks:
301
+ for blk in block_chunk[i:]: # Passing the nn.Identity()
302
+ x = blk(x)
303
+ if i in blocks_to_take:
304
+ output.append(x)
305
+ i += 1
306
+ assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
307
+ return output
308
+
309
+ def get_intermediate_layers(
310
+ self,
311
+ x: torch.Tensor,
312
+ n: Union[int, Sequence] = 1, # Layers or n last layers to take
313
+ reshape: bool = False,
314
+ return_class_token: bool = False,
315
+ norm=True,
316
+ ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
317
+ if self.chunked_blocks:
318
+ outputs = self._get_intermediate_layers_chunked(x, n)
319
+ else:
320
+ outputs = self._get_intermediate_layers_not_chunked(x, n)
321
+ if norm:
322
+ outputs = [self.norm(out) for out in outputs]
323
+ class_tokens = [out[:, 0] for out in outputs]
324
+ outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
325
+ if reshape:
326
+ B, _, w, h = x.shape
327
+ outputs = [
328
+ out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
329
+ for out in outputs
330
+ ]
331
+ if return_class_token:
332
+ return tuple(zip(outputs, class_tokens))
333
+ return tuple(outputs)
334
+
335
+ def forward(self, *args, is_training=True, **kwargs):
336
+ ret = self.forward_features(*args, **kwargs)
337
+ if is_training:
338
+ return ret
339
+ else:
340
+ return self.head(ret["x_norm_clstoken"])
341
+
342
+
343
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
344
+ """ViT weight initialization, original timm impl (for reproducibility)"""
345
+ if isinstance(module, nn.Linear):
346
+ trunc_normal_(module.weight, std=0.02)
347
+ if module.bias is not None:
348
+ nn.init.zeros_(module.bias)
349
+
350
+
351
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
352
+ model = DinoVisionTransformer(
353
+ patch_size=patch_size,
354
+ embed_dim=384,
355
+ depth=12,
356
+ num_heads=6,
357
+ mlp_ratio=4,
358
+ block_fn=partial(Block, attn_class=MemEffAttention),
359
+ num_register_tokens=num_register_tokens,
360
+ **kwargs,
361
+ )
362
+ return model
363
+
364
+
365
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
366
+ model = DinoVisionTransformer(
367
+ patch_size=patch_size,
368
+ embed_dim=768,
369
+ depth=12,
370
+ num_heads=12,
371
+ mlp_ratio=4,
372
+ block_fn=partial(Block, attn_class=MemEffAttention),
373
+ num_register_tokens=num_register_tokens,
374
+ **kwargs,
375
+ )
376
+ return model
377
+
378
+
379
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
380
+ model = DinoVisionTransformer(
381
+ patch_size=patch_size,
382
+ embed_dim=1024,
383
+ depth=24,
384
+ num_heads=16,
385
+ mlp_ratio=4,
386
+ block_fn=partial(Block, attn_class=MemEffAttention),
387
+ num_register_tokens=num_register_tokens,
388
+ **kwargs,
389
+ )
390
+ return model
391
+
392
+
393
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
394
+ """
395
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
396
+ """
397
+ model = DinoVisionTransformer(
398
+ patch_size=patch_size,
399
+ embed_dim=1536,
400
+ depth=40,
401
+ num_heads=24,
402
+ mlp_ratio=4,
403
+ block_fn=partial(Block, attn_class=MemEffAttention),
404
+ num_register_tokens=num_register_tokens,
405
+ **kwargs,
406
+ )
407
+ return model
vggt/models/__pycache__/aggregator.cpython-310.pyc ADDED
Binary file (8.69 kB). View file
 
vggt/utils/__pycache__/geometry.cpython-310.pyc ADDED
Binary file (4.53 kB). View file
 
vggt/utils/__pycache__/load_fn.cpython-310.pyc ADDED
Binary file (2.65 kB). View file
 
vggt/utils/__pycache__/rotation.cpython-310.pyc ADDED
Binary file (3.35 kB). View file
 
vggt/utils/geometry.py ADDED
@@ -0,0 +1,166 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import os
8
+ import torch
9
+ import numpy as np
10
+
11
+
12
+ def unproject_depth_map_to_point_map(
13
+ depth_map: np.ndarray, extrinsics_cam: np.ndarray, intrinsics_cam: np.ndarray
14
+ ) -> np.ndarray:
15
+ """
16
+ Unproject a batch of depth maps to 3D world coordinates.
17
+
18
+ Args:
19
+ depth_map (np.ndarray): Batch of depth maps of shape (S, H, W, 1) or (S, H, W)
20
+ extrinsics_cam (np.ndarray): Batch of camera extrinsic matrices of shape (S, 3, 4)
21
+ intrinsics_cam (np.ndarray): Batch of camera intrinsic matrices of shape (S, 3, 3)
22
+
23
+ Returns:
24
+ np.ndarray: Batch of 3D world coordinates of shape (S, H, W, 3)
25
+ """
26
+ if isinstance(depth_map, torch.Tensor):
27
+ depth_map = depth_map.cpu().numpy()
28
+ if isinstance(extrinsics_cam, torch.Tensor):
29
+ extrinsics_cam = extrinsics_cam.cpu().numpy()
30
+ if isinstance(intrinsics_cam, torch.Tensor):
31
+ intrinsics_cam = intrinsics_cam.cpu().numpy()
32
+
33
+ world_points_list = []
34
+ for frame_idx in range(depth_map.shape[0]):
35
+ cur_world_points, _, _ = depth_to_world_coords_points(
36
+ depth_map[frame_idx].squeeze(-1), extrinsics_cam[frame_idx], intrinsics_cam[frame_idx]
37
+ )
38
+ world_points_list.append(cur_world_points)
39
+ world_points_array = np.stack(world_points_list, axis=0)
40
+
41
+ return world_points_array
42
+
43
+
44
+ def depth_to_world_coords_points(
45
+ depth_map: np.ndarray,
46
+ extrinsic: np.ndarray,
47
+ intrinsic: np.ndarray,
48
+ eps=1e-8,
49
+ ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
50
+ """
51
+ Convert a depth map to world coordinates.
52
+
53
+ Args:
54
+ depth_map (np.ndarray): Depth map of shape (H, W).
55
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
56
+ extrinsic (np.ndarray): Camera extrinsic matrix of shape (3, 4). OpenCV camera coordinate convention, cam from world.
57
+
58
+ Returns:
59
+ tuple[np.ndarray, np.ndarray, np.ndarray]: World coordinates (H, W, 3), camera coordinates (H, W, 3), and valid depth mask (H, W).
60
+ """
61
+ if depth_map is None:
62
+ return None, None, None
63
+
64
+ # Valid depth mask
65
+ point_mask = depth_map > eps
66
+
67
+ # Convert depth map to camera coordinates
68
+ cam_coords_points = depth_to_cam_coords_points(depth_map, intrinsic)
69
+
70
+ # Multiply with the inverse of extrinsic matrix to transform to world coordinates
71
+ # extrinsic_inv is 4x4 (note closed_form_inverse_OpenCV is batched, the output is (N, 4, 4))
72
+ cam_to_world_extrinsic = closed_form_inverse_se3(extrinsic[None])[0]
73
+
74
+ R_cam_to_world = cam_to_world_extrinsic[:3, :3]
75
+ t_cam_to_world = cam_to_world_extrinsic[:3, 3]
76
+
77
+ # Apply the rotation and translation to the camera coordinates
78
+ world_coords_points = np.dot(cam_coords_points, R_cam_to_world.T) + t_cam_to_world # HxWx3, 3x3 -> HxWx3
79
+ # world_coords_points = np.einsum("ij,hwj->hwi", R_cam_to_world, cam_coords_points) + t_cam_to_world
80
+
81
+ return world_coords_points, cam_coords_points, point_mask
82
+
83
+
84
+ def depth_to_cam_coords_points(depth_map: np.ndarray, intrinsic: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
85
+ """
86
+ Convert a depth map to camera coordinates.
87
+
88
+ Args:
89
+ depth_map (np.ndarray): Depth map of shape (H, W).
90
+ intrinsic (np.ndarray): Camera intrinsic matrix of shape (3, 3).
91
+
92
+ Returns:
93
+ tuple[np.ndarray, np.ndarray]: Camera coordinates (H, W, 3)
94
+ """
95
+ H, W = depth_map.shape
96
+ assert intrinsic.shape == (3, 3), "Intrinsic matrix must be 3x3"
97
+ assert intrinsic[0, 1] == 0 and intrinsic[1, 0] == 0, "Intrinsic matrix must have zero skew"
98
+
99
+ # Intrinsic parameters
100
+ fu, fv = intrinsic[0, 0], intrinsic[1, 1]
101
+ cu, cv = intrinsic[0, 2], intrinsic[1, 2]
102
+
103
+ # Generate grid of pixel coordinates
104
+ u, v = np.meshgrid(np.arange(W), np.arange(H))
105
+
106
+ # Unproject to camera coordinates
107
+ x_cam = (u - cu) * depth_map / fu
108
+ y_cam = (v - cv) * depth_map / fv
109
+ z_cam = depth_map
110
+
111
+ # Stack to form camera coordinates
112
+ cam_coords = np.stack((x_cam, y_cam, z_cam), axis=-1).astype(np.float32)
113
+
114
+ return cam_coords
115
+
116
+
117
+ def closed_form_inverse_se3(se3, R=None, T=None):
118
+ """
119
+ Compute the inverse of each 4x4 (or 3x4) SE3 matrix in a batch.
120
+
121
+ If `R` and `T` are provided, they must correspond to the rotation and translation
122
+ components of `se3`. Otherwise, they will be extracted from `se3`.
123
+
124
+ Args:
125
+ se3: Nx4x4 or Nx3x4 array or tensor of SE3 matrices.
126
+ R (optional): Nx3x3 array or tensor of rotation matrices.
127
+ T (optional): Nx3x1 array or tensor of translation vectors.
128
+
129
+ Returns:
130
+ Inverted SE3 matrices with the same type and device as `se3`.
131
+
132
+ Shapes:
133
+ se3: (N, 4, 4)
134
+ R: (N, 3, 3)
135
+ T: (N, 3, 1)
136
+ """
137
+ # Check if se3 is a numpy array or a torch tensor
138
+ is_numpy = isinstance(se3, np.ndarray)
139
+
140
+ # Validate shapes
141
+ if se3.shape[-2:] != (4, 4) and se3.shape[-2:] != (3, 4):
142
+ raise ValueError(f"se3 must be of shape (N,4,4), got {se3.shape}.")
143
+
144
+ # Extract R and T if not provided
145
+ if R is None:
146
+ R = se3[:, :3, :3] # (N,3,3)
147
+ if T is None:
148
+ T = se3[:, :3, 3:] # (N,3,1)
149
+
150
+ # Transpose R
151
+ if is_numpy:
152
+ # Compute the transpose of the rotation for NumPy
153
+ R_transposed = np.transpose(R, (0, 2, 1))
154
+ # -R^T t for NumPy
155
+ top_right = -np.matmul(R_transposed, T)
156
+ inverted_matrix = np.tile(np.eye(4), (len(R), 1, 1))
157
+ else:
158
+ R_transposed = R.transpose(1, 2) # (N,3,3)
159
+ top_right = -torch.bmm(R_transposed, T) # (N,3,1)
160
+ inverted_matrix = torch.eye(4, 4)[None].repeat(len(R), 1, 1)
161
+ inverted_matrix = inverted_matrix.to(R.dtype).to(R.device)
162
+
163
+ inverted_matrix[:, :3, :3] = R_transposed
164
+ inverted_matrix[:, :3, 3:] = top_right
165
+
166
+ return inverted_matrix
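Since geometry.py defines the full unprojection chain (pixel grid to camera coordinates to world coordinates via the closed-form SE(3) inverse), a quick synthetic smoke test can confirm the expected shapes. The sketch below is illustrative only; the identity extrinsics, focal length, and import path are assumptions, not part of the commit.

# Synthetic smoke test for unproject_depth_map_to_point_map (illustrative only).
import numpy as np
# from vggt.utils.geometry import unproject_depth_map_to_point_map  # assumed import path

S, H, W = 2, 48, 64
depth = np.random.uniform(0.5, 2.0, size=(S, H, W, 1)).astype(np.float32)

# Identity camera-from-world extrinsics (3x4) for every frame.
extrinsics = np.tile(np.eye(3, 4, dtype=np.float32), (S, 1, 1))

# Simple pinhole intrinsics with the principal point at the image center.
intrinsics = np.tile(
    np.array([[60.0, 0.0, W / 2], [0.0, 60.0, H / 2], [0.0, 0.0, 1.0]], dtype=np.float32),
    (S, 1, 1),
)

world_points = unproject_depth_map_to_point_map(depth, extrinsics, intrinsics)
print(world_points.shape)  # (S, H, W, 3)

With identity extrinsics the world points coincide with the camera-frame points, which makes the output easy to sanity-check by eye.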
vggt/utils/load_fn.py ADDED
@@ -0,0 +1,111 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from PIL import Image
+ from torchvision import transforms as TF
+
+
+ def load_and_preprocess_images(image_path_list):
+     """
+     A quick start function to load and preprocess images for model input.
+     This assumes the images should have the same shape for easier batching, but our model can also work well with different shapes.
+
+     Args:
+         image_path_list (list): List of paths to image files
+
+     Returns:
+         torch.Tensor: Batched tensor of preprocessed images with shape (N, 3, H, W)
+
+     Raises:
+         ValueError: If the input list is empty
+
+     Notes:
+         - Images with different dimensions will be padded with white (value=1.0)
+         - A warning is printed when images have different shapes
+         - The function ensures width=518px while maintaining aspect ratio
+         - Height is adjusted to be divisible by 14 for compatibility with model requirements
+     """
+     # Check for empty list
+     if len(image_path_list) == 0:
+         raise ValueError("At least 1 image is required")
+
+     images = []
+     shapes = set()
+     to_tensor = TF.ToTensor()
+
+     # First process all images and collect their shapes
+     for image_path in image_path_list:
+
+         # Open image
+         img = Image.open(image_path)
+
+         # If there's an alpha channel, blend onto white background:
+         if img.mode == "RGBA":
+             # Create white background
+             background = Image.new("RGBA", img.size, (255, 255, 255, 255))
+             # Alpha composite onto the white background
+             img = Image.alpha_composite(background, img)
+
+         # Now convert to "RGB" (this step assigns white for transparent areas)
+         img = img.convert("RGB")
+
+         width, height = img.size
+         new_width = 518
+
+         # Calculate height maintaining aspect ratio, divisible by 14
+         new_height = round(height * (new_width / width) / 14) * 14
+
+         # Resize with new dimensions (width, height)
+
+         img = img.resize((new_width, new_height), Image.Resampling.BICUBIC)
+         img = to_tensor(img)  # Convert to tensor (0, 1)
+
+         # Center crop height if it's larger than 518
+
+         if new_height > 518:
+             start_y = (new_height - 518) // 2
+             img = img[:, start_y : start_y + 518, :]
+
+         shapes.add((img.shape[1], img.shape[2]))
+         images.append(img)
+
+     # Check if we have different shapes
+     # In theory our model can also work well with different shapes
+
+     if len(shapes) > 1:
+         print(f"Warning: Found images with different shapes: {shapes}")
+         # Find maximum dimensions
+         max_height = max(shape[0] for shape in shapes)
+         max_width = max(shape[1] for shape in shapes)
+
+         # Pad images if necessary
+         padded_images = []
+         for img in images:
+             h_padding = max_height - img.shape[1]
+             w_padding = max_width - img.shape[2]
+
+             if h_padding > 0 or w_padding > 0:
+                 pad_top = h_padding // 2
+                 pad_bottom = h_padding - pad_top
+                 pad_left = w_padding // 2
+                 pad_right = w_padding - pad_left
+
+                 img = torch.nn.functional.pad(
+                     img, (pad_left, pad_right, pad_top, pad_bottom), mode="constant", value=1.0
+                 )
+             padded_images.append(img)
+         images = padded_images
+
+     images = torch.stack(images)  # concatenate images
+
+     # Ensure correct shape when single image
+     if len(image_path_list) == 1:
+         # Verify shape is (1, C, H, W)
+         if images.dim() == 3:
+             images = images.unsqueeze(0)
+
+     return images
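The preprocessing above fixes the width at 518 px, rounds the height to a multiple of 14, and center-crops anything taller than 518, so the batch shape depends on the input aspect ratios. A minimal, hedged usage sketch follows; the file names, import path, and the printed shape are placeholders.

# Illustrative call; paths are placeholders.
# from vggt.utils.load_fn import load_and_preprocess_images  # assumed import path

image_paths = ["frame_000.png", "frame_001.png"]  # hypothetical files
images = load_and_preprocess_images(image_paths)

# Width is fixed at 518; height is aspect-ratio dependent, divisible by 14,
# and center-cropped to at most 518.
print(images.shape)  # e.g. (2, 3, 392, 518)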
vggt/utils/pose_enc.py ADDED
@@ -0,0 +1,130 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import torch
+ from .rotation import quat_to_mat, mat_to_quat
+
+
+ def extri_intri_to_pose_encoding(
+     extrinsics,
+     intrinsics,
+     image_size_hw=None,  # e.g., (256, 512)
+     pose_encoding_type="absT_quaR_FoV",
+ ):
+     """Convert camera extrinsics and intrinsics to a compact pose encoding.
+
+     This function transforms camera parameters into a unified pose encoding format,
+     which can be used for various downstream tasks like pose prediction or representation.
+
+     Args:
+         extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4,
+             where B is batch size and S is sequence length.
+             In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world transformation.
+             The format is [R|t] where R is a 3x3 rotation matrix and t is a 3x1 translation vector.
+         intrinsics (torch.Tensor): Camera intrinsic parameters with shape BxSx3x3.
+             Defined in pixels, with format:
+             [[fx, 0, cx],
+              [0, fy, cy],
+              [0, 0, 1]]
+             where fx, fy are focal lengths and (cx, cy) is the principal point
+         image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+             Required for computing field of view values. For example: (256, 512).
+         pose_encoding_type (str): Type of pose encoding to use. Currently only
+             supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+
+     Returns:
+         torch.Tensor: Encoded camera pose parameters with shape BxSx9.
+             For "absT_quaR_FoV" type, the 9 dimensions are:
+             - [:3] = absolute translation vector T (3D)
+             - [3:7] = rotation as quaternion quat (4D)
+             - [7:] = field of view (2D)
+     """
+
+     # extrinsics: BxSx3x4
+     # intrinsics: BxSx3x3
+
+     if pose_encoding_type == "absT_quaR_FoV":
+         R = extrinsics[:, :, :3, :3]  # BxSx3x3
+         T = extrinsics[:, :, :3, 3]  # BxSx3
+
+         quat = mat_to_quat(R)
+         # Note the order of h and w here
+         H, W = image_size_hw
+         fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
+         fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
+         pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
+     else:
+         raise NotImplementedError
+
+     return pose_encoding
+
+
+ def pose_encoding_to_extri_intri(
+     pose_encoding,
+     image_size_hw=None,  # e.g., (256, 512)
+     pose_encoding_type="absT_quaR_FoV",
+     build_intrinsics=True,
+ ):
+     """Convert a pose encoding back to camera extrinsics and intrinsics.
+
+     This function performs the inverse operation of extri_intri_to_pose_encoding,
+     reconstructing the full camera parameters from the compact encoding.
+
+     Args:
+         pose_encoding (torch.Tensor): Encoded camera pose parameters with shape BxSx9,
+             where B is batch size and S is sequence length.
+             For "absT_quaR_FoV" type, the 9 dimensions are:
+             - [:3] = absolute translation vector T (3D)
+             - [3:7] = rotation as quaternion quat (4D)
+             - [7:] = field of view (2D)
+         image_size_hw (tuple): Tuple of (height, width) of the image in pixels.
+             Required for reconstructing intrinsics from field of view values.
+             For example: (256, 512).
+         pose_encoding_type (str): Type of pose encoding used. Currently only
+             supports "absT_quaR_FoV" (absolute translation, quaternion rotation, field of view).
+         build_intrinsics (bool): Whether to reconstruct the intrinsics matrix.
+             If False, only extrinsics are returned and intrinsics will be None.
+
+     Returns:
+         tuple: (extrinsics, intrinsics)
+             - extrinsics (torch.Tensor): Camera extrinsic parameters with shape BxSx3x4.
+               In OpenCV coordinate system (x-right, y-down, z-forward), representing camera from world
+               transformation. The format is [R|t] where R is a 3x3 rotation matrix and t is
+               a 3x1 translation vector.
+             - intrinsics (torch.Tensor or None): Camera intrinsic parameters with shape BxSx3x3,
+               or None if build_intrinsics is False. Defined in pixels, with format:
+               [[fx, 0, cx],
+                [0, fy, cy],
+                [0, 0, 1]]
+               where fx, fy are focal lengths and (cx, cy) is the principal point,
+               assumed to be at the center of the image (W/2, H/2).
+     """
+
+     intrinsics = None
+
+     if pose_encoding_type == "absT_quaR_FoV":
+         T = pose_encoding[..., :3]
+         quat = pose_encoding[..., 3:7]
+         fov_h = pose_encoding[..., 7]
+         fov_w = pose_encoding[..., 8]
+
+         R = quat_to_mat(quat)
+         extrinsics = torch.cat([R, T[..., None]], dim=-1)
+
+         if build_intrinsics:
+             H, W = image_size_hw
+             fy = (H / 2.0) / torch.tan(fov_h / 2.0)
+             fx = (W / 2.0) / torch.tan(fov_w / 2.0)
+             intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
+             intrinsics[..., 0, 0] = fx
+             intrinsics[..., 1, 1] = fy
+             intrinsics[..., 0, 2] = W / 2
+             intrinsics[..., 1, 2] = H / 2
+             intrinsics[..., 2, 2] = 1.0  # Set the homogeneous coordinate to 1
+     else:
+         raise NotImplementedError
+
+     return extrinsics, intrinsics
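Because the FoV terms are derived from the intrinsics and the image size, encoding and decoding with the same image_size_hw should round-trip the camera parameters. A hedged round-trip sketch with identity rotations and arbitrary focal lengths (shapes follow the docstrings; nothing here is taken from the training scripts):

# Round-trip sketch: extrinsics/intrinsics -> 9-D pose encoding -> back again.
import torch
# from vggt.utils.pose_enc import extri_intri_to_pose_encoding, pose_encoding_to_extri_intri  # assumed import path

B, S, H, W = 1, 2, 256, 512
R = torch.eye(3).expand(B, S, 3, 3)      # identity rotations
T = torch.randn(B, S, 3, 1)              # random translations
extrinsics = torch.cat([R, T], dim=-1)   # (B, S, 3, 4), camera-from-world [R|t]

intrinsics = torch.zeros(B, S, 3, 3)
intrinsics[..., 0, 0] = 300.0   # fx
intrinsics[..., 1, 1] = 300.0   # fy
intrinsics[..., 0, 2] = W / 2   # cx
intrinsics[..., 1, 2] = H / 2   # cy
intrinsics[..., 2, 2] = 1.0

pose_enc = extri_intri_to_pose_encoding(extrinsics, intrinsics, image_size_hw=(H, W))
extr_rec, intr_rec = pose_encoding_to_extri_intri(pose_enc, image_size_hw=(H, W))
print(pose_enc.shape, extr_rec.shape, intr_rec.shape)  # (1, 2, 9) (1, 2, 3, 4) (1, 2, 3, 3)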
vggt/utils/visual_track.py ADDED
@@ -0,0 +1,239 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import cv2
+ import torch
+ import numpy as np
+ import os
+
+
+ def color_from_xy(x, y, W, H, cmap_name="hsv"):
+     """
+     Map (x, y) -> color in (R, G, B).
+     1) Normalize x,y to [0,1].
+     2) Combine them into a single scalar c in [0,1].
+     3) Use matplotlib's colormap to convert c -> (R,G,B).
+
+     You can customize step 2, e.g., c = (x + y)/2, or some function of (x, y).
+     """
+     import matplotlib.cm
+     import matplotlib.colors
+
+     x_norm = x / max(W - 1, 1)
+     y_norm = y / max(H - 1, 1)
+     # Simple combination:
+     c = (x_norm + y_norm) / 2.0
+
+     cmap = matplotlib.cm.get_cmap(cmap_name)
+     # cmap(c) -> (r,g,b,a) in [0,1]
+     rgba = cmap(c)
+     r, g, b = rgba[0], rgba[1], rgba[2]
+     return (r, g, b)  # in [0,1], RGB order
+
+
+ def get_track_colors_by_position(tracks_b, vis_mask_b=None, image_width=None, image_height=None, cmap_name="hsv"):
+     """
+     Given all tracks in one sample (b), compute a (N,3) array of RGB color values
+     in [0,255]. The color is determined by the (x,y) position in the first
+     visible frame for each track.
+
+     Args:
+         tracks_b: Tensor of shape (S, N, 2). (x,y) for each track in each frame.
+         vis_mask_b: (S, N) boolean mask; if None, assume all are visible.
+         image_width, image_height: used for normalizing (x, y).
+         cmap_name: for matplotlib (e.g., 'hsv', 'rainbow', 'jet').
+
+     Returns:
+         track_colors: np.ndarray of shape (N, 3), each row is (R,G,B) in [0,255].
+     """
+     S, N, _ = tracks_b.shape
+     track_colors = np.zeros((N, 3), dtype=np.uint8)
+
+     if vis_mask_b is None:
+         # treat all as visible
+         vis_mask_b = torch.ones(S, N, dtype=torch.bool, device=tracks_b.device)
+
+     for i in range(N):
+         # Find first visible frame for track i
+         visible_frames = torch.where(vis_mask_b[:, i])[0]
+         if len(visible_frames) == 0:
+             # track is never visible; just assign black or something
+             track_colors[i] = (0, 0, 0)
+             continue
+
+         first_s = int(visible_frames[0].item())
+         # use that frame's (x,y)
+         x, y = tracks_b[first_s, i].tolist()
+
+         # map (x,y) -> (R,G,B) in [0,1]
+         r, g, b = color_from_xy(x, y, W=image_width, H=image_height, cmap_name=cmap_name)
+         # scale to [0,255]
+         r, g, b = int(r * 255), int(g * 255), int(b * 255)
+         track_colors[i] = (r, g, b)
+
+     return track_colors
+
+
+ def visualize_tracks_on_images(
+     images,
+     tracks,
+     track_vis_mask=None,
+     out_dir="track_visuals_concat_by_xy",
+     image_format="CHW",  # "CHW" or "HWC"
+     normalize_mode="[0,1]",
+     cmap_name="hsv",  # e.g. "hsv", "rainbow", "jet"
+     frames_per_row=4,  # New parameter for grid layout
+     save_grid=True,  # Flag to control whether to save the grid image
+ ):
+     """
+     Visualizes frames in a grid layout with specified frames per row.
+     Each track's color is determined by its (x,y) position
+     in the first visible frame (or frame 0 if always visible).
+     Finally convert the BGR result to RGB before saving.
+     Also saves each individual frame as a separate PNG file.
+
+     Args:
+         images: torch.Tensor (S, 3, H, W) if CHW or (S, H, W, 3) if HWC.
+         tracks: torch.Tensor (S, N, 2), last dim = (x, y).
+         track_vis_mask: torch.Tensor (S, N) or None.
+         out_dir: folder to save visualizations.
+         image_format: "CHW" or "HWC".
+         normalize_mode: "[0,1]", "[-1,1]", or None for direct raw -> 0..255
+         cmap_name: a matplotlib colormap name for color_from_xy.
+         frames_per_row: number of frames to display in each row of the grid.
+         save_grid: whether to save all frames in one grid image.
+
+     Returns:
+         None (saves images in out_dir).
+     """
+
+     if len(tracks.shape) == 4:
+         tracks = tracks.squeeze(0)
+         images = images.squeeze(0)
+         if track_vis_mask is not None:
+             track_vis_mask = track_vis_mask.squeeze(0)
+
+     import matplotlib
+
+     matplotlib.use("Agg")  # for non-interactive (optional)
+
+     os.makedirs(out_dir, exist_ok=True)
+
+     S = images.shape[0]
+     _, N, _ = tracks.shape  # (S, N, 2)
+
+     # Move to CPU
+     images = images.cpu().clone()
+     tracks = tracks.cpu().clone()
+     if track_vis_mask is not None:
+         track_vis_mask = track_vis_mask.cpu().clone()
+
+     # Infer H, W from images shape
+     if image_format == "CHW":
+         # e.g. images[s].shape = (3, H, W)
+         H, W = images.shape[2], images.shape[3]
+     else:
+         # e.g. images[s].shape = (H, W, 3)
+         H, W = images.shape[1], images.shape[2]
+
+     # Pre-compute the color for each track i based on first visible position
+     track_colors_rgb = get_track_colors_by_position(
+         tracks,  # shape (S, N, 2)
+         vis_mask_b=track_vis_mask if track_vis_mask is not None else None,
+         image_width=W,
+         image_height=H,
+         cmap_name=cmap_name,
+     )
+
+     # We'll accumulate each frame's drawn image in a list
+     frame_images = []
+
+     for s in range(S):
+         # shape => either (3, H, W) or (H, W, 3)
+         img = images[s]
+
+         # Convert to (H, W, 3)
+         if image_format == "CHW":
+             img = img.permute(1, 2, 0)  # (H, W, 3)
+         # else "HWC", do nothing
+
+         img = img.numpy().astype(np.float32)
+
+         # Scale to [0,255] if needed
+         if normalize_mode == "[0,1]":
+             img = np.clip(img, 0, 1) * 255.0
+         elif normalize_mode == "[-1,1]":
+             img = (img + 1.0) * 0.5 * 255.0
+             img = np.clip(img, 0, 255.0)
+         # else no normalization
+
+         # Convert to uint8
+         img = img.astype(np.uint8)
+
+         # For drawing in OpenCV, convert to BGR
+         img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
+
+         # Draw each visible track
+         cur_tracks = tracks[s]  # shape (N, 2)
+         if track_vis_mask is not None:
+             valid_indices = torch.where(track_vis_mask[s])[0]
+         else:
+             valid_indices = range(N)
+
+         cur_tracks_np = cur_tracks.numpy()
+         for i in valid_indices:
+             x, y = cur_tracks_np[i]
+             pt = (int(round(x)), int(round(y)))
+
+             # track_colors_rgb[i] is (R,G,B). For OpenCV circle, we need BGR
+             R, G, B = track_colors_rgb[i]
+             color_bgr = (int(B), int(G), int(R))
+             cv2.circle(img_bgr, pt, radius=3, color=color_bgr, thickness=-1)
+
+         # Convert back to RGB for consistent final saving:
+         img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
+
+         # Save individual frame
+         frame_path = os.path.join(out_dir, f"frame_{s:04d}.png")
+         # Convert to BGR for OpenCV imwrite
+         frame_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
+         cv2.imwrite(frame_path, frame_bgr)
+
+         frame_images.append(img_rgb)
+
+     # Only create and save the grid image if save_grid is True
+     if save_grid:
+         # Calculate grid dimensions
+         num_rows = (S + frames_per_row - 1) // frames_per_row  # Ceiling division
+
+         # Create a grid of images
+         grid_img = None
+         for row in range(num_rows):
+             start_idx = row * frames_per_row
+             end_idx = min(start_idx + frames_per_row, S)
+
+             # Concatenate this row horizontally
+             row_img = np.concatenate(frame_images[start_idx:end_idx], axis=1)
+
+             # If this row has fewer than frames_per_row images, pad with black
+             if end_idx - start_idx < frames_per_row:
+                 padding_width = (frames_per_row - (end_idx - start_idx)) * W
+                 padding = np.zeros((H, padding_width, 3), dtype=np.uint8)
+                 row_img = np.concatenate([row_img, padding], axis=1)
+
+             # Add this row to the grid
+             if grid_img is None:
+                 grid_img = row_img
+             else:
+                 grid_img = np.concatenate([grid_img, row_img], axis=0)
+
+         out_path = os.path.join(out_dir, "tracks_grid.png")
+         # Convert back to BGR for OpenCV imwrite
+         grid_img_bgr = cv2.cvtColor(grid_img, cv2.COLOR_RGB2BGR)
+         cv2.imwrite(out_path, grid_img_bgr)
+         print(f"[INFO] Saved color-by-XY track visualization grid -> {out_path}")
+
+     print(f"[INFO] Saved {S} individual frames to {out_dir}/frame_*.png")
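A hedged sketch of driving visualize_tracks_on_images with synthetic tensors; the sizes, output directory, and import path below are arbitrary choices for illustration only:

# Illustrative call with synthetic data; writes PNGs under ./track_vis_demo.
import torch
# from vggt.utils.visual_track import visualize_tracks_on_images  # assumed import path

S, N, H, W = 4, 16, 128, 160
images = torch.rand(S, 3, H, W)                       # CHW frames in [0, 1]
tracks = torch.stack(
    [torch.rand(S, N) * (W - 1), torch.rand(S, N) * (H - 1)], dim=-1
)                                                     # (S, N, 2) as (x, y)
vis_mask = torch.ones(S, N, dtype=torch.bool)

visualize_tracks_on_images(
    images, tracks, track_vis_mask=vis_mask,
    out_dir="track_vis_demo", image_format="CHW",
    normalize_mode="[0,1]", frames_per_row=2,
)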
wandb/offline-run-20250711_184611-k8qgu560/files/requirements.txt ADDED
@@ -0,0 +1,199 @@
1
+ openvla-oft==0.0.1
2
+ coloredlogs==15.0.1
3
+ openvla-oft==0.0.1
4
+ pydantic_core==2.27.2
5
+ opt_einsum==3.4.0
6
+ nvidia-cublas-cu12==12.1.3.1
7
+ nltk==3.9.1
8
+ anyio==4.9.0
9
+ tensorflow-addons==0.23.0
10
+ h11==0.14.0
11
+ setproctitle==1.3.5
12
+ json-numpy==2.1.0
13
+ gast==0.6.0
14
+ protobuf==3.20.3
15
+ array_record==0.7.1
16
+ keras==2.15.0
17
+ scipy==1.15.2
18
+ sentencepiece==0.1.99
19
+ Jinja2==3.1.6
20
+ glfw==2.8.0
21
+ tensorflow-io-gcs-filesystem==0.37.1
22
+ transformers==4.40.1
23
+ gitdb==4.0.12
24
+ packaging==24.2
25
+ ml-dtypes==0.2.0
26
+ pillow==11.1.0
27
+ nvidia-cusolver-cu12==11.4.5.107
28
+ jsonlines==4.0.0
29
+ google-auth==2.38.0
30
+ rpds-py==0.23.1
31
+ nvidia-nvjitlink-cu12==12.4.127
32
+ torch==2.2.0
33
+ fonttools==4.56.0
34
+ opencv-python==4.11.0.86
35
+ numba==0.61.0
36
+ jupyter_core==5.7.2
37
+ grpcio==1.71.0
38
+ peft==0.11.1
39
+ annotated-types==0.7.0
40
+ typing-inspect==0.9.0
41
+ termcolor==2.5.0
42
+ antlr4-python3-runtime==4.9.3
43
+ markdown-it-py==3.0.0
44
+ huggingface-hub==0.29.3
45
+ imageio==2.37.0
46
+ nvidia-nvtx-cu12==12.1.105
47
+ draccus==0.8.0
48
+ mypy-extensions==1.0.0
49
+ future==1.0.0
50
+ onnxsim==0.4.36
51
+ tensorboard-data-server==0.7.2
52
+ six==1.17.0
53
+ tqdm==4.67.1
54
+ rsa==4.9
55
+ typing_extensions==4.12.2
56
+ rich==13.9.4
57
+ nvidia-cusparse-cu12==12.1.0.106
58
+ jsonschema-specifications==2024.10.1
59
+ libclang==18.1.1
60
+ ninja==1.11.1.3
61
+ cloudpickle==3.1.1
62
+ onnx==1.17.0
63
+ python-xlib==0.33
64
+ referencing==0.36.2
65
+ filelock==3.18.0
66
+ debugpy==1.8.13
67
+ pip==25.0
68
+ mdurl==0.1.2
69
+ tensorflow-graphics==2021.12.3
70
+ pydantic==2.10.6
71
+ docker-pycreds==0.4.0
72
+ kiwisolver==1.4.8
73
+ networkx==3.4.2
74
+ pyasn1==0.6.1
75
+ humanfriendly==10.0
76
+ pynput==1.8.0
77
+ certifi==2025.1.31
78
+ pytest==8.3.5
79
+ sniffio==1.3.1
80
+ nbformat==5.10.4
81
+ requests-oauthlib==2.0.0
82
+ etils==1.12.2
83
+ tensorflow-estimator==2.15.0
84
+ cachetools==5.5.2
85
+ click==8.1.8
86
+ importlib_resources==6.5.2
87
+ robosuite==1.4.1
88
+ pyasn1_modules==0.4.1
89
+ nvidia-nccl-cu12==2.19.3
90
+ qwen-vl-utils==0.0.11
91
+ cycler==0.12.1
92
+ nvidia-cufft-cu12==11.0.2.54
93
+ typeguard==2.13.3
94
+ iniconfig==2.0.0
95
+ idna==3.10
96
+ MarkupSafe==3.0.2
97
+ matplotlib==3.10.1
98
+ promise==2.3
99
+ easydict==1.13
100
+ tensorflow-datasets==4.9.3
101
+ Werkzeug==3.1.3
102
+ tomli==2.2.1
103
+ nvidia-cuda-cupti-cu12==12.1.105
104
+ omegaconf==2.3.0
105
+ imageio-ffmpeg==0.6.0
106
+ absl-py==2.1.0
107
+ mujoco==3.3.0
108
+ evdev==1.9.1
109
+ sentry-sdk==2.22.0
110
+ pyparsing==3.2.1
111
+ dm-tree==0.1.9
112
+ psutil==7.0.0
113
+ torchaudio==2.2.0
114
+ h5py==3.13.0
115
+ PyOpenGL==3.1.9
116
+ triton==2.2.0
117
+ fsspec==2025.3.0
118
+ nvidia-cudnn-cu12==8.9.2.26
119
+ trimesh==4.6.4
120
+ Pygments==2.19.1
121
+ nvidia-cuda-runtime-cu12==12.1.105
122
+ wheel==0.45.1
123
+ astunparse==1.6.3
124
+ requests==2.32.3
125
+ importlib_metadata==8.6.1
126
+ starlette==0.46.1
127
+ charset-normalizer==3.4.1
128
+ tokenizers==0.19.1
129
+ accelerate==1.5.2
130
+ tensorflow-metadata==1.16.1
131
+ OpenEXR==3.3.2
132
+ mpmath==1.3.0
133
+ einops==0.8.1
134
+ google-pasta==0.2.0
135
+ exceptiongroup==1.2.2
136
+ bddl==3.5.0
137
+ safetensors==0.5.3
138
+ nvidia-cuda-nvrtc-cu12==12.1.105
139
+ regex==2024.11.6
140
+ zipp==3.21.0
141
+ mdit-py-plugins==0.4.2
142
+ contourpy==1.3.1
143
+ nvidia-cusparselt-cu12==0.6.2
144
+ wandb==0.19.8
145
+ tensorboard==2.15.2
146
+ wrapt==1.14.1
147
+ pyyaml-include==1.4.1
148
+ urllib3==2.3.0
149
+ setuptools==75.8.0
150
+ fastjsonschema==2.21.1
151
+ fastapi==0.115.11
152
+ oauthlib==3.2.2
153
+ uvicorn==0.34.0
154
+ gym-notices==0.0.8
155
+ jupytext==1.16.7
156
+ diffusers==0.32.2
157
+ flatbuffers==25.2.10
158
+ timm==0.9.10
159
+ traitlets==5.14.3
160
+ tensorflow==2.15.0
161
+ flash-attn==2.5.5
162
+ Markdown==3.7
163
+ torchvision==0.17.0
164
+ smmap==5.0.2
165
+ attrs==25.3.0
166
+ google-auth-oauthlib==1.2.1
167
+ av==14.3.0
168
+ onnxruntime==1.21.0
169
+ gym==0.26.2
170
+ platformdirs==4.3.6
171
+ mergedeep==1.3.4
172
+ nvidia-curand-cu12==10.3.2.106
173
+ python-dateutil==2.9.0.post0
174
+ toml==0.10.2
175
+ numpy==1.26.4
176
+ GitPython==3.1.44
177
+ jsonschema==4.23.0
178
+ joblib==1.4.2
179
+ PyYAML==6.0.2
180
+ sympy==1.13.1
181
+ llvmlite==0.44.0
182
+ pluggy==1.5.0
183
+ dlimp==0.0.1
184
+ jaraco.collections==5.1.0
185
+ packaging==24.2
186
+ importlib_metadata==8.0.0
187
+ tomli==2.0.1
188
+ backports.tarfile==1.2.0
189
+ typing_extensions==4.12.2
190
+ jaraco.context==5.3.0
191
+ typeguard==4.3.0
192
+ wheel==0.43.0
193
+ autocommand==2.2.2
194
+ jaraco.text==3.12.1
195
+ more-itertools==10.3.0
196
+ platformdirs==4.2.2
197
+ inflect==7.3.1
198
+ jaraco.functools==4.0.1
199
+ zipp==3.19.2
wandb/offline-run-20250711_184611-k8qgu560/files/wandb-metadata.json ADDED
@@ -0,0 +1,133 @@
1
+ {
2
+ "os": "Linux-5.4.0-144-generic-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.10.16",
4
+ "startedAt": "2025-07-11T10:46:12.241804Z",
5
+ "args": [
6
+ "--vla_path",
7
+ "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b",
8
+ "--data_root_dir",
9
+ "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/TianxingChen/RoboTwin2.0/tfds",
10
+ "--dataset_name",
11
+ "aloha_agilex_robotwin2_benchmark",
12
+ "--run_root_dir",
13
+ "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/simvla_twin2/results/simvla_all_25",
14
+ "--use_l1_regression",
15
+ "True",
16
+ "--use_diffusion",
17
+ "False",
18
+ "--use_film",
19
+ "True",
20
+ "--num_images_in_input",
21
+ "3",
22
+ "--use_proprio",
23
+ "True",
24
+ "--batch_size",
25
+ "4",
26
+ "--learning_rate",
27
+ "1e-4",
28
+ "--num_steps_before_decay",
29
+ "10000",
30
+ "--max_steps",
31
+ "20000",
32
+ "--save_freq",
33
+ "10000",
34
+ "--save_latest_checkpoint_only",
35
+ "False",
36
+ "--image_aug",
37
+ "True",
38
+ "--lora_rank",
39
+ "32",
40
+ "--wandb_entity",
41
+ "chenghaha",
42
+ "--wandb_project",
43
+ "robotwin",
44
+ "--wandb_log_freq",
45
+ "1",
46
+ "--run_id_note",
47
+ "simvla_all_25_inner1_proj_type_onlynorm_ffn_type_relu_mlp_ffn_decoder_num_blocks_4-M20000-F10000-D10000",
48
+ "--use_predict_future_prop",
49
+ "False",
50
+ "--use_action_ts_head",
51
+ "True",
52
+ "--use_one_embed",
53
+ "True",
54
+ "--use_multi_scaling",
55
+ "False",
56
+ "--mlp_type",
57
+ "ffn",
58
+ "--decoder_num_blocks",
59
+ "4",
60
+ "--robot_platform",
61
+ "aloha",
62
+ "--proj_type",
63
+ "onlynorm",
64
+ "--ffn_type",
65
+ "relu",
66
+ "--expand_inner_ratio",
67
+ "1",
68
+ "--linear_drop_ratio",
69
+ "0.0",
70
+ "--multi_query_norm_type",
71
+ "layernorm",
72
+ "--multi_queries_num",
73
+ "25",
74
+ "--action_norm",
75
+ "layernorm",
76
+ "--use_fredf",
77
+ "False"
78
+ ],
79
+ "program": "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/simvla_twin2/vla-scripts/finetune.py",
80
+ "codePath": "vla-scripts/finetune.py",
81
+ "git": {
82
+ "remote": "https://github.com/cheng-haha/SimVLA.git",
83
+ "commit": "7cd4f4827011a08c125854a6cbb073e24ed00b43"
84
+ },
85
+ "root": "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/simvla_twin2",
86
+ "host": "lmms5--aa4284ee6b0c-imxfhs5an2",
87
+ "executable": "/opt/conda/envs/openvla-oft/bin/python",
88
+ "codePathLocal": "vla-scripts/finetune.py",
89
+ "cpu_count": 64,
90
+ "cpu_count_logical": 128,
91
+ "gpu": "NVIDIA H100 80GB HBM3",
92
+ "gpu_count": 4,
93
+ "disk": {
94
+ "/": {
95
+ "total": "7679364128768",
96
+ "used": "5943319695360"
97
+ }
98
+ },
99
+ "memory": {
100
+ "total": "2164203204608"
101
+ },
102
+ "cpu": {
103
+ "count": 64,
104
+ "countLogical": 128
105
+ },
106
+ "gpu_nvidia": [
107
+ {
108
+ "name": "NVIDIA H100 80GB HBM3",
109
+ "memoryTotal": "85520809984",
110
+ "cudaCores": 16896,
111
+ "architecture": "Hopper"
112
+ },
113
+ {
114
+ "name": "NVIDIA H100 80GB HBM3",
115
+ "memoryTotal": "85520809984",
116
+ "cudaCores": 16896,
117
+ "architecture": "Hopper"
118
+ },
119
+ {
120
+ "name": "NVIDIA H100 80GB HBM3",
121
+ "memoryTotal": "85520809984",
122
+ "cudaCores": 16896,
123
+ "architecture": "Hopper"
124
+ },
125
+ {
126
+ "name": "NVIDIA H100 80GB HBM3",
127
+ "memoryTotal": "85520809984",
128
+ "cudaCores": 16896,
129
+ "architecture": "Hopper"
130
+ }
131
+ ],
132
+ "cudaVersion": "12.2"
133
+ }
wandb/offline-run-20250711_184611-k8qgu560/logs/debug-internal.log ADDED
@@ -0,0 +1,7 @@
+ {"time":"2025-07-11T18:46:12.32323806+08:00","level":"INFO","msg":"stream: starting","core version":"0.19.8","symlink path":"/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/simvla_twin2/wandb/offline-run-20250711_184611-k8qgu560/logs/debug-core.log"}
+ {"time":"2025-07-11T18:46:12.562037522+08:00","level":"INFO","msg":"created new stream","id":"k8qgu560"}
+ {"time":"2025-07-11T18:46:12.562125722+08:00","level":"INFO","msg":"stream: started","id":"k8qgu560"}
+ {"time":"2025-07-11T18:46:12.562186708+08:00","level":"INFO","msg":"handler: started","stream_id":"k8qgu560"}
+ {"time":"2025-07-11T18:46:12.562194872+08:00","level":"INFO","msg":"writer: Do: started","stream_id":"k8qgu560"}
+ {"time":"2025-07-11T18:46:12.562223056+08:00","level":"INFO","msg":"sender: started","stream_id":"k8qgu560"}
+ {"time":"2025-07-11T18:46:12.566390076+08:00","level":"INFO","msg":"Starting system monitor"}
wandb/offline-run-20250711_184611-k8qgu560/logs/debug.log ADDED
File without changes
wandb/offline-run-20250711_211915-s4epglyq/files/requirements.txt ADDED
@@ -0,0 +1,199 @@
1
+ openvla-oft==0.0.1
2
+ coloredlogs==15.0.1
3
+ openvla-oft==0.0.1
4
+ pydantic_core==2.27.2
5
+ opt_einsum==3.4.0
6
+ nvidia-cublas-cu12==12.1.3.1
7
+ nltk==3.9.1
8
+ anyio==4.9.0
9
+ tensorflow-addons==0.23.0
10
+ h11==0.14.0
11
+ setproctitle==1.3.5
12
+ json-numpy==2.1.0
13
+ gast==0.6.0
14
+ protobuf==3.20.3
15
+ array_record==0.7.1
16
+ keras==2.15.0
17
+ scipy==1.15.2
18
+ sentencepiece==0.1.99
19
+ Jinja2==3.1.6
20
+ glfw==2.8.0
21
+ tensorflow-io-gcs-filesystem==0.37.1
22
+ transformers==4.40.1
23
+ gitdb==4.0.12
24
+ packaging==24.2
25
+ ml-dtypes==0.2.0
26
+ pillow==11.1.0
27
+ nvidia-cusolver-cu12==11.4.5.107
28
+ jsonlines==4.0.0
29
+ google-auth==2.38.0
30
+ rpds-py==0.23.1
31
+ nvidia-nvjitlink-cu12==12.4.127
32
+ torch==2.2.0
33
+ fonttools==4.56.0
34
+ opencv-python==4.11.0.86
35
+ numba==0.61.0
36
+ jupyter_core==5.7.2
37
+ grpcio==1.71.0
38
+ peft==0.11.1
39
+ annotated-types==0.7.0
40
+ typing-inspect==0.9.0
41
+ termcolor==2.5.0
42
+ antlr4-python3-runtime==4.9.3
43
+ markdown-it-py==3.0.0
44
+ huggingface-hub==0.29.3
45
+ imageio==2.37.0
46
+ nvidia-nvtx-cu12==12.1.105
47
+ draccus==0.8.0
48
+ mypy-extensions==1.0.0
49
+ future==1.0.0
50
+ onnxsim==0.4.36
51
+ tensorboard-data-server==0.7.2
52
+ six==1.17.0
53
+ tqdm==4.67.1
54
+ rsa==4.9
55
+ typing_extensions==4.12.2
56
+ rich==13.9.4
57
+ nvidia-cusparse-cu12==12.1.0.106
58
+ jsonschema-specifications==2024.10.1
59
+ libclang==18.1.1
60
+ ninja==1.11.1.3
61
+ cloudpickle==3.1.1
62
+ onnx==1.17.0
63
+ python-xlib==0.33
64
+ referencing==0.36.2
65
+ filelock==3.18.0
66
+ debugpy==1.8.13
67
+ pip==25.0
68
+ mdurl==0.1.2
69
+ tensorflow-graphics==2021.12.3
70
+ pydantic==2.10.6
71
+ docker-pycreds==0.4.0
72
+ kiwisolver==1.4.8
73
+ networkx==3.4.2
74
+ pyasn1==0.6.1
75
+ humanfriendly==10.0
76
+ pynput==1.8.0
77
+ certifi==2025.1.31
78
+ pytest==8.3.5
79
+ sniffio==1.3.1
80
+ nbformat==5.10.4
81
+ requests-oauthlib==2.0.0
82
+ etils==1.12.2
83
+ tensorflow-estimator==2.15.0
84
+ cachetools==5.5.2
85
+ click==8.1.8
86
+ importlib_resources==6.5.2
87
+ robosuite==1.4.1
88
+ pyasn1_modules==0.4.1
89
+ nvidia-nccl-cu12==2.19.3
90
+ qwen-vl-utils==0.0.11
91
+ cycler==0.12.1
92
+ nvidia-cufft-cu12==11.0.2.54
93
+ typeguard==2.13.3
94
+ iniconfig==2.0.0
95
+ idna==3.10
96
+ MarkupSafe==3.0.2
97
+ matplotlib==3.10.1
98
+ promise==2.3
99
+ easydict==1.13
100
+ tensorflow-datasets==4.9.3
101
+ Werkzeug==3.1.3
102
+ tomli==2.2.1
103
+ nvidia-cuda-cupti-cu12==12.1.105
104
+ omegaconf==2.3.0
105
+ imageio-ffmpeg==0.6.0
106
+ absl-py==2.1.0
107
+ mujoco==3.3.0
108
+ evdev==1.9.1
109
+ sentry-sdk==2.22.0
110
+ pyparsing==3.2.1
111
+ dm-tree==0.1.9
112
+ psutil==7.0.0
113
+ torchaudio==2.2.0
114
+ h5py==3.13.0
115
+ PyOpenGL==3.1.9
116
+ triton==2.2.0
117
+ fsspec==2025.3.0
118
+ nvidia-cudnn-cu12==8.9.2.26
119
+ trimesh==4.6.4
120
+ Pygments==2.19.1
121
+ nvidia-cuda-runtime-cu12==12.1.105
122
+ wheel==0.45.1
123
+ astunparse==1.6.3
124
+ requests==2.32.3
125
+ importlib_metadata==8.6.1
126
+ starlette==0.46.1
127
+ charset-normalizer==3.4.1
128
+ tokenizers==0.19.1
129
+ accelerate==1.5.2
130
+ tensorflow-metadata==1.16.1
131
+ OpenEXR==3.3.2
132
+ mpmath==1.3.0
133
+ einops==0.8.1
134
+ google-pasta==0.2.0
135
+ exceptiongroup==1.2.2
136
+ bddl==3.5.0
137
+ safetensors==0.5.3
138
+ nvidia-cuda-nvrtc-cu12==12.1.105
139
+ regex==2024.11.6
140
+ zipp==3.21.0
141
+ mdit-py-plugins==0.4.2
142
+ contourpy==1.3.1
143
+ nvidia-cusparselt-cu12==0.6.2
144
+ wandb==0.19.8
145
+ tensorboard==2.15.2
146
+ wrapt==1.14.1
147
+ pyyaml-include==1.4.1
148
+ urllib3==2.3.0
149
+ setuptools==75.8.0
150
+ fastjsonschema==2.21.1
151
+ fastapi==0.115.11
152
+ oauthlib==3.2.2
153
+ uvicorn==0.34.0
154
+ gym-notices==0.0.8
155
+ jupytext==1.16.7
156
+ diffusers==0.32.2
157
+ flatbuffers==25.2.10
158
+ timm==0.9.10
159
+ traitlets==5.14.3
160
+ tensorflow==2.15.0
161
+ flash-attn==2.5.5
162
+ Markdown==3.7
163
+ torchvision==0.17.0
164
+ smmap==5.0.2
165
+ attrs==25.3.0
166
+ google-auth-oauthlib==1.2.1
167
+ av==14.3.0
168
+ onnxruntime==1.21.0
169
+ gym==0.26.2
170
+ platformdirs==4.3.6
171
+ mergedeep==1.3.4
172
+ nvidia-curand-cu12==10.3.2.106
173
+ python-dateutil==2.9.0.post0
174
+ toml==0.10.2
175
+ numpy==1.26.4
176
+ GitPython==3.1.44
177
+ jsonschema==4.23.0
178
+ joblib==1.4.2
179
+ PyYAML==6.0.2
180
+ sympy==1.13.1
181
+ llvmlite==0.44.0
182
+ pluggy==1.5.0
183
+ dlimp==0.0.1
184
+ jaraco.collections==5.1.0
185
+ packaging==24.2
186
+ importlib_metadata==8.0.0
187
+ tomli==2.0.1
188
+ backports.tarfile==1.2.0
189
+ typing_extensions==4.12.2
190
+ jaraco.context==5.3.0
191
+ typeguard==4.3.0
192
+ wheel==0.43.0
193
+ autocommand==2.2.2
194
+ jaraco.text==3.12.1
195
+ more-itertools==10.3.0
196
+ platformdirs==4.2.2
197
+ inflect==7.3.1
198
+ jaraco.functools==4.0.1
199
+ zipp==3.19.2
wandb/offline-run-20250711_211915-s4epglyq/files/wandb-metadata.json ADDED
@@ -0,0 +1,137 @@
1
+ {
2
+ "os": "Linux-5.4.0-144-generic-x86_64-with-glibc2.35",
3
+ "python": "CPython 3.10.16",
4
+ "startedAt": "2025-07-11T13:19:15.784155Z",
5
+ "args": [
6
+ "--vla_path",
7
+ "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/ai_models/openvla/openvla-7b",
8
+ "--data_root_dir",
9
+ "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/datasets/TianxingChen/RoboTwin2.0/tfds",
10
+ "--dataset_name",
11
+ "aloha_agilex_robotwin2_benchmark",
12
+ "--run_root_dir",
13
+ "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/simvla_twin2/results/simvla_patch_all_25",
14
+ "--use_l1_regression",
15
+ "True",
16
+ "--use_diffusion",
17
+ "False",
18
+ "--use_film",
19
+ "True",
20
+ "--num_images_in_input",
21
+ "3",
22
+ "--use_proprio",
23
+ "True",
24
+ "--batch_size",
25
+ "4",
26
+ "--learning_rate",
27
+ "1e-4",
28
+ "--num_steps_before_decay",
29
+ "10000",
30
+ "--max_steps",
31
+ "20000",
32
+ "--save_freq",
33
+ "10000",
34
+ "--save_latest_checkpoint_only",
35
+ "False",
36
+ "--image_aug",
37
+ "True",
38
+ "--lora_rank",
39
+ "32",
40
+ "--wandb_entity",
41
+ "chenghaha",
42
+ "--wandb_project",
43
+ "robotwin",
44
+ "--wandb_log_freq",
45
+ "1",
46
+ "--run_id_note",
47
+ "simvla_patch_all_25_inner1_proj_type_onlynorm_ffn_type_relu_mlp_ffn_decoder_num_blocks_4-M20000-F10000-D10000",
48
+ "--use_predict_future_prop",
49
+ "False",
50
+ "--use_action_ts_head",
51
+ "True",
52
+ "--use_one_embed",
53
+ "True",
54
+ "--use_multi_scaling",
55
+ "False",
56
+ "--mlp_type",
57
+ "ffn",
58
+ "--decoder_num_blocks",
59
+ "4",
60
+ "--robot_platform",
61
+ "aloha",
62
+ "--proj_type",
63
+ "onlynorm",
64
+ "--ffn_type",
65
+ "relu",
66
+ "--expand_inner_ratio",
67
+ "1",
68
+ "--linear_drop_ratio",
69
+ "0.5",
70
+ "--multi_query_norm_type",
71
+ "layernorm",
72
+ "--multi_queries_num",
73
+ "2",
74
+ "--action_norm",
75
+ "layernorm",
76
+ "--use_fredf",
77
+ "False",
78
+ "--use_patch_wise_loss",
79
+ "True",
80
+ "--use_dual_arm_head",
81
+ "True"
82
+ ],
83
+ "program": "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/simvla_twin2/vla-scripts/finetune.py",
84
+ "codePath": "vla-scripts/finetune.py",
85
+ "git": {
86
+ "remote": "https://github.com/cheng-haha/SimVLA.git",
87
+ "commit": "7cd4f4827011a08c125854a6cbb073e24ed00b43"
88
+ },
89
+ "root": "/inspire/hdd/ws-f4d69b29-e0a5-44e6-bd92-acf4de9990f0/public-project/chengdongzhou-240108390137/vla_projects/simvla_twin2",
90
+ "host": "lmms5--aa4284ee6b0c-imxfhs5an2",
91
+ "executable": "/opt/conda/envs/openvla-oft/bin/python",
92
+ "codePathLocal": "vla-scripts/finetune.py",
93
+ "cpu_count": 64,
94
+ "cpu_count_logical": 128,
95
+ "gpu": "NVIDIA H100 80GB HBM3",
96
+ "gpu_count": 4,
97
+ "disk": {
98
+ "/": {
99
+ "total": "7679364128768",
100
+ "used": "5943384707072"
101
+ }
102
+ },
103
+ "memory": {
104
+ "total": "2164203204608"
105
+ },
106
+ "cpu": {
107
+ "count": 64,
108
+ "countLogical": 128
109
+ },
110
+ "gpu_nvidia": [
111
+ {
112
+ "name": "NVIDIA H100 80GB HBM3",
113
+ "memoryTotal": "85520809984",
114
+ "cudaCores": 16896,
115
+ "architecture": "Hopper"
116
+ },
117
+ {
118
+ "name": "NVIDIA H100 80GB HBM3",
119
+ "memoryTotal": "85520809984",
120
+ "cudaCores": 16896,
121
+ "architecture": "Hopper"
122
+ },
123
+ {
124
+ "name": "NVIDIA H100 80GB HBM3",
125
+ "memoryTotal": "85520809984",
126
+ "cudaCores": 16896,
127
+ "architecture": "Hopper"
128
+ },
129
+ {
130
+ "name": "NVIDIA H100 80GB HBM3",
131
+ "memoryTotal": "85520809984",
132
+ "cudaCores": 16896,
133
+ "architecture": "Hopper"
134
+ }
135
+ ],
136
+ "cudaVersion": "12.2"
137
+ }
wandb/offline-run-20250711_211915-s4epglyq/logs/debug.log ADDED
File without changes
wandb/offline-run-20250711_211915-s4epglyq/run-s4epglyq.wandb ADDED
Binary file (98.3 kB).
 
wandb/offline-run-20250711_212208-i2mclkeg/files/requirements.txt ADDED
@@ -0,0 +1,199 @@
1
+ openvla-oft==0.0.1
2
+ coloredlogs==15.0.1
3
+ openvla-oft==0.0.1
4
+ pydantic_core==2.27.2
5
+ opt_einsum==3.4.0
6
+ nvidia-cublas-cu12==12.1.3.1
7
+ nltk==3.9.1
8
+ anyio==4.9.0
9
+ tensorflow-addons==0.23.0
10
+ h11==0.14.0
11
+ setproctitle==1.3.5
12
+ json-numpy==2.1.0
13
+ gast==0.6.0
14
+ protobuf==3.20.3
15
+ array_record==0.7.1
16
+ keras==2.15.0
17
+ scipy==1.15.2
18
+ sentencepiece==0.1.99
19
+ Jinja2==3.1.6
20
+ glfw==2.8.0
21
+ tensorflow-io-gcs-filesystem==0.37.1
22
+ transformers==4.40.1
23
+ gitdb==4.0.12
24
+ packaging==24.2
25
+ ml-dtypes==0.2.0
26
+ pillow==11.1.0
27
+ nvidia-cusolver-cu12==11.4.5.107
28
+ jsonlines==4.0.0
29
+ google-auth==2.38.0
30
+ rpds-py==0.23.1
31
+ nvidia-nvjitlink-cu12==12.4.127
32
+ torch==2.2.0
33
+ fonttools==4.56.0
34
+ opencv-python==4.11.0.86
35
+ numba==0.61.0
36
+ jupyter_core==5.7.2
37
+ grpcio==1.71.0
38
+ peft==0.11.1
39
+ annotated-types==0.7.0
40
+ typing-inspect==0.9.0
41
+ termcolor==2.5.0
42
+ antlr4-python3-runtime==4.9.3
43
+ markdown-it-py==3.0.0
44
+ huggingface-hub==0.29.3
45
+ imageio==2.37.0
46
+ nvidia-nvtx-cu12==12.1.105
47
+ draccus==0.8.0
48
+ mypy-extensions==1.0.0
49
+ future==1.0.0
50
+ onnxsim==0.4.36
51
+ tensorboard-data-server==0.7.2
52
+ six==1.17.0
53
+ tqdm==4.67.1
54
+ rsa==4.9
55
+ typing_extensions==4.12.2
56
+ rich==13.9.4
57
+ nvidia-cusparse-cu12==12.1.0.106
58
+ jsonschema-specifications==2024.10.1
59
+ libclang==18.1.1
60
+ ninja==1.11.1.3
61
+ cloudpickle==3.1.1
62
+ onnx==1.17.0
63
+ python-xlib==0.33
64
+ referencing==0.36.2
65
+ filelock==3.18.0
66
+ debugpy==1.8.13
67
+ pip==25.0
68
+ mdurl==0.1.2
69
+ tensorflow-graphics==2021.12.3
70
+ pydantic==2.10.6
71
+ docker-pycreds==0.4.0
72
+ kiwisolver==1.4.8
73
+ networkx==3.4.2
74
+ pyasn1==0.6.1
75
+ humanfriendly==10.0
76
+ pynput==1.8.0
77
+ certifi==2025.1.31
78
+ pytest==8.3.5
79
+ sniffio==1.3.1
80
+ nbformat==5.10.4
81
+ requests-oauthlib==2.0.0
82
+ etils==1.12.2
83
+ tensorflow-estimator==2.15.0
84
+ cachetools==5.5.2
85
+ click==8.1.8
86
+ importlib_resources==6.5.2
87
+ robosuite==1.4.1
88
+ pyasn1_modules==0.4.1
89
+ nvidia-nccl-cu12==2.19.3
90
+ qwen-vl-utils==0.0.11
91
+ cycler==0.12.1
92
+ nvidia-cufft-cu12==11.0.2.54
93
+ typeguard==2.13.3
94
+ iniconfig==2.0.0
95
+ idna==3.10
96
+ MarkupSafe==3.0.2
97
+ matplotlib==3.10.1
98
+ promise==2.3
99
+ easydict==1.13
100
+ tensorflow-datasets==4.9.3
101
+ Werkzeug==3.1.3
102
+ tomli==2.2.1
103
+ nvidia-cuda-cupti-cu12==12.1.105
104
+ omegaconf==2.3.0
105
+ imageio-ffmpeg==0.6.0
106
+ absl-py==2.1.0
107
+ mujoco==3.3.0
108
+ evdev==1.9.1
109
+ sentry-sdk==2.22.0
110
+ pyparsing==3.2.1
111
+ dm-tree==0.1.9
112
+ psutil==7.0.0
113
+ torchaudio==2.2.0
114
+ h5py==3.13.0
115
+ PyOpenGL==3.1.9
116
+ triton==2.2.0
117
+ fsspec==2025.3.0
118
+ nvidia-cudnn-cu12==8.9.2.26
119
+ trimesh==4.6.4
120
+ Pygments==2.19.1
121
+ nvidia-cuda-runtime-cu12==12.1.105
122
+ wheel==0.45.1
123
+ astunparse==1.6.3
124
+ requests==2.32.3
125
+ importlib_metadata==8.6.1
126
+ starlette==0.46.1
127
+ charset-normalizer==3.4.1
128
+ tokenizers==0.19.1
129
+ accelerate==1.5.2
130
+ tensorflow-metadata==1.16.1
131
+ OpenEXR==3.3.2
132
+ mpmath==1.3.0
133
+ einops==0.8.1
134
+ google-pasta==0.2.0
135
+ exceptiongroup==1.2.2
136
+ bddl==3.5.0
137
+ safetensors==0.5.3
138
+ nvidia-cuda-nvrtc-cu12==12.1.105
139
+ regex==2024.11.6
140
+ zipp==3.21.0
141
+ mdit-py-plugins==0.4.2
142
+ contourpy==1.3.1
143
+ nvidia-cusparselt-cu12==0.6.2
144
+ wandb==0.19.8
145
+ tensorboard==2.15.2
146
+ wrapt==1.14.1
147
+ pyyaml-include==1.4.1
148
+ urllib3==2.3.0
149
+ setuptools==75.8.0
150
+ fastjsonschema==2.21.1
151
+ fastapi==0.115.11
152
+ oauthlib==3.2.2
153
+ uvicorn==0.34.0
154
+ gym-notices==0.0.8
155
+ jupytext==1.16.7
156
+ diffusers==0.32.2
157
+ flatbuffers==25.2.10
158
+ timm==0.9.10
159
+ traitlets==5.14.3
160
+ tensorflow==2.15.0
161
+ flash-attn==2.5.5
162
+ Markdown==3.7
163
+ torchvision==0.17.0
164
+ smmap==5.0.2
165
+ attrs==25.3.0
166
+ google-auth-oauthlib==1.2.1
167
+ av==14.3.0
168
+ onnxruntime==1.21.0
169
+ gym==0.26.2
170
+ platformdirs==4.3.6
171
+ mergedeep==1.3.4
172
+ nvidia-curand-cu12==10.3.2.106
173
+ python-dateutil==2.9.0.post0
174
+ toml==0.10.2
175
+ numpy==1.26.4
176
+ GitPython==3.1.44
177
+ jsonschema==4.23.0
178
+ joblib==1.4.2
179
+ PyYAML==6.0.2
180
+ sympy==1.13.1
181
+ llvmlite==0.44.0
182
+ pluggy==1.5.0
183
+ dlimp==0.0.1
184
+ jaraco.collections==5.1.0
185
+ packaging==24.2
186
+ importlib_metadata==8.0.0
187
+ tomli==2.0.1
188
+ backports.tarfile==1.2.0
189
+ typing_extensions==4.12.2
190
+ jaraco.context==5.3.0
191
+ typeguard==4.3.0
192
+ wheel==0.43.0
193
+ autocommand==2.2.2
194
+ jaraco.text==3.12.1
195
+ more-itertools==10.3.0
196
+ platformdirs==4.2.2
197
+ inflect==7.3.1
198
+ jaraco.functools==4.0.1
199
+ zipp==3.19.2