lsnu committed (verified)
Commit 7eb3f10
Parent(s): 80c771a

Add files using upload-large-folder tool

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full set.
Files changed (50)
  1. .gitattributes +9 -0
  2. environment/base_pip_freeze.txt +176 -0
  3. environment/base_python.txt +1 -0
  4. environment/env_list.txt +4 -0
  5. environment/hardware_snapshot.txt +1 -0
  6. environment/nvidia_smi.txt +22 -0
  7. environment/reconstruct_anybimanual_overlap_replay.sh +22 -0
  8. environment/rlbench_pip_freeze.txt +200 -0
  9. environment/rlbench_python.txt +1 -0
  10. environment/runtime_env_vars.sh +4 -0
  11. environment/setup_same_hardware.sh +25 -0
  12. environment/uname.txt +1 -0
  13. handoff/instructions4.md +591 -0
  14. history/VLAarchtests_previous_README.md +172 -0
  15. metadata/source_sizes.txt +4 -0
  16. metadata/staged_size.txt +1 -0
  17. metadata/staged_tree_top2.txt +64 -0
  18. third_party/AnyBimanual/agents/__init__.py +0 -0
  19. third_party/AnyBimanual/agents/agent_factory.py +101 -0
  20. third_party/AnyBimanual/agents/peract_bc/__init__.py +1 -0
  21. third_party/AnyBimanual/agents/peract_bc/launch_utils.py +128 -0
  22. third_party/AnyBimanual/agents/peract_bc/perceiver_lang_io.py +481 -0
  23. third_party/AnyBimanual/agents/peract_bc/qattention_peract_bc_agent.py +939 -0
  24. third_party/AnyBimanual/agents/peract_bc/qattention_stack_agent.py +132 -0
  25. third_party/AnyBimanual/agents/peract_bc/skill_manager.py +70 -0
  26. third_party/AnyBimanual/agents/peract_bc/trajectory_gpt2.py +775 -0
  27. third_party/AnyBimanual/agents/peract_bc/visual_aligner.py +39 -0
  28. third_party/AnyBimanual/agents/peract_bimanual/__init__.py +1 -0
  29. third_party/AnyBimanual/agents/peract_bimanual/launch_utils.py +117 -0
  30. third_party/AnyBimanual/agents/peract_bimanual/perceiver_lang_io.py +628 -0
  31. third_party/AnyBimanual/agents/peract_bimanual/qattention_peract_bc_agent.py +1317 -0
  32. third_party/AnyBimanual/agents/peract_bimanual/qattention_stack_agent.py +209 -0
  33. third_party/AnyBimanual/agents/peract_bimanual/skill_manager.py +70 -0
  34. third_party/AnyBimanual/agents/peract_bimanual/trajectory_gpt2.py +775 -0
  35. third_party/AnyBimanual/agents/peract_bimanual/visual_aligner.py +39 -0
  36. third_party/AnyBimanual/agents/replay_utils.py +667 -0
  37. third_party/AnyBimanual/agents/rvt/__init__.py +6 -0
  38. third_party/AnyBimanual/agents/rvt/launch_utils.py +221 -0
  39. third_party/AnyBimanual/agents/rvt/rvt/config.py +54 -0
  40. third_party/AnyBimanual/agents/rvt/rvt/configs/peract_official_config.yaml +127 -0
  41. third_party/AnyBimanual/agents/rvt/rvt/configs/rvt.yaml +15 -0
  42. third_party/AnyBimanual/agents/rvt/rvt/configs/rvt2.yaml +19 -0
  43. third_party/AnyBimanual/agents/rvt/rvt/eval.py +556 -0
  44. third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/.gitattributes +1 -0
  45. third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/.gitignore +4 -0
  46. third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/LICENSE +97 -0
  47. third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/README.md +55 -0
  48. third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/demo.png +3 -0
  49. third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/image_0_splat_2xaa.png +0 -0
  50. third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/point_renderer/cameras.py +119 -0
.gitattributes CHANGED
@@ -14068,3 +14068,12 @@ baselines/AnyBimanual_overlap_replay/multi/10000-14999/14031.replay filter=lfs d
  baselines/AnyBimanual_overlap_replay/multi/10000-14999/14030.replay filter=lfs diff=lfs merge=lfs -text
  baselines/AnyBimanual_overlap_replay/multi/10000-14999/14032.replay filter=lfs diff=lfs merge=lfs -text
  baselines/AnyBimanual_overlap_replay/multi/10000-14999/14029.replay filter=lfs diff=lfs merge=lfs -text
+ third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/demo.png filter=lfs diff=lfs merge=lfs -text
+ third_party/AnyBimanual/third_party/RLBench/readme_files/task_grid.png filter=lfs diff=lfs merge=lfs -text
+ third_party/AnyBimanual/third_party/PyRep/tutorials/images/kinematics_group.png filter=lfs diff=lfs merge=lfs -text
+ third_party/AnyBimanual/third_party/PyRep/tutorials/images/collision_collections.png filter=lfs diff=lfs merge=lfs -text
+ third_party/AnyBimanual/third_party/PyRep/tests/assets/test_scene_robots.ttt filter=lfs diff=lfs merge=lfs -text
+ third_party/AnyBimanual/third_party/PyRep/tests/assets/test_scene_mobiles_with_arms.ttt filter=lfs diff=lfs merge=lfs -text
+ third_party/AnyBimanual/third_party/PyRep/tests/assets/test_scene_mobiles.ttt filter=lfs diff=lfs merge=lfs -text
+ third_party/AnyBimanual/third_party/PyRep/tests/assets/cracker_box/texture_map.png filter=lfs diff=lfs merge=lfs -text
environment/base_pip_freeze.txt ADDED
@@ -0,0 +1,176 @@
+ accelerate==1.13.0
+ annotated-doc==0.0.4
+ antlr4-python3-runtime==4.9.3
+ anyio==4.6.0
+ argon2-cffi==23.1.0
+ argon2-cffi-bindings==21.2.0
+ arrow==1.3.0
+ asttokens==2.4.1
+ async-lru==2.0.4
+ attrs==24.2.0
+ babel==2.16.0
+ beautifulsoup4==4.12.3
+ bleach==6.1.0
+ blinker==1.4
+ certifi==2024.8.30
+ cffi==1.17.1
+ charset-normalizer==3.3.2
+ click==8.3.1
+ comm==0.2.2
+ cryptography==3.4.8
+ cuda-bindings==12.9.4
+ cuda-pathfinder==1.2.2
+ cuda-toolkit==12.8.1
+ dbus-python==1.2.18
+ debugpy==1.8.5
+ decorator==5.1.1
+ defusedxml==0.7.1
+ distro==1.7.0
+ entrypoints==0.4
+ execnet==2.1.2
+ executing==2.1.0
+ fastjsonschema==2.20.0
+ filelock==3.13.1
+ fqdn==1.5.1
+ fsspec==2024.2.0
+ h11==0.14.0
+ hf-xet==1.4.2
+ httpcore==1.0.5
+ httplib2==0.20.2
+ httpx==0.27.2
+ huggingface_hub==1.8.0
+ idna==3.10
+ importlib-metadata==4.6.4
+ iniconfig==2.3.0
+ ipykernel==6.29.5
+ ipython==8.27.0
+ ipython-genutils==0.2.0
+ ipywidgets==8.1.5
+ isoduration==20.11.0
+ jedi==0.19.1
+ jeepney==0.7.1
+ Jinja2==3.1.3
+ json5==0.9.25
+ jsonpointer==3.0.0
+ jsonschema==4.23.0
+ jsonschema-specifications==2023.12.1
+ jupyter-archive==3.4.0
+ jupyter-events==0.10.0
+ jupyter-highlight-selected-word==0.2.0
+ jupyter-lsp==2.2.5
+ jupyter_client==7.4.9
+ jupyter_contrib_core==0.4.2
+ jupyter_contrib_nbextensions==0.7.0
+ jupyter_core==5.7.2
+ jupyter_nbextensions_configurator==0.6.4
+ jupyter_server==2.14.2
+ jupyter_server_terminals==0.5.3
+ jupyterlab==4.2.5
+ jupyterlab_pygments==0.3.0
+ jupyterlab_server==2.27.3
+ jupyterlab_widgets==3.0.13
+ keyring==23.5.0
+ launchpadlib==1.10.16
+ lazr.restfulclient==0.14.4
+ lazr.uri==1.0.6
+ lxml==5.3.0
+ markdown-it-py==4.0.0
+ MarkupSafe==2.1.5
+ matplotlib-inline==0.1.7
+ mdurl==0.1.2
+ mistune==3.0.2
+ more-itertools==8.10.0
+ mpmath==1.3.0
+ nbclassic==1.1.0
+ nbclient==0.10.0
+ nbconvert==7.16.4
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx==3.2.1
+ notebook==6.5.5
+ notebook_shim==0.2.4
+ numpy==1.26.3
+ nvidia-cublas-cu12==12.8.4.1
+ nvidia-cuda-cupti-cu12==12.8.90
+ nvidia-cuda-nvrtc-cu12==12.8.93
+ nvidia-cuda-runtime-cu12==12.8.90
+ nvidia-cudnn-cu12==9.19.0.56
+ nvidia-cufft-cu12==11.3.3.83
+ nvidia-cufile-cu12==1.13.1.3
+ nvidia-curand-cu12==10.3.9.90
+ nvidia-cusolver-cu12==11.7.3.90
+ nvidia-cusparse-cu12==12.5.8.93
+ nvidia-cusparselt-cu12==0.7.1
+ nvidia-nccl-cu12==2.28.9
+ nvidia-nvjitlink-cu12==12.8.93
+ nvidia-nvshmem-cu12==3.4.5
+ nvidia-nvtx-cu12==12.8.90
+ oauthlib==3.2.0
+ omegaconf==2.3.0
+ overrides==7.7.0
+ packaging==24.1
+ pandocfilters==1.5.1
+ parso==0.8.4
+ pexpect==4.9.0
+ pillow==10.2.0
+ platformdirs==4.3.6
+ pluggy==1.6.0
+ prometheus_client==0.21.0
+ prompt_toolkit==3.0.47
+ psutil==6.0.0
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ py-spy==0.4.1
+ pycparser==2.22
+ Pygments==2.18.0
+ PyGObject==3.42.1
+ PyJWT==2.3.0
+ pyparsing==2.4.7
+ pytest==9.0.2
+ pytest-xdist==3.8.0
+ python-apt==2.4.0+ubuntu4
+ python-dateutil==2.9.0.post0
+ python-json-logger==2.0.7
+ PyYAML==6.0.2
+ pyzmq==24.0.1
+ referencing==0.35.1
+ regex==2026.3.32
+ requests==2.32.3
+ rfc3339-validator==0.1.4
+ rfc3986-validator==0.1.1
+ rich==14.3.3
+ rpds-py==0.20.0
+ safetensors==0.7.0
+ SecretStorage==3.3.1
+ Send2Trash==1.8.3
+ sentencepiece==0.2.1
+ shellingham==1.5.4
+ six==1.16.0
+ sniffio==1.3.1
+ soupsieve==2.6
+ stack-data==0.6.3
+ sympy==1.14.0
+ systemd-python==234
+ terminado==0.18.1
+ tinycss2==1.3.0
+ tokenizers==0.22.2
+ torch==2.11.0+cu128
+ torchaudio==2.11.0+cu128
+ torchvision==0.26.0+cu128
+ tornado==6.4.1
+ tqdm==4.67.3
+ traitlets==5.14.3
+ transformers==5.4.0
+ triton==3.6.0
+ typer==0.24.1
+ types-python-dateutil==2.9.0.20240906
+ typing_extensions==4.15.0
+ uri-template==1.3.0
+ urllib3==2.2.3
+ wadllib==1.3.6
+ wcwidth==0.2.13
+ webcolors==24.8.0
+ webencodings==0.5.1
+ websocket-client==1.8.0
+ widgetsnbextension==4.0.13
+ zipp==1.0.0
environment/base_python.txt ADDED
@@ -0,0 +1 @@
+ Python 3.11.10
environment/env_list.txt ADDED
@@ -0,0 +1,4 @@
+ Name      Active  Path
+ ────────────────────────────────────────────
+ base      *       /workspace
+ rlbench           /workspace/envs/rlbench
environment/hardware_snapshot.txt ADDED
@@ -0,0 +1 @@
+ NVIDIA RTX PRO 6000 Blackwell Server Edition, 580.126.09, 97887 MiB
environment/nvidia_smi.txt ADDED
@@ -0,0 +1,22 @@
+ Mon Mar 30 14:48:02 2026
+ +-----------------------------------------------------------------------------------------+
+ | NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 |
+ +-----------------------------------------+------------------------+----------------------+
+ | GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
+ | Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
+ | | | MIG M. |
+ |=========================================+========================+======================|
+ | 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:F4:00.0 Off | 0 |
+ | N/A 34C P0 83W / 600W | 6110MiB / 97887MiB | 0% Default |
+ | | | Disabled |
+ +-----------------------------------------+------------------------+----------------------+
+
+ +-----------------------------------------------------------------------------------------+
+ | Processes: |
+ | GPU GI CI PID Type Process name GPU Memory |
+ | ID ID Usage |
+ |=========================================================================================|
+ | 0 N/A N/A 181865 G /usr/lib/xorg/Xorg 97MiB |
+ | 0 N/A N/A 278028 C python 570MiB |
+ | 0 N/A N/A 278251 C+G ...space/envs/rlbench/bin/python 5401MiB |
+ +-----------------------------------------------------------------------------------------+
environment/reconstruct_anybimanual_overlap_replay.sh ADDED
@@ -0,0 +1,22 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ if [ "$#" -ne 2 ]; then
+   echo "usage: $0 <sharded_multi_dir> <flat_output_dir>" >&2
+   echo "example: $0 baselines/AnyBimanual_overlap_replay/multi /tmp/multi_flat" >&2
+   exit 1
+ fi
+
+ src="$1"
+ dst="$2"
+
+ mkdir -p "$dst"
+
+ # -I '{}' makes xargs substitute each filename for '{}'; without it the inner
+ # script would receive the literal string '{}' as "$1" and link nothing.
+ find "$src" -mindepth 2 -maxdepth 2 -type f -name '*.replay' -print0 \
+   | xargs -0 -P 32 -I '{}' bash -c '
+       f="$1"
+       out="$2/$(basename "$f")"
+       ln "$f" "$out"
+     ' _ '{}' "$dst"
+
+ echo "reconstructed flat replay directory at: $dst"
environment/rlbench_pip_freeze.txt ADDED
@@ -0,0 +1,200 @@
+ absl-py==2.1.0
+ accelerate==0.31.0
+ addict==2.4.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.13.4
+ aiosignal==1.4.0
+ antlr4-python3-runtime==4.9.3
+ asttokens==3.0.1
+ async-timeout==5.0.1
+ attrs==26.1.0
+ backports.zstd @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_backports.zstd_1767044984/work
+ blinker==1.9.0
+ blosc==1.11.4
+ Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1764016952863/work
+ cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
+ certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1772001073725/work/certifi
+ cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1761202865726/work
+ charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1773659966602/work
+ click==8.3.1
+ click-prompt==0.5.1
+ clip @ git+https://github.com/openai/CLIP.git@d05afc436d78f1c48dc0dbf8e5980a9d471f35f6
+ cloudpickle==3.1.2
+ comm==0.2.3
+ ConfigArgParse==1.7.5
+ contourpy @ file:///home/conda/feedstock_root/build_artifacts/contourpy_1744743067588/work
+ cuda-bindings==12.9.4
+ cuda-pathfinder==1.2.2
+ cuda-toolkit==12.8.1
+ cycler @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_cycler_1764466758/work
+ dash==4.1.0
+ decorator==5.2.1
+ docker-pycreds==0.4.0
+ einops==0.8.0
+ exceptiongroup==1.3.1
+ executing==2.2.1
+ Farama-Notifications==0.0.4
+ fastjsonschema==2.21.2
+ filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1773313889543/work
+ Flask==3.1.3
+ fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1773137064424/work
+ freetype-py==2.5.1
+ frozenlist==1.8.0
+ fsspec==2026.3.0
+ ftfy==6.2.0
+ gitdb==4.0.12
+ GitPython==3.1.46
+ gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1773244929835/work
+ grpcio==1.78.0
+ gym==0.26.2
+ gym-notices==0.1.0
+ gymnasium==1.0.0a2
+ h2 @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_h2_1756364871/work
+ h5py @ file:///home/conda/feedstock_root/build_artifacts/h5py_1774712049671/work
+ hf-xet==1.4.2
+ hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
+ huggingface_hub==0.36.2
+ hydra-core==1.3.2
+ hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
+ idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1760286409563/work
+ imageio @ file:///home/conda/feedstock_root/build_artifacts/imageio_1738273805233/work
+ imgaug==0.4.0
+ importlib_metadata==9.0.0
+ ipython==8.39.0
+ ipywidgets==8.1.8
+ itsdangerous==2.2.0
+ jedi==0.19.2
+ Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_jinja2_1764517220/work
+ joblib==1.5.3
+ jsonschema==4.26.0
+ jsonschema-specifications==2025.9.1
+ jupyter_core==5.9.1
+ jupyterlab_widgets==3.0.16
+ kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_kiwisolver_1773067043/work
+ lazy-loader==0.5
+ Markdown==3.10.2
+ markdown-it-py==4.0.0
+ MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1772444934960/work
+ matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1715976200404/work
+ matplotlib-inline==0.2.1
+ mdurl==0.1.2
+ meshcat==0.3.2
+ moviepy==2.2.1
+ mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1773661943568/work
+ multidict==6.7.1
+ munkres==1.1.4
+ narwhals==2.18.1
+ natsort==8.4.0
+ nbformat==5.10.4
+ nest-asyncio==1.6.0
+ networkx @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_networkx_1731521053/work
+ numpy==1.26.4
+ nvidia-cublas-cu12==12.8.4.1
+ nvidia-cuda-cupti-cu12==12.8.90
+ nvidia-cuda-nvrtc-cu12==12.8.93
+ nvidia-cuda-runtime-cu12==12.8.90
+ nvidia-cudnn-cu12==9.19.0.56
+ nvidia-cufft-cu12==11.3.3.83
+ nvidia-cufile-cu12==1.13.1.3
+ nvidia-curand-cu12==10.3.9.90
+ nvidia-cusolver-cu12==11.7.3.90
+ nvidia-cusparse-cu12==12.5.8.93
+ nvidia-cusparselt-cu12==0.7.1
+ nvidia-nccl-cu12==2.28.9
+ nvidia-nvjitlink-cu12==12.8.93
+ nvidia-nvshmem-cu12==3.4.5
+ nvidia-nvtx-cu12==12.8.90
+ omegaconf==2.3.0
+ open3d==0.19.0
+ openai==0.28.1
+ opencv-python==4.10.0.84
+ packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1769093650/work
+ pandas @ file:///home/conda/feedstock_root/build_artifacts/pandas_1744430447393/work
+ parso==0.8.6
+ -e git+https://github.com/markusgrotz/peract_bimanual.git@bb0232a6ba3fe116566e9568f0c7af980ed6703d#egg=peract_bimanual
+ perceiver-pytorch==0.8.8
+ pexpect==4.9.0
+ pillow==12.1.1
+ platformdirs==4.9.4
+ plotly==6.6.0
+ ply @ file:///home/conda/feedstock_root/build_artifacts/ply_1733239724146/work
+ poetry-core==2.3.2
+ prompt_toolkit==3.0.52
+ propcache==0.4.1
+ protobuf==5.29.6
+ psutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_psutil_1769678154/work
+ ptyprocess==0.7.0
+ pure_eval==0.2.3
+ pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
+ pyglet==2.1.13
+ Pygments==2.20.0
+ pyngrok==7.5.1
+ PyOpenGL==3.1.0
+ pyparsing @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pyparsing_1769003998/work
+ PyQt5==5.15.11
+ PyQt5_sip==12.17.0
+ pyquaternion==0.9.9
+ pyrender==0.1.45
+ -e git+https://github.com/markusgrotz/PyRep.git@b8bd1d7a3182adcd570d001649c0849047ebf197#egg=PyRep
+ PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
+ python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work
+ pytorch-lamb==1.0.0
+ pytz @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pytz_1773679724/work
+ PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1770223234623/work
+ pyzmq==27.1.0
+ referencing==0.37.0
+ regex==2024.5.15
+ requests @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_requests_1774462091/work
+ retrying==1.4.2
+ # Editable install with no version control (reveal-vla-bimanual==0.1.0)
+ -e /workspace/reveal_vla_bimanual
+ rich==13.9.4
+ rich-click==1.8.9
+ -e git+https://github.com/markusgrotz/RLBench.git@8af748c51287989294e00c9c670e3330a0e35ed5#egg=rlbench
+ rpds-py==0.30.0
+ safetensors==0.4.3
+ scikit-image==0.25.2
+ scikit-learn==1.7.2
+ scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy-split_1716470219380/work/dist/scipy-1.13.1-cp310-cp310-linux_x86_64.whl#sha256=a4ff22b6dc27b61196be51695f53f9b0676e7c1bc564872b51fc3c41b79ae80b
+ segment-anything==1.0
+ sentry-sdk==2.56.0
+ setproctitle==1.3.7
+ shapely==2.1.2
+ sip @ file:///home/conda/feedstock_root/build_artifacts/sip_1759437834046/work
+ six @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_six_1753199211/work
+ smmap==5.0.3
+ stack-data==0.6.3
+ sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1771952240620/work
+ tensorboard==2.16.2
+ tensorboard-data-server==0.7.2
+ tensorboardX==2.6.4
+ termcolor==3.3.0
+ threadpoolctl==3.6.0
+ tifffile==2025.5.10
+ timeout-decorator==0.5.0
+ timm==1.0.26
+ tokenizers==0.19.1
+ toml @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_toml_1764486833/work
+ tomli @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_tomli_1774492402/work
+ torch==2.11.0+cu128
+ torchaudio==2.11.0+cu128
+ torchvision==0.26.0+cu128
+ tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1774357896577/work
+ tqdm @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_tqdm_1770153424/work
+ traitlets==5.14.3
+ transformers==4.41.2
+ transforms3d==0.4.1
+ trimesh @ file:///home/conda/feedstock_root/build_artifacts/trimesh_1774412449209/work
+ triton==3.6.0
+ typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1756220668/work
+ tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1765719872007/work
+ u-msgpack-python==2.8.0
+ unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1770908960326/work
+ urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1767817748113/work
+ wandb==0.18.0
+ wcwidth==0.2.14
+ Werkzeug==3.1.7
+ widgetsnbextension==4.0.15
+ yarl==1.23.0
+ -e git+https://github.com/markusgrotz/YARR.git@6822ff78602c77878b27d4cfe759ce029c67bffb#egg=yarr
+ zipp==3.23.0
environment/rlbench_python.txt ADDED
@@ -0,0 +1 @@
+ Python 3.10.20
environment/runtime_env_vars.sh ADDED
@@ -0,0 +1,4 @@
+ export HF_HOME=/workspace/.cache/huggingface
+ export MAMBA_ROOT_PREFIX=/workspace/.micromamba
+ export DISPLAY=:99
+ export PYTHONPATH=/workspace/VLAarchtests/code/reveal_vla_bimanual
environment/setup_same_hardware.sh ADDED
@@ -0,0 +1,25 @@
+ #!/usr/bin/env bash
+ set -euo pipefail
+
+ ROOT_DIR="${ROOT_DIR:-/workspace}"
+
+ echo "[setup] expected hardware: NVIDIA RTX PRO 6000 Blackwell Server Edition"
+ echo "[setup] expected OS family: Ubuntu 24.04 / Linux 6.8"
+
+ if [ ! -d "$ROOT_DIR/VLAarchtests" ]; then
+   echo "[setup] expected staged tree at $ROOT_DIR/VLAarchtests" >&2
+ fi
+
+ echo "[setup] base python:"
+ python --version || true
+
+ if [ -x "$ROOT_DIR/.tools/micromamba/bin/micromamba" ]; then
+   echo "[setup] micromamba envs:"
+   "$ROOT_DIR/.tools/micromamba/bin/micromamba" env list || true
+ fi
+
+ echo "[setup] current GPU snapshot:"
+ nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader || true
+
+ echo "[setup] this repo includes package snapshots in environment/"
+ echo "[setup] recreate the rlbench env at $ROOT_DIR/envs/rlbench before running the overlap baseline path"
environment/uname.txt ADDED
@@ -0,0 +1 @@
+ Linux 129a2ec0f4a9 6.8.0-106-generic #106-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 6 07:58:08 UTC 2026 x86_64 x86_64 x86_64 GNU/Linux
handoff/instructions4.md ADDED
@@ -0,0 +1,591 @@
+ # Developer handoff: 10-hour sim sprint on 1×RTX PRO 6000 for elastic-occlusion bimanual reveal/retrieve
+
+ ## Scope
+
+ This handoff is for the current simulation phase only. The purpose is not to produce publication-grade evidence. The purpose is to use one short sprint to extract the most decision-relevant signal possible before the custom three-task teleoperation benchmark exists.
+
+ This document does **not** include explicit instructions for future teleoperation data collection.
+
+ The target problem remains the same: bimanual reveal and retrieve under elastic occlusion, with task families that map to (1) foliage reveal with safe actor insertion and retrieval, (2) bag opening plus retrieval, and (3) folded-cloth / suitcase reveal with minimal fold disruption.
+
+ ## Hard constraints for this sprint
+
+ The sprint must assume the following:
+
+ - Hardware: **1× RTX PRO 6000** workstation GPU.
+ - Wall-clock budget: **~10 hours total**.
+ - Deliverable standard: **decision-quality results**, not paper-quality results. They must nonetheless be rigorous: absolutely no data leaks.
+ - The output of the sprint must be enough to shape the next development cycle.
+
+ This means the sprint is not allowed to expand into a broad refactor, a large hyperparameter search, an external benchmark integration project, or a foundation-model migration. The objective is to get the strongest signal per hour.
+
+ ## What this sprint needs to answer
+
+ At the end of the 10-hour window, the repo should let us answer these questions with reasonable confidence:
+
+ 1. Does the current full architecture look directionally better than trivial and structured baselines on rough proxy versions of the three real task families?
+ 2. Which parts of the architecture appear to matter most right now: explicit task conditioning, geometry, memory, planner, or candidate family structure?
+ 3. For each of the three target tasks, is the current architecture best described as **promising**, **uncertain**, or **weak** under the proxy tests?
+ 4. If the system is weak, is the weakness coming mostly from perception/state estimation, memory, retrieve gating/planning, or proxy mismatch?
+ 5. Which next engineering changes are most likely to widen the eventual real-task performance gap, and which ones are unlikely to matter?
+
+ That is the bar for success in this sprint. It is acceptable if the results are approximate. It is not acceptable if the results are too noisy or too poorly instrumented to support those decisions.
+
+ ## What this sprint can and cannot tell us
+
+ This sprint **can** tell us whether the current structured architecture is showing the right dependencies under stress. For example, it can tell us whether memory actually helps under reocclusion, whether geometry actually helps under camera perturbation, whether retrieve gating blocks premature retrieve, and whether the planner adds value beyond trivial candidate selection.
+
+ This sprint **cannot** tell us true real-world performance on live foliage, real bag interiors, or folded clothes in a suitcase. It also cannot tell us whether the final production backbone should be CLIP, OpenVLA-style, LingBot-style, or something else. Those are later decisions.
+
+ So the correct target is not “exact future performance.” The correct target is “a useful approximation of whether the current structure is pointed in the right direction, and where the biggest current bottlenecks are.”
+
+ ## High-level decisions to lock now
+
+ 1. Treat the current **compact-phase CLIP/RGB-D handoff branch** as the only reference branch for this sprint.
+ 2. Keep the explicit reveal-state stack. Do **not** rewrite the repo into a monolithic end-to-end VLA policy in this sprint.
+ 3. Keep RLBench only as a smoke test for three-camera and bimanual integration. Do **not** use RLBench mean success as the main selector for reveal/retrieve architecture changes.
+ 4. Do **not** switch to a new backbone in this sprint. The goal here is to evaluate the structure, not to spend the 10-hour budget on trunk migration.
+ 5. Prefer **eval-time knockouts and toggles** over fully retrained matched ablations in this sprint. Retrained matched ablations remain important later, but they are too expensive for the current time budget.
+
+ ## Immediate read of the current repo
+
+ The current repo is still a good scaffold, but the signal quality is weaker than the structure quality. The strongest part of the system remains the decomposition itself: explicit reveal-state fields, dual scene/belief memory, task-conditioned proposal families, and a planner that reasons about persistence, reocclusion, support, and actor feasibility. The weakest part is that the current evidence is still too easy to misread. The proxy results already suggest that the large spatial rollout path is not the right place to spend another iteration right now, while the compact-phase line remains the most credible base. That should be treated as settled for this sprint.
+
+ ## Sprint strategy
+
+ The fastest path to useful conclusions is:
+
+ 1. Remove the most obvious confounds.
+ 2. Add just enough stress slicing and logging to make the proxy benchmark informative.
+ 3. Run a **small but fixed stratified benchmark** that is reused for every comparison.
+ 4. Compare the base model against trivial baselines and a few **eval-time architecture knockouts**.
+ 5. Use the results to build a task-by-task bottleneck map.
+
+ This is deliberately narrower than the earlier broad handoff. The point is to obtain meaningful conclusions within one GPU and one workday.
+
+ ## Do now vs defer
+
+ ### Must do in this sprint
+
+ These are the changes that are worth the time even under a 10-hour cap.
+
+ 1. Explicit task metadata must override text routing.
+ 2. History camera geometry must propagate correctly.
+ 3. The compact-phase branch must become the default base config.
+ 4. The proxy benchmark must become stratified by task and stress slice.
+ 5. The benchmark runner must support simple baselines and eval-time architecture knockouts.
+ 6. Reporting must include the small set of task-specific metrics that actually matter for the three target tasks.
+
+ ### Explicitly defer until after this sprint
+
+ These are good ideas, but they should not consume the current 10-hour budget.
+
+ 1. Foundation trunk migration (OpenVLA, LingBot, π0.5, etc.).
+ 2. External deformable benchmark integration.
+ 3. Full matched retraining ablation suite.
+ 4. Major loss redesign or long retraining campaigns.
+ 5. Large nuisance sweeps beyond the narrow stress slices listed below.
+ 6. Spatial rollout branch rescue work.
+
+ ## Mandatory repo changes for this sprint
+
+ ### 1. Replace heuristic task routing as the primary path
+
+ **Why now:** this is a real confound and cheap to fix.
+
+ Files to edit:
+
+ - `code/reveal_vla_bimanual/models/policy.py`
+ - `code/reveal_vla_bimanual/models/action_decoder.py`
+ - `code/reveal_vla_bimanual/sim_reveal/dataset.py`
+ - `code/reveal_vla_bimanual/sim_reveal/generate_dataset.py`
+ - `code/reveal_vla_bimanual/eval/run_reveal_benchmark.py`
+
+ Required changes:
+
+ - Add `task_name` and `task_id` to every proxy training and evaluation example.
+ - Make explicit task metadata override any text-based inference everywhere.
+ - Keep keyword routing only as a fallback for legacy examples that do not carry task metadata.
+ - Surface the resolved task family in benchmark logs so mistakes are easy to see.
+
+ Acceptance criterion:
+
+ - A misleading prompt string must not change the task family when `task_name` is present.
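+
+ A minimal sketch of the intended precedence, assuming hypothetical names (`resolve_task_family`, `KEYWORD_RULES`) rather than the repo's actual ones:
+
+ ```python
+ # Hypothetical sketch: explicit metadata wins, keyword routing survives
+ # only as a legacy fallback for examples without task metadata.
+ KEYWORD_RULES = {"foliage": "foliage", "leaf": "foliage",
+                  "bag": "bag", "cloth": "cloth", "fold": "cloth"}
+
+ def resolve_task_family(example: dict) -> tuple[str, str]:
+     """Return (task_family, routing_source); metadata always overrides text."""
+     if example.get("task_name"):
+         return example["task_name"], "metadata"
+     text = example.get("instruction", "").lower()
+     for keyword, family in KEYWORD_RULES.items():
+         if keyword in text:
+             return family, "keyword_fallback"
+     return "unknown", "unresolved"
+ ```
+
+ Logging the returned routing source next to the resolved family makes routing mistakes easy to spot in the benchmark logs.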
+
+ ### 2. Fix history geometry propagation
+
+ **Why now:** if geometry is broken in history, the current geometry ablations are not trustworthy.
+
+ Files to edit:
+
+ - `code/reveal_vla_bimanual/models/policy.py`
+ - any history batching utility that currently drops camera matrices
+ - proxy dataset serialization if history camera metadata is missing
+
+ Required changes:
+
+ - Save and batch history camera intrinsics and extrinsics.
+ - Pass them through the history encoder when geometry and camera-pose tokens are enabled.
+ - Add a validity mask if some history frames do not have full camera metadata.
+ - Add a debug log or assertion path that makes it obvious whether history geometry is really being used.
+
+ Acceptance criterion:
+
+ - A geometry-enabled run must receive non-null history camera tensors in the forward path.
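+
+ A minimal sketch of the intended batching behavior, with a hypothetical function name and frame schema (the repo's actual utility will differ):
+
+ ```python
+ # Hypothetical sketch: batch history camera matrices with a validity mask
+ # instead of silently dropping frames that lack metadata.
+ import torch
+
+ def batch_history_cameras(frames: list[dict], device: str = "cpu"):
+     """frames[i] may or may not carry 'intrinsics' (3x3) and 'extrinsics' (4x4)."""
+     T = len(frames)
+     intrinsics = torch.zeros(T, 3, 3, device=device)
+     extrinsics = torch.eye(4, device=device).repeat(T, 1, 1)
+     valid = torch.zeros(T, dtype=torch.bool, device=device)
+     for i, f in enumerate(frames):
+         if f.get("intrinsics") is not None and f.get("extrinsics") is not None:
+             intrinsics[i] = torch.as_tensor(f["intrinsics"])
+             extrinsics[i] = torch.as_tensor(f["extrinsics"])
+             valid[i] = True
+     # Makes it obvious when geometry is enabled but never actually fed.
+     assert valid.any(), "geometry enabled but no history frame carries camera metadata"
+     return intrinsics, extrinsics, valid
+ ```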
+
+ ### 3. Freeze the compact-phase branch as the main base
+
+ **Why now:** the time budget does not allow another architecture round on the weaker spatial branch.
+
+ Files to edit:
+
+ - `code/reveal_vla_bimanual/train/configs/*.yaml`
+ - `code/reveal_vla_bimanual/eval/run_ablations.py`
+ - any training or eval launcher that still points to older dummy ablation configs
+
+ Required changes:
+
+ - Create one new base config derived from the compact-phase recipe. Suggested filename:
+   - `proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact_phase_v7_base.yaml`
+ - Mark the current spatial configs as experimental.
+ - Make the new v7 base the default for this sprint.
+ - Do not create a broad new family of retrained ablation configs now. That is for later.
+
+ Acceptance criterion:
+
+ - The benchmark runner should use the v7 compact-phase config by default unless explicitly told otherwise.
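+
+ One cheap way to make that default unambiguous is on the runner side; a sketch assuming an argparse-style runner and a `--config` flag (flag name not confirmed against the repo):
+
+ ```python
+ # Hypothetical sketch: the v7 compact-phase base is the default unless
+ # a config is passed explicitly.
+ import argparse
+
+ DEFAULT_CONFIG = (
+     "code/reveal_vla_bimanual/train/configs/"
+     "proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact_phase_v7_base.yaml"
+ )
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", default=DEFAULT_CONFIG,
+                     help="v7 compact-phase base unless explicitly overridden")
+ ```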
+
+ ### 4. Add a narrow but informative proxy stress suite
+
+ **Why now:** the current proxies are too undifferentiated. The sprint only needs enough stress structure to make results interpretable.
+
+ Files to edit:
+
+ - `code/reveal_vla_bimanual/sim_reveal/procedural_envs.py`
+ - `code/reveal_vla_bimanual/sim_reveal/proxy_specs.py`
+ - `code/reveal_vla_bimanual/sim_reveal/dataset.py`
+ - `code/reveal_vla_bimanual/sim_reveal/generate_dataset.py`
+
+ Required changes:
+
+ Add only these stress slices for this sprint:
+
+ - `nominal`
+ - `high_reocclusion`
+ - `camera_perturbation`
+ - one task-specific critical slice per task:
+   - foliage: `tight_corridor_high_collateral`
+   - bag: `one_sided_slip`
+   - cloth: `fold_sensitive_long_persistence`
+
+ Also add:
+
+ - `difficulty_bin` with only `medium` and `hard` for this sprint (skip easy and extreme to save time and focus on decision-relevant cases)
+ - episode metadata for the sampled nuisance parameters used in each stress slice
+ - per-step traces for visibility, support, access/corridor, reocclusion risk, disturbance, and chosen candidate family
+
+ Acceptance criterion:
+
+ - Every benchmark report must be sliceable by task family, stress slice, and difficulty bin.
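+
+ A minimal sketch of the per-episode metadata this implies; the field names are illustrative, not the repo's actual schema:
+
+ ```python
+ # Hypothetical sketch of the episode metadata carried through dataset
+ # serialization and into every benchmark report.
+ from dataclasses import dataclass, field
+
+ STRESS_SLICES = {
+     "nominal", "high_reocclusion", "camera_perturbation",
+     "tight_corridor_high_collateral",   # foliage critical slice
+     "one_sided_slip",                   # bag critical slice
+     "fold_sensitive_long_persistence",  # cloth critical slice
+ }
+
+ @dataclass
+ class EpisodeMeta:
+     task_name: str               # "foliage" | "bag" | "cloth"
+     stress_slice: str            # one of STRESS_SLICES
+     difficulty_bin: str          # "medium" | "hard" for this sprint
+     seed: int
+     nuisance_params: dict = field(default_factory=dict)  # sampled per-slice nuisances
+ ```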
+
+ ### 5. Add simple baselines and oracle-style planner evaluation
+
+ **Why now:** without these, base-model numbers are hard to interpret.
+
+ Files to edit/add:
+
+ - `code/reveal_vla_bimanual/eval/run_reveal_benchmark.py`
+ - `code/reveal_vla_bimanual/eval/metrics.py`
+ - add `code/reveal_vla_bimanual/eval/run_proxy_random_eval.py`
+ - add `code/reveal_vla_bimanual/eval/run_proxy_candidate0_eval.py`
+ - add `code/reveal_vla_bimanual/eval/run_planner_oracle_eval.py`
+ - add `code/reveal_vla_bimanual/eval/run_proxy_scripted_eval.py` if the existing scripted path is not already callable directly
+
+ Required changes:
+
+ - Add random candidate selection.
+ - Add candidate-0 selection.
+ - Add scripted teacher execution.
+ - Add oracle-planner evaluation that uses proxy candidate summaries directly.
+ - Add support for eval-time architecture toggles:
+   - `--disable_planner`
+   - `--disable_memory`
+   - `--disable_task_conditioning`
+   - `--disable_geometry`
+   - `--disable_camera_pose` (optional if it is cheap)
+
+ Acceptance criterion:
+
+ - The benchmark must be able to compare the same checkpoint against trivial baselines and against architecture knockouts without retraining all variants.
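+
+ A sketch of the toggle plumbing; the flag names follow this document, while the policy attributes they flip (`use_planner`, `memory_horizon`, and the rest) are assumptions about the wiring, not the repo's actual API:
+
+ ```python
+ # Hypothetical sketch: eval-time knockouts swap learned components for
+ # trivial pass-throughs, so one checkpoint serves every knockout run.
+ import argparse
+
+ def add_knockout_flags(parser: argparse.ArgumentParser) -> None:
+     for name in ("planner", "memory", "task_conditioning", "geometry", "camera_pose"):
+         parser.add_argument(f"--disable_{name}", action="store_true")
+
+ def apply_knockouts(policy, args) -> None:
+     if args.disable_planner:
+         policy.use_planner = False          # fall back to top-scored candidate
+     if args.disable_memory:
+         policy.memory_horizon = 0           # current frame only
+     if args.disable_task_conditioning:
+         policy.task_embedding_scale = 0.0
+     if args.disable_geometry:
+         policy.use_geometry_tokens = False
+     if args.disable_camera_pose:
+         policy.use_camera_pose_tokens = False
+ ```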
+
+ ### 6. Strengthen reporting, but only where it matters
+
+ **Why now:** the current sprint only succeeds if it produces conclusions, not just numbers.
+
+ Files to edit:
+
+ - `code/reveal_vla_bimanual/eval/metrics.py`
+ - `code/reveal_vla_bimanual/eval/run_reveal_benchmark.py`
+
+ Required metric outputs:
+
+ Global:
+
+ - overall proxy success
+ - per-task success
+ - success by stress slice
+ - success by difficulty bin
+ - premature retrieve rate
+ - reocclusion-after-reveal rate
+ - planner regret (where oracle summaries are available)
+
+ Task-specific headline metrics:
+
+ Foliage:
+
+ - visibility integral
+ - corridor availability
+ - collateral motion / damage proxy
+ - actor-feasibility floor before retrieve
+
+ Bag:
+
+ - mouth aperture
+ - hold persistence
+ - rim slip rate
+ - insertable corridor
+
+ Cloth:
+
+ - fold preservation
+ - layer separation quality
+ - top-layer stability
+ - lift-too-high rate
+
+ Required report shape:
+
+ - one overall table
+ - one task × stress slice table
+ - per-episode JSON traces for failure clustering later
+
+ Acceptance criterion:
+
+ - A single report should make it obvious whether the model is failing because of reocclusion, geometry sensitivity, premature retrieve, or task-specific degradation.
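+
+ A sketch of how the task × stress-slice table can fall out of the per-episode traces; it assumes JSON-lines traces and the record keys used in the metadata sketch above, neither of which is confirmed against the repo:
+
+ ```python
+ # Hypothetical sketch: collapse per-episode traces into the required
+ # task × stress-slice success table.
+ import json
+ from collections import defaultdict
+
+ def task_by_slice_table(trace_path: str) -> dict:
+     records = [json.loads(line) for line in open(trace_path)]
+     acc = defaultdict(lambda: [0, 0])  # (task, slice) -> [successes, episodes]
+     for r in records:
+         key = (r["task_name"], r["stress_slice"])
+         acc[key][0] += int(r["success"])
+         acc[key][1] += 1
+     return {k: s / n for k, (s, n) in acc.items()}
+ ```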
+
+ ## Changes that are useful only if they are cheap
+
+ These are allowed only if the mandatory work finishes early.
+
+ ### A. Light training rebalance
+
+ Only do this if it can be implemented in under about one hour.
+
+ Allowed small changes:
+
+ - oversample obvious hard negatives already present in the proxy dataset
+ - slightly increase loss weight on unsafe-retrieve ranking errors
+ - log candidate ranking diagnostics during training
+
+ Do **not** do a broad loss redesign in this sprint.
+
+ ### B. Trunk modularity prep
+
+ Only do this if the mandatory work is already complete.
+
+ Allowed small changes:
+
+ - define a simple trunk adapter interface around the current CLIP path
+ - avoid touching planner, memory, or reveal-head code
+
+ Do **not** attempt a real new-backbone integration in this sprint.
+
+ ## 10-hour execution plan
+
+ This schedule is the intended wall-clock plan. It is aggressive but realistic if the scope stays narrow.
+
+ ### Hour 0 to 1.5: remove confounds
+
+ Complete:
+
+ - explicit task metadata path
+ - history geometry propagation
+ - v7 compact-phase base config
+ - eval-time toggle plumbing in the benchmark runner
+
+ Output expected by the end of this block:
+
+ - code compiles
+ - a tiny smoke run confirms task routing and history geometry are active
+
+ ### Hour 1.5 to 3: build the fixed eval set and reporting
+
+ Complete:
+
+ - stratified proxy eval set generation
+ - task/stress/difficulty metadata roundtrip
+ - benchmark tables and per-episode JSON traces
+ - random, candidate-0, scripted, and oracle evaluation entry points
+
+ Output expected by the end of this block:
+
+ - one fixed benchmark episode set reused by every later run
+
+ ### Hour 3 to 5.5: produce one base-model result
+
+ Preferred path:
+
+ - reuse the best existing compact-phase checkpoint if it still loads cleanly after the metadata and geometry fixes
+ - if needed, run a short warm-start fine-tune from that checkpoint rather than starting from scratch
+
+ Do not spend this block on multi-seed training. One strong base run is more valuable than several weak incomplete runs.
+
+ Output expected by the end of this block:
+
+ - one evaluated base model on the full fixed proxy suite
+
+ ### Hour 5.5 to 8.5: run baselines and eval-time knockouts
+
+ Required comparisons:
+
+ - random
+ - candidate-0
+ - scripted teacher
+ - oracle planner
+ - base model
+ - base model with planner disabled
+ - base model with memory disabled
+ - base model with task conditioning disabled
+ - base model with geometry disabled
+ - base model with camera pose disabled if cheap enough
+
+ Output expected by the end of this block:
+
+ - a complete comparison table on the same episodes
+
+ ### Hour 8.5 to 10: summarize and extract conclusions
+
+ Complete:
+
+ - task-by-task bottleneck summary
+ - approximate transfer-readiness labels for foliage, bag, and cloth
+ - ranked next-step engineering priorities
+
+ Output expected by the end of this block:
+
+ - one benchmark summary table
+ - one short conclusion memo in the repo or artifact directory
+
+ ## Fixed benchmark design for the 10-hour sprint
+
+ Use one small, fixed, stratified benchmark. Reuse the same episode seeds for all variants.
+
+ Recommended size:
+
+ - **300 total episodes**
+ - **100 per task family**
+ - per task family:
+   - 20 `nominal` / `medium`
+   - 20 `nominal` / `hard`
+   - 20 `high_reocclusion`
+   - 20 `camera_perturbation`
+   - 20 task-specific critical slice
+
+ This is small enough to run repeatedly and large enough to show directional differences if the logging is good.
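+
+ A sketch of one way to pin that down deterministically; the `task_critical` stratum name, the `mixed` difficulty label, and the seed formula are illustrative assumptions:
+
+ ```python
+ # Hypothetical sketch: a fixed, stratified 300-episode benchmark whose
+ # deterministic seeds make every variant replay identical episodes.
+ STRATA = [("nominal", "medium"), ("nominal", "hard"),
+           ("high_reocclusion", "mixed"), ("camera_perturbation", "mixed"),
+           ("task_critical", "mixed")]
+
+ def benchmark_episodes():
+     episodes = []
+     for t, task in enumerate(("foliage", "bag", "cloth")):
+         for s, (slice_name, diff) in enumerate(STRATA):
+             for i in range(20):
+                 seed = 100_000 * t + 1_000 * s + i   # stable across runs
+                 episodes.append({"task_name": task, "stress_slice": slice_name,
+                                  "difficulty_bin": diff, "seed": seed})
+     return episodes  # 3 tasks × 5 strata × 20 = 300
+ ```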
+
+ ## Required tests to add now
+
+ The test suite in this sprint is not meant to be exhaustive. It is meant to prevent false conclusions.
+
+ ### Unit tests
+
+ `tests/test_explicit_task_metadata_overrides_text.py`
+
+ - Batch has `task_name="bag"` and misleading foliage text.
+ - Assert that bag proposal families and bag task heads are used.
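+
+ A pytest-style sketch of that first test, reusing the hypothetical `resolve_task_family` routing sketch from earlier rather than the repo's real entry point:
+
+ ```python
+ # Hypothetical sketch: explicit metadata must win over misleading text.
+ def test_explicit_task_metadata_overrides_text():
+     example = {
+         "task_name": "bag",
+         "instruction": "push the foliage aside and reveal the target",  # misleading
+     }
+     family, source = resolve_task_family(example)
+     assert family == "bag"
+     assert source == "metadata"
+ ```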
+
+ `tests/test_text_routing_only_used_as_fallback.py`
+
+ - Assert that keyword routing is skipped when task metadata exists.
+
+ `tests/test_history_camera_geometry_propagates.py`
+
+ - Assert that history frames receive non-null camera tensors when geometry is enabled.
+
+ `tests/test_history_geometry_validity_mask.py`
+
+ - Assert that mixed valid/invalid history geometry uses a validity mask rather than silent nulling.
+
+ `tests/test_eval_toggle_paths_work.py`
+
+ - Assert that planner, memory, task-conditioning, and geometry toggles actually change the execution path.
+
+ `tests/test_benchmark_report_contains_task_and_stress_slices.py`
+
+ - Assert that the output report includes task family, stress slice, and difficulty bin tables.
+
+ ### Integration tests
+
+ `tests/test_proxy_stress_profile_metadata_roundtrip.py`
+
+ - Generate all sprint stress slices and assert the metadata survives dataset serialization and evaluation.
+
+ `tests/test_planner_beats_random_on_oracle_candidates.py`
+
+ - Use oracle candidate summaries and assert the planner beats random and candidate-0 on regret/top-1.
+
+ `tests/test_memory_matters_under_high_reocclusion.py`
+
+ - Compare full memory vs disabled-memory on a small high-reocclusion slice and assert a directional drop.
+
+ `tests/test_geometry_matters_under_camera_perturbation.py`
+
+ - Compare geometry-on vs geometry-off on a small camera-perturbation slice and assert a directional drop.
+
+ `tests/test_retrieve_gating_blocks_premature_retrieve.py`
+
+ - Feed candidates where raw retrieve looks tempting but support/persistence/corridor are unsafe and assert that retrieve is rejected.
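+
+ A sketch of the gating predicate that test exercises; thresholds and candidate fields are illustrative placeholders, not the repo's actual gating logic:
+
+ ```python
+ # Hypothetical sketch: a tempting raw retrieve score is not enough; support,
+ # persistence, and corridor must all clear their floors before retrieve.
+ def retrieve_is_safe(candidate: dict,
+                      min_support: float = 0.8,
+                      min_persistence: float = 0.7,
+                      min_corridor: float = 0.5) -> bool:
+     return (candidate["support"] >= min_support
+             and candidate["persistence"] >= min_persistence
+             and candidate["corridor"] >= min_corridor)
+ ```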
+
+ These tests are enough for this sprint. Do not expand the suite unless the mandatory implementation is already done.
+
+ ## Benchmark runs to perform in this sprint
+
+ ### 1. Baselines
+
+ Run on the fixed 300-episode benchmark:
+
+ - random candidate selection
+ - candidate-0 selection
+ - scripted teacher
+ - oracle planner
+
+ These establish the floor, the structured upper bound for the current proposal set, and whether the learned planner adds anything.
+
+ ### 2. Base model run
+
+ Run the v7 compact-phase base on the same 300 episodes.
+
+ ### 3. Eval-time architecture knockouts
+
+ Run the same checkpoint with these toggles:
+
+ - no planner
+ - no memory
+ - no task conditioning
+ - no geometry
+ - no camera pose (only if cheap)
+
+ These are not final scientific ablations. They are fast directional probes that tell us what appears to matter right now.
+
+ ### 4. Optional cheap extra run
+
+ Only if time remains:
+
+ - a short warm-start fine-tune with a small hard-negative rebalance
+
+ Treat this only as bonus signal. It is not part of the minimum sprint success condition.
+
+ ## Decision rubric for interpreting the results
+
+ At the end of the sprint, do **not** only report the raw numbers. Convert the results into a bottleneck map.
+
+ ### Signals that the architecture is directionally healthy
+
+ - Base clearly beats random and candidate-0 on all three task families.
+ - Oracle planner beats random and candidate-0 by a wide margin, showing the proposal/planner structure is at least usable.
+ - Disabling planner hurts most on premature-retrieve and task-specific stress slices.
+ - Disabling memory hurts most on `high_reocclusion` and long-persistence cloth cases.
+ - Disabling geometry hurts most on `camera_perturbation`.
+ - Task conditioning matters on the mixed-task benchmark and under misleading text.
+
+ ### Signals that a component is currently unproven
+
+ - Base is only slightly above candidate-0 or random.
+ - Oracle planner is weak, which means the proposal set or planner utility is not yet reliable.
+ - Memory removal is almost flat on `high_reocclusion`.
+ - Geometry removal is almost flat on `camera_perturbation`.
+ - Planner removal is flat, which suggests that the learned scores or the candidate shortlist are not carrying useful structure.
+
+ ### Task-specific interpretation
+
+ For foliage, judge transfer-readiness mainly from:
+
+ - corridor availability
+ - actor-feasibility floor before retrieve
+ - collateral motion / damage proxy
+ - robustness under `high_reocclusion` and `tight_corridor_high_collateral`
+
+ For bag, judge transfer-readiness mainly from:
+
+ - mouth aperture
+ - hold persistence
+ - rim slip rate
+ - robustness under `one_sided_slip`
+
+ For cloth, judge transfer-readiness mainly from:
+
+ - fold preservation
+ - layer separation quality
+ - top-layer stability
+ - lift-too-high rate
+ - robustness under `fold_sensitive_long_persistence`
+
+ At the end of the sprint, label each task family as:
+
+ - **Promising**: base beats weak baselines clearly and the expected architectural components matter on the right stress slices.
+ - **Uncertain**: base is somewhat above baselines but at least one critical stress slice or component dependency is weak.
+ - **Weak**: base is near trivial baselines or the critical stress slices fail badly enough that the current structure is not yet convincing.
+
+ ## Practical GPU/runtime guidance
+
+ This sprint assumes a single 96 GB workstation GPU. That is enough for the current CLIP-based compact-phase line and repeated proxy evaluations, but the time budget does not justify broad parallel experiments.
+
+ Use the following operating rules:
+
+ - run experiments sequentially
+ - prefer one strong base run over many partial runs
+ - use mixed precision if already supported
+ - keep evaluation batch sizes modest and stable
+ - avoid large retraining loops or many seeds
+ - reuse the same fixed benchmark episodes for every comparison
+
+ The main bottleneck in this sprint should be engineering and interpretation, not raw VRAM.
+
+ ## Things not to do in this sprint
+
+ Do not switch backbones. Do not integrate an external simulator. Do not rescue the spatial rollout branch. Do not run broad hyperparameter sweeps. Do not attempt a five-seed retraining ablation matrix. Do not use RLBench averages as the main argument for or against architecture changes meant for foliage, bag, or folded cloth.
+
+ ## Minimal deliverables at the 10-hour mark
+
+ At the end of the sprint, the repo or artifact directory should contain:
+
+ 1. the new `v7` compact-phase base config
+ 2. the fixed 300-episode stratified benchmark definition or metadata file
+ 3. updated benchmark runner with stress-slice reporting
+ 4. random, candidate-0, scripted, and oracle evaluation runners
+ 5. the required unit and integration tests listed above
+ 6. one complete result table comparing:
+    - random
+    - candidate-0
+    - scripted teacher
+    - oracle planner
+    - base model
+    - no planner
+    - no memory
+    - no task conditioning
+    - no geometry
+    - optional no camera pose
+ 7. one short decision memo that states:
+    - approximate transfer-readiness for foliage, bag, and cloth
+    - which architectural components look most important
+    - which current weakness is most likely to block real-task success
+    - what should be strengthened next
+
+ ## Success condition for this sprint
+
+ This sprint is successful if, by the end of 10 hours, we can say something like the following and defend it with benchmark evidence:
+
+ - “The structured architecture appears meaningfully better than trivial baselines on foliage and bag proxies, but cloth remains fragile because fold preservation degrades under long persistence.”
+ - or “The planner structure looks sound under oracle candidates, but the learned state estimate is still too weak, so the next work should target perception/memory rather than planner redesign.”
+ - or “Geometry and task conditioning matter, but memory does not yet move the reocclusion slice, so the current memory story is still unproven.”
+
+ If we can make claims of that form with actual run outputs, the sprint has done its job.
history/VLAarchtests_previous_README.md ADDED
@@ -0,0 +1,172 @@
+ ---
+ tags:
+ - robotics
+ - vision-language-action
+ - bimanual-manipulation
+ - rlbench
+ - rgbd
+ ---
+
+ # VLAarchtests
+
+ Bundle uploaded from `/workspace` runpod sessions dated `2026-03-25 UTC` and `2026-03-26 UTC`.
+
+ ## Top-Level Contents
+
+ - `code/reveal_vla_bimanual/`
+   - project code used for the proxy and RLBench runs in this bundle
+ - `artifacts/data/reveal_proxy/`
+   - proxy dataset bundles used by the handoff runs
+ - `artifacts/outputs/r3d/`
+   - previously uploaded R3D proxy outputs already present in the bundle
+ - `artifacts/outputs/r3d_handoff/`
+   - handoff proxy checkpoints
+ - `artifacts/outputs/r3d_handoff_phase/`
+   - phase-supervised handoff proxy checkpoints
+ - `artifacts/outputs/rlbench_current/`
+   - RLBench checkpoints from the current session
+ - `artifacts/reports/`
+   - proxy and RLBench result files copied from `/workspace/reports`
+ - `environment/`
+   - same-machine setup files and validation helpers
+ - `tests/`
+   - local test suite
+ - `handoff/instructions.md`
+   - instruction file used for the handoff work
+ - `MODEL_INDEX.md`
+   - checkpoint and result index
+ - `results/session_results_20260326.md`
+   - raw result tables for the `2026-03-25/26` work
+
+ ## Code Added Or Updated
+
+ ### Core model, memory, planner, and dataset paths
+
+ - `code/reveal_vla_bimanual/models/backbones.py`
+ - `code/reveal_vla_bimanual/models/multiview_fusion.py`
+ - `code/reveal_vla_bimanual/models/observation_memory.py`
+ - `code/reveal_vla_bimanual/models/reveal_head.py`
+ - `code/reveal_vla_bimanual/models/world_model.py`
+ - `code/reveal_vla_bimanual/models/action_decoder.py`
+ - `code/reveal_vla_bimanual/models/planner.py`
+ - `code/reveal_vla_bimanual/models/policy.py`
+ - `code/reveal_vla_bimanual/train/losses.py`
+ - `code/reveal_vla_bimanual/sim_reveal/dataset.py`
+ - `code/reveal_vla_bimanual/sim_reveal/procedural_envs.py`
+ - `code/reveal_vla_bimanual/sim_rlbench/dataset.py`
+
+ ### Training and evaluation paths
+
+ - `code/reveal_vla_bimanual/train/run_rlbench_experiment.py`
+ - `code/reveal_vla_bimanual/eval/run_reveal_benchmark.py`
+ - `code/reveal_vla_bimanual/eval/run_ablations.py`
+ - `code/reveal_vla_bimanual/eval/run_teacher_audit.py`
+ - `code/reveal_vla_bimanual/eval/run_rlbench_rollout_eval.py`
+ - `code/reveal_vla_bimanual/eval/run_rlbench_knn_eval.py`
+
+ ### Added or updated training configs
+
+ - `code/reveal_vla_bimanual/train/configs/proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact.yaml`
+ - `code/reveal_vla_bimanual/train/configs/proxy_interaction_r3d_stage3_clip_rgbd_handoff_spatial.yaml`
+ - `code/reveal_vla_bimanual/train/configs/proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact_phase.yaml`
+ - `code/reveal_vla_bimanual/train/configs/proxy_interaction_r3d_stage3_clip_rgbd_handoff_spatial_phase.yaml`
+ - `code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_clip_current_valid9.yaml`
+ - `code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_clip_current_common23.yaml`
+ - `code/reveal_vla_bimanual/train/configs/rlbench_lift_ball_backbone_only_clip_current_wide.yaml`
+ - `code/reveal_vla_bimanual/train/configs/rlbench_lift_ball_backbone_only_clip_step1.yaml`
+ - `code/reveal_vla_bimanual/train/configs/rlbench_push_box_backbone_only_clip_step1.yaml`
+
+ ### Test files
+
+ The staged `tests/` directory contains `32` test modules plus `conftest.py`, including:
+
+ - geometry and camera rotation coverage
+ - phase-label and candidate-ranking coverage
+ - planner gradient-flow and reocclusion gating coverage
+ - world-model null-rollout, field-consistency, and task-adapter coverage
+ - proxy scripted benchmark and teacher-audit coverage
+
+ ## Verification
+
+ - local test command:
+   - `PYTHONPATH=/workspace/VLAarchtests_work/code/reveal_vla_bimanual python -m pytest -q /workspace/VLAarchtests_work/tests`
+ - result:
+   - `33 passed`
+
+ ## Raw Result Files
+
+ ### Proxy and handoff results
+
+ - `artifacts/reports/reveal_smoke_mod/reveal_benchmark.json`
+ - `artifacts/reports/reveal_smoke_nogeom/reveal_benchmark.json`
+ - `artifacts/reports/reveal_smoke_noplanner/reveal_benchmark.json`
+ - `artifacts/reports/reveal_handoff_compare_serious/reveal_benchmark.json`
+ - `artifacts/reports/reveal_handoff_compare_serious_compact/reveal_benchmark.json`
+ - `artifacts/reports/reveal_phase_compare_serious_compact/reveal_benchmark.json`
+ - `artifacts/reports/reveal_phase_compare_serious_spatial_compactwm/reveal_benchmark.json`
+ - `artifacts/reports/reveal_phase_ablations_compact/ablations.json`
+ - `artifacts/reports/reveal_teacher_audit_serious/teacher_audit.json`
+
+ ### RLBench result files
+
+ - `artifacts/reports/rlbench_dual_buttons_baseline_len100_ep1_ik_rescale/rollout_eval.json`
+ - `artifacts/reports/rlbench_dual_buttons_common23_len100_ep1_ik_rescale/rollout_eval.json`
+ - `artifacts/reports/rlbench_push_box_common23_len100_ep1_ik_rescale/rollout_eval.json`
+ - `artifacts/reports/rlbench_lift_ball_wide_len160_ep1_ik_c1/rollout_eval.json`
+ - `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1/rollout_eval.json`
+ - `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1_s005/rollout_eval.json`
+ - `artifacts/reports/rlbench_push_box_knn_step1_ep1/rollout_eval.json`
+ - `artifacts/reports/rlbench_push_box_knn_step1_ep5/rollout_eval.json`
+ - `artifacts/reports/rlbench_push_box_knn_step1_ep5_top1_dense/rollout_eval.json`
+
+ ## Raw Result Tables
+
+ ### Proxy serious runs
+
+ | Artifact | File | Raw values |
+ | --- | --- | --- |
+ | spatial handoff vs released baseline | `artifacts/reports/reveal_handoff_compare_serious/reveal_benchmark.json` | baseline mean success `0.5833`, handoff mean success `0.2167` |
+ | spatial-trained checkpoint with compact world model vs released baseline | `artifacts/reports/reveal_handoff_compare_serious_compact/reveal_benchmark.json` | baseline mean success `0.5833`, handoff mean success `0.5200` |
+ | compact-phase vs released baseline | `artifacts/reports/reveal_phase_compare_serious_compact/reveal_benchmark.json` | baseline mean success `0.5833`, compact-phase mean success `0.5133` |
+ | spatial-phase with compact world model vs released baseline | `artifacts/reports/reveal_phase_compare_serious_spatial_compactwm/reveal_benchmark.json` | baseline mean success `0.5833`, spatial-phase compact-world-model mean success `0.4933` |
+
+ ### Proxy ablations
+
+ | Artifact | File | Raw values |
+ | --- | --- | --- |
+ | compact-phase ablations | `artifacts/reports/reveal_phase_ablations_compact/ablations.json` | full `0.5133`, `no_geometry` `0.5133`, `no_spatial_memory` `0.4967`, `compact_world_model` `0.5133`, `no_planner` `0.4333`, `gaussian_candidates_only` `0.4667`, `no_task_head` `0.5133`, `no_support_mode_conditioning` `0.5133` |
+
+ ### RLBench direct-policy runs
+
+ | Artifact | File | Raw values |
+ | --- | --- | --- |
+ | lift-ball wide checkpoint, one-step replanning | `artifacts/reports/rlbench_lift_ball_wide_len160_ep1_ik_c1/rollout_eval.json` | mean success `0.0`, mean return `0.0`, path recoveries `[148]`, noop fallbacks `[11]` |
+ | push-box step-1 checkpoint, one-step replanning | `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1/rollout_eval.json` | mean success `0.0`, mean return `0.0`, path recoveries `[177]`, noop fallbacks `[0]` |
+ | push-box step-1 checkpoint, one-step replanning, `delta_scale=0.05` | `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1_s005/rollout_eval.json` | mean success `0.0`, mean return `0.0`, path recoveries `[180]`, noop fallbacks `[0]` |
+ | push-box step-1 checkpoint, one-step replanning, `delta_scale=0.05` | `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1_s005/rollout_eval.json` | mean success `0.0`, mean return `0.0`, path recoveries `[180]`, noop fallbacks `[0]` |
146
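+ All three direct-policy rows share one rollout shape: replan at every step, scale the predicted end-effector delta by `delta_scale`, and fall back to a noop action when the motion planner rejects the target. A minimal sketch of such a loop is below; `IKError`, `DummyEnv`, and the `policy` callable are illustrative stand-ins (not the repo's API), and the exact semantics of the `path recoveries` / `noop fallbacks` counters in the tables are an assumption.
+
+ ```python
+ import numpy as np
+
+ class IKError(Exception):
+     """Stand-in for a motion-planning/IK failure."""
+
+ def rollout(policy, env, horizon=180, delta_scale=0.05):
+     """One-step replanning with noop fallback; returns (success, recoveries, noops)."""
+     recoveries = noops = 0
+     obs = env.reset()
+     for _ in range(horizon):
+         delta = delta_scale * policy(obs)  # re-predict from the latest observation
+         try:
+             obs, success = env.step(delta)
+         except IKError:
+             recoveries += 1  # target rejected: count a recovery ...
+             try:
+                 obs, success = env.step(np.zeros_like(delta))  # ... and try a noop
+             except IKError:
+                 noops += 1
+                 continue
+         if success:
+             return 1.0, recoveries, noops
+     return 0.0, recoveries, noops
+
+ # Toy check: a dummy env that fails IK every 7th step and succeeds at t >= 20.
+ class DummyEnv:
+     def reset(self):
+         self.t = 0
+         return np.zeros(3)
+     def step(self, delta):
+         self.t += 1
+         if self.t % 7 == 0:
+             raise IKError
+         return np.zeros(3), self.t >= 20
+
+ print(rollout(lambda obs: np.ones(3), DummyEnv()))
+ ```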
+
+ ### RLBench retrieval runs
+
+ | Artifact | File | Raw values |
+ | --- | --- | --- |
+ | push-box kNN, `bank_stride=4`, `top_k=5`, `time_window=8`, `episodes=1` | `artifacts/reports/rlbench_push_box_knn_step1_ep1/rollout_eval.json` | mean success `1.0`, mean return `1.0`, bank size `2815` |
+ | push-box kNN, `bank_stride=4`, `top_k=5`, `time_window=8`, `episodes=5` | `artifacts/reports/rlbench_push_box_knn_step1_ep5/rollout_eval.json` | successes `[0.0, 1.0, 0.0, 0.0, 0.0]`, mean success `0.2`, bank size `2815` |
+ | push-box kNN, `bank_stride=1`, `top_k=1`, `time_window=4`, `episodes=5` | `artifacts/reports/rlbench_push_box_knn_step1_ep5_top1_dense/rollout_eval.json` | successes `[0.0, 0.0, 1.0, 1.0, 0.0]`, mean success `0.4`, bank size `11259` |
+
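+ The retrieval rows vary three knobs: `bank_stride` (how densely demo features are subsampled into the bank, hence bank sizes `2815` vs `11259`), `top_k` (how many neighbours vote), and `time_window` (how far a matched timestep may sit from the current one). A minimal sketch of that kind of nearest-neighbour action lookup follows; the function and variable names are illustrative, not the repo's API.
+
+ ```python
+ import numpy as np
+
+ def build_bank(demo_feats, demo_actions, bank_stride=4):
+     """Subsample (feature, action, timestep) triples from demos into a flat bank."""
+     idx = np.arange(0, len(demo_feats), bank_stride)
+     return demo_feats[idx], demo_actions[idx], idx
+
+ def knn_action(query, bank_feats, bank_actions, bank_times, t, top_k=5, time_window=8):
+     """Average the actions of the top-k cosine matches within +/- time_window of t."""
+     mask = np.abs(bank_times - t) <= time_window
+     feats, actions = bank_feats[mask], bank_actions[mask]
+     sims = feats @ query
+     sims /= np.linalg.norm(feats, axis=1) * np.linalg.norm(query) + 1e-8
+     best = np.argsort(-sims)[:top_k]
+     return actions[best].mean(axis=0)
+
+ # Toy usage: 100 demo steps, 16-d features, 3-d actions.
+ rng = np.random.default_rng(0)
+ feats, acts = rng.normal(size=(100, 16)), rng.normal(size=(100, 3))
+ bank_f, bank_a, bank_t = build_bank(feats, acts, bank_stride=4)
+ print(knn_action(feats[10], bank_f, bank_a, bank_t, t=10))  # a 3-d action
+ ```
+
+ With `top_k=1` and a stride-1 bank the lookup replays the single closest demo step, which is how the `top1_dense` row above is configured.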
+ ## Environment Recreation Files
+
+ - `environment/setup_same_machine.sh`
+ - `environment/validate_same_machine.sh`
+ - `environment/run_peract2_13_rollouts.sh`
+ - `environment/runtime_env_vars.sh`
+ - `environment/hardware_snapshot.txt`
+ - `environment/glxinfo_B.txt`
+ - `environment/upstream_revisions.txt`
+ - `environment/system_packages_same_machine.txt`
+ - `environment/rlbench_env_export.yaml`
+ - `environment/rlbench_env_explicit.txt`
+ - `environment/rlbench_pip_freeze.txt`
+ - `environment/reveal_env_export.yaml`
+ - `environment/reveal_env_explicit.txt`
+ - `environment/reveal_pip_freeze.txt`
+
+ Detailed raw tables for the `2026-03-25/26` work are in `results/session_results_20260326.md`.
metadata/source_sizes.txt ADDED
@@ -0,0 +1,4 @@
+ 13G /workspace/VLAarchtests
+ 2.2G /workspace/third_party/AnyBimanual
+ 54G /workspace/baselines
+ 219M /workspace/reports
metadata/staged_size.txt ADDED
@@ -0,0 +1 @@
+ 69G /workspace/hf_publish/VLAarchtests2
metadata/staged_tree_top2.txt ADDED
@@ -0,0 +1,64 @@
+ /workspace/hf_publish/VLAarchtests2/CHANGE_AND_TEST_LOG.md
+ /workspace/hf_publish/VLAarchtests2/MODEL_AND_ARTIFACT_INDEX.md
+ /workspace/hf_publish/VLAarchtests2/README.md
+ /workspace/hf_publish/VLAarchtests2/RESULTS_RAW.md
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/.cache
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/.gitattributes
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/.gitignore
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/.pytest_cache
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/FILE_MANIFEST.txt
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/MODEL_INDEX.md
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/README.md
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/artifacts
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/code
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/results
+ /workspace/hf_publish/VLAarchtests2/VLAarchtests/tests
+ /workspace/hf_publish/VLAarchtests2/baselines
+ /workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual
+ /workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_dummy_demo_root
+ /workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_evalroot
+ /workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_overlap_replay
+ /workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_overlap_runs
+ /workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_subset3_demo_root
+ /workspace/hf_publish/VLAarchtests2/environment
+ /workspace/hf_publish/VLAarchtests2/environment/base_pip_freeze.txt
+ /workspace/hf_publish/VLAarchtests2/environment/base_python.txt
+ /workspace/hf_publish/VLAarchtests2/environment/env_list.txt
+ /workspace/hf_publish/VLAarchtests2/environment/hardware_snapshot.txt
+ /workspace/hf_publish/VLAarchtests2/environment/nvidia_smi.txt
+ /workspace/hf_publish/VLAarchtests2/environment/rlbench_pip_freeze.txt
+ /workspace/hf_publish/VLAarchtests2/environment/rlbench_python.txt
+ /workspace/hf_publish/VLAarchtests2/environment/runtime_env_vars.sh
+ /workspace/hf_publish/VLAarchtests2/environment/setup_same_hardware.sh
+ /workspace/hf_publish/VLAarchtests2/environment/uname.txt
+ /workspace/hf_publish/VLAarchtests2/handoff
+ /workspace/hf_publish/VLAarchtests2/handoff/instructions4.md
+ /workspace/hf_publish/VLAarchtests2/history
+ /workspace/hf_publish/VLAarchtests2/history/VLAarchtests_previous_README.md
+ /workspace/hf_publish/VLAarchtests2/metadata
+ /workspace/hf_publish/VLAarchtests2/metadata/source_sizes.txt
+ /workspace/hf_publish/VLAarchtests2/metadata/staged_size.txt
+ /workspace/hf_publish/VLAarchtests2/metadata/staged_tree_top2.txt
+ /workspace/hf_publish/VLAarchtests2/reports
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_resume1000_chain.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_resume1000_eval.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_resume1000_eval_watcher.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_resume1000_train.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb2_train.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb3_eval.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb3_train.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb3_train_presavefix.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb_train.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_train.log
+ /workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_train.log
+ /workspace/hf_publish/VLAarchtests2/reports/peract2_13_launch_smoke_live
+ /workspace/hf_publish/VLAarchtests2/reports/rlbench_common23_exec_calib
+ /workspace/hf_publish/VLAarchtests2/reports/rlbench_common23_exec_calib_iso
+ /workspace/hf_publish/VLAarchtests2/reports/rlbench_debug_common23_pushbox_ep1
+ /workspace/hf_publish/VLAarchtests2/reports/rlbench_general_debug
+ /workspace/hf_publish/VLAarchtests2/reports/rlbench_subset3_common23_live_ep1
+ /workspace/hf_publish/VLAarchtests2/reports/run_bag_selector_iter9_prebuild.log
+ /workspace/hf_publish/VLAarchtests2/reports/true_baseline_compare_subset3_v1
+ /workspace/hf_publish/VLAarchtests2/third_party
+ /workspace/hf_publish/VLAarchtests2/third_party/AnyBimanual
third_party/AnyBimanual/agents/__init__.py ADDED
File without changes
third_party/AnyBimanual/agents/agent_factory.py ADDED
@@ -0,0 +1,101 @@
+ import os
+ import logging
+
+ from omegaconf import DictConfig
+
+ from yarr.agents.agent import BimanualAgent
+ from yarr.agents.agent import LeaderFollowerAgent
+ from yarr.agents.agent import Agent
+
+
+ supported_agents = {"leader_follower": ("PERACT_BC", "RVT"),
+                     "independent": ("PERACT_BC", "RVT"),
+                     "bimanual": ("BIMANUAL_PERACT", "ACT_BC_LANG"),
+                     "unimanual": ()}
+
+
+ def create_agent(cfg: DictConfig) -> Agent:
+     method_name = cfg.method.name
+     agent_type = cfg.method.agent_type
+
+     logging.info("Using method %s with type %s", method_name, agent_type)
+
+     assert method_name in supported_agents[agent_type]
+
+     agent_fn = agent_fn_by_name(method_name)
+
+     if agent_type == "leader_follower":
+         checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix
+         cfg.method.robot_name = "right"
+         cfg.framework.checkpoint_name_prefix = f"{checkpoint_name_prefix}_{method_name.lower()}_leader"
+         leader_agent = agent_fn(cfg)
+
+         cfg.method.robot_name = "left"
+         cfg.framework.checkpoint_name_prefix = f"{checkpoint_name_prefix}_{method_name.lower()}_follower"
+         cfg.method.low_dim_size = cfg.method.low_dim_size + 8  # also add the action size
+         follower_agent = agent_fn(cfg)
+
+         cfg.method.robot_name = "bimanual"
+
+         return LeaderFollowerAgent(leader_agent, follower_agent)
+
+     elif agent_type == "independent":
+         checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix
+         cfg.method.robot_name = "right"
+         cfg.framework.checkpoint_name_prefix = f"{checkpoint_name_prefix}_{method_name.lower()}_right"
+         right_agent = agent_fn(cfg)
+
+         cfg.method.robot_name = "left"
+         cfg.framework.checkpoint_name_prefix = f"{checkpoint_name_prefix}_{method_name.lower()}_left"
+         left_agent = agent_fn(cfg)
+
+         cfg.method.robot_name = "bimanual"
+
+         return BimanualAgent(right_agent, left_agent)
+
+     elif agent_type == "bimanual" or agent_type == "unimanual":
+         return agent_fn(cfg)
+     else:
+         raise Exception("invalid agent type")
+
+
+ def agent_fn_by_name(method_name: str) -> Agent:
+     if method_name == "ARM":
+         from agents import arm
+
+         raise NotImplementedError("ARM not yet supported for eval.py")
+     elif method_name == "BC_LANG":
+         from agents.baselines import bc_lang
+
+         return bc_lang.launch_utils.create_agent
+     elif method_name == "VIT_BC_LANG":
+         from agents.baselines import vit_bc_lang
+
+         return vit_bc_lang.launch_utils.create_agent
+     elif method_name == "C2FARM_LINGUNET_BC":
+         from agents import c2farm_lingunet_bc
+
+         return c2farm_lingunet_bc.launch_utils.create_agent
+     elif method_name.startswith("PERACT_BC"):
+         from agents import peract_bc
+
+         return peract_bc.launch_utils.create_agent
+     elif method_name.startswith("BIMANUAL_PERACT"):
+         from agents import peract_bimanual
+
+         return peract_bimanual.launch_utils.create_agent
+     elif method_name.startswith("RVT"):
+         from agents import rvt
+
+         return rvt.launch_utils.create_agent
+     elif method_name.startswith("ACT_BC_LANG"):
+         from agents import act_bc_lang
+
+         return act_bc_lang.launch_utils.create_agent
+     elif method_name == "PERACT_RL":
+         raise NotImplementedError("PERACT_RL not yet supported for eval.py")
+     else:
+         raise ValueError("Method %s does not exist." % method_name)
third_party/AnyBimanual/agents/peract_bc/__init__.py ADDED
@@ -0,0 +1 @@
+ import agents.peract_bc.launch_utils
third_party/AnyBimanual/agents/peract_bc/launch_utils.py ADDED
@@ -0,0 +1,128 @@
+ # Adapted from ARM
+ # Source: https://github.com/stepjam/ARM
+ # License: https://github.com/stepjam/ARM/LICENSE
+
+ from helpers.preprocess_agent import PreprocessAgent
+ from agents.peract_bc.perceiver_lang_io import PerceiverVoxelLangEncoder
+ from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
+ from agents.peract_bc.qattention_stack_agent import QAttentionStackAgent
+ import pickle
+ import torch
+ from agents.peract_bc.skill_manager import SkillManager
+ from agents.peract_bc.visual_aligner import VisualAligner
+ from omegaconf import DictConfig
+ import os
+
+
+ def create_agent(cfg: DictConfig):
+     LATENT_SIZE = 64
+     depth_0bounds = cfg.rlbench.scene_bounds
+     cam_resolution = cfg.rlbench.camera_resolution
+
+     num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
+     qattention_agents = []
+
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     pkl_path = os.path.join(current_dir, "../../lang_token.pkl")
+     pkl_path = os.path.abspath(pkl_path)
+     with open(pkl_path, "rb") as f:
+         embeddings_dict = pickle.load(f)
+     flattened_embeddings = []
+     for key in embeddings_dict.keys():
+         embedding = torch.tensor(embeddings_dict[key])
+         flattened_embedding = embedding.view(-1)
+         flattened_embeddings.append(flattened_embedding)
+     embeddings_matrix = torch.stack(flattened_embeddings)
+     # The released AnyBimanual checkpoints were trained with a wider skill-manager
+     # hidden size than the public repo currently hardcodes. Keep the original
+     # default for non-AnyBimanual runs, but use the released width when loading
+     # the AnyBimanual checkpoint family.
+     skill_manager_hidden_size = int(
+         getattr(cfg.framework, "skill_manager_hidden_size", 256 if cfg.framework.anybimanual else 128)
+     )
+     skill_manager = SkillManager(
+         num_classes=18,
+         embedding_matrix=embeddings_matrix,
+         hidden_size=skill_manager_hidden_size,
+     )
+     visual_aligner = VisualAligner()
+
+     for depth, vox_size in enumerate(cfg.method.voxel_sizes):
+         last = depth == len(cfg.method.voxel_sizes) - 1
+         perceiver_encoder = PerceiverVoxelLangEncoder(
+             depth=cfg.method.transformer_depth,
+             iterations=cfg.method.transformer_iterations,
+             voxel_size=vox_size,
+             initial_dim=3 + 3 + 1 + 3,
+             low_dim_size=cfg.method.low_dim_size,
+             layer=depth,
+             num_rotation_classes=num_rotation_classes if last else 0,
+             num_grip_classes=2 if last else 0,
+             num_collision_classes=2 if last else 0,
+             input_axis=3,
+             num_latents=cfg.method.num_latents,
+             latent_dim=cfg.method.latent_dim,
+             cross_heads=cfg.method.cross_heads,
+             latent_heads=cfg.method.latent_heads,
+             cross_dim_head=cfg.method.cross_dim_head,
+             latent_dim_head=cfg.method.latent_dim_head,
+             weight_tie_layers=False,
+             activation=cfg.method.activation,
+             pos_encoding_with_lang=cfg.method.pos_encoding_with_lang,
+             input_dropout=cfg.method.input_dropout,
+             attn_dropout=cfg.method.attn_dropout,
+             decoder_dropout=cfg.method.decoder_dropout,
+             lang_fusion_type=cfg.method.lang_fusion_type,
+             voxel_patch_size=cfg.method.voxel_patch_size,
+             voxel_patch_stride=cfg.method.voxel_patch_stride,
+             no_skip_connection=cfg.method.no_skip_connection,
+             no_perceiver=cfg.method.no_perceiver,
+             no_language=cfg.method.no_language,
+             final_dim=cfg.method.final_dim,
+             anybimanual=cfg.framework.anybimanual,
+             skill_manager=skill_manager,
+             visual_aligner=visual_aligner,
+         )
+
+         qattention_agent = QAttentionPerActBCAgent(
+             layer=depth,
+             coordinate_bounds=depth_0bounds,
+             perceiver_encoder=perceiver_encoder,
+             camera_names=cfg.rlbench.cameras,
+             voxel_size=vox_size,
+             bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
+             image_crop_size=cfg.method.image_crop_size,
+             lr=cfg.method.lr,
+             training_iterations=cfg.framework.training_iterations,
+             lr_scheduler=cfg.method.lr_scheduler,
+             num_warmup_steps=cfg.method.num_warmup_steps,
+             trans_loss_weight=cfg.method.trans_loss_weight,
+             rot_loss_weight=cfg.method.rot_loss_weight,
+             grip_loss_weight=cfg.method.grip_loss_weight,
+             collision_loss_weight=cfg.method.collision_loss_weight,
+             include_low_dim_state=True,
+             image_resolution=cam_resolution,
+             batch_size=cfg.replay.batch_size,
+             voxel_feature_size=3,
+             lambda_weight_l2=cfg.method.lambda_weight_l2,
+             num_rotation_classes=num_rotation_classes,
+             rotation_resolution=cfg.method.rotation_resolution,
+             transform_augmentation=cfg.method.transform_augmentation.apply_se3,
+             transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
+             transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
+             transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
+             optimizer_type=cfg.method.optimizer,
+             num_devices=cfg.ddp.num_devices,
+             checkpoint_name_prefix=cfg.framework.checkpoint_name_prefix,
+             anybimanual=cfg.framework.anybimanual,
+         )
+         qattention_agents.append(qattention_agent)
+
+     rotation_agent = QAttentionStackAgent(
+         qattention_agents=qattention_agents,
+         rotation_resolution=cfg.method.rotation_resolution,
+         camera_names=cfg.rlbench.cameras,
+     )
+     preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
+     return preprocess_agent
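A small aside on the `skill_manager_hidden_size` lookup above: it is a plain optional-key fallback, where an explicit config value wins and the width is otherwise inferred from the `anybimanual` flag. A self-contained illustration of the same pattern (the config dicts here are illustrative; note that `getattr`'s default only fires when missing-key access raises, which is why the sketch enables struct mode, as Hydra-loaded configs have it):

```python
from omegaconf import OmegaConf

def resolve_hidden_size(framework_cfg):
    # Explicit key wins; otherwise 256 for the released AnyBimanual
    # checkpoints and 128 for the original repo default.
    return int(getattr(framework_cfg, "skill_manager_hidden_size",
                       256 if framework_cfg.anybimanual else 128))

for raw in ({"anybimanual": True},
            {"anybimanual": False},
            {"anybimanual": False, "skill_manager_hidden_size": 192}):
    cfg = OmegaConf.create(raw)
    OmegaConf.set_struct(cfg, True)  # missing keys now raise AttributeError
    print(resolve_hidden_size(cfg))  # 256, then 128, then 192
```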
third_party/AnyBimanual/agents/peract_bc/perceiver_lang_io.py ADDED
@@ -0,0 +1,481 @@
+ # Perceiver IO implementation adapted for manipulation
+ # Source: https://github.com/lucidrains/perceiver-pytorch
+ # License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE
+
+ import torch
+ from torch import nn
+
+ from einops import rearrange
+ from einops import repeat
+ import torch.nn.functional as F
+ from perceiver_pytorch.perceiver_pytorch import cache_fn
+ from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention
+
+ from helpers.network_utils import (
+     DenseBlock,
+     SpatialSoftmax3D,
+     Conv3DBlock,
+     Conv3DUpsampleBlock,
+ )
+
+
+ def symmetric_kl_divergence(left, right):
+     eps = 1e-2
+     left_prob = torch.clamp(F.log_softmax(left, dim=-1), min=-10, max=10)
+     right_prob = torch.clamp(F.log_softmax(right, dim=-1), min=-10, max=10)
+
+     kl_left_to_right = F.kl_div(left_prob, right_prob.exp(), reduction="batchmean") * eps
+     kl_right_to_left = F.kl_div(right_prob, left_prob.exp(), reduction="batchmean") * eps
+
+     symmetric_kl = -(kl_left_to_right + kl_right_to_left) / 2.0
+     return symmetric_kl
+
+
+ def l1_norm(tensor):
+     return torch.sum(torch.abs(tensor)) + 1e-4 * torch.norm(tensor)
+
+
+ def l2_1_norm(tensor):
+     l2_norm_per_skill = torch.norm(tensor, dim=-1)
+     return torch.sum(l2_norm_per_skill)
+
+
+ # PerceiverIO adapted for 6-DoF manipulation
+ class PerceiverVoxelLangEncoder(nn.Module):
+     def __init__(
+         self,
+         depth,                    # number of self-attention layers
+         iterations,               # number of cross-attention iterations (PerceiverIO uses just 1)
+         voxel_size,               # N voxels per side (size: N*N*N)
+         initial_dim,              # 10 dimensions - dimension of the input sequence to be encoded
+         low_dim_size,             # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep}
+         layer=0,
+         num_rotation_classes=72,  # 5-degree increments (5*72=360) for each of the 3 axes
+         num_grip_classes=2,       # open or not open
+         num_collision_classes=2,  # collisions allowed or not allowed
+         input_axis=3,             # 3D tensors have 3 axes
+         num_latents=512,          # number of latent vectors
+         im_channels=64,           # intermediate channel size
+         latent_dim=512,           # dimensions of latent vectors
+         cross_heads=1,            # number of cross-attention heads
+         latent_heads=8,           # number of latent heads
+         cross_dim_head=64,
+         latent_dim_head=64,
+         activation="relu",
+         weight_tie_layers=False,
+         pos_encoding_with_lang=True,
+         input_dropout=0.1,
+         attn_dropout=0.1,
+         decoder_dropout=0.0,
+         lang_fusion_type="seq",
+         voxel_patch_size=9,
+         voxel_patch_stride=8,
+         no_skip_connection=False,
+         no_perceiver=False,
+         no_language=False,
+         final_dim=64,
+         anybimanual=False,
+         skill_manager=None,
+         visual_aligner=None,
+     ):
+         super().__init__()
+         self.depth = depth
+         self.layer = layer
+         self.init_dim = int(initial_dim)
+         self.iterations = iterations
+         self.input_axis = input_axis
+         self.voxel_size = voxel_size
+         self.low_dim_size = low_dim_size
+         self.im_channels = im_channels
+         self.pos_encoding_with_lang = pos_encoding_with_lang
+         self.lang_fusion_type = lang_fusion_type
+         self.voxel_patch_size = voxel_patch_size
+         self.voxel_patch_stride = voxel_patch_stride
+         self.num_rotation_classes = num_rotation_classes
+         self.num_grip_classes = num_grip_classes
+         self.num_collision_classes = num_collision_classes
+         self.final_dim = final_dim
+         self.input_dropout = input_dropout
+         self.attn_dropout = attn_dropout
+         self.decoder_dropout = decoder_dropout
+         self.no_skip_connection = no_skip_connection
+         self.no_perceiver = no_perceiver
+         self.no_language = no_language
+         self.anybimanual = anybimanual
+         self.skill_manager = skill_manager
+         self.visual_aligner = visual_aligner
+         # patchified input dimensions
+         spatial_size = voxel_size // self.voxel_patch_stride  # 100/5 = 20
+
+         # 64 voxel features + 64 proprio features (+ 64 lang goal features if concatenated)
+         self.input_dim_before_seq = (
+             self.im_channels * 3
+             if self.lang_fusion_type == "concat"
+             else self.im_channels * 2
+         )
+         if self.anybimanual:
+             self.input_dim_before_seq_ = self.input_dim_before_seq * 2
+         else:
+             self.input_dim_before_seq_ = self.input_dim_before_seq
+         # CLIP language feature dimensions
+         if self.anybimanual:
+             lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 154
+         else:
+             lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77
+
+         self.lang_max_seq_len = lang_max_seq_len
+         # learnable positional encoding
+         if self.pos_encoding_with_lang:
+             self.pos_encoding = nn.Parameter(
+                 torch.randn(
+                     1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
+                 )
+             )
+         else:
+             # assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.'
+             self.pos_encoding = nn.Parameter(
+                 torch.randn(
+                     1,
+                     spatial_size,
+                     spatial_size,
+                     spatial_size,
+                     self.input_dim_before_seq,
+                 )
+             )
+
+         # voxel input preprocessing 1x1 conv encoder
+         self.input_preprocess = Conv3DBlock(
+             self.init_dim,
+             self.im_channels,
+             kernel_sizes=1,
+             strides=1,
+             norm=None,
+             activation=activation,
+         )
+
+         # patchify conv
+         self.patchify = Conv3DBlock(
+             self.input_preprocess.out_channels,
+             self.im_channels,
+             kernel_sizes=self.voxel_patch_size,
+             strides=self.voxel_patch_stride,
+             norm=None,
+             activation=activation,
+         )
+
+         # language preprocess
+         if self.lang_fusion_type == "concat":
+             self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
+         elif self.lang_fusion_type == "seq":
+             self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)
+
+         # proprioception
+         if self.low_dim_size > 0:
+             self.proprio_preprocess = DenseBlock(
+                 self.low_dim_size,
+                 self.im_channels,
+                 norm=None,
+                 activation=activation,
+             )
+
+         # pooling functions
+         self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
+         self.global_maxp = nn.AdaptiveMaxPool3d(1)
+
+         # 1st 3D softmax
+         self.ss0 = SpatialSoftmax3D(
+             self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
+         )
+         flat_size = self.im_channels * 4
+
+         # latent vectors (that are randomly initialized)
+         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+
+         # encoder cross attention
+         self.cross_attend_blocks = nn.ModuleList(
+             [
+                 PreNorm(
+                     latent_dim,
+                     Attention(
+                         latent_dim,
+                         self.input_dim_before_seq_,
+                         heads=cross_heads,
+                         dim_head=cross_dim_head,
+                         dropout=input_dropout,
+                     ),
+                     context_dim=self.input_dim_before_seq_,
+                 ),
+                 PreNorm(latent_dim, FeedForward(latent_dim)),
+             ]
+         )
+
+         get_latent_attn = lambda: PreNorm(
+             latent_dim,
+             Attention(
+                 latent_dim,
+                 heads=latent_heads,
+                 dim_head=latent_dim_head,
+                 dropout=attn_dropout,
+             ),
+         )
+         get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
+         get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
+
+         # self attention layers
+         self.layers = nn.ModuleList([])
+         cache_args = {"_cache": weight_tie_layers}
+
+         for i in range(depth):
+             self.layers.append(
+                 nn.ModuleList(
+                     [get_latent_attn(**cache_args), get_latent_ff(**cache_args)]
+                 )
+             )
+
+         # decoder cross attention
+         self.decoder_cross_attn = PreNorm(
+             self.input_dim_before_seq_,
+             Attention(
+                 self.input_dim_before_seq_,
+                 latent_dim,
+                 heads=cross_heads,
+                 dim_head=cross_dim_head,
+                 dropout=decoder_dropout,
+             ),
+             context_dim=latent_dim,
+         )
+
+         # upsample conv
+         self.up0 = Conv3DUpsampleBlock(
+             self.input_dim_before_seq_,
+             self.final_dim,
+             kernel_sizes=self.voxel_patch_size,
+             strides=self.voxel_patch_stride,
+             norm=None,
+             activation=activation,
+         )
+
+         # 2nd 3D softmax
+         self.ss1 = SpatialSoftmax3D(
+             spatial_size, spatial_size, spatial_size, self.input_dim_before_seq_
+         )
+
+         flat_size += self.input_dim_before_seq_ * 4
+
+         # final 3D softmax
+         self.final = Conv3DBlock(
+             self.im_channels
+             if (self.no_perceiver or self.no_skip_connection)
+             else self.im_channels * 2,
+             self.im_channels,
+             kernel_sizes=3,
+             strides=1,
+             norm=None,
+             activation=activation,
+         )
+
+         self.trans_decoder = Conv3DBlock(
+             self.final_dim,
+             1,
+             kernel_sizes=3,
+             strides=1,
+             norm=None,
+             activation=None,
+         )
+
+         # rotation, gripper, and collision MLP layers
+         if self.num_rotation_classes > 0:
+             self.ss_final = SpatialSoftmax3D(
+                 self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
+             )
+
+             flat_size += self.im_channels * 4
+
+             self.dense0 = DenseBlock(flat_size, 256, None, activation)
+             self.dense1 = DenseBlock(256, self.final_dim, None, activation)
+
+             self.rot_grip_collision_ff = DenseBlock(
+                 self.final_dim,
+                 self.num_rotation_classes * 3
+                 + self.num_grip_classes
+                 + self.num_collision_classes,
+                 None,
+                 None,
+             )
+
+     def encode_text(self, x):
+         with torch.no_grad():
+             text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)
+
+         text_feat = text_feat.detach()
+         text_emb = text_emb.detach()
+         text_mask = torch.where(x == 0, x, 1)  # [1, max_token_len]
+         return text_feat, text_emb
+
+     def forward(
+         self,
+         ins,
+         proprio,
+         lang_goal_emb,
+         lang_token_embs,
+         prev_layer_voxel_grid,
+         bounds,
+         prev_layer_bounds,
+         mask=None,
+         arm=None,
+     ):
+         # preprocess input
+         d0 = self.input_preprocess(ins)  # [B,10,100,100,100] -> [B,64,100,100,100]
+
+         # aggregated features from 1st softmax and maxpool for MLP decoders
+         feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]
+
+         # patchify input (5x5x5 patches)
+         ins = self.patchify(d0)  # [B,64,100,100,100] -> [B,64,20,20,20]
+
+         b, c, d, h, w, device = *ins.shape, ins.device
+         axis = [d, h, w]
+         assert (
+             len(axis) == self.input_axis
+         ), "input must have the same number of axes as input_axis"
+
+         # concat proprio
+         if self.low_dim_size > 0:
+             p = self.proprio_preprocess(proprio)  # [B,4] -> [B,64]
+             p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
+             ins = torch.cat([ins, p], dim=1)  # [B,128,20,20,20]
+
+         # language ablation
+         if self.no_language:
+             lang_goal_emb = torch.zeros_like(lang_goal_emb)
+             lang_token_embs = torch.zeros_like(lang_token_embs)
+
+         # option 1: tile and concat lang goal to input
+         if self.lang_fusion_type == "concat":
+             lang_emb = lang_goal_emb
+             lang_emb = lang_emb.to(dtype=ins.dtype)
+             l = self.lang_preprocess(lang_emb)
+             l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
+             ins = torch.cat([ins, l], dim=1)
+
+         # channel last
+         ins = rearrange(ins, "b d ... -> b ... d")  # [B,20,20,20,128]
+
+         # add pos encoding to grid
+         if not self.pos_encoding_with_lang:
+             ins = ins + self.pos_encoding
+
+         ######################## NOTE #############################
+         # NOTE: If you add positional encodings ^here the lang embs
+         # won't have positional encodings. I accidentally forgot
+         # to turn this off for all the experiments in the paper.
+         # So I guess those models were using language embs
+         # as a bag of words :( But it doesn't matter much for
+         # RLBench tasks since we don't test for novel instructions
+         # at test time anyway. The recommended way is to add
+         # positional encodings to the final input sequence
+         # fed into the Perceiver Transformer, as done below
+         # (and also in the Colab tutorial).
+         ###########################################################
+
+         # remember the pre-flatten shape of the query grid
+         queries_orig_shape = ins.shape
+
+         # rearrange input to be channel last
+         ins = rearrange(ins, "b ... d -> b (...) d")  # [B,8000,128]
+         ins_wo_prev_layers = ins
+         # option 2: add lang token embs as a sequence
+         if self.anybimanual:
+             # decompose the shared voxel features into per-arm masks and skill tokens
+             l = self.lang_preprocess(lang_token_embs)  # [B,77,512] -> [B,77,128]
+             mask_right, mask_left = self.visual_aligner(ins)
+             L_voxel = symmetric_kl_divergence(mask_left, mask_right)
+             right_skill = self.skill_manager(mask_right, l)
+             left_skill = self.skill_manager(mask_left, l)
+             right_skill = self.lang_preprocess(right_skill)
+             left_skill = self.lang_preprocess(left_skill)
+             L_skill = (
+                 l1_norm(left_skill) + l1_norm(right_skill) +
+                 0.01 * (l2_1_norm(left_skill) + l2_1_norm(right_skill))
+             )
+             l_right = torch.cat((right_skill, l), dim=1)
+             ins_right = torch.cat((l_right, mask_right), dim=1)
+             l_left = torch.cat((left_skill, l), dim=1)
+             ins_left = torch.cat((l_left, mask_left), dim=1)
+             if arm == "right":
+                 skill = right_skill
+                 ins_ = ins_right
+             else:
+                 skill = left_skill
+                 ins_ = ins_left
+             if self.pos_encoding_with_lang:
+                 ins_ = ins_ + self.pos_encoding
+         else:
+             if self.lang_fusion_type == "seq":
+                 l = self.lang_preprocess(lang_token_embs)  # [B,77,1024] -> [B,77,128]
+                 ins = torch.cat((l, ins), dim=1)  # [B,8077,128]
+             # add pos encoding to language + flattened grid (the recommended way)
+             if self.pos_encoding_with_lang:
+                 ins = ins + self.pos_encoding
+
+         if self.anybimanual:
+             skill_l = torch.cat((skill, l), dim=1)
+             ins = torch.cat((skill_l, ins), dim=1)
+             ins = torch.cat((ins_, ins), dim=2)
+         # batchify latents
+         x = repeat(self.latents, "n d -> b n d", b=b)
+
+         cross_attn, cross_ff = self.cross_attend_blocks
+
+         for it in range(self.iterations):
+             # encoder cross attention
+             x = cross_attn(x, context=ins, mask=mask) + x
+             x = cross_ff(x) + x
+
+             # self-attention layers
+             for self_attn, self_ff in self.layers:
+                 x = self_attn(x) + x
+                 x = self_ff(x) + x
+
+         # decoder cross attention
+         latents = self.decoder_cross_attn(ins, context=x)
+         # crop out the language part of the output sequence
+         if self.lang_fusion_type == "seq":
+             latents = latents[:, self.lang_max_seq_len :]
+
+         # reshape back to voxel grid
+         latents = latents.view(
+             b, *queries_orig_shape[1:-1], latents.shape[-1]
+         )  # [B,20,20,20,64]
+         latents = rearrange(latents, "b ... d -> b d ...")  # [B,64,20,20,20]
+
+         # aggregated features from 2nd softmax and maxpool for MLP decoders
+         feats.extend(
+             [self.ss1(latents.contiguous()), self.global_maxp(latents).view(b, -1)]
+         )
+
+         # upsample
+         u0 = self.up0(latents)
+
+         # ablations
+         if self.no_skip_connection:
+             u = self.final(u0)
+         elif self.no_perceiver:
+             u = self.final(d0)
+         else:
+             u = self.final(torch.cat([d0, u0], dim=1))
+
+         # translation decoder
+         trans = self.trans_decoder(u)
+
+         # rotation, gripper, and collision MLPs
+         rot_and_grip_out = None
+         if self.num_rotation_classes > 0:
+             feats.extend(
+                 [self.ss_final(u.contiguous()), self.global_maxp(u).view(b, -1)]
+             )
+
+             dense0 = self.dense0(torch.cat(feats, dim=1))
+             dense1 = self.dense1(dense0)  # [B,72*3+2+2]
+
+             rot_and_grip_collision_out = self.rot_grip_collision_ff(dense1)
+             rot_and_grip_out = rot_and_grip_collision_out[
+                 :, : -self.num_collision_classes
+             ]
+             collision_out = rot_and_grip_collision_out[:, -self.num_collision_classes :]
+
+         return trans, rot_and_grip_out, collision_out
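`symmetric_kl_divergence` above is the voxel-mask regulariser: it averages the two directed KL terms between the clamped log-softmax distributions of the two arm masks, scales by `eps = 1e-2`, and returns the *negated* average, so if the term were added to a minimised loss it would reward divergence between the masks (in this file it is computed in `forward` but not returned). A self-contained check of that behaviour; the function body is copied from the file, and the toy tensors are illustrative:

```python
import torch
import torch.nn.functional as F

def symmetric_kl_divergence(left, right):
    eps = 1e-2
    left_prob = torch.clamp(F.log_softmax(left, dim=-1), min=-10, max=10)
    right_prob = torch.clamp(F.log_softmax(right, dim=-1), min=-10, max=10)
    kl_left_to_right = F.kl_div(left_prob, right_prob.exp(), reduction="batchmean") * eps
    kl_right_to_left = F.kl_div(right_prob, left_prob.exp(), reduction="batchmean") * eps
    return -(kl_left_to_right + kl_right_to_left) / 2.0

torch.manual_seed(0)
a, b = torch.randn(4, 16), torch.randn(4, 16)
print(symmetric_kl_divergence(a, a))  # 0: identical distributions
print(symmetric_kl_divergence(a, b))  # typically negative: larger divergence
                                      # drives the summed loss lower
```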
third_party/AnyBimanual/agents/peract_bc/qattention_peract_bc_agent.py ADDED
@@ -0,0 +1,939 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from pytorch3d import transforms as torch3d_tf
12
+ from yarr.agents.agent import (
13
+ Agent,
14
+ ActResult,
15
+ ScalarSummary,
16
+ HistogramSummary,
17
+ ImageSummary,
18
+ Summary,
19
+ )
20
+ import matplotlib.pyplot as plt
21
+ import PIL.Image as Image
22
+ import wandb
23
+ import io
24
+ from termcolor import colored, cprint
25
+ from helpers import utils
26
+ from helpers.utils import visualise_voxel, stack_on_channel
27
+ from voxel.voxel_grid import VoxelGrid
28
+ from einops import rearrange
29
+ from helpers.clip.core.clip import build_model, load_clip
30
+
31
+ import transformers
32
+ from helpers.optim.lamb import Lamb
33
+
34
+ from torch.nn.parallel import DistributedDataParallel as DDP
35
+
36
+
37
+ class QFunction(nn.Module):
38
+ def __init__(
39
+ self,
40
+ perceiver_encoder: nn.Module,
41
+ voxelizer: VoxelGrid,
42
+ bounds_offset: float,
43
+ rotation_resolution: float,
44
+ device,
45
+ training,
46
+ ):
47
+ super(QFunction, self).__init__()
48
+ self._rotation_resolution = rotation_resolution
49
+ self._voxelizer = voxelizer
50
+ self._bounds_offset = bounds_offset
51
+ self._qnet = perceiver_encoder.to(device)
52
+
53
+ # distributed training
54
+ if training:
55
+ self._qnet = DDP(self._qnet, device_ids=[device], find_unused_parameters=True)
56
+
57
+ def _argmax_3d(self, tensor_orig):
58
+ b, c, d, h, w = tensor_orig.shape # c will be one
59
+ idxs = tensor_orig.view(b, c, -1).argmax(-1)
60
+ indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
61
+ return indices
62
+
63
+ def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
64
+ coords = self._argmax_3d(q_trans)
65
+ rot_and_grip_indicies = None
66
+ ignore_collision = None
67
+ if q_rot_grip is not None:
68
+ q_rot = torch.stack(
69
+ torch.split(
70
+ q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
71
+ ),
72
+ dim=1,
73
+ )
74
+ rot_and_grip_indicies = torch.cat(
75
+ [
76
+ q_rot[:, 0:1].argmax(-1),
77
+ q_rot[:, 1:2].argmax(-1),
78
+ q_rot[:, 2:3].argmax(-1),
79
+ q_rot_grip[:, -2:].argmax(-1, keepdim=True),
80
+ ],
81
+ -1,
82
+ )
83
+ ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
84
+ return coords, rot_and_grip_indicies, ignore_collision
85
+
86
+ def forward(
87
+ self,
88
+ rgb_pcd,
89
+ proprio,
90
+ pcd,
91
+ lang_goal_emb,
92
+ lang_token_embs,
93
+ bounds=None,
94
+ prev_bounds=None,
95
+ prev_layer_voxel_grid=None,
96
+ arm=None,
97
+ ):
98
+ # rgb_pcd will be list of list (list of [rgb, pcd])
99
+ b = rgb_pcd[0][0].shape[0]
100
+ pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
101
+
102
+ # flatten RGBs and Pointclouds
103
+ rgb = [rp[0] for rp in rgb_pcd]
104
+ feat_size = rgb[0].shape[1]
105
+ flat_imag_features = torch.cat(
106
+ [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
107
+ )
108
+
109
+ # construct voxel grid
110
+ voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
111
+ pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
112
+ )
113
+
114
+ # swap to channels fist
115
+ voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
116
+
117
+ # batch bounds if necessary
118
+ if bounds.shape[0] != b:
119
+ bounds = bounds.repeat(b, 1)
120
+
121
+ # forward pass
122
+ q_trans, q_rot_and_grip, q_ignore_collisions = self._qnet(
123
+ voxel_grid,
124
+ proprio,
125
+ lang_goal_emb,
126
+ lang_token_embs,
127
+ prev_layer_voxel_grid,
128
+ bounds,
129
+ prev_bounds,
130
+ arm=arm,
131
+ )
132
+
133
+ return q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid
134
+
135
+
136
+ class QAttentionPerActBCAgent(Agent):
137
+ def __init__(
138
+ self,
139
+ layer: int,
140
+ coordinate_bounds: list,
141
+ perceiver_encoder: nn.Module,
142
+ camera_names: list,
143
+ batch_size: int,
144
+ voxel_size: int,
145
+ bounds_offset: float,
146
+ voxel_feature_size: int,
147
+ image_crop_size: int,
148
+ num_rotation_classes: int,
149
+ rotation_resolution: float,
150
+ lr: float = 0.0001,
151
+ lr_scheduler: bool = False,
152
+ training_iterations: int = 100000,
153
+ num_warmup_steps: int = 20000,
154
+ trans_loss_weight: float = 1.0,
155
+ rot_loss_weight: float = 1.0,
156
+ grip_loss_weight: float = 1.0,
157
+ collision_loss_weight: float = 1.0,
158
+ include_low_dim_state: bool = False,
159
+ image_resolution: list = None,
160
+ lambda_weight_l2: float = 0.0,
161
+ transform_augmentation: bool = True,
162
+ transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
163
+ transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
164
+ transform_augmentation_rot_resolution: int = 5,
165
+ optimizer_type: str = "adam",
166
+ num_devices: int = 1,
167
+ checkpoint_name_prefix=None,
168
+ anybimanual = False,
169
+ cfg=None,
170
+ ):
171
+ self._layer = layer
172
+ self._coordinate_bounds = coordinate_bounds
173
+ self._perceiver_encoder = perceiver_encoder
174
+ self._voxel_feature_size = voxel_feature_size
175
+ self._bounds_offset = bounds_offset
176
+ self._image_crop_size = image_crop_size
177
+ self._lr = lr
178
+ self._lr_scheduler = lr_scheduler
179
+ self._training_iterations = training_iterations
180
+ self._num_warmup_steps = num_warmup_steps
181
+ self._trans_loss_weight = trans_loss_weight
182
+ self._rot_loss_weight = rot_loss_weight
183
+ self._grip_loss_weight = grip_loss_weight
184
+ self._collision_loss_weight = collision_loss_weight
185
+ self._include_low_dim_state = include_low_dim_state
186
+ self._image_resolution = image_resolution or [128, 128]
187
+ self._voxel_size = voxel_size
188
+ self._camera_names = camera_names
189
+ self._num_cameras = len(camera_names)
190
+ self._batch_size = batch_size
191
+ self._lambda_weight_l2 = lambda_weight_l2
192
+ self._transform_augmentation = transform_augmentation
193
+ self._transform_augmentation_xyz = torch.from_numpy(
194
+ np.array(transform_augmentation_xyz)
195
+ )
196
+ self._transform_augmentation_rpy = transform_augmentation_rpy
197
+ self._transform_augmentation_rot_resolution = (
198
+ transform_augmentation_rot_resolution
199
+ )
200
+ self._optimizer_type = optimizer_type
201
+ self._num_devices = num_devices
202
+ self._num_rotation_classes = num_rotation_classes
203
+ self._rotation_resolution = rotation_resolution
204
+
205
+ self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
206
+ checkpoint_name_prefix = checkpoint_name_prefix or "QAttentionAgent"
207
+ self._name = f"{checkpoint_name_prefix}_layer_{self._layer}"
208
+ self.anybimanual = anybimanual
209
+ self.cfg=cfg
210
+
211
+
212
+ def build(self, training: bool, device: torch.device = None):
213
+ self._training = training
214
+
215
+ if device is None:
216
+ device = torch.device("cpu")
217
+
218
+ self._device = device
219
+
220
+ self._voxelizer = VoxelGrid(
221
+ coord_bounds=self._coordinate_bounds,
222
+ voxel_size=self._voxel_size,
223
+ device=device,
224
+ batch_size=self._batch_size if training else 1,
225
+ feature_size=self._voxel_feature_size,
226
+ max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
227
+ )
228
+
229
+ self._q = (
230
+ QFunction(
231
+ self._perceiver_encoder,
232
+ self._voxelizer,
233
+ self._bounds_offset,
234
+ self._rotation_resolution,
235
+ device,
236
+ training,
237
+ )
238
+ .to(device)
239
+ .train(training)
240
+ )
241
+
242
+ grid_for_crop = (
243
+ torch.arange(0, self._image_crop_size, device=device)
244
+ .unsqueeze(0)
245
+ .repeat(self._image_crop_size, 1)
246
+ .unsqueeze(-1)
247
+ )
248
+ self._grid_for_crop = torch.cat(
249
+ [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
250
+ ).unsqueeze(0)
251
+
252
+ self._coordinate_bounds = torch.tensor(
253
+ self._coordinate_bounds, device=device
254
+ ).unsqueeze(0)
255
+
256
+ if self._training:
257
+ # optimizer
258
+ if self._optimizer_type == "lamb":
259
+ self._optimizer = Lamb(
260
+ self._q.parameters(),
261
+ lr=self._lr,
262
+ weight_decay=self._lambda_weight_l2,
263
+ betas=(0.9, 0.999),
264
+ adam=False,
265
+ )
266
+ elif self._optimizer_type == "adam":
267
+ self._optimizer = torch.optim.Adam(
268
+ self._q.parameters(),
269
+ lr=self._lr,
270
+ weight_decay=self._lambda_weight_l2,
271
+ )
272
+ else:
273
+ raise Exception("Unknown optimizer type")
274
+
275
+ # learning rate scheduler
276
+ if self._lr_scheduler:
277
+ self._scheduler = (
278
+ transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
279
+ self._optimizer,
280
+ num_warmup_steps=self._num_warmup_steps,
281
+ num_training_steps=self._training_iterations,
282
+ num_cycles=self._training_iterations // 10000,
283
+ )
284
+ )
285
+
286
+ # one-hot zero tensors
287
+ self._action_trans_one_hot_zeros = torch.zeros(
288
+ (
289
+ self._batch_size,
290
+ 1,
291
+ self._voxel_size,
292
+ self._voxel_size,
293
+ self._voxel_size,
294
+ ),
295
+ dtype=int,
296
+ device=device,
297
+ )
298
+ self._action_rot_x_one_hot_zeros = torch.zeros(
299
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
300
+ )
301
+ self._action_rot_y_one_hot_zeros = torch.zeros(
302
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
303
+ )
304
+ self._action_rot_z_one_hot_zeros = torch.zeros(
305
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
306
+ )
307
+ self._action_grip_one_hot_zeros = torch.zeros(
308
+ (self._batch_size, 2), dtype=int, device=device
309
+ )
310
+ self._action_ignore_collisions_one_hot_zeros = torch.zeros(
311
+ (self._batch_size, 2), dtype=int, device=device
312
+ )
313
+
314
+ # print total params
315
+ logging.info(
316
+ "# Q Params: %d"
317
+ % sum(
318
+ p.numel()
319
+ for name, p in self._q.named_parameters()
320
+ if p.requires_grad and "clip" not in name
321
+ )
322
+ )
323
+ else:
324
+ for param in self._q.parameters():
325
+ param.requires_grad = False
326
+
327
+ # load CLIP for encoding language goals during evaluation
328
+ model, _ = load_clip("RN50", jit=False)
329
+ self._clip_rn50 = build_model(model.state_dict())
330
+ self._clip_rn50 = self._clip_rn50.float().to(device)
331
+ self._clip_rn50.eval()
332
+ del model
333
+
334
+ self._voxelizer.to(device)
335
+ self._q.to(device)
336
+
337
+ def _extract_crop(self, pixel_action, observation):
338
+ # Pixel action will now be (B, 2)
339
+ # observation = stack_on_channel(observation)
340
+ h = observation.shape[-1]
341
+ top_left_corner = torch.clamp(
342
+ pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
343
+ )
344
+ grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
345
+ grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
346
+ # Used for cropping the images across a batch
347
+ # swap fro y x, to x, y
348
+ grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
349
+ crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
350
+ return crop
351
+
352
+ def _preprocess_inputs(self, replay_sample):
353
+ obs = []
354
+ pcds = []
355
+ self._crop_summary = []
356
+ for n in self._camera_names:
357
+ rgb = replay_sample["%s_rgb" % n]
358
+ pcd = replay_sample["%s_point_cloud" % n]
359
+
360
+ obs.append([rgb, pcd])
361
+ pcds.append(pcd)
362
+ return obs, pcds
363
+
364
+ def _act_preprocess_inputs(self, observation):
365
+ obs, pcds = [], []
366
+ for n in self._camera_names:
367
+ rgb = observation["%s_rgb" % n]
368
+ pcd = observation["%s_point_cloud" % n]
369
+
370
+ obs.append([rgb, pcd])
371
+ pcds.append(pcd)
372
+ return obs, pcds
373
+
374
+ def _get_value_from_voxel_index(self, q, voxel_idx):
375
+ b, c, d, h, w = q.shape
376
+ q_trans_flat = q.view(b, c, d * h * w)
377
+ flat_indicies = (
378
+ voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
379
+ )[:, None].int()
380
+ highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
381
+ chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
382
+ ..., 0
383
+ ] # (B, trans + rot + grip)
384
+ return chosen_voxel_values
385
+
386
+ def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
387
+ q_rot = torch.stack(
388
+ torch.split(
389
+ rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
390
+ ),
391
+ dim=1,
392
+ ) # B, 3, 72
393
+ q_grip = rot_grip_q[:, -2:]
394
+ rot_and_grip_values = torch.cat(
395
+ [
396
+ q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
397
+ q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
398
+ q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
399
+ q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
400
+ ],
401
+ -1,
402
+ )
403
+ return rot_and_grip_values
404
+
405
+ def _celoss(self, pred, labels):
406
+ return self._cross_entropy_loss(pred, labels.argmax(-1))
407
+
408
+ def _softmax_q_trans(self, q):
409
+ q_shape = q.shape
410
+ return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
411
+
412
+ def _softmax_q_rot_grip(self, q_rot_grip):
413
+ q_rot_x_flat = q_rot_grip[
414
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
415
+ ]
416
+ q_rot_y_flat = q_rot_grip[
417
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
418
+ ]
419
+ q_rot_z_flat = q_rot_grip[
420
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
421
+ ]
422
+ q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
423
+
424
+ q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
425
+ q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
426
+ q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
427
+ q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
428
+
429
+ return torch.cat(
430
+ [
431
+ q_rot_x_flat_softmax,
432
+ q_rot_y_flat_softmax,
433
+ q_rot_z_flat_softmax,
434
+ q_grip_flat_softmax,
435
+ ],
436
+ dim=1,
437
+ )
438
+
439
+ def _softmax_ignore_collision(self, q_collision):
440
+ q_collision_softmax = F.softmax(q_collision, dim=1)
441
+ return q_collision_softmax
442
+
443
+ def update(self, step: int, replay_sample: dict) -> dict:
444
+ action_trans = replay_sample["trans_action_indicies"][
445
+ :, self._layer * 3 : self._layer * 3 + 3
446
+ ].int()
447
+ action_rot_grip = replay_sample["rot_grip_action_indicies"].int()
448
+ action_gripper_pose = replay_sample["gripper_pose"]
449
+ action_ignore_collisions = replay_sample["ignore_collisions"].int()
450
+ lang_goal_emb = replay_sample["lang_goal_emb"].float()
451
+ lang_token_embs = replay_sample["lang_token_embs"].float()
452
+ prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
453
+ prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
454
+ device = self._device
455
+ rank = device
456
+ bounds = self._coordinate_bounds.to(device)
457
+ if self._layer > 0:
458
+ cp = replay_sample["attention_coordinate_layer_%d" % (self._layer - 1)]
459
+ bounds = torch.cat(
460
+ [cp - self._bounds_offset, cp + self._bounds_offset], dim=1
461
+ )
462
+
463
+ proprio = None
464
+ if self._include_low_dim_state:
465
+ proprio = replay_sample["low_dim_state"]
466
+
467
+ obs, pcd = self._preprocess_inputs(replay_sample)
468
+ if proprio.shape[-1] == 4:
469
+ arm = "right"
470
+ else:
471
+ arm = "left"
472
+ # batch size
473
+ bs = pcd[0].shape[0]
474
+
475
+ # SE(3) augmentation of point clouds and actions
476
+ if self._transform_augmentation:
477
+ from voxel import augmentation
478
+ action_trans, action_rot_grip, pcd = augmentation.apply_se3_augmentation(
479
+ pcd,
480
+ action_gripper_pose,
481
+ action_trans,
482
+ action_rot_grip,
483
+ bounds,
484
+ self._layer,
485
+ self._transform_augmentation_xyz,
486
+ self._transform_augmentation_rpy,
487
+ self._transform_augmentation_rot_resolution,
488
+ self._voxel_size,
489
+ self._rotation_resolution,
490
+ self._device,
491
+ )
492
+
493
+ # forward pass
494
+ q_trans, q_rot_grip, q_collision, voxel_grid = self._q(
495
+ obs,
496
+ proprio,
497
+ pcd,
498
+ lang_goal_emb,
499
+ lang_token_embs,
500
+ bounds,
501
+ prev_layer_bounds,
502
+ prev_layer_voxel_grid,
503
+ arm=arm,
504
+ )
505
+
506
+ # argmax to choose best action
507
+ (
508
+ coords,
509
+ rot_and_grip_indicies,
510
+ ignore_collision_indicies,
511
+ ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_collision)
512
+
513
+ q_trans_loss, q_rot_loss, q_grip_loss, q_collision_loss = 0.0, 0.0, 0.0, 0.0
514
+
515
+ # translation one-hot
516
+ action_trans_one_hot = self._action_trans_one_hot_zeros.clone()
517
+ for b in range(bs):
518
+ gt_coord = action_trans[b, :].int()
519
+ action_trans_one_hot[b, :, gt_coord[0], gt_coord[1], gt_coord[2]] = 1
520
+
521
+ # translation loss
522
+ q_trans_flat = q_trans.view(bs, -1)
523
+ action_trans_one_hot_flat = action_trans_one_hot.view(bs, -1)
524
+ q_trans_loss = self._celoss(q_trans_flat, action_trans_one_hot_flat)
525
+
526
+ with_rot_and_grip = rot_and_grip_indicies is not None
527
+ if with_rot_and_grip:
528
+ # rotation, gripper, and collision one-hots
529
+ action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
530
+ action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
531
+ action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
532
+ action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
533
+ action_ignore_collisions_one_hot = (
534
+ self._action_ignore_collisions_one_hot_zeros.clone()
535
+ )
536
+
537
+ for b in range(bs):
538
+ gt_rot_grip = action_rot_grip[b, :].int()
539
+ action_rot_x_one_hot[b, gt_rot_grip[0]] = 1
540
+ action_rot_y_one_hot[b, gt_rot_grip[1]] = 1
541
+ action_rot_z_one_hot[b, gt_rot_grip[2]] = 1
542
+ action_grip_one_hot[b, gt_rot_grip[3]] = 1
543
+
544
+ gt_ignore_collisions = action_ignore_collisions[b, :].int()
545
+ action_ignore_collisions_one_hot[b, gt_ignore_collisions[0]] = 1
546
+
547
+ # flatten predictions
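+ # q_rot_grip lays out [x-rot bins | y-rot bins | z-rot bins | 2 gripper logits] along dim 1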
548
+ q_rot_x_flat = q_rot_grip[
549
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
550
+ ]
551
+ q_rot_y_flat = q_rot_grip[
552
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
553
+ ]
554
+ q_rot_z_flat = q_rot_grip[
555
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
556
+ ]
557
+ q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
558
+ q_ignore_collisions_flat = q_collision
559
+
560
+ # rotation loss
561
+ q_rot_loss += self._celoss(q_rot_x_flat, action_rot_x_one_hot)
562
+ q_rot_loss += self._celoss(q_rot_y_flat, action_rot_y_one_hot)
563
+ q_rot_loss += self._celoss(q_rot_z_flat, action_rot_z_one_hot)
564
+
565
+ # gripper loss
566
+ q_grip_loss += self._celoss(q_grip_flat, action_grip_one_hot)
567
+
568
+ # collision loss
569
+ q_collision_loss += self._celoss(
570
+ q_ignore_collisions_flat, action_ignore_collisions_one_hot
571
+ )
572
+
573
+ combined_losses = (
574
+ (q_trans_loss * self._trans_loss_weight)
575
+ + (q_rot_loss * self._rot_loss_weight)
576
+ + (q_grip_loss * self._grip_loss_weight)
577
+ + (q_collision_loss * self._collision_loss_weight)
578
+ )
579
+ total_loss = combined_losses.mean()
580
+ if step % 10 == 0 and rank == 0 and wandb.run is not None:
581
+ wandb.log({
582
+ 'train/grip_loss': q_grip_loss.mean(),
583
+ 'train/trans_loss': q_trans_loss.mean(),
584
+ 'train/rot_loss': q_rot_loss.mean(),
585
+ 'train/collision_loss': q_collision_loss.mean(),
586
+ 'train/total_loss': total_loss,
587
+ }, step=step)
588
+
589
+ self._optimizer.zero_grad()
590
+ total_loss.backward()
591
+ self._optimizer.step()
592
+
593
+ self._summaries = {
594
+ "losses/total_loss": total_loss,
595
+ "losses/trans_loss": q_trans_loss.mean(),
596
+ "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
597
+ "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
598
+ "losses/collision_loss": q_collision_loss.mean()
599
+ if with_rot_and_grip
600
+ else 0.0,
601
+ }
602
+ self._wandb_summaries = {
603
+ 'losses/total_loss': total_loss,
604
+ 'losses/trans_loss': q_trans_loss.mean(),
605
+ 'losses/rot_loss': q_rot_loss.mean() if with_rot_and_grip else 0.,
606
+ 'losses/grip_loss': q_grip_loss.mean() if with_rot_and_grip else 0.,
607
+ 'losses/collision_loss': q_collision_loss.mean() if with_rot_and_grip else 0.
608
+ }
609
+ if self._lr_scheduler:
610
+ self._scheduler.step()
611
+ self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
612
+
613
+ self._vis_voxel_grid = voxel_grid[0]
614
+ self._vis_translation_qvalue = self._softmax_q_trans(q_trans[0])
615
+ self._vis_max_coordinate = coords[0]
616
+ self._vis_gt_coordinate = action_trans[0]
617
+
618
+ # Note: PerAct doesn't use multi-layer voxel grids like C2FARM
619
+ # stack prev_layer_voxel_grid(s) from previous layers into a list
620
+ if prev_layer_voxel_grid is None:
621
+ prev_layer_voxel_grid = [voxel_grid]
622
+ else:
623
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
624
+
625
+ # stack prev_layer_bound(s) from previous layers into a list
626
+ if prev_layer_bounds is None:
627
+ prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
628
+ else:
629
+ prev_layer_bounds = prev_layer_bounds + [bounds]
630
+
631
+ q_trans_vis = True  # overlay the translation Q-value heatmap in the debug render
632
+ log_freq = getattr(getattr(getattr(self, "cfg", None), "framework", None), "log_freq", None)  # safely read cfg.framework.log_freq when a config is attached
633
+ if log_freq and step % log_freq == 0 and rank == 0:
634
+ print(f"{arm}_arm_predict: {self._vis_max_coordinate}")
635
+ print(f"{arm}_gt: {self._vis_gt_coordinate}")
636
+ rendered_img = visualise_voxel(
637
+ voxel_grid[0].cpu().detach().numpy(), # [10, 100, 100, 100]
638
+ self._vis_translation_qvalue.detach().cpu().numpy() if q_trans_vis else None,
639
+ self._vis_max_coordinate.detach().cpu().numpy(),
640
+ self._vis_gt_coordinate.detach().cpu().numpy(),
641
+ voxel_size=0.045,
642
+ # voxel_size=0.1  # a larger voxel size renders a more zoomed-in view
643
+ rotation_amount=np.deg2rad(-90),
644
+ highlight_alpha=1.0,
645
+ alpha=0.4,
646
+ )
647
+ os.makedirs('recon', exist_ok=True)
648
+ # plot three images in one row with subplots:
649
+ rgb_src = obs[0][0][0].squeeze(0).permute(1, 2, 0) / 2 + 0.5
650
+
651
+ fig, axs = plt.subplots(1, 4, figsize=(9, 3))  # only the first two panels are filled; the rest stay blank
652
+ # src
653
+ axs[0].imshow(rgb_src.cpu().numpy())
654
+ axs[0].title.set_text('src')
655
+
656
+ axs[1].imshow(rendered_img)
657
+ axs[1].text(0, 40, 'predicted', color='blue')
658
+ axs[1].text(0, 80, 'gt', color='red')
659
+ for ax in axs:
660
+ ax.axis('off')
661
+ plt.tight_layout()
662
+
663
+ if rank == 0:
664
+ if wandb.run is not None:
665
+ buf = io.BytesIO()
666
+ plt.savefig(buf, format='png')
667
+ buf.seek(0)
668
+
669
+ image = Image.open(buf)
670
+ wandb.log({"eval/recon_img": wandb.Image(image)}, step=step)
671
+
672
+ buf.close()
673
+ cprint('Saved reconstruction image to wandb', 'cyan')
674
+ else:
675
+ plt.savefig(f'recon/{step}_rgb.png')
676
+ workdir = os.getcwd()
677
+ cprint(f'Saved {workdir}/recon/{step}_rgb.png locally', 'cyan')
678
+ return {
679
+ "total_loss": total_loss,
680
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
681
+ "prev_layer_bounds": prev_layer_bounds,
682
+ }
683
+
684
+ def update_wandb_summaries(self):
685
+ summaries = dict()
686
+ for k, v in self._wandb_summaries.items():
687
+ summaries[k] = v
688
+ return summaries
689
+
690
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
691
+ deterministic = True  # always act greedily at evaluation time, regardless of the caller's flag
692
+ bounds = self._coordinate_bounds
693
+ prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
694
+ prev_layer_bounds = observation.get("prev_layer_bounds", None)
695
+ lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
696
+
697
+ # extract CLIP language embs
698
+ with torch.no_grad():
699
+ lang_goal_tokens = lang_goal_tokens.to(device=self._device)
700
+ (
701
+ lang_goal_emb,
702
+ lang_token_embs,
703
+ ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
704
+
705
+ # voxelization resolution
706
+ res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
707
+ max_rot_index = int(360 // self._rotation_resolution)  # number of discrete rotation bins
708
+ proprio = None
709
+
710
+ if self._include_low_dim_state:
711
+ proprio = observation["low_dim_state"]
712
+ proprio = proprio[0].to(self._device)
713
+
714
+ obs, pcd = self._act_preprocess_inputs(observation)
715
+
716
+ # correct batch size and device
717
+ obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]
718
+ pcd = [p[0].to(self._device) for p in pcd]
719
+ lang_goal_emb = lang_goal_emb.to(self._device)
720
+ lang_token_embs = lang_token_embs.to(self._device)
721
+ bounds = torch.as_tensor(bounds, device=self._device)
722
+ prev_layer_voxel_grid = (
723
+ prev_layer_voxel_grid.to(self._device)
724
+ if prev_layer_voxel_grid is not None
725
+ else None
726
+ )
727
+ prev_layer_bounds = (
728
+ prev_layer_bounds.to(self._device)
729
+ if prev_layer_bounds is not None
730
+ else None
731
+ )
732
+
733
+ # inference
734
+ q_trans, q_rot_grip, q_ignore_collisions, vox_grid = self._q(
735
+ obs,
736
+ proprio,
737
+ pcd,
738
+ lang_goal_emb,
739
+ lang_token_embs,
740
+ bounds,
741
+ prev_layer_bounds,
742
+ prev_layer_voxel_grid,
743
+ )
744
+
745
+ # softmax Q predictions
746
+ q_trans = self._softmax_q_trans(q_trans)
747
+ q_rot_grip = (
748
+ self._softmax_q_rot_grip(q_rot_grip)
749
+ if q_rot_grip is not None
750
+ else q_rot_grip
751
+ )
752
+ q_ignore_collisions = (
753
+ self._softmax_ignore_collision(q_ignore_collisions)
754
+ if q_ignore_collisions is not None
755
+ else q_ignore_collisions
756
+ )
757
+
758
+ # argmax Q predictions
759
+ (
760
+ coords,
761
+ rot_and_grip_indicies,
762
+ ignore_collisions,
763
+ ) = self._q.choose_highest_action(q_trans, q_rot_grip, q_ignore_collisions)
764
+
765
+ rot_grip_action = rot_and_grip_indicies if q_rot_grip is not None else None
766
+ ignore_collisions_action = (
767
+ ignore_collisions.int() if ignore_collisions is not None else None
768
+ )
769
+
770
+ coords = coords.int()
771
+ attention_coordinate = bounds[:, :3] + res * coords + res / 2
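+ # map the discrete voxel index to a world-frame point at the voxel centre (hence + res / 2)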
772
+
773
+ # stack prev_layer_voxel_grid(s) into a list
774
+ # NOTE: PerAct doesn't use multi-layer voxel grids like C2FARM
775
+ if prev_layer_voxel_grid is None:
776
+ prev_layer_voxel_grid = [vox_grid]
777
+ else:
778
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
779
+
780
+ if prev_layer_bounds is None:
781
+ prev_layer_bounds = [bounds]
782
+ else:
783
+ prev_layer_bounds = prev_layer_bounds + [bounds]
784
+
785
+ observation_elements = {
786
+ "attention_coordinate": attention_coordinate,
787
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
788
+ "prev_layer_bounds": prev_layer_bounds,
789
+ }
790
+ info = {
791
+ "voxel_grid_depth%d" % self._layer: vox_grid,
792
+ "q_depth%d" % self._layer: q_trans,
793
+ "voxel_idx_depth%d" % self._layer: coords,
794
+ }
795
+ self._act_voxel_grid = vox_grid[0]
796
+ self._act_max_coordinate = coords[0]
797
+ self._act_qvalues = q_trans[0].detach()
798
+ return ActResult(
799
+ (coords, rot_grip_action, ignore_collisions_action),
800
+ observation_elements=observation_elements,
801
+ info=info,
802
+ )
803
+
804
+ def update_summaries(self) -> List[Summary]:
805
+ summaries = [
806
+ ImageSummary(
807
+ "%s/update_qattention" % self._name,
808
+ transforms.ToTensor()(
809
+ visualise_voxel(
810
+ self._vis_voxel_grid.detach().cpu().numpy(),
811
+ self._vis_translation_qvalue.detach().cpu().numpy(),
812
+ self._vis_max_coordinate.detach().cpu().numpy(),
813
+ self._vis_gt_coordinate.detach().cpu().numpy(),
814
+ )
815
+ ),
816
+ )
817
+ ]
818
+
819
+ for n, v in self._summaries.items():
820
+ summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
821
+
822
+ for name, crop in self._crop_summary:
823
+ crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
824
+ summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
825
+
826
+ for tag, param in self._q.named_parameters():
827
+ # assert not torch.isnan(param.grad.abs() <= 1.0).all()
828
+ summaries.append(
829
+ HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
830
+ )
831
+ summaries.append(
832
+ HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
833
+ )
834
+
835
+ return summaries
836
+
837
+ def act_summaries(self) -> List[Summary]:
838
+ return [
839
+ ImageSummary(
840
+ "%s/act_Qattention" % self._name,
841
+ transforms.ToTensor()(
842
+ visualise_voxel(
843
+ self._act_voxel_grid.cpu().numpy(),
844
+ self._act_qvalues.cpu().numpy(),
845
+ self._act_max_coordinate.cpu().numpy(),
846
+ )
847
+ ),
848
+ )
849
+ ]
850
+ def concat_weights(self, param, target_size, dims=-1):
+ # double the tensor along `dims` when that dimension is smaller than target_size
+ if param.size(dims) < target_size:
+ param = torch.cat([param, param], dims)
+ return param
854
+
855
+ def load_weights(self, savedir: str):
856
+ device = (
857
+ self._device
858
+ if not self._training
859
+ else torch.device("cuda:%d" % self._device)
860
+ )
861
+ weight_file = os.path.join(savedir, "%s.pt" % self._name)
862
+ state_dict = torch.load(weight_file, map_location=device)
863
+
864
+ # load only keys that are in the current model
865
+ merged_state_dict = self._q.state_dict()
866
+ if not self._training:
+ for k, v in state_dict.items():
+ # strip the DataParallel "module" prefix for single-process inference
+ k = k.replace("_qnet.module", "_qnet")
+ if k in merged_state_dict:
+ merged_state_dict[k] = v
+ else:
+ if "_voxelizer" not in k:
+ logging.warning("key %s not found in checkpoint" % k)
875
+ else:
876
+ for k, v in state_dict.items():
877
+ if not self._training:
878
+ k = k.replace("_qnet.module", "_qnet")
879
+ elif k == "_qnet.module.pos_encoding":
880
+ # extend only checkpoints whose pos_encoding lacks the longer language context,
+ # i.e. is neither of the merged sizes (8077 / 8154) and shorter than 154 rows
+ if (v.shape[1] != 8077 and v.shape[1] != 8154) and v.shape[1] < 154:
881
+ if self.anybimanual:
882
+ lang_max_seq_len = 154
883
+ else:
884
+ lang_max_seq_len = 77
885
+ spatial_size = v.shape[1]
886
+ input_dim_before_seq = v.shape[-1]
887
+ flattened_v = v.view(1, -1, input_dim_before_seq) # (1, spatial_size**3, self.input_dim_before_seq)
888
+ new_pos_encoding = torch.randn(1, lang_max_seq_len, input_dim_before_seq, device=device)
889
+ merged_pos_encoding = torch.cat([flattened_v, new_pos_encoding], dim=1) # (1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq)
890
+ merged_state_dict["_qnet.module.pos_encoding"] = merged_pos_encoding
891
+ else:
892
+ merged_state_dict["_qnet.module.pos_encoding"] = v
893
+ elif k.startswith("_qnet.module.cross_attend_blocks"):
894
+ if self.anybimanual:
895
+ if v.size(-1) == 128:
896
+ merged_state_dict[k] = self.concat_weights(v, 256)
897
+ elif k.startswith("_qnet.module.decoder_cross_attn"):
898
+ if self.anybimanual:
899
+ if v.size(0) == 128:
900
+ merged_state_dict[k] = self.concat_weights(v, 256, 0)
901
+ merged_state_dict[k] = self.concat_weights(v, 256, 0)
902
+ if v.size(-1) == 128:
903
+ merged_state_dict[k] = self.concat_weights(v, 256)
904
+ merged_state_dict[k] = self.concat_weights(v, 256)
905
+ elif k == "_qnet.module.up0.conv_up.0.conv3d.weight":
906
+ if self.anybimanual:
907
+ if v.size(1) == 128:
908
+ merged_state_dict[k] = self.concat_weights(v, 256, 1)
909
+ elif k.startswith("_qnet.module.dense0"):
910
+ if self.anybimanual:
911
+ if v.size(-1) == 1024:
912
+ merged_state_dict[k] = torch.cat([v, v[:, :512]], dim=-1)
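+ # heuristic weight inflation: the wider dense0 layer re-uses the checkpoint's first 512 columns to fill its extra inputs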
913
+ elif k in merged_state_dict:
914
+ merged_state_dict[k] = v
915
+ else:
916
+ if "_voxelizer" not in k:
917
+ logging.warning("key %s not found in checkpoint" % k)
918
+
919
+ if not self._training:
920
+ # reshape voxelizer weights
921
+ b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
922
+ merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
923
+ "_voxelizer._ones_max_coords"
924
+ ][0:1]
925
+ flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
926
+ merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
927
+ "_voxelizer._flat_output"
928
+ ][0 : flat_shape // b]
929
+ merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
930
+ "_voxelizer._tiled_batch_indices"
931
+ ][0:1]
932
+ merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
933
+ "_voxelizer._index_grid"
934
+ ][0:1]
935
+ self._q.load_state_dict(merged_state_dict)
936
+ print("loaded weights from %s" % weight_file)
937
+
938
+ def save_weights(self, savedir: str):
939
+ torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
third_party/AnyBimanual/agents/peract_bc/qattention_stack_agent.py ADDED
@@ -0,0 +1,132 @@
1
+ from typing import List
2
+
3
+ import torch
4
+ from yarr.agents.agent import Agent, ActResult, Summary
5
+
6
+ import numpy as np
7
+
8
+ from helpers import utils
9
+ from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
10
+
11
+ NAME = "QAttentionStackAgent"
12
+
13
+
14
+ class QAttentionStackAgent(Agent):
15
+ def __init__(
16
+ self,
17
+ qattention_agents: List[QAttentionPerActBCAgent],
18
+ rotation_resolution: float,
19
+ camera_names: List[str],
20
+ rotation_prediction_depth: int = 0,
21
+ ):
22
+ super(QAttentionStackAgent, self).__init__()
23
+ self._qattention_agents = qattention_agents
24
+ self._rotation_resolution = rotation_resolution
25
+ self._camera_names = camera_names
26
+ self._rotation_prediction_depth = rotation_prediction_depth
27
+
28
+ def build(self, training: bool, device=None) -> None:
29
+ self._device = device
30
+ if self._device is None:
31
+ self._device = torch.device("cpu")
32
+ for qa in self._qattention_agents:
33
+ qa.build(training, device)
34
+
35
+ def update(self, step: int, replay_sample: dict) -> dict:
36
+ total_losses = 0.0
38
+ for qa in self._qattention_agents:
39
+ update_dict = qa.update(step, replay_sample)
40
+ replay_sample.update(update_dict)
41
+ total_losses += update_dict["total_loss"]
42
+ return {
43
+ "total_losses": total_losses,
44
+ }
45
+
46
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
47
+ observation_elements = {}
48
+ translation_results, rot_grip_results, ignore_collisions_results = [], [], []
49
+ infos = {}
50
+ for depth, qagent in enumerate(self._qattention_agents):
51
+ act_results = qagent.act(step, observation, deterministic)
52
+ attention_coordinate = (
53
+ act_results.observation_elements["attention_coordinate"].cpu().numpy()
54
+ )
55
+ observation_elements[
56
+ "attention_coordinate_layer_%d" % depth
57
+ ] = attention_coordinate[0]
58
+
59
+ translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
60
+ translation_results.append(translation_idxs)
61
+ if rot_grip_idxs is not None:
62
+ rot_grip_results.append(rot_grip_idxs)
63
+ if ignore_collisions_idxs is not None:
64
+ ignore_collisions_results.append(ignore_collisions_idxs)
65
+
66
+ observation["attention_coordinate"] = act_results.observation_elements[
67
+ "attention_coordinate"
68
+ ]
69
+ observation["prev_layer_voxel_grid"] = act_results.observation_elements[
70
+ "prev_layer_voxel_grid"
71
+ ]
72
+ observation["prev_layer_bounds"] = act_results.observation_elements[
73
+ "prev_layer_bounds"
74
+ ]
75
+
76
+ for n in self._camera_names:
77
+ px, py = utils.point_to_pixel_index(
78
+ attention_coordinate[0],
79
+ observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
80
+ observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
81
+ )
82
+ pc_t = torch.tensor(
83
+ [[[py, px]]], dtype=torch.float32, device=self._device
84
+ )
85
+ observation["%s_pixel_coord" % n] = pc_t
86
+ observation_elements["%s_pixel_coord" % n] = [py, px]
87
+
88
+ infos.update(act_results.info)
89
+
90
+ rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
91
+ ignore_collisions = float(
92
+ torch.cat(ignore_collisions_results, 1)[0].cpu().numpy()
93
+ )
94
+ observation_elements["trans_action_indicies"] = (
95
+ torch.cat(translation_results, 1)[0].cpu().numpy()
96
+ )
97
+ observation_elements["rot_grip_action_indicies"] = rgai
98
+ continuous_action = np.concatenate(
99
+ [
100
+ act_results.observation_elements["attention_coordinate"]
101
+ .cpu()
102
+ .numpy()[0],
103
+ utils.discrete_euler_to_quaternion(
104
+ rgai[-4:-1], self._rotation_resolution
105
+ ),
106
+ rgai[-1:],
107
+ [ignore_collisions],
108
+ ]
109
+ )
110
+ return ActResult(
111
+ continuous_action, observation_elements=observation_elements, info=infos
112
+ )
113
+
114
+ def update_summaries(self) -> List[Summary]:
115
+ summaries = []
116
+ for qa in self._qattention_agents:
117
+ summaries.extend(qa.update_summaries())
118
+ return summaries
119
+
120
+ def act_summaries(self) -> List[Summary]:
121
+ s = []
122
+ for qa in self._qattention_agents:
123
+ s.extend(qa.act_summaries())
124
+ return s
125
+
126
+ def load_weights(self, savedir: str):
127
+ for qa in self._qattention_agents:
128
+ qa.load_weights(savedir)
129
+
130
+ def save_weights(self, savedir: str):
131
+ for qa in self._qattention_agents:
132
+ qa.save_weights(savedir)
third_party/AnyBimanual/agents/peract_bc/skill_manager.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import transformers
4
+ from agents.peract_bimanual.trajectory_gpt2 import GPT2Model
5
+ import torch.nn.functional as F
+
+
+ class SkillManager(nn.Module):
+ """Predicts a skill embedding as a softmax mixture over learned skill prototypes."""
7
+ def __init__(
8
+ self,
9
+ num_classes,
10
+ embedding_matrix=None,
11
+ voxel_dim=128,
12
+ lang_dim=128,
13
+ hidden_size=128,
14
+ output_dim=18,
15
+ max_voxels=8000,
16
+ max_lang_tokens=77,
17
+ **kwargs):
18
+ super().__init__()
19
+
20
+ self.hidden_size = hidden_size
21
+ self.output_dim = output_dim
22
+
23
+ # GPT-2 configuration
24
+ config = transformers.GPT2Config(
25
+ vocab_size=1, # not used
26
+ n_embd=hidden_size,
27
+ n_head=4,
28
+ n_ctx=1077,
29
+ )
30
+
31
+ self.max_voxels = max_voxels
32
+ self.max_lang_tokens = max_lang_tokens
33
+ self.embed_voxel = nn.Linear(voxel_dim, hidden_size)
34
+ self.embed_lang = nn.Linear(lang_dim, hidden_size)
35
+ self.transformer = GPT2Model(config)
36
+ self.embed_ln = nn.LayerNorm(hidden_size)
37
+ self.predict_logits = nn.Linear(hidden_size, output_dim)
38
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
39
+ self.num_class = num_classes
40
+ if embedding_matrix is not None:
41
+ self.embeddings_matrix = embedding_matrix.to(self.device)
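+ # each row of embedding_matrix appears to hold one skill prototype, flattened from 77 tokens x 512 dims (see the view(-1, 77, 512) in forward)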
42
+
43
+ def forward(self, voxel_embedding, language_embedding):
44
+ batch_size = voxel_embedding.shape[0]
45
+ voxel_embeddings = self.embed_voxel(voxel_embedding) # [b, 8000, hidden_size]
46
+ language_embeddings = self.embed_lang(language_embedding) # [b, 77, hidden_size]
47
+ voxel_embeddings = voxel_embeddings.permute(0, 2, 1) # [b, hidden_size, 8000]
48
+ voxel_embeddings = F.avg_pool1d(voxel_embeddings, kernel_size=16, stride=16) # [b, hidden_size, 500]
+ voxel_embeddings = voxel_embeddings.permute(0, 2, 1) # [b, 500, hidden_size]
+ inputs = torch.cat([language_embeddings, voxel_embeddings], dim=1) # [b, 577, hidden_size]
51
+ stacked_inputs = self.embed_ln(inputs)
52
+ # no attention mask is needed: every position is a real token (no padding),
+ # so the transformer below is called with attention_mask=None
59
+ transformer_outputs = self.transformer(
60
+ inputs_embeds=stacked_inputs,
61
+ attention_mask=None,
62
+ )
63
+
64
+ hidden_state = transformer_outputs.last_hidden_state # [b, 577, hidden_size]
65
+ aggregated_hidden = hidden_state.mean(dim=1) # [b, hidden_size]
66
+ logits = self.predict_logits(aggregated_hidden) # [b, output_dim]
67
+ probs = F.softmax(logits, dim=1)
68
+ skill = torch.matmul(probs, self.embeddings_matrix.to(probs.device))
69
+ skill = skill.view(-1, 77, 512)  # reshape the mixed prototype into 77 language-token embeddings of dim 512
70
+ return skill
third_party/AnyBimanual/agents/peract_bc/trajectory_gpt2.py ADDED
@@ -0,0 +1,775 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch OpenAI GPT-2 model."""
17
+
18
+ import os
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.nn import CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.file_utils import (
28
+ ModelOutput,
29
+ add_code_sample_docstrings,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ replace_return_docstrings,
33
+ )
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPastAndCrossAttentions,
36
+ )
37
+ from transformers.modeling_utils import (
38
+ Conv1D,
39
+ PreTrainedModel,
40
+ SequenceSummary,
41
+ find_pruneable_heads_and_indices,
42
+ prune_conv1d_layer,
43
+ )
44
+ from transformers.utils import logging
45
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
46
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ _CONFIG_FOR_DOC = "GPT2Config"
51
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
52
+
53
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
54
+ "gpt2",
55
+ "gpt2-medium",
56
+ "gpt2-large",
57
+ "gpt2-xl",
58
+ "distilgpt2",
59
+ # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
60
+ ]
61
+
62
+
63
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
64
+ """Load tf checkpoints in a pytorch model"""
65
+ try:
66
+ import re
67
+
68
+ import tensorflow as tf
69
+ except ImportError:
70
+ logger.error(
71
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
72
+ "https://www.tensorflow.org/install/ for installation instructions."
73
+ )
74
+ raise
75
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
76
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
77
+ # Load weights from TF model
78
+ init_vars = tf.train.list_variables(tf_path)
79
+ names = []
80
+ arrays = []
81
+ for name, shape in init_vars:
82
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
83
+ array = tf.train.load_variable(tf_path, name)
84
+ names.append(name)
85
+ arrays.append(array.squeeze())
86
+
87
+ for name, array in zip(names, arrays):
88
+ name = name[6:] # skip "model/"
89
+ name = name.split("/")
90
+ pointer = model
91
+ for m_name in name:
92
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
93
+ scope_names = re.split(r"(\d+)", m_name)
94
+ else:
95
+ scope_names = [m_name]
96
+ if scope_names[0] == "w" or scope_names[0] == "g":
97
+ pointer = getattr(pointer, "weight")
98
+ elif scope_names[0] == "b":
99
+ pointer = getattr(pointer, "bias")
100
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
101
+ pointer = getattr(pointer, scope_names[0])
102
+ pointer = getattr(pointer, "weight")
103
+ else:
104
+ pointer = getattr(pointer, scope_names[0])
105
+ if len(scope_names) >= 2:
106
+ num = int(scope_names[1])
107
+ pointer = pointer[num]
108
+ try:
109
+ assert (
110
+ pointer.shape == array.shape
111
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
112
+ except AssertionError as e:
113
+ e.args += (pointer.shape, array.shape)
114
+ raise
115
+ logger.info("Initialize PyTorch weight {}".format(name))
116
+ pointer.data = torch.from_numpy(array)
117
+ return model
118
+
119
+
120
+ class Attention(nn.Module):
121
+ def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
122
+ super().__init__()
123
+
124
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
125
+ # [switch nx => n_state from Block to Attention to keep identical to TF implem]
126
+ assert n_state % config.n_head == 0
127
+ self.register_buffer(
128
+ "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
129
+ )
130
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
131
+ self.n_head = config.n_head
132
+ self.split_size = n_state
133
+ self.scale = scale
134
+ self.is_cross_attention = is_cross_attention
135
+ if self.is_cross_attention:
136
+ self.c_attn = Conv1D(2 * n_state, nx)
137
+ self.q_attn = Conv1D(n_state, nx)
138
+ else:
139
+ self.c_attn = Conv1D(3 * n_state, nx)
140
+ self.c_proj = Conv1D(n_state, nx)
141
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
142
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
143
+ self.pruned_heads = set()
144
+
145
+ def prune_heads(self, heads):
146
+ if len(heads) == 0:
147
+ return
148
+ heads, index = find_pruneable_heads_and_indices(
149
+ heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
150
+ )
151
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
152
+
153
+ # Prune conv1d layers
154
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
155
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
156
+
157
+ # Update hyper params
158
+ self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
159
+ self.n_head = self.n_head - len(heads)
160
+ self.pruned_heads = self.pruned_heads.union(heads)
161
+
162
+ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
163
+ w = torch.matmul(q, k)
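+ # k arrives pre-transposed from split_heads(k=True) as (batch, head, head_features, seq), so this computes q @ k^T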
164
+ if self.scale:
165
+ w = w / (float(v.size(-1)) ** 0.5)
166
+ nd, ns = w.size(-2), w.size(-1)
167
+
168
+ if not self.is_cross_attention:
169
+ # if only "normal" attention layer implements causal mask
170
+ mask = self.bias[:, :, ns - nd: ns, :ns]
171
+ w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
172
+
173
+ if attention_mask is not None:
174
+ # Apply the attention mask
175
+ w = w + attention_mask
176
+
177
+ w = nn.Softmax(dim=-1)(w)
178
+ w = self.attn_dropout(w)
179
+
180
+ # Mask heads if we want to
181
+ if head_mask is not None:
182
+ w = w * head_mask
183
+
184
+ outputs = [torch.matmul(w, v)]
185
+ if output_attentions:
186
+ outputs.append(w)
187
+ return outputs
188
+
189
+ def merge_heads(self, x):
190
+ x = x.permute(0, 2, 1, 3).contiguous()
191
+ new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
192
+ return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
193
+
194
+ def split_heads(self, x, k=False):
195
+ new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
196
+ x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
197
+ if k:
198
+ return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
199
+ else:
200
+ return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states,
205
+ layer_past=None,
206
+ attention_mask=None,
207
+ head_mask=None,
208
+ encoder_hidden_states=None,
209
+ encoder_attention_mask=None,
210
+ use_cache=False,
211
+ output_attentions=False,
212
+ ):
213
+ if encoder_hidden_states is not None:
214
+ assert hasattr(
215
+ self, "q_attn"
216
+ ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
217
+ query = self.q_attn(hidden_states)
218
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
219
+ attention_mask = encoder_attention_mask
220
+ else:
221
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
222
+
223
+ query = self.split_heads(query)
224
+ key = self.split_heads(key, k=True)
225
+ value = self.split_heads(value)
226
+ if layer_past is not None:
227
+ past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
228
+ key = torch.cat((past_key, key), dim=-1)
229
+ value = torch.cat((past_value, value), dim=-2)
230
+
231
+ if use_cache is True:
232
+ present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
233
+ else:
234
+ present = (None,)
235
+
236
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
237
+ a = attn_outputs[0]
238
+
239
+ a = self.merge_heads(a)
240
+ a = self.c_proj(a)
241
+ a = self.resid_dropout(a)
242
+
243
+ outputs = [a, present] + attn_outputs[1:]
244
+ return outputs # a, present, (attentions)
245
+
246
+
247
+ class MLP(nn.Module):
248
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
249
+ super().__init__()
250
+ nx = config.n_embd
251
+ self.c_fc = Conv1D(n_state, nx)
252
+ self.c_proj = Conv1D(nx, n_state)
253
+ self.act = ACT2FN[config.activation_function]
254
+ self.dropout = nn.Dropout(config.resid_pdrop)
255
+
256
+ def forward(self, x):
257
+ h = self.act(self.c_fc(x))
258
+ h2 = self.c_proj(h)
259
+ return self.dropout(h2)
260
+
261
+
262
+ class AdapterMLP(nn.Module):
263
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
264
+ super().__init__()
265
+ nx = config.n_embd
266
+ self.c_fc = Conv1D(n_state, nx)
267
+ self.c_proj = Conv1D(nx, n_state)
268
+ self.act = ACT2FN[config.activation_function]
269
+ self.dropout = nn.Dropout(config.resid_pdrop)
270
+
271
+ def forward(self, x):
272
+ h = self.act(self.c_fc(x))
273
+ h2 = self.c_proj(h)
274
+ return self.dropout(h2)
275
+
276
+
277
+ class Block(nn.Module):
278
+ def __init__(self, n_ctx, config, scale=False):
279
+ super().__init__()
280
+ hidden_size = config.n_embd
281
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
282
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
283
+ self.attn = Attention(hidden_size, n_ctx, config, scale)
284
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
285
+ # self.adapter_ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
286
+ if config.add_cross_attention:
287
+ self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
288
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
289
+ self.mlp = MLP(inner_dim, config)
290
+ # self.adapter_mlp = AdapterMLP(512, config) # ADAPTER
291
+
292
+ def forward(
293
+ self,
294
+ hidden_states,
295
+ layer_past=None,
296
+ attention_mask=None,
297
+ head_mask=None,
298
+ encoder_hidden_states=None,
299
+ encoder_attention_mask=None,
300
+ use_cache=False,
301
+ output_attentions=False,
302
+ ):
303
+ attn_outputs = self.attn(
304
+ self.ln_1(hidden_states),
305
+ layer_past=layer_past,
306
+ attention_mask=attention_mask,
307
+ head_mask=head_mask,
308
+ use_cache=use_cache,
309
+ output_attentions=output_attentions,
310
+ )
311
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
312
+ outputs = attn_outputs[1:]
313
+ # residual connection
314
+ hidden_states = attn_output + hidden_states
315
+
316
+ if encoder_hidden_states is not None:
317
+ # add one self-attention block for cross-attention
318
+ assert hasattr(
319
+ self, "crossattention"
320
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
321
+ cross_attn_outputs = self.crossattention(
322
+ self.ln_cross_attn(hidden_states),
323
+ attention_mask=attention_mask,
324
+ head_mask=head_mask,
325
+ encoder_hidden_states=encoder_hidden_states,
326
+ encoder_attention_mask=encoder_attention_mask,
327
+ output_attentions=output_attentions,
328
+ )
329
+ attn_output = cross_attn_outputs[0]
330
+ # residual connection
331
+ hidden_states = hidden_states + attn_output
332
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
333
+
334
+ feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
335
+ # residual connection
336
+ hidden_states = hidden_states + feed_forward_hidden_states
337
+ # hidden_states = hidden_states + self.adapter_ln(self.adapter_mlp(hidden_states))
338
+
339
+ outputs = [hidden_states] + outputs
340
+ return outputs # hidden_states, present, (attentions, cross_attentions)
341
+
342
+
343
+ class GPT2PreTrainedModel(PreTrainedModel):
344
+ """
345
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
346
+ models.
347
+ """
348
+
349
+ config_class = GPT2Config
350
+ load_tf_weights = load_tf_weights_in_gpt2
351
+ base_model_prefix = "transformer"
352
+
353
+ def __init__(self, *inputs, **kwargs):
354
+ super().__init__(*inputs, **kwargs)
355
+
356
+ def _init_weights(self, module):
357
+ """Initialize the weights."""
358
+ if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
359
+ # Slightly different from the TF version which uses truncated_normal for initialization
360
+ # cf https://github.com/pytorch/pytorch/pull/5617
361
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
362
+ if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
363
+ module.bias.data.zero_()
364
+ elif isinstance(module, nn.LayerNorm):
365
+ module.bias.data.zero_()
366
+ module.weight.data.fill_(1.0)
367
+ # module.weight.data.fill_(.01) # KL: Adapter change
368
+
369
+
370
+ @dataclass
371
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
372
+ """
373
+ Base class for outputs of models predicting if two sentences are consecutive or not.
374
+ Args:
375
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
376
+ Language modeling loss.
377
+ mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
378
+ Multiple choice classification loss.
379
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
380
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
381
+ mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
382
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
383
+ past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
384
+ List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
385
+ batch_size, num_heads, sequence_length, embed_size_per_head)`).
386
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
387
+ :obj:`past_key_values` input) to speed up sequential decoding.
388
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
389
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
390
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
391
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
392
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
393
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
394
+ sequence_length, sequence_length)`.
395
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
396
+ heads.
397
+ """
398
+
399
+ loss: Optional[torch.FloatTensor] = None
400
+ mc_loss: Optional[torch.FloatTensor] = None
401
+ logits: torch.FloatTensor = None
402
+ mc_logits: torch.FloatTensor = None
403
+ past_key_values: Optional[List[torch.FloatTensor]] = None
404
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
405
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
406
+
407
+
408
+ GPT2_START_DOCSTRING = r"""
409
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
410
+ methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
411
+ pruning heads etc.)
412
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
413
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
414
+ general usage and behavior.
415
+ Parameters:
416
+ config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
417
+ Initializing with a config file does not load the weights associated with the model, only the
418
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
419
+ weights.
420
+ """
421
+
422
+ GPT2_INPUTS_DOCSTRING = r"""
423
+ Args:
424
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
425
+ :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
426
+ ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
427
+ sequence tokens in the vocabulary.
428
+ If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
429
+ passed as ``input_ids``.
430
+ Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
431
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
432
+ details.
433
+ `What are input IDs? <../glossary.html#input-ids>`__
434
+ past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
435
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
436
+ :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
437
+ have their past given to this model should not be passed as ``input_ids`` as they have already been
438
+ computed.
439
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
440
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
441
+ - 1 for tokens that are **not masked**,
442
+ - 0 for tokens that are **masked**.
443
+ `What are attention masks? <../glossary.html#attention-mask>`__
444
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
445
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
446
+ 1]``:
447
+ - 0 corresponds to a `sentence A` token,
448
+ - 1 corresponds to a `sentence B` token.
449
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
450
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
451
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
452
+ config.max_position_embeddings - 1]``.
453
+ `What are position IDs? <../glossary.html#position-ids>`_
454
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
455
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
456
+ - 1 indicates the head is **not masked**,
457
+ - 0 indicates the head is **masked**.
458
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
459
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
460
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
461
+ vectors than the model's internal embedding lookup matrix.
462
+ If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
463
+ :obj:`past_key_values`).
464
+ use_cache (:obj:`bool`, `optional`):
465
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
466
+ decoding (see :obj:`past_key_values`).
467
+ output_attentions (:obj:`bool`, `optional`):
468
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
469
+ tensors for more detail.
470
+ output_hidden_states (:obj:`bool`, `optional`):
471
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
472
+ more detail.
473
+ return_dict (:obj:`bool`, `optional`):
474
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
475
+ """
476
+ PARALLELIZE_DOCSTRING = r"""
477
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
478
+ it will evenly distribute blocks across all devices.
479
+ Args:
480
+ device_map (:obj:`Dict[int, list]`, optional, defaults to None):
481
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
482
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
483
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
484
+ following number of attention modules:
485
+ - gpt2: 12
486
+ - gpt2-medium: 24
487
+ - gpt2-large: 36
488
+ - gpt2-xl: 48
489
+ Example::
490
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
491
+ model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
492
+ device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
493
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
494
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
495
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
496
+ model.parallelize(device_map)
497
+ """
498
+ DEPARALLELIZE_DOCSTRING = r"""
499
+ Moves the model to cpu from a model parallel state.
500
+ Example::
501
+ # On a 4 GPU machine with gpt2-large:
502
+ model = GPT2LMHeadModel.from_pretrained('gpt2-large')
503
+ device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
504
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
505
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
506
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
507
+ model.parallelize(device_map) # Splits the model across several devices
508
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
509
+ """
510
+
511
+
512
+ @add_start_docstrings(
513
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
514
+ GPT2_START_DOCSTRING,
515
+ )
516
+ class GPT2Model(GPT2PreTrainedModel):
517
+ def __init__(self, config):
518
+ super().__init__(config)
519
+
520
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
521
+ # self.wpe = nn.Embedding(config.n_positions, config.n_embd)
522
+ self.drop = nn.Dropout(config.embd_pdrop)
523
+ self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
524
+ self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
525
+
526
+ self.init_weights()
527
+ # Model parallel
528
+ self.model_parallel = False
529
+ self.device_map = None
530
+
531
+ self.use_layers = None
532
+
533
+ def set_layers(self, num_layers):
+ # cap how many transformer blocks forward() executes
+ assert 1 <= num_layers <= len(self.h)
+ num_layers -= 1
+ self.use_layers = num_layers
538
+
539
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
540
+ def parallelize(self, device_map=None):
541
+ # Check validity of device_map
542
+ self.device_map = (
543
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
544
+ )
545
+ assert_device_map(self.device_map, len(self.h))
546
+ self.model_parallel = True
547
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
548
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
549
+ self.wte = self.wte.to(self.first_device)
550
+ # self.wpe does not exist in this adapted copy (position embeddings are disabled in __init__), so only wte is moved
551
+ # Load onto devices
552
+ for k, v in self.device_map.items():
553
+ for block in v:
554
+ cuda_device = "cuda:" + str(k)
555
+ self.h[block] = self.h[block].to(cuda_device)
556
+ # ln_f to last
557
+ self.ln_f = self.ln_f.to(self.last_device)
558
+
559
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
560
+ def deparallelize(self):
561
+ self.model_parallel = False
562
+ self.device_map = None
563
+ self.first_device = "cpu"
564
+ self.last_device = "cpu"
565
+ self.wte = self.wte.to("cpu")
566
+ self.wpe = self.wpe.to("cpu")
567
+ for index in range(len(self.h)):
568
+ self.h[index] = self.h[index].to("cpu")
569
+ self.ln_f = self.ln_f.to("cpu")
570
+ torch.cuda.empty_cache()
571
+
572
+ def get_input_embeddings(self):
573
+ return self.wte
574
+
575
+ def set_input_embeddings(self, new_embeddings):
576
+ self.wte = new_embeddings
577
+
578
+ def _prune_heads(self, heads_to_prune):
579
+ """
580
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
581
+ """
582
+ for layer, heads in heads_to_prune.items():
583
+ self.h[layer].attn.prune_heads(heads)
584
+
585
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
586
+ @add_code_sample_docstrings(
587
+ processor_class=_TOKENIZER_FOR_DOC,
588
+ checkpoint="gpt2",
589
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
590
+ config_class=_CONFIG_FOR_DOC,
591
+ )
592
+ def forward(
593
+ self,
594
+ input_ids=None,
595
+ past_key_values=None,
596
+ attention_mask=None,
597
+ token_type_ids=None,
598
+ position_ids=None,
599
+ head_mask=None,
600
+ inputs_embeds=None,
601
+ encoder_hidden_states=None,
602
+ encoder_attention_mask=None,
603
+ use_cache=None,
604
+ output_attentions=None,
605
+ output_hidden_states=None,
606
+ return_dict=None,
607
+ ):
608
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
609
+ output_hidden_states = (
610
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
611
+ )
612
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
613
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
614
+
615
+ if input_ids is not None and inputs_embeds is not None:
616
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
617
+ elif input_ids is not None:
618
+ input_shape = input_ids.size()
619
+ input_ids = input_ids.view(-1, input_shape[-1])
620
+ batch_size = input_ids.shape[0]
621
+ elif inputs_embeds is not None:
622
+ input_shape = inputs_embeds.size()[:-1]
623
+ batch_size = inputs_embeds.shape[0]
624
+ else:
625
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
626
+
627
+ if token_type_ids is not None:
628
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
629
+ if position_ids is not None:
630
+ position_ids = position_ids.view(-1, input_shape[-1])
631
+
632
+ if past_key_values is None:
633
+ past_length = 0
634
+ past_key_values = [None] * len(self.h)
635
+ else:
636
+ past_length = past_key_values[0][0].size(-2)
637
+ if position_ids is None:
638
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
639
+ position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
640
+ position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
641
+
642
+ # Attention mask.
643
+ if attention_mask is not None:
644
+ assert batch_size > 0, "batch_size has to be defined and > 0"
645
+ attention_mask = attention_mask.view(batch_size, -1)
646
+ # We create a 3D attention mask from a 2D tensor mask.
647
+ # Sizes are [batch_size, 1, 1, to_seq_length]
648
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
649
+ # this attention mask is more simple than the triangular masking of causal attention
650
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
651
+ attention_mask = attention_mask[:, None, None, :]
652
+
653
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
654
+ # masked positions, this operation will create a tensor which is 0.0 for
655
+ # positions we want to attend and -10000.0 for masked positions.
656
+ # Since we are adding it to the raw scores before the softmax, this is
657
+ # effectively the same as removing these entirely.
658
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
659
+ attention_mask = (1.0 - attention_mask) * -10000.0
660
+
661
+ # If a 2D or 3D attention mask is provided for the cross-attention
662
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
663
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
664
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
665
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
666
+ if encoder_attention_mask is None:
667
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
668
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
669
+ else:
670
+ encoder_attention_mask = None
671
+
672
+ # Prepare head mask if needed
673
+ # 1.0 in head_mask indicate we keep the head
674
+ # attention_probs has shape bsz x n_heads x N x N
675
+ # head_mask has shape n_layer x batch x n_heads x N x N
676
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
677
+
678
+ if inputs_embeds is None:
679
+ inputs_embeds = self.wte(input_ids)
680
+ # position_embeds = self.wpe(position_ids)
681
+ hidden_states = inputs_embeds # + position_embeds
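+ # position embeddings are deliberately disabled in this adapted copy (wpe is commented out in __init__); callers bake positional information into inputs_embeds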
682
+
683
+ if token_type_ids is not None:
684
+ token_type_embeds = self.wte(token_type_ids)
685
+ hidden_states = hidden_states + token_type_embeds
686
+
687
+ hidden_states = self.drop(hidden_states)
688
+
689
+ output_shape = input_shape + (hidden_states.size(-1),)
690
+
691
+ presents = () if use_cache else None
692
+ all_self_attentions = () if output_attentions else None
693
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
694
+ all_hidden_states = () if output_hidden_states else None
695
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
696
+
697
+ if self.use_layers is not None and i >= self.use_layers:
698
+ break
699
+
700
+ # Model parallel
701
+ if self.model_parallel:
702
+ torch.cuda.set_device(hidden_states.device)
703
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
704
+ if layer_past is not None:
705
+ layer_past = layer_past.to(hidden_states.device)
706
+ # Ensure that attention_mask is always on the same device as hidden_states
707
+ if attention_mask is not None:
708
+ attention_mask = attention_mask.to(hidden_states.device)
709
+ if isinstance(head_mask, torch.Tensor):
710
+ head_mask = head_mask.to(hidden_states.device)
711
+ if output_hidden_states:
712
+ all_hidden_states = all_hidden_states + (hidden_states,)
713
+
714
+ if getattr(self.config, "gradient_checkpointing", False):
715
+
716
+ def create_custom_forward(module):
717
+ def custom_forward(*inputs):
718
+ # checkpointing only works with tuple returns, not with lists
719
+ return tuple(output for output in module(*inputs, use_cache, output_attentions))
720
+
721
+ return custom_forward
722
+
723
+ outputs = torch.utils.checkpoint.checkpoint(
724
+ create_custom_forward(block),
725
+ hidden_states,
726
+ layer_past,
727
+ attention_mask,
728
+ head_mask[i],
729
+ encoder_hidden_states,
730
+ encoder_attention_mask,
731
+ )
732
+ else:
733
+ outputs = block(
734
+ hidden_states,
735
+ layer_past=layer_past,
736
+ attention_mask=attention_mask,
737
+ head_mask=head_mask[i],
738
+ encoder_hidden_states=encoder_hidden_states,
739
+ encoder_attention_mask=encoder_attention_mask,
740
+ use_cache=use_cache,
741
+ output_attentions=output_attentions,
742
+ )
743
+
744
+ hidden_states, present = outputs[:2]
745
+ if use_cache is True:
746
+ presents = presents + (present,)
747
+
748
+ if output_attentions:
749
+ all_self_attentions = all_self_attentions + (outputs[2],)
750
+ if self.config.add_cross_attention:
751
+ all_cross_attentions = all_cross_attentions + (outputs[3],)
752
+
753
+ # Model Parallel: If it's the last layer for that device, put things on the next device
754
+ if self.model_parallel:
755
+ for k, v in self.device_map.items():
756
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
757
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
758
+
759
+ hidden_states = self.ln_f(hidden_states)
760
+
761
+ hidden_states = hidden_states.view(*output_shape)
762
+ # Add last hidden state
763
+ if output_hidden_states:
764
+ all_hidden_states = all_hidden_states + (hidden_states,)
765
+
766
+ if not return_dict:
767
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
768
+
769
+ return BaseModelOutputWithPastAndCrossAttentions(
770
+ last_hidden_state=hidden_states,
771
+ past_key_values=presents,
772
+ hidden_states=all_hidden_states,
773
+ attentions=all_self_attentions,
774
+ cross_attentions=all_cross_attentions,
775
+ )
third_party/AnyBimanual/agents/peract_bc/visual_aligner.py ADDED
@@ -0,0 +1,39 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class VisualAligner(nn.Module):
+     def __init__(self, input_dim=128, hidden_dim=256, mask_dim=128):
+         super(VisualAligner, self).__init__()
+
+         self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
+
+         self.conv_res1 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
+         self.conv_res2 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
+
+         # one mask head per arm
+         self.conv2_right = nn.Conv1d(in_channels=hidden_dim, out_channels=mask_dim, kernel_size=3, padding=1)
+         self.conv2_left = nn.Conv1d(in_channels=hidden_dim, out_channels=mask_dim, kernel_size=3, padding=1)
+
+         self.activation = nn.ReLU()
+
+     def forward(self, ins):
+         # [B, N, C] -> [B, C, N] for Conv1d
+         ins = ins.transpose(1, 2)
+
+         features = self.activation(self.conv1(ins))
+
+         # residual refinement block
+         residual = features
+         features = self.activation(self.conv_res1(features))
+         features = self.conv_res2(features)
+         features = features + residual
+
+         # non-negative per-arm masks
+         mask_right = self.activation(self.conv2_right(features))
+         mask_left = self.activation(self.conv2_left(features))
+
+         # back to [B, N, C]
+         mask_right = mask_right.transpose(1, 2)
+         mask_left = mask_left.transpose(1, 2)
+         ins = ins.transpose(1, 2)
+
+         # element-wise gating of the shared input into right/left streams
+         masked_ins1 = ins * mask_right
+         masked_ins2 = ins * mask_left
+
+         return masked_ins1, masked_ins2
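
A minimal usage sketch for the module above (batch and sequence sizes are illustrative; the only hard constraint is that the channel dimension matches input_dim/mask_dim, 128 by default):

    import torch
    from agents.peract_bc.visual_aligner import VisualAligner

    aligner = VisualAligner(input_dim=128, hidden_dim=256, mask_dim=128)
    tokens = torch.randn(2, 8000, 128)           # [B, N, C] voxel-token sequence
    right_stream, left_stream = aligner(tokens)  # per-arm masked copies of the input
    assert right_stream.shape == left_stream.shape == tokens.shape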
third_party/AnyBimanual/agents/peract_bimanual/__init__.py ADDED
@@ -0,0 +1 @@
+ import agents.peract_bimanual.launch_utils
third_party/AnyBimanual/agents/peract_bimanual/launch_utils.py ADDED
@@ -0,0 +1,117 @@
+ # Adapted from ARM
+ # Source: https://github.com/stepjam/ARM
+ # License: https://github.com/stepjam/ARM/LICENSE
+ from helpers.preprocess_agent import PreprocessAgent
+
+ from agents.peract_bimanual.perceiver_lang_io import PerceiverVoxelLangEncoder
+ from agents.peract_bimanual.qattention_peract_bc_agent import QAttentionPerActBCAgent
+ from agents.peract_bimanual.qattention_stack_agent import QAttentionStackAgent
+ from agents.peract_bimanual.skill_manager import SkillManager
+ from agents.peract_bimanual.visual_aligner import VisualAligner
+ from omegaconf import DictConfig
+ import pickle
+ import torch
+ import os
+
+
+ def create_agent(cfg: DictConfig):
+     depth_0bounds = cfg.rlbench.scene_bounds
+     cam_resolution = cfg.rlbench.camera_resolution
+
+     num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
+     qattention_agents = []
+
+     current_dir = os.path.dirname(os.path.abspath(__file__))
+     pkl_path = os.path.join(current_dir, "../../lang_token.pkl")
+     pkl_path = os.path.abspath(pkl_path)
+     with open(pkl_path, "rb") as f:
+         embeddings_dict = pickle.load(f)
+     flattened_embeddings = []
+     for key in embeddings_dict.keys():
+         embedding = torch.tensor(embeddings_dict[key])
+         flattened_embedding = embedding.view(-1)
+         flattened_embeddings.append(flattened_embedding)
+     embeddings_matrix = torch.stack(flattened_embeddings)
+
+     skill_manager = SkillManager(num_classes=18, embedding_matrix=embeddings_matrix)
+     visual_aligner = VisualAligner()
+     for depth, vox_size in enumerate(cfg.method.voxel_sizes):
+         last = depth == len(cfg.method.voxel_sizes) - 1
+         if cfg.framework.use_skill:
+             perceiver_encoder = PerceiverVoxelLangEncoder(
+                 depth=cfg.method.transformer_depth,
+                 iterations=cfg.method.transformer_iterations,
+                 voxel_size=vox_size,
+                 initial_dim=3 + 3 + 1 + 3,
+                 low_dim_size=cfg.method.low_dim_size,
+                 layer=depth,
+                 num_rotation_classes=num_rotation_classes if last else 0,
+                 num_grip_classes=2 if last else 0,
+                 num_collision_classes=2 if last else 0,
+                 input_axis=3,
+                 num_latents=cfg.method.num_latents,
+                 latent_dim=cfg.method.latent_dim,
+                 cross_heads=cfg.method.cross_heads,
+                 latent_heads=cfg.method.latent_heads,
+                 cross_dim_head=cfg.method.cross_dim_head,
+                 latent_dim_head=cfg.method.latent_dim_head,
+                 weight_tie_layers=False,
+                 activation=cfg.method.activation,
+                 pos_encoding_with_lang=cfg.method.pos_encoding_with_lang,
+                 input_dropout=cfg.method.input_dropout,
+                 attn_dropout=cfg.method.attn_dropout,
+                 decoder_dropout=cfg.method.decoder_dropout,
+                 lang_fusion_type=cfg.method.lang_fusion_type,
+                 voxel_patch_size=cfg.method.voxel_patch_size,
+                 voxel_patch_stride=cfg.method.voxel_patch_stride,
+                 no_skip_connection=cfg.method.no_skip_connection,
+                 no_perceiver=cfg.method.no_perceiver,
+                 no_language=cfg.method.no_language,
+                 final_dim=cfg.method.final_dim,
+                 anybimanual=cfg.framework.anybimanual,
+                 skill_manager=skill_manager,
+                 visual_aligner=visual_aligner,
+             )
+
+         qattention_agent = QAttentionPerActBCAgent(
+             layer=depth,
+             coordinate_bounds=depth_0bounds,
+             perceiver_encoder=perceiver_encoder,
+             camera_names=cfg.rlbench.cameras,
+             voxel_size=vox_size,
+             bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
+             image_crop_size=cfg.method.image_crop_size,
+             lr=cfg.method.lr,
+             training_iterations=cfg.framework.training_iterations,
+             lr_scheduler=cfg.method.lr_scheduler,
+             num_warmup_steps=cfg.method.num_warmup_steps,
+             trans_loss_weight=cfg.method.trans_loss_weight,
+             rot_loss_weight=cfg.method.rot_loss_weight,
+             grip_loss_weight=cfg.method.grip_loss_weight,
+             collision_loss_weight=cfg.method.collision_loss_weight,
+             include_low_dim_state=True,
+             image_resolution=cam_resolution,
+             batch_size=cfg.replay.batch_size,
+             voxel_feature_size=3,
+             lambda_weight_l2=cfg.method.lambda_weight_l2,
+             num_rotation_classes=num_rotation_classes,
+             rotation_resolution=cfg.method.rotation_resolution,
+             transform_augmentation=cfg.method.transform_augmentation.apply_se3,
+             transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
+             transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
+             transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
+             optimizer_type=cfg.method.optimizer,
+             num_devices=cfg.ddp.num_devices,
+             anybimanual=cfg.framework.anybimanual,
+             load_exists_weights=cfg.framework.load_existing_weights,
+             frozen=cfg.framework.frozen,
+             cfg=cfg,
+             aug_type=cfg.framework.augmentation_type,
+         )
+         qattention_agents.append(qattention_agent)
+
+     rotation_agent = QAttentionStackAgent(
+         qattention_agents=qattention_agents,
+         rotation_resolution=cfg.method.rotation_resolution,
+         camera_names=cfg.rlbench.cameras,
+     )
+     preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
+     return preprocess_agent
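
create_agent above only assumes that lang_token.pkl holds a dict whose values torch.tensor() accepts; each value is flattened and stacked into an 18-row matrix for SkillManager. A hypothetical sketch of writing a compatible file (the skill names and the 77x512 shape are illustrative placeholders, not the project's real values):

    import pickle
    import numpy as np

    # Hypothetical contents: 18 skill entries, each an array-like embedding.
    embeddings = {"skill_%02d" % i: np.random.randn(77, 512).astype("float32")
                  for i in range(18)}
    with open("lang_token.pkl", "wb") as f:
        pickle.dump(embeddings, f)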
third_party/AnyBimanual/agents/peract_bimanual/perceiver_lang_io.py ADDED
@@ -0,0 +1,628 @@
+ # Perceiver IO implementation adapted for manipulation
+ # Source: https://github.com/lucidrains/perceiver-pytorch
+ # License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE
+
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from einops import rearrange
+ from einops import repeat
+ import numpy as np
+ from perceiver_pytorch.perceiver_pytorch import cache_fn
+ from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention
+
+ from helpers.network_utils import (
+     DenseBlock,
+     SpatialSoftmax3D,
+     Conv3DBlock,
+     Conv3DUpsampleBlock,
+ )
+
+ def symmetric_kl_divergence(left, right):
+     eps = 1e-2
+     left_prob = torch.clamp(F.log_softmax(left, dim=-1), min=-10, max=10)
+     right_prob = torch.clamp(F.log_softmax(right, dim=-1), min=-10, max=10)
+
+     kl_left_to_right = F.kl_div(left_prob, right_prob.exp(), reduction="batchmean") * eps
+     kl_right_to_left = F.kl_div(right_prob, left_prob.exp(), reduction="batchmean") * eps
+
+     symmetric_kl = -(kl_left_to_right + kl_right_to_left) / 2.0
+     return symmetric_kl
+
+ def l1_norm(tensor):
+     return torch.sum(torch.abs(tensor)) + 1e-4 * torch.norm(tensor)
+
+ def l2_1_norm(tensor):
+     l2_norm_per_skill = torch.norm(tensor, dim=-1)
+     return torch.sum(l2_norm_per_skill)
+
+ torch.autograd.set_detect_anomaly(True)
+
+ # PerceiverIO adapted for 6-DoF manipulation
+ class PerceiverVoxelLangEncoder(nn.Module):
+     def __init__(
+         self,
+         depth,  # number of self-attention layers
+         iterations,  # number of cross-attention iterations (PerceiverIO uses just 1)
+         voxel_size,  # N voxels per side (size: N*N*N)
+         initial_dim,  # 10 dimensions - dimension of the input sequence to be encoded
+         low_dim_size,  # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep}
+         layer=0,
+         num_rotation_classes=72,  # 5 degree increments (5*72=360) for each of the 3 axes
+         num_grip_classes=2,  # open or not open
+         num_collision_classes=2,  # collisions allowed or not allowed
+         input_axis=3,  # 3D tensors have 3 axes
+         num_latents=512,  # number of latent vectors
+         im_channels=64,  # intermediate channel size
+         latent_dim=512,  # dimensions of latent vectors
+         cross_heads=1,  # number of cross-attention heads
+         latent_heads=8,  # number of latent heads
+         cross_dim_head=64,
+         latent_dim_head=64,
+         activation="relu",
+         weight_tie_layers=False,
+         pos_encoding_with_lang=True,
+         input_dropout=0.1,
+         attn_dropout=0.1,
+         decoder_dropout=0.0,
+         lang_fusion_type="seq",
+         voxel_patch_size=9,
+         voxel_patch_stride=8,
+         no_skip_connection=False,
+         no_perceiver=False,
+         no_language=False,
+         final_dim=64,
+         anybimanual=False,
+         skill_manager=None,
+         visual_aligner=None,
+     ):
+         super().__init__()
+         self.depth = depth
+         self.layer = layer
+         self.init_dim = int(initial_dim)
+         self.iterations = iterations
+         self.input_axis = input_axis
+         self.voxel_size = voxel_size
+         self.low_dim_size = low_dim_size
+         self.im_channels = im_channels
+         self.pos_encoding_with_lang = pos_encoding_with_lang
+         self.lang_fusion_type = lang_fusion_type
+         self.voxel_patch_size = voxel_patch_size
+         self.voxel_patch_stride = voxel_patch_stride
+         self.num_rotation_classes = num_rotation_classes
+         self.num_grip_classes = num_grip_classes
+         self.num_collision_classes = num_collision_classes
+         self.final_dim = final_dim
+         self.input_dropout = input_dropout
+         self.attn_dropout = attn_dropout
+         self.decoder_dropout = decoder_dropout
+         self.no_skip_connection = no_skip_connection
+         self.no_perceiver = no_perceiver
+         self.no_language = no_language
+         self.anybimanual = anybimanual
+         self.skill_manager = skill_manager
+         self.visual_aligner = visual_aligner
+
+         # patchified input dimensions
+         spatial_size = voxel_size // self.voxel_patch_stride  # 100 / 5 = 20
+         # 64 voxel features + 64 proprio features (+ 64 lang goal features if concatenated)
+         self.input_dim_before_seq = (
+             self.im_channels * 3
+             if self.lang_fusion_type == "concat"
+             else self.im_channels * 2
+         )
+         if self.anybimanual:
+             self.input_dim_before_seq_ = self.input_dim_before_seq * 2
+         else:
+             self.input_dim_before_seq_ = self.input_dim_before_seq
+
+         # CLIP language feature dimensions
+         if self.anybimanual:
+             lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 154
+         else:
+             lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77
+
+         # learnable positional encoding
+         # peract2: pos_encoding_with_lang = True / peract: False?
+         if self.pos_encoding_with_lang:
+             self.pos_encoding = nn.Parameter(
+                 torch.randn(
+                     1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
+                 )
+             )
+         else:
+             # assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.'
+             self.pos_encoding = nn.Parameter(
+                 torch.randn(
+                     1,
+                     spatial_size,
+                     spatial_size,
+                     spatial_size,
+                     self.input_dim_before_seq,
+                 )
+             )
+
+         # voxel input preprocessing 1x1 conv encoder
+         self.input_preprocess = Conv3DBlock(
+             self.init_dim,
+             self.im_channels,
+             kernel_sizes=1,
+             strides=1,
+             norm=None,
+             activation=activation,
+         )
+
+         # patchify conv
+         self.patchify = Conv3DBlock(
+             self.input_preprocess.out_channels,
+             self.im_channels,
+             kernel_sizes=self.voxel_patch_size,
+             strides=self.voxel_patch_stride,
+             norm=None,
+             activation=activation,
+         )
+
+         # language preprocess
+         if self.lang_fusion_type == "concat":
+             self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
+         elif self.lang_fusion_type == "seq":
+             self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)
+
+         # proprioception
+         if self.low_dim_size > 0:
+             self.proprio_preprocess = DenseBlock(
+                 self.low_dim_size,
+                 self.im_channels,
+                 norm=None,
+                 activation=activation,
+             )
+
+         # pooling functions
+         self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
+         self.global_maxp = nn.AdaptiveMaxPool3d(1)
+
+         # 1st 3D softmax
+         self.ss0 = SpatialSoftmax3D(
+             self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
+         )
+         flat_size = self.im_channels * 4
+
+         # latent vectors (that are randomly initialized)
+         self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
+
+         if self.anybimanual:
+             # per-arm encoder cross attention
+             self.cross_attend_blocks = nn.ModuleList(
+                 [
+                     PreNorm(
+                         latent_dim,
+                         Attention(
+                             latent_dim,
+                             self.input_dim_before_seq_,
+                             heads=cross_heads,
+                             dim_head=cross_dim_head,
+                             dropout=input_dropout,
+                         ),
+                         context_dim=self.input_dim_before_seq_,
+                     ),
+                     PreNorm(
+                         latent_dim,
+                         Attention(
+                             latent_dim,
+                             self.input_dim_before_seq_,
+                             heads=cross_heads,
+                             dim_head=cross_dim_head,
+                             dropout=input_dropout,
+                         ),
+                         context_dim=self.input_dim_before_seq_,
+                     ),
+                     PreNorm(latent_dim, FeedForward(latent_dim)),
+                     PreNorm(latent_dim, FeedForward(latent_dim)),
+                 ]
+             )
+         else:
+             # encoder cross attention
+             self.cross_attend_blocks = nn.ModuleList(
+                 [
+                     PreNorm(
+                         latent_dim,
+                         Attention(
+                             latent_dim,
+                             self.input_dim_before_seq_,
+                             heads=cross_heads,
+                             dim_head=cross_dim_head,
+                             dropout=input_dropout,
+                         ),
+                         context_dim=self.input_dim_before_seq_,
+                     ),
+                     PreNorm(latent_dim, FeedForward(latent_dim)),
+                     PreNorm(latent_dim, FeedForward(latent_dim)),
+                 ]
+             )
+
+         get_latent_attn = lambda: PreNorm(
+             latent_dim,
+             Attention(
+                 latent_dim,
+                 heads=latent_heads,
+                 dim_head=latent_dim_head,
+                 dropout=attn_dropout,
+             ),
+         )
+         get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
+         get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
+
+         # self attention layers
+         self.layers = nn.ModuleList([])
+         cache_args = {"_cache": weight_tie_layers}
+
+         for i in range(depth):
+             self.layers.append(
+                 nn.ModuleList(
+                     [get_latent_attn(**cache_args), get_latent_ff(**cache_args),
+                      get_latent_attn(**cache_args), get_latent_ff(**cache_args)]
+                 )
+             )
+
+         self.combined_latent_attn = get_latent_attn(**cache_args)
+         self.combined_latent_ff = get_latent_ff(**cache_args)
+
+         # decoder cross attention
+         self.decoder_cross_attn_right = PreNorm(
+             self.input_dim_before_seq_,
+             Attention(
+                 self.input_dim_before_seq_,
+                 latent_dim,
+                 heads=cross_heads,
+                 dim_head=cross_dim_head,
+                 dropout=decoder_dropout,
+             ),
+             context_dim=latent_dim,
+         )
+
+         self.decoder_cross_attn_left = PreNorm(
+             self.input_dim_before_seq_,
+             Attention(
+                 self.input_dim_before_seq_,
+                 latent_dim,
+                 heads=cross_heads,
+                 dim_head=cross_dim_head,
+                 dropout=decoder_dropout,
+             ),
+             context_dim=latent_dim,
+         )
+
+         # upsample conv
+         self.up0 = Conv3DUpsampleBlock(
+             self.input_dim_before_seq_,
+             self.final_dim,
+             kernel_sizes=self.voxel_patch_size,
+             strides=self.voxel_patch_stride,
+             norm=None,
+             activation=activation,
+         )
+
+         # 2nd 3D softmax
+         self.ss1 = SpatialSoftmax3D(
+             spatial_size, spatial_size, spatial_size, self.input_dim_before_seq_
+         )
+
+         flat_size += self.input_dim_before_seq_ * 4
+
+         # final 3D softmax
+         self.final = Conv3DBlock(
+             self.im_channels
+             if (self.no_perceiver or self.no_skip_connection)
+             else self.im_channels * 2,
+             self.im_channels,
+             kernel_sizes=3,
+             strides=1,
+             norm=None,
+             activation=activation,
+         )
+
+         self.right_trans_decoder = Conv3DBlock(
+             self.final_dim,
+             1,
+             kernel_sizes=3,
+             strides=1,
+             norm=None,
+             activation=None,
+         )
+
+         self.left_trans_decoder = Conv3DBlock(
+             self.final_dim,
+             1,
+             kernel_sizes=3,
+             strides=1,
+             norm=None,
+             activation=None,
+         )
+
+         # rotation, gripper, and collision MLP layers
+         if self.num_rotation_classes > 0:
+             self.ss_final = SpatialSoftmax3D(
+                 self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
+             )
+
+             flat_size += self.im_channels * 4
+
+             self.right_dense0 = DenseBlock(flat_size, 256, None, activation)
+             self.right_dense1 = DenseBlock(256, self.final_dim, None, activation)
+
+             self.left_dense0 = DenseBlock(flat_size, 256, None, activation)
+             self.left_dense1 = DenseBlock(256, self.final_dim, None, activation)
+
+             self.right_rot_grip_collision_ff = DenseBlock(
+                 self.final_dim,
+                 self.num_rotation_classes * 3
+                 + self.num_grip_classes
+                 + self.num_collision_classes,
+                 None,
+                 None,
+             )
+
+             self.left_rot_grip_collision_ff = DenseBlock(
+                 self.final_dim,
+                 self.num_rotation_classes * 3
+                 + self.num_grip_classes
+                 + self.num_collision_classes,
+                 None,
+                 None,
+             )
+
+     def encode_text(self, x):
+         with torch.no_grad():
+             text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)
+
+         text_feat = text_feat.detach()
+         text_emb = text_emb.detach()
+         text_mask = torch.where(x == 0, x, 1)  # [1, max_token_len]
+         return text_feat, text_emb
+
+     def forward(
+         self,
+         ins,
+         proprio,
+         lang_goal_emb,
+         lang_token_embs,
+         prev_layer_voxel_grid,
+         bounds,
+         prev_layer_bounds,
+         mask=None,
+     ):
+         # preprocess input
+         ins_numpy = str(ins.cpu().numpy())  # (debug leftover; unused)
+         d0 = self.input_preprocess(ins)  # [B,10,100,100,100] -> [B,64,100,100,100]
+
+         # aggregated features from 1st softmax and maxpool for MLP decoders
+         feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]
+
+         # patchify input (5x5x5 patches)
+         ins = self.patchify(d0)  # [B,64,100,100,100] -> [B,64,20,20,20]
+
+         b, c, d, h, w, device = *ins.shape, ins.device
+         axis = [d, h, w]
+         assert (
+             len(axis) == self.input_axis
+         ), "input must have the same number of axis as input_axis"
+
+         # concat proprio
+         if self.low_dim_size > 0:
+             p = self.proprio_preprocess(proprio)  # [B,8] -> [B,64]
+             p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
+             ins = torch.cat([ins, p], dim=1)  # [B,128,20,20,20]
+
+         # language ablation
+         if self.no_language:
+             lang_goal_emb = torch.zeros_like(lang_goal_emb)
+             lang_token_embs = torch.zeros_like(lang_token_embs)
+
+         # option 1: tile and concat lang goal to input
+         if self.lang_fusion_type == "concat":
+             lang_emb = lang_goal_emb
+             lang_emb = lang_emb.to(dtype=ins.dtype)
+             l = self.lang_preprocess(lang_emb)
+             l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
+             ins = torch.cat([ins, l], dim=1)
+
+         # channel last
+         ins = rearrange(ins, "b d ... -> b ... d")  # [B,20,20,20,128]
+
+         # add pos encoding to grid
+         if not self.pos_encoding_with_lang:
+             ins = ins + self.pos_encoding
+
+         ######################## NOTE #############################
+         # NOTE: If you add positional encodings ^here the lang embs
+         # won't have positional encodings. I accidentally forgot
+         # to turn this off for all the experiments in the paper.
+         # So I guess those models were using language embs
+         # as a bag of words :( But it doesn't matter much for
+         # RLBench tasks since we don't test for novel instructions
+         # at test time anyway. The recommended way is to add
+         # positional encodings to the final input sequence
+         # fed into the Perceiver Transformer, as done below
+         # (and also in the Colab tutorial).
+         ###########################################################
+
+         # remember the shape before flattening the spatial axes
+         queries_orig_shape = ins.shape
+
+         # rearrange input to be channel last
+         ins = rearrange(ins, "b ... d -> b (...) d")  # [B,8000,128]
+         ins_wo_prev_layers = ins
+
+         # option 2: add lang token embs as a sequence
+         if self.anybimanual:
+             if self.lang_fusion_type == "seq":
+                 l = self.lang_preprocess(lang_token_embs)  # [B,77,512] -> [B,77,128]
+                 mask_right, mask_left = self.visual_aligner(ins)
+                 L_voxel = symmetric_kl_divergence(mask_left, mask_right)
+                 right_skill = self.skill_manager(mask_right, l)
+                 left_skill = self.skill_manager(mask_left, l)
+                 right_skill = self.lang_preprocess(right_skill)
+                 left_skill = self.lang_preprocess(left_skill)
+                 L_skill = (
+                     l1_norm(left_skill) + l1_norm(right_skill) +
+                     0.01 * (l2_1_norm(left_skill) + l2_1_norm(right_skill))
+                 )
+                 l_right = torch.cat((right_skill, l), dim=1)
+                 ins_right = torch.cat((l_right, mask_right), dim=1)
+                 l_left = torch.cat((left_skill, l), dim=1)
+                 ins_left = torch.cat((l_left, mask_left), dim=1)
+                 if self.pos_encoding_with_lang:
+                     ins_right = ins_right + self.pos_encoding
+                     ins_left = ins_left + self.pos_encoding
+         else:
+             if self.lang_fusion_type == "seq":
+                 # print(lang_token_embs.requires_grad)  # False
+                 l = self.lang_preprocess(lang_token_embs)  # [B,77,512] -> [B,77,128]
+                 # print(l.requires_grad)  # True
+                 ins = torch.cat((l, ins), dim=1)
+                 # add pos encoding to language + flattened grid (the recommended way)
+                 if self.pos_encoding_with_lang:
+                     ins = ins + self.pos_encoding
+
+         # batchify latents
+         if self.anybimanual:
+             x = repeat(self.latents, "n d -> b n d", b=b)
+             cross_attn_right, cross_attn_left, cross_ff_right, cross_ff_left = self.cross_attend_blocks
+         else:
+             x = repeat(self.latents, "n d -> b n d", b=b)
+             cross_attn, cross_ff_right, cross_ff_left = self.cross_attend_blocks
+
+         if self.anybimanual:
+             ins_r = torch.cat((l_right, ins), dim=1)
+             ins_l = torch.cat((l_left, ins), dim=1)
+             ins_right = torch.cat((ins_right, ins_r), dim=2)
+             ins_left = torch.cat((ins_left, ins_l), dim=2)
+
+         for it in range(self.iterations):
+             # encoder cross attention
+             if self.anybimanual:
+                 x_r, x_l = x.chunk(2, dim=1)
+                 x_right = cross_attn_right(x_r, context=ins_right, mask=mask) + x_r
+                 x_left = cross_attn_left(x_l, context=ins_left, mask=mask) + x_l
+             else:
+                 x = cross_attn(x, context=ins, mask=mask) + x
+                 x_right, x_left = x.chunk(2, dim=1)
+             x_right = cross_ff_right(x_right) + x_right
+             x_left = cross_ff_left(x_left) + x_left
+
+             # self-attention layers
+             for self_attn_right, self_ff_right, self_attn_left, self_ff_left in self.layers:
+
+                 x_right = self_attn_right(x_right) + x_right
+                 x_right = self_ff_right(x_right) + x_right
+
+                 x_left = self_attn_left(x_left) + x_left
+                 x_left = self_ff_left(x_left) + x_left
+
+             x = torch.concat([x_right, x_left], dim=1)
+             x = self.combined_latent_attn(x) + x
+             x = self.combined_latent_ff(x) + x
+
+         x_right, x_left = x.chunk(2, dim=1)
+
+         # decoder cross attention
+         if self.anybimanual:
+             latents_right = self.decoder_cross_attn_right(ins_right, context=x_right)
+             latents_left = self.decoder_cross_attn_left(ins_left, context=x_left)
+             # crop out the language part of the output sequence
+             if self.lang_fusion_type == "seq":
+                 latents_right = latents_right[:, l_right.shape[1] :]
+                 latents_left = latents_left[:, l_left.shape[1] :]
+         else:
+             latents_right = self.decoder_cross_attn_right(ins, context=x_right)
+             latents_left = self.decoder_cross_attn_left(ins, context=x_left)
+             # crop out the language part of the output sequence
+             if self.lang_fusion_type == "seq":
+                 latents_right = latents_right[:, l.shape[1] :]
+                 latents_left = latents_left[:, l.shape[1] :]
+
+         # reshape back to voxel grid
+         latents_right = latents_right.view(
+             b, *queries_orig_shape[1:-1], latents_right.shape[-1]
+         )  # [B,20,20,20,64]
+         latents_right = rearrange(latents_right, "b ... d -> b d ...")  # [B,64,20,20,20]
+
+         # reshape back to voxel grid
+         latents_left = latents_left.view(
+             b, *queries_orig_shape[1:-1], latents_left.shape[-1]
+         )  # [B,20,20,20,64]
+         latents_left = rearrange(latents_left, "b ... d -> b d ...")  # [B,64,20,20,20]
+
+         # aggregated features from 2nd softmax and maxpool for MLP decoders
+         feats_right = feats.copy()
+         feats_left = feats  # note: aliases feats (extended in place below)
+
+         feats_right.extend(
+             [self.ss1(latents_right.contiguous()), self.global_maxp(latents_right).view(b, -1)]
+         )
+         feats_left.extend(
+             [self.ss1(latents_left.contiguous()), self.global_maxp(latents_left).view(b, -1)]
+         )
+
+         # upsample
+         u0_right = self.up0(latents_right)
+         u0_left = self.up0(latents_left)
+
+         # ablations
+         if self.no_skip_connection:
+             u_right = self.final(u0_right)
+             u_left = self.final(u0_left)
+         elif self.no_perceiver:
+             u_right = self.final(d0)
+             u_left = self.final(d0)
+         else:
+             u_right = self.final(torch.cat([d0, u0_right], dim=1))
+             u_left = self.final(torch.cat([d0, u0_left], dim=1))
+
+         # translation decoder
+         right_trans = self.right_trans_decoder(u_right)
+         left_trans = self.left_trans_decoder(u_left)
+
+         # rotation, gripper, and collision MLPs
+         rot_and_grip_out = None
+         if self.num_rotation_classes > 0:
+             feats_right.extend(
+                 [self.ss_final(u_right.contiguous()), self.global_maxp(u_right).view(b, -1)]
+             )
+
+             right_dense0 = self.right_dense0(torch.cat(feats_right, dim=1))
+             right_dense1 = self.right_dense1(right_dense0)  # [B,72*3+2+2]
+
+             right_rot_and_grip_collision_out = self.right_rot_grip_collision_ff(
+                 right_dense1
+             )
+             right_rot_and_grip_out = right_rot_and_grip_collision_out[
+                 :, : -self.num_collision_classes
+             ]
+             right_collision_out = right_rot_and_grip_collision_out[
+                 :, -self.num_collision_classes :
+             ]
+
+             feats_left.extend(
+                 [self.ss_final(u_left.contiguous()), self.global_maxp(u_left).view(b, -1)]
+             )
+
+             left_dense0 = self.left_dense0(torch.cat(feats_left, dim=1))
+             left_dense1 = self.left_dense1(left_dense0)  # [B,72*3+2+2]
+
+             left_rot_and_grip_collision_out = self.left_rot_grip_collision_ff(
+                 left_dense1
+             )
+             left_rot_and_grip_out = left_rot_and_grip_collision_out[
+                 :, : -self.num_collision_classes
+             ]
+             left_collision_out = left_rot_and_grip_collision_out[
+                 :, -self.num_collision_classes :
+             ]
+
+         if not self.anybimanual:
+             L_skill = 0
+             L_voxel = 0
+         return (
+             right_trans,
+             right_rot_and_grip_out,
+             right_collision_out,
+             left_trans,
+             left_rot_and_grip_out,
+             left_collision_out,
+         ), L_skill, L_voxel
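
The auxiliary terms returned above come from the helpers at the top of this file. A minimal sketch of evaluating them standalone (shapes are illustrative; the import assumes the module path used elsewhere in this repo):

    import torch
    from agents.peract_bimanual.perceiver_lang_io import (
        symmetric_kl_divergence, l1_norm, l2_1_norm)

    left = torch.randn(2, 10, 128)   # e.g. per-arm mask activations
    right = torch.randn(2, 10, 128)

    # As written, symmetric_kl_divergence returns the *negated* scaled mean of
    # the two directional KL terms, so minimizing it drives the two
    # distributions apart (per-arm specialization).
    print(symmetric_kl_divergence(left, right))

    # Sparsity terms of the kind that enter L_skill in forward() above.
    print(l1_norm(left) + 0.01 * l2_1_norm(left))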
third_party/AnyBimanual/agents/peract_bimanual/qattention_peract_bc_agent.py ADDED
@@ -0,0 +1,1317 @@
1
+ import copy
2
+ import logging
3
+ import os
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ import torch.nn.functional as F
10
+ from torchvision import transforms
11
+ from pytorch3d import transforms as torch3d_tf
12
+ from yarr.agents.agent import (
13
+ Agent,
14
+ ActResult,
15
+ ScalarSummary,
16
+ HistogramSummary,
17
+ ImageSummary,
18
+ Summary,
19
+ )
20
+ import io
21
+ import PIL.Image as Image
22
+ import matplotlib.pyplot as plt
23
+ from helpers import utils
24
+ from helpers.utils import visualise_voxel, stack_on_channel
25
+ from voxel import augmentation_ab
26
+ from voxel.voxel_grid import VoxelGrid
27
+ from voxel.augmentation import apply_se3_augmentation
28
+ from einops import rearrange
29
+ from helpers.clip.core.clip import build_model, load_clip
30
+
31
+ import transformers
32
+ from helpers.optim.lamb import Lamb
33
+ import wandb
34
+ from termcolor import colored, cprint
35
+ from torch.nn.parallel import DistributedDataParallel as DDP
36
+ NAME = "QAttentionAgent"
37
+ import plotly.graph_objects as go
38
+
39
+ class QFunction(nn.Module):
40
+ def __init__(
41
+ self,
42
+ perceiver_encoder: nn.Module,
43
+ voxelizer: VoxelGrid,
44
+ bounds_offset: float,
45
+ rotation_resolution: float,
46
+ device,
47
+ training,
48
+ ):
49
+ super(QFunction, self).__init__()
50
+ self._rotation_resolution = rotation_resolution
51
+ self._voxelizer = voxelizer
52
+ self._bounds_offset = bounds_offset
53
+ self._qnet = perceiver_encoder.to(device)
54
+ # distributed training
55
+ if training:
56
+ self._qnet = DDP(self._qnet, device_ids=[device], find_unused_parameters=True)
57
+
58
+ def _argmax_3d(self, tensor_orig):
59
+ b, c, d, h, w = tensor_orig.shape # c will be one
60
+ idxs = tensor_orig.view(b, c, -1).argmax(-1)
61
+ indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
62
+ return indices
63
+
64
+ def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
65
+ coords = self._argmax_3d(q_trans)
66
+ rot_and_grip_indicies = None
67
+ ignore_collision = None
68
+ if q_rot_grip is not None:
69
+ q_rot = torch.stack(
70
+ torch.split(
71
+ q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
72
+ ),
73
+ dim=1,
74
+ )
75
+ rot_and_grip_indicies = torch.cat(
76
+ [
77
+ q_rot[:, 0:1].argmax(-1),
78
+ q_rot[:, 1:2].argmax(-1),
79
+ q_rot[:, 2:3].argmax(-1),
80
+ q_rot_grip[:, -2:].argmax(-1, keepdim=True),
81
+ ],
82
+ -1,
83
+ )
84
+ ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
85
+ return coords, rot_and_grip_indicies, ignore_collision
86
+
87
+ def forward(
88
+ self,
89
+ rgb_pcd,
90
+ proprio,
91
+ pcd,
92
+ lang_goal_emb,
93
+ lang_token_embs,
94
+ bounds=None,
95
+ prev_bounds=None,
96
+ prev_layer_voxel_grid=None,
97
+ ):
98
+ # rgb_pcd will be list of list (list of [rgb, pcd])
99
+ b = rgb_pcd[0][0].shape[0]
100
+ pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
101
+
102
+ # flatten RGBs and Pointclouds
103
+ rgb = [rp[0] for rp in rgb_pcd]
104
+ feat_size = rgb[0].shape[1]
105
+ flat_imag_features = torch.cat(
106
+ [p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
107
+ )
108
+
109
+ # construct voxel grid
110
+ voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
111
+ pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
112
+ )
113
+
114
+ # swap to channels fist
115
+ voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
116
+
117
+ # print(voxel_grid.shape) # [b, 10, 100, 100, 100]
118
+ # batch bounds if necessary
119
+ if bounds.shape[0] != b:
120
+ bounds = bounds.repeat(b, 1)
121
+ # print(lang_goal_emb.shape) # [B, 1024]
122
+ # forward pass
123
+
124
+ #TO DO: return more information
125
+ split_pred, L_skill, L_voxel = self._qnet(
126
+ voxel_grid,
127
+ proprio,
128
+ lang_goal_emb,
129
+ lang_token_embs,
130
+ prev_layer_voxel_grid,
131
+ bounds,
132
+ prev_bounds,
133
+ )
134
+ return split_pred, voxel_grid, L_skill, L_voxel
135
+
136
+
137
+ class QAttentionPerActBCAgent(Agent):
138
+ def __init__(
139
+ self,
140
+ layer: int,
141
+ coordinate_bounds: list,
142
+ perceiver_encoder: nn.Module,
143
+ camera_names: list,
144
+ batch_size: int,
145
+ voxel_size: int,
146
+ bounds_offset: float,
147
+ voxel_feature_size: int,
148
+ image_crop_size: int,
149
+ num_rotation_classes: int,
150
+ rotation_resolution: float,
151
+ lr: float = 0.0001,
152
+ lr_scheduler: bool = False,
153
+ training_iterations: int = 100000,
154
+ num_warmup_steps: int = 20000,
155
+ trans_loss_weight: float = 1.0,
156
+ rot_loss_weight: float = 1.0,
157
+ grip_loss_weight: float = 1.0,
158
+ collision_loss_weight: float = 1.0,
159
+ include_low_dim_state: bool = False,
160
+ image_resolution: list = None,
161
+ lambda_weight_l2: float = 0.0,
162
+ transform_augmentation: bool = True,
163
+ transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
164
+ transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
165
+ transform_augmentation_rot_resolution: int = 5,
166
+ optimizer_type: str = "adam",
167
+ num_devices: int = 1,
168
+ anybimanual = False,
169
+ load_exists_weights = False,
170
+ frozen = False,
171
+ cfg = None,
172
+ aug_type = "standard",
173
+ ):
174
+ self.frozen = frozen
175
+ self.load = load_exists_weights
176
+ self._layer = layer
177
+ self._coordinate_bounds = coordinate_bounds
178
+ self._perceiver_encoder = perceiver_encoder
179
+ self._voxel_feature_size = voxel_feature_size
180
+ self._bounds_offset = bounds_offset
181
+ self._image_crop_size = image_crop_size
182
+ self._lr = lr
183
+ self._lr_scheduler = lr_scheduler
184
+ self._training_iterations = training_iterations
185
+ self._num_warmup_steps = num_warmup_steps
186
+ self._trans_loss_weight = trans_loss_weight
187
+ self._rot_loss_weight = rot_loss_weight
188
+ self._grip_loss_weight = grip_loss_weight
189
+ self._collision_loss_weight = collision_loss_weight
190
+ self._include_low_dim_state = include_low_dim_state
191
+ self._image_resolution = image_resolution or [128, 128]
192
+ self._voxel_size = voxel_size
193
+ self._camera_names = camera_names
194
+ self._num_cameras = len(camera_names)
195
+ self._batch_size = batch_size
196
+ self._lambda_weight_l2 = lambda_weight_l2
197
+ self._transform_augmentation = transform_augmentation
198
+ self._transform_augmentation_xyz = torch.from_numpy(
199
+ np.array(transform_augmentation_xyz)
200
+ )
201
+ self._transform_augmentation_rpy = transform_augmentation_rpy
202
+ self._transform_augmentation_rot_resolution = (
203
+ transform_augmentation_rot_resolution
204
+ )
205
+ self._optimizer_type = optimizer_type
206
+ self._num_devices = num_devices
207
+ self._num_rotation_classes = num_rotation_classes
208
+ self._rotation_resolution = rotation_resolution
209
+
210
+ self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
211
+ self._name = NAME + "_layer" + str(self._layer)
212
+ self.anybimanual = anybimanual
213
+ self.aug_type = aug_type
214
+ self.cfg = cfg
215
+ def build(self, training: bool, device: torch.device = None):
216
+ self._training = training
217
+
218
+ if device is None:
219
+ device = torch.device("cpu")
220
+
221
+ self._device = device
222
+
223
+ self._voxelizer = VoxelGrid(
224
+ coord_bounds=self._coordinate_bounds,
225
+ voxel_size=self._voxel_size,
226
+ device=device,
227
+ batch_size=self._batch_size if training else 1,
228
+ feature_size=self._voxel_feature_size,
229
+ max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
230
+ )
231
+
232
+ self._q = (
233
+ QFunction(
234
+ self._perceiver_encoder,
235
+ self._voxelizer,
236
+ self._bounds_offset,
237
+ self._rotation_resolution,
238
+ device,
239
+ training,
240
+ )
241
+ .to(device)
242
+ .train(training)
243
+ )
244
+
245
+ grid_for_crop = (
246
+ torch.arange(0, self._image_crop_size, device=device)
247
+ .unsqueeze(0)
248
+ .repeat(self._image_crop_size, 1)
249
+ .unsqueeze(-1)
250
+ )
251
+ self._grid_for_crop = torch.cat(
252
+ [grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
253
+ ).unsqueeze(0)
254
+
255
+ self._coordinate_bounds = torch.tensor(
256
+ self._coordinate_bounds, device=device
257
+ ).unsqueeze(0)
258
+
259
+ if self._training:
260
+ # optimizer
261
+ if self._optimizer_type == "lamb":
262
+ self._optimizer = Lamb(
263
+ self._q.parameters(),
264
+ lr=self._lr,
265
+ weight_decay=self._lambda_weight_l2,
266
+ betas=(0.9, 0.999),
267
+ adam=False,
268
+ )
269
+ elif self._optimizer_type == "adam":
270
+ self._optimizer = torch.optim.Adam(
271
+ self._q.parameters(),
272
+ lr=self._lr,
273
+ weight_decay=self._lambda_weight_l2,
274
+ )
275
+ else:
276
+ raise Exception("Unknown optimizer type")
277
+
278
+ # learning rate scheduler
279
+ if self._lr_scheduler:
280
+ self._scheduler = (
281
+ transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
282
+ self._optimizer,
283
+ num_warmup_steps=self._num_warmup_steps,
284
+ num_training_steps=self._training_iterations,
285
+ num_cycles=self._training_iterations // 10000,
286
+ )
287
+ )
288
+
289
+ # one-hot zero tensors
290
+ self._action_trans_one_hot_zeros = torch.zeros(
291
+ (
292
+ self._batch_size,
293
+ 1,
294
+ self._voxel_size,
295
+ self._voxel_size,
296
+ self._voxel_size,
297
+ ),
298
+ dtype=int,
299
+ device=device,
300
+ )
301
+ self._action_rot_x_one_hot_zeros = torch.zeros(
302
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
303
+ )
304
+ self._action_rot_y_one_hot_zeros = torch.zeros(
305
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
306
+ )
307
+ self._action_rot_z_one_hot_zeros = torch.zeros(
308
+ (self._batch_size, self._num_rotation_classes), dtype=int, device=device
309
+ )
310
+ self._action_grip_one_hot_zeros = torch.zeros(
311
+ (self._batch_size, 2), dtype=int, device=device
312
+ )
313
+ self._action_ignore_collisions_one_hot_zeros = torch.zeros(
314
+ (self._batch_size, 2), dtype=int, device=device
315
+ )
316
+
317
+ # print total params
318
+ logging.info(
319
+ "# Q Params: %d"
320
+ % sum(
321
+ p.numel()
322
+ for name, p in self._q.named_parameters()
323
+ if p.requires_grad and "clip" not in name
324
+ )
325
+ )
326
+ # for name, p in self._q.named_parameters():
327
+ # print(f"Param: {name}, requires_grad: {p.requires_grad}")
328
+ else:
329
+ for param in self._q.parameters():
330
+ param.requires_grad = False
331
+
332
+ # load CLIP for encoding language goals during evaluation
333
+ model, _ = load_clip("RN50", jit=False)
334
+ self._clip_rn50 = build_model(model.state_dict())
335
+ self._clip_rn50 = self._clip_rn50.float().to(device)
336
+ self._clip_rn50.eval()
337
+ del model
338
+
339
+ self._voxelizer.to(device)
340
+ self._q.to(device)
341
+
342
+ def _extract_crop(self, pixel_action, observation):
343
+ # Pixel action will now be (B, 2)
344
+ # observation = stack_on_channel(observation)
345
+ h = observation.shape[-1]
346
+ top_left_corner = torch.clamp(
347
+ pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
348
+ )
349
+ grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
350
+ grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
351
+ # Used for cropping the images across a batch
352
+ # swap fro y x, to x, y
353
+ grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
354
+ crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
355
+ return crop
356
+
357
+ def _preprocess_inputs(self, replay_sample):
358
+ obs = []
359
+ pcds = []
360
+ rgbs = []
361
+ self._crop_summary = []
362
+ for n in self._camera_names:
363
+ rgb = replay_sample["%s_rgb" % n]
364
+ pcd = replay_sample["%s_point_cloud" % n]
365
+
366
+ obs.append([rgb, pcd])
367
+ pcds.append(pcd)
368
+ rgbs.append(rgb)
369
+ return obs, pcds, rgbs
370
+
371
+ def _act_preprocess_inputs(self, observation):
372
+ obs, pcds = [], []
373
+ for n in self._camera_names:
374
+ rgb = observation["%s_rgb" % n]
375
+ pcd = observation["%s_point_cloud" % n]
376
+
377
+ obs.append([rgb, pcd])
378
+ pcds.append(pcd)
379
+ return obs, pcds
380
+
381
+ def _get_value_from_voxel_index(self, q, voxel_idx):
382
+ b, c, d, h, w = q.shape
383
+ q_trans_flat = q.view(b, c, d * h * w)
384
+ flat_indicies = (
385
+ voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
386
+ )[:, None].int()
387
+ highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
388
+ chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
389
+ ..., 0
390
+ ] # (B, trans + rot + grip)
391
+ return chosen_voxel_values
392
+
393
+ def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
394
+ q_rot = torch.stack(
395
+ torch.split(
396
+ rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
397
+ ),
398
+ dim=1,
399
+ ) # B, 3, 72
400
+ q_grip = rot_grip_q[:, -2:]
401
+ rot_and_grip_values = torch.cat(
402
+ [
403
+ q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
404
+ q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
405
+ q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
406
+ q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
407
+ ],
408
+ -1,
409
+ )
410
+ return rot_and_grip_values
411
+
412
+ def _celoss(self, pred, labels):
413
+ return self._cross_entropy_loss(pred, labels.argmax(-1))
414
+
415
+ def _softmax_q_trans(self, q):
416
+ q_shape = q.shape
417
+ return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
418
+
419
+ def _softmax_q_rot_grip(self, q_rot_grip):
420
+ q_rot_x_flat = q_rot_grip[
421
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
422
+ ]
423
+ q_rot_y_flat = q_rot_grip[
424
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
425
+ ]
426
+ q_rot_z_flat = q_rot_grip[
427
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
428
+ ]
429
+ q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
430
+
431
+ q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
432
+ q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
433
+ q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
434
+ q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
435
+
436
+ return torch.cat(
437
+ [
438
+ q_rot_x_flat_softmax,
439
+ q_rot_y_flat_softmax,
440
+ q_rot_z_flat_softmax,
441
+ q_grip_flat_softmax,
442
+ ],
443
+ dim=1,
444
+ )
445
+
446
+ def _softmax_ignore_collision(self, q_collision):
447
+ q_collision_softmax = F.softmax(q_collision, dim=1)
448
+ return q_collision_softmax
449
+
450
+ def update(self, step: int, replay_sample: dict) -> dict:
451
+ if step > 50:
452
+ for name, param in self._q.named_parameters():
453
+ if 'fc1_right' in name:
454
+ param.requires_grad = False
455
+ if 'fc1_left' in name:
456
+ param.requires_grad = False
457
+ right_action_trans = replay_sample["right_trans_action_indicies"][
458
+ ..., self._layer * 3 : self._layer * 3 + 3
459
+ ].int()
460
+ right_action_rot_grip = replay_sample["right_rot_grip_action_indicies"].int()
461
+ right_action_gripper_pose = replay_sample["right_gripper_pose"]
462
+ right_action_ignore_collisions = replay_sample["right_ignore_collisions"].int()
463
+
464
+ left_action_trans = replay_sample["left_trans_action_indicies"][
465
+ ..., self._layer * 3 : self._layer * 3 + 3
466
+ ].int()
467
+ left_action_rot_grip = replay_sample["left_rot_grip_action_indicies"].int()
468
+ left_action_gripper_pose = replay_sample["left_gripper_pose"]
469
+ left_action_ignore_collisions = replay_sample["left_ignore_collisions"].int()
470
+
471
+ lang_goal_emb = replay_sample["lang_goal_emb"].float()
472
+ lang_token_embs = replay_sample["lang_token_embs"].float()
473
+ prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
474
+ prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
475
+ device = self._device
476
+
477
+ rank = device
478
+ bounds = self._coordinate_bounds.to(device)
479
+ if self._layer > 0:
480
+ right_cp = replay_sample[
481
+ "right_attention_coordinate_layer_%d" % (self._layer - 1)
482
+ ]
483
+
484
+ left_cp = replay_sample[
485
+ "left_attention_coordinate_layer_%d" % (self._layer - 1)
486
+ ]
487
+
488
+ right_bounds = torch.cat(
489
+ [right_cp - self._bounds_offset, right_cp + self._bounds_offset], dim=1
490
+ )
491
+ left_bounds = torch.cat(
492
+ [left_cp - self._bounds_offset, left_cp + self._bounds_offset], dim=1
493
+ )
494
+
495
+ else:
496
+ right_bounds = bounds
497
+ left_bounds = bounds
498
+
499
+ right_proprio = None
500
+ left_proprio = None
501
+ if self._include_low_dim_state:
502
+ right_proprio = replay_sample["right_low_dim_state"]
503
+ left_proprio = replay_sample["left_low_dim_state"]
504
+
505
+ # ..TODO::
506
+ # Can we add the coordinates of both robots?
507
+ #
508
+
509
+ obs, pcd, rgbs = self._preprocess_inputs(replay_sample)
510
+
511
+ # batch size
512
+ bs = pcd[0].shape[0]
513
+
514
+ # We can move the point cloud w.r.t to the other robot's cooridinate system
515
+ # similar to apply_se3_augmentation
516
+ # SE(3) augmentation of point clouds and actions
517
+ if self._transform_augmentation:
518
+ from voxel import augmentation, augmentation_ab
519
+ if self.aug_type == "ab":
520
+ (
521
+ right_action_trans,
522
+ right_action_rot_grip,
523
+ left_action_trans,
524
+ left_action_rot_grip,
525
+ pcd,
526
+ ) = augmentation_ab.bimanual_apply_se3_augmentation(
527
+ pcd,
528
+ right_action_gripper_pose,
529
+ right_action_trans,
530
+ right_action_rot_grip,
531
+ left_action_gripper_pose,
532
+ left_action_trans,
533
+ left_action_rot_grip,
534
+ bounds,
535
+ self._layer,
536
+ self._transform_augmentation_xyz,
537
+ self._transform_augmentation_rpy,
538
+ self._transform_augmentation_rot_resolution,
539
+ self._voxel_size,
540
+ self._rotation_resolution,
541
+ self._device,
542
+ )
543
+ else:
544
+ (
545
+ right_action_trans,
546
+ right_action_rot_grip,
547
+ left_action_trans,
548
+ left_action_rot_grip,
549
+ pcd,
550
+ ) = augmentation.bimanual_apply_se3_augmentation(
551
+ pcd,
552
+ right_action_gripper_pose,
553
+ right_action_trans,
554
+ right_action_rot_grip,
555
+ left_action_gripper_pose,
556
+ left_action_trans,
557
+ left_action_rot_grip,
558
+ bounds,
559
+ self._layer,
560
+ self._transform_augmentation_xyz,
561
+ self._transform_augmentation_rpy,
562
+ self._transform_augmentation_rot_resolution,
563
+ self._voxel_size,
564
+ self._rotation_resolution,
565
+ self._device,
566
+ )
567
+ else:
568
+ right_action_trans = right_action_trans.int()
569
+ left_action_trans = left_action_trans.int()
570
+
571
+ proprio = torch.cat((right_proprio, left_proprio), dim=1)
572
+ right_action = (
573
+ right_action_trans,
574
+ right_action_rot_grip,
575
+ right_action_ignore_collisions,
576
+ )
577
+ left_action = (
578
+ left_action_trans,
579
+ left_action_rot_grip,
580
+ left_action_ignore_collisions,
581
+ )
582
+ # forward pass
583
+ q, voxel_grid, L_skill, L_voxel = self._q(
584
+ obs,
585
+ proprio,
586
+ pcd,
587
+ lang_goal_emb,
588
+ lang_token_embs,
589
+ bounds,
590
+ prev_layer_bounds,
591
+ prev_layer_voxel_grid,
592
+
593
+ )
594
+
595
+ (
596
+ right_q_trans,
597
+ right_q_rot_grip,
598
+ right_q_collision,
599
+ left_q_trans,
600
+ left_q_rot_grip,
601
+ left_q_collision,
602
+ ) = q
603
+
604
+ # argmax to choose best action
605
+ (
606
+ right_coords,
607
+ right_rot_and_grip_indicies,
608
+ right_ignore_collision_indicies,
609
+ ) = self._q.choose_highest_action(
610
+ right_q_trans, right_q_rot_grip, right_q_collision
611
+ )
612
+
613
+ (
614
+ left_coords,
615
+ left_rot_and_grip_indicies,
616
+ left_ignore_collision_indicies,
617
+ ) = self._q.choose_highest_action(
618
+ left_q_trans, left_q_rot_grip, left_q_collision
619
+ )
620
+
621
+ right_q_trans_loss, right_q_rot_loss, right_q_grip_loss, right_q_collision_loss = 0.0, 0.0, 0.0, 0.0
622
+ left_q_trans_loss, left_q_rot_loss, left_q_grip_loss, left_q_collision_loss = 0.0, 0.0, 0.0, 0.0
623
+
624
+ # translation one-hot
625
+ right_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach()
626
+ left_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach()
627
+ for b in range(bs):
628
+ right_gt_coord = right_action_trans[b, :].int()
629
+ right_action_trans_one_hot[
630
+ b, :, right_gt_coord[0], right_gt_coord[1], right_gt_coord[2]
631
+ ] = 1
632
+ left_gt_coord = left_action_trans[b, :].int()
633
+ left_action_trans_one_hot[
634
+ b, :, left_gt_coord[0], left_gt_coord[1], left_gt_coord[2]
635
+ ] = 1
636
+
637
+ # translation loss
638
+ right_q_trans_flat = right_q_trans.view(bs, -1)
639
+ right_action_trans_one_hot_flat = right_action_trans_one_hot.view(bs, -1)
640
+ right_q_trans_loss = self._celoss(
641
+ right_q_trans_flat, right_action_trans_one_hot_flat
642
+ )
643
+ left_q_trans_flat = left_q_trans.view(bs, -1)
644
+ left_action_trans_one_hot_flat = left_action_trans_one_hot.view(bs, -1)
645
+ left_q_trans_loss = self._celoss(
646
+ left_q_trans_flat, left_action_trans_one_hot_flat
647
+ )
648
+
649
+ q_trans_loss = right_q_trans_loss + left_q_trans_loss
650
+
651
+ with_rot_and_grip = (
652
+ len(right_rot_and_grip_indicies) > 0 and len(left_rot_and_grip_indicies) > 0
653
+ )
654
+ if with_rot_and_grip:
655
+ # rotation, gripper, and collision one-hots
656
+ right_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
657
+ right_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
658
+ right_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
659
+ right_action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
660
+ right_action_ignore_collisions_one_hot = (
661
+ self._action_ignore_collisions_one_hot_zeros.clone()
662
+ )
663
+
664
+ left_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
665
+ left_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
666
+ left_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
667
+ left_action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
668
+ left_action_ignore_collisions_one_hot = (
669
+ self._action_ignore_collisions_one_hot_zeros.clone()
670
+ )
671
+
672
+ for b in range(bs):
673
+ right_gt_rot_grip = right_action_rot_grip[b, :].int()
674
+ right_action_rot_x_one_hot[b, right_gt_rot_grip[0]] = 1
675
+ right_action_rot_y_one_hot[b, right_gt_rot_grip[1]] = 1
676
+ right_action_rot_z_one_hot[b, right_gt_rot_grip[2]] = 1
677
+ right_action_grip_one_hot[b, right_gt_rot_grip[3]] = 1
678
+
679
+ right_gt_ignore_collisions = right_action_ignore_collisions[b, :].int()
680
+ right_action_ignore_collisions_one_hot[
681
+ b, right_gt_ignore_collisions[0]
682
+ ] = 1
683
+
684
+ left_gt_rot_grip = left_action_rot_grip[b, :].int()
685
+ left_action_rot_x_one_hot[b, left_gt_rot_grip[0]] = 1
686
+ left_action_rot_y_one_hot[b, left_gt_rot_grip[1]] = 1
687
+ left_action_rot_z_one_hot[b, left_gt_rot_grip[2]] = 1
688
+ left_action_grip_one_hot[b, left_gt_rot_grip[3]] = 1
689
+
690
+ left_gt_ignore_collisions = left_action_ignore_collisions[b, :].int()
691
+ left_action_ignore_collisions_one_hot[
692
+ b, left_gt_ignore_collisions[0]
693
+ ] = 1
694
+
695
+ # flatten predictions
696
+ right_q_rot_x_flat = right_q_rot_grip[
697
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
698
+ ]
699
+ right_q_rot_y_flat = right_q_rot_grip[
700
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
701
+ ]
702
+ right_q_rot_z_flat = right_q_rot_grip[
703
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
704
+ ]
705
+ right_q_grip_flat = right_q_rot_grip[:, 3 * self._num_rotation_classes :]
706
+ right_q_ignore_collisions_flat = right_q_collision
707
+
708
+ left_q_rot_x_flat = left_q_rot_grip[
709
+ :, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
710
+ ]
711
+ left_q_rot_y_flat = left_q_rot_grip[
712
+ :, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
713
+ ]
714
+ left_q_rot_z_flat = left_q_rot_grip[
715
+ :, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
716
+ ]
717
+ left_q_grip_flat = left_q_rot_grip[:, 3 * self._num_rotation_classes :]
718
+ left_q_ignore_collisions_flat = left_q_collision
719
+
720
+
721
+ # rotation loss
722
+ right_q_rot_loss += self._celoss(right_q_rot_x_flat, right_action_rot_x_one_hot)
723
+ right_q_rot_loss += self._celoss(right_q_rot_y_flat, right_action_rot_y_one_hot)
724
+ right_q_rot_loss += self._celoss(right_q_rot_z_flat, right_action_rot_z_one_hot)
725
+
726
+ left_q_rot_loss += self._celoss(left_q_rot_x_flat, left_action_rot_x_one_hot)
727
+ left_q_rot_loss += self._celoss(left_q_rot_y_flat, left_action_rot_y_one_hot)
728
+ left_q_rot_loss += self._celoss(left_q_rot_z_flat, left_action_rot_z_one_hot)
729
+
730
+ # gripper loss
731
+ right_q_grip_loss += self._celoss(right_q_grip_flat, right_action_grip_one_hot)
732
+ left_q_grip_loss += self._celoss(left_q_grip_flat, left_action_grip_one_hot)
733
+
734
+ # collision loss
735
+ right_q_collision_loss += self._celoss(
736
+ right_q_ignore_collisions_flat, right_action_ignore_collisions_one_hot
737
+ )
738
+ left_q_collision_loss += self._celoss(
739
+ left_q_ignore_collisions_flat, left_action_ignore_collisions_one_hot
740
+ )
741
+
742
+
743
+ q_trans_loss = right_q_trans_loss + left_q_trans_loss
744
+ q_rot_loss = right_q_rot_loss + left_q_rot_loss
745
+ q_grip_loss = right_q_grip_loss + left_q_grip_loss
746
+ q_collision_loss = right_q_collision_loss + left_q_collision_loss
747
+
748
+ combined_losses = (
749
+ (q_trans_loss * self._trans_loss_weight)
750
+ + (q_rot_loss * self._rot_loss_weight)
751
+ + (q_grip_loss * self._grip_loss_weight)
752
+ + (q_collision_loss * self._collision_loss_weight)
753
+ + 0.0001 * L_skill
754
+ + 0.01 * L_voxel
755
+ )
756
+ total_loss = combined_losses.mean()
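+ # NOTE: the four Q terms are balanced by the configured *_loss_weight values,
+ # while the auxiliary L_skill and L_voxel terms use fixed coefficients
+ # (1e-4 and 1e-2); those constants, not the config, set their influence.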
757
+
758
+ if step % 10 == 0 and rank == 0 and wandb.run is not None:
759
+ wandb.log({
760
+ 'train/grip_loss': q_grip_loss.mean(),
761
+ 'train/trans_loss': q_trans_loss.mean(),
762
+ 'train/rot_loss': q_rot_loss.mean(),
763
+ 'train/collision_loss': q_collision_loss.mean(),
764
+ 'train/total_loss': total_loss,
765
+ }, step=step)
766
+
767
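+ # NOTE: anomaly detection is a debugging aid; enabling it on every update adds
+ # noticeable overhead to each backward pass and would normally be switched off
+ # for long training runs.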
+ torch.autograd.set_detect_anomaly(True)
768
+ self._optimizer.zero_grad()
769
+ total_loss.backward()
770
+ self._optimizer.step()
771
+ torch.cuda.empty_cache()
772
+
773
+ self._summaries = {
774
+ "losses/total_loss": total_loss,
775
+ "losses/trans_loss": q_trans_loss.mean(),
776
+ "losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
777
+ "losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
778
+
779
+ "losses/right/trans_loss": q_trans_loss.mean(),
780
+ "losses/right/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
781
+ "losses/right/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
782
+ "losses/right/collision_loss": q_collision_loss.mean() if with_rot_and_grip else 0.0,
783
+
784
+ "losses/left/trans_loss": q_trans_loss.mean(),
785
+ "losses/left/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
786
+ "losses/left/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
787
+ "losses/left/collision_loss": q_collision_loss.mean() if with_rot_and_grip else 0.0,
788
+
789
+ "losses/collision_loss": q_collision_loss.mean()
790
+ if with_rot_and_grip
791
+ else 0.0,
792
+ }
793
+
794
+ self._wandb_summaries = {
795
+ 'losses/total_loss': total_loss,
796
+ 'losses/trans_loss': q_trans_loss.mean(),
797
+ 'losses/rot_loss': q_rot_loss.mean() if with_rot_and_grip else 0.,
798
+ 'losses/grip_loss': q_grip_loss.mean() if with_rot_and_grip else 0.,
799
+ 'losses/collision_loss': q_collision_loss.mean() if with_rot_and_grip else 0.
800
+ }
801
+
802
+ if self._lr_scheduler:
803
+ self._scheduler.step()
804
+ self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
805
+
806
+ self._vis_voxel_grid = voxel_grid[0]
807
+ self._right_vis_translation_qvalue = self._softmax_q_trans(right_q_trans[0])
808
+ self._right_vis_max_coordinate = right_coords[0]
809
+ self._right_vis_gt_coordinate = right_action_trans[0]
810
+
811
+ self._left_vis_translation_qvalue = self._softmax_q_trans(left_q_trans[0])
812
+ self._left_vis_max_coordinate = left_coords[0]
813
+ self._left_vis_gt_coordinate = left_action_trans[0]
814
+
815
+
816
+ # Note: PerAct doesn't use multi-layer voxel grids like C2FARM
817
+ # stack prev_layer_voxel_grid(s) from previous layers into a list
818
+ if prev_layer_voxel_grid is None:
819
+ prev_layer_voxel_grid = [voxel_grid]
820
+ else:
821
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
822
+
823
+ # stack prev_layer_bound(s) from previous layers into a list
824
+ if prev_layer_bounds is None:
825
+ prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
826
+ else:
827
+ prev_layer_bounds = prev_layer_bounds + [bounds]
828
+
829
+ q_trans_vis = True
830
+ log_freq = getattr(getattr(getattr(self, "cfg", None), "framework", None), "log_freq", None)
831
+ if log_freq and step % log_freq == 0 and rank == 0:
832
+ print("right_predict: ",self._right_vis_max_coordinate)
833
+ print("right_gt: ",self._right_vis_gt_coordinate)
834
+ print("left_predict: ",self._left_vis_max_coordinate)
835
+ print("left_gt: ",self._left_vis_gt_coordinate)
836
+ rendered_img_right = visualise_voxel(
837
+ voxel_grid[0].cpu().detach().numpy(), # [10, 100, 100, 100]
838
+ self._right_vis_translation_qvalue.detach().cpu().numpy() if q_trans_vis else None,
839
+ self._right_vis_max_coordinate.detach().cpu().numpy(),
840
+ self._right_vis_gt_coordinate.detach().cpu().numpy(),
841
+ voxel_size=0.045,
842
+ # voxel_size=0.1, # more focus ??
843
+ rotation_amount=np.deg2rad(-90),
844
+ highlight_alpha=1.0,
845
+ alpha=0.4,
846
+ )
847
+ rendered_img_left = visualise_voxel(
848
+ voxel_grid[0].cpu().detach().numpy(), # [10, 100, 100, 100]
849
+ self._left_vis_translation_qvalue.detach().cpu().numpy() if q_trans_vis else None,
850
+ self._left_vis_max_coordinate.detach().cpu().numpy(),
851
+ self._left_vis_gt_coordinate.detach().cpu().numpy(),
852
+ voxel_size=0.045,
853
+ # voxel_size=0.1, # more focus ??
854
+ rotation_amount=np.deg2rad(-90),
855
+ highlight_alpha=1.0,
856
+ alpha=0.4,
857
+ )
858
+ os.makedirs('recon', exist_ok=True)
859
+ # plot three images in one row with subplots:
860
+ rgb_src = obs[0][0][0].squeeze(0).permute(1, 2, 0) / 2 + 0.5  # un-normalise RGB from [-1, 1] to [0, 1] for display
861
+
862
+ fig, axs = plt.subplots(1, 4, figsize=(9, 3))  # 4 panels allocated; only 3 are used, the last stays blank
863
+ # src
864
+ axs[0].imshow(rgb_src.cpu().numpy())
865
+ axs[0].title.set_text('src')
866
+
867
+ axs[1].imshow(rendered_img_right)
868
+ axs[1].text(0, 40, 'predicted', color='blue')
869
+ axs[1].text(0, 80, 'gt', color='red')
870
+ axs[2].imshow(rendered_img_left)
871
+ axs[2].text(0, 40, 'predicted', color='blue')
872
+ axs[2].text(0, 80, 'gt', color='red')
873
+ for ax in axs:
874
+ ax.axis('off')
875
+ plt.tight_layout()
876
+
877
+ if rank == 0:
878
+ if wandb.run is not None:
879
+ buf = io.BytesIO()
880
+ plt.savefig(buf, format='png')
881
+ buf.seek(0)
882
+
883
+ image = Image.open(buf)
884
+ wandb.log({"eval/recon_img": wandb.Image(image)}, step=step)
885
+
886
+ buf.close()
887
+ cprint('Saved to wandb', 'cyan')
888
+ else:
889
+ plt.savefig(f'recon/{step}_rgb.png')
890
+ workdir = os.getcwd()
891
+ cprint(f'Saved {workdir}/recon/{step}_rgb.png locally', 'cyan')
892
+
893
+ return {
894
+ "total_loss": total_loss,
895
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
896
+ "prev_layer_bounds": prev_layer_bounds,
897
+ }
898
+
899
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
900
+ deterministic = True  # always act deterministically at evaluation time, overriding the caller's flag
901
+ bounds = self._coordinate_bounds
902
+ prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
903
+ prev_layer_bounds = observation.get("prev_layer_bounds", None)
904
+ lang_goal_tokens = observation["lang_goal_tokens"].long()  # required key; a missing goal should fail loudly
905
+
906
+ # extract CLIP language embs
907
+ with torch.no_grad():
908
+ lang_goal_tokens = lang_goal_tokens.to(device=self._device)
909
+ (
910
+ lang_goal_emb,
911
+ lang_token_embs,
912
+ ) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
913
+
914
+ # voxelization resolution
915
+ res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
916
+ max_rot_index = int(360 // self._rotation_resolution)
917
+ right_proprio = None
918
+ left_proprio = None
919
+
920
+ if self._include_low_dim_state:
921
+ right_proprio = observation["right_low_dim_state"]
922
+ left_proprio = observation["left_low_dim_state"]
923
+ right_proprio = right_proprio[0].to(self._device)
924
+ left_proprio = left_proprio[0].to(self._device)
925
+
926
+ obs, pcd = self._act_preprocess_inputs(observation)
927
+
928
+ # correct batch size and device
929
+ obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]
930
+
931
+ pcd = [p[0].to(self._device) for p in pcd]
932
+ lang_goal_emb = lang_goal_emb.to(self._device)
933
+ lang_token_embs = lang_token_embs.to(self._device)
934
+ bounds = torch.as_tensor(bounds, device=self._device)
935
+ prev_layer_voxel_grid = (
936
+ prev_layer_voxel_grid.to(self._device)
937
+ if prev_layer_voxel_grid is not None
938
+ else None
939
+ )
940
+ prev_layer_bounds = (
941
+ prev_layer_bounds.to(self._device)
942
+ if prev_layer_bounds is not None
943
+ else None
944
+ )
945
+
946
+ proprio = torch.cat((right_proprio, left_proprio), dim=1)
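+ # NOTE: both arms' low-dim states are concatenated into one proprioception
+ # vector; this matches the doubled proprio_preprocess input handled in
+ # load_weights below (a 4-dim single-arm state becomes 8-dim for two arms).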
947
+ # inference
948
+ (
949
+ right_q_trans,
950
+ right_q_rot_grip,
951
+ right_q_ignore_collisions,
952
+ left_q_trans,
953
+ left_q_rot_grip,
954
+ left_q_ignore_collisions,
955
+ ), vox_grid = self._q(
956
+ obs,
957
+ proprio,
958
+ pcd,
959
+ lang_goal_emb,
960
+ lang_token_embs,
961
+ bounds,
962
+ prev_layer_bounds,
963
+ prev_layer_voxel_grid
964
+ )
965
+
966
+ # softmax Q predictions
967
+ right_q_trans = self._softmax_q_trans(right_q_trans)
968
+ left_q_trans = self._softmax_q_trans(left_q_trans)
969
+
970
+ if right_q_rot_grip is not None:
971
+ right_q_rot_grip = self._softmax_q_rot_grip(right_q_rot_grip)
972
+
973
+ if left_q_rot_grip is not None:
974
+ left_q_rot_grip = self._softmax_q_rot_grip(left_q_rot_grip)
975
+
976
+ if right_q_ignore_collisions is not None:
977
+ right_q_ignore_collisions = self._softmax_ignore_collision(
978
+ right_q_ignore_collisions
979
+ )
980
+
981
+ if left_q_ignore_collisions is not None:
982
+ left_q_ignore_collisions = self._softmax_ignore_collision(
983
+ left_q_ignore_collisions
984
+ )
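+ # NOTE: softmax is monotonic, so it does not change the argmax taken below; it is
+ # applied mainly so the q-volumes exposed in the info dict and visualisations
+ # read as probabilities.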
985
+
986
+ # argmax Q predictions
987
+ (
988
+ right_coords,
989
+ right_rot_and_grip_indicies,
990
+ right_ignore_collisions,
991
+ ) = self._q.choose_highest_action(
992
+ right_q_trans, right_q_rot_grip, right_q_ignore_collisions
993
+ )
994
+ (
995
+ left_coords,
996
+ left_rot_and_grip_indicies,
997
+ left_ignore_collisions,
998
+ ) = self._q.choose_highest_action(
999
+ left_q_trans, left_q_rot_grip, left_q_ignore_collisions
1000
+ )
1001
+
1002
+ if right_q_rot_grip is not None:
1003
+ right_rot_grip_action = right_rot_and_grip_indicies
1004
+ if right_q_ignore_collisions is not None:
1005
+ right_ignore_collisions_action = right_ignore_collisions.int()
1006
+
1007
+ if left_q_rot_grip is not None:
1008
+ left_rot_grip_action = left_rot_and_grip_indicies
1009
+ if left_q_ignore_collisions is not None:
1010
+ left_ignore_collisions_action = left_ignore_collisions.int()
1011
+
1012
+ right_coords = right_coords.int()
1013
+ left_coords = left_coords.int()
1014
+
1015
+ right_attention_coordinate = bounds[:, :3] + res * right_coords + res / 2
1016
+ left_attention_coordinate = bounds[:, :3] + res * left_coords + res / 2
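+ # NOTE: this maps a discrete voxel index back to a continuous world-frame point
+ # at the voxel centre. Sketch, assuming bounds [-0.5, 0.5] per axis and a
+ # 100-voxel grid (res = 0.01): index 50 -> -0.5 + 0.01 * 50 + 0.005 = 0.005.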
1017
+
1018
+ # stack prev_layer_voxel_grid(s) into a list
1019
+ # NOTE: PerAct doesn't use multi-layer voxel grids like C2FARM
1020
+ if prev_layer_voxel_grid is None:
1021
+ prev_layer_voxel_grid = [vox_grid]
1022
+ else:
1023
+ prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
1024
+
1025
+ if prev_layer_bounds is None:
1026
+ prev_layer_bounds = [bounds]
1027
+ else:
1028
+ prev_layer_bounds = prev_layer_bounds + [bounds]
1029
+
1030
+ observation_elements = {
1031
+ "right_attention_coordinate": right_attention_coordinate,
1032
+ "left_attention_coordinate": left_attention_coordinate,
1033
+ "prev_layer_voxel_grid": prev_layer_voxel_grid,
1034
+ "prev_layer_bounds": prev_layer_bounds,
1035
+ }
1036
+ info = {
1037
+ "voxel_grid_depth%d" % self._layer: vox_grid,
1038
+ "right_q_depth%d" % self._layer: right_q_trans,
1039
+ "right_voxel_idx_depth%d" % self._layer: right_coords,
1040
+ "left_q_depth%d" % self._layer: left_q_trans,
1041
+ "left_voxel_idx_depth%d" % self._layer: left_coords,
1042
+ }
1043
+ self._act_voxel_grid = vox_grid[0]
1044
+ self._right_act_max_coordinate = right_coords[0]
1045
+ self._right_act_qvalues = right_q_trans[0].detach()
1046
+ self._left_act_max_coordinate = left_coords[0]
1047
+ self._left_act_qvalues = left_q_trans[0].detach()
1048
+
1049
+ action = (
1050
+ right_coords,
1051
+ right_rot_grip_action,
1052
+ right_ignore_collisions,
1053
+ left_coords,
1054
+ left_rot_grip_action,
1055
+ left_ignore_collisions,
1056
+ )
1057
+
1058
+ return ActResult(action, observation_elements=observation_elements, info=info)
1059
+
1060
+ def update_summaries(self) -> List[Summary]:
1061
+ # voxel_grid = self._vis_voxel_grid.detach().cpu().numpy()
1062
+ summaries = []
1063
+ # summaries.append(
1064
+ # ImageSummary(
1065
+ # "%s/right_update_qattention" % self._name,
1066
+ # transforms.ToTensor()(
1067
+ # visualise_voxel(
1068
+ # voxel_grid,
1069
+ # self._right_vis_translation_qvalue.detach().cpu().numpy(),
1070
+ # self._right_vis_max_coordinate.detach().cpu().numpy(),
1071
+ # self._right_vis_gt_coordinate.detach().cpu().numpy(),
1072
+ # )
1073
+ # ),
1074
+ # )
1075
+ # )
1076
+ # summaries.append(
1077
+ # ImageSummary(
1078
+ # "%s/left_update_qattention" % self._name,
1079
+ # transforms.ToTensor()(
1080
+ # visualise_voxel(
1081
+ # voxel_grid,
1082
+ # self._left_vis_translation_qvalue.detach().cpu().numpy(),
1083
+ # self._left_vis_max_coordinate.detach().cpu().numpy(),
1084
+ # self._left_vis_gt_coordinate.detach().cpu().numpy(),
1085
+ # )
1086
+ # ),
1087
+ # )
1088
+ # )
1089
+ for n, v in self._summaries.items():
1090
+ summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
1091
+
1092
+ for name, crop in self._crop_summary:
1093
+ crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
1094
+ summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
1095
+
1096
+ for tag, param in self._q.named_parameters():
1097
+ # assert not torch.isnan(param.grad.abs() <= 1.0).all()
1098
+ summaries.append(
1099
+ HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
1100
+ )
1101
+ summaries.append(
1102
+ HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
1103
+ )
1104
+
1105
+ return summaries
1106
+
1107
+ def update_wandb_summaries(self):
1108
+ summaries = dict()
1109
+
1110
+ for k, v in self._wandb_summaries.items():
1111
+ summaries[k] = v
1112
+ return summaries
1113
+
1114
+ def act_summaries(self) -> List[Summary]:
1115
+ # voxel_grid = self._act_voxel_grid.cpu().numpy()
1116
+ # right_q_attention = self._right_act_qvalues.cpu().numpy()
1117
+ # right_highlight_coordinate = self._right_act_max_coordinate.cpu().numpy()
1118
+ # right_visualization = visualise_voxel(
1119
+ # voxel_grid, right_q_attention, right_highlight_coordinate
1120
+ # )
1121
+
1122
+ # left_q_attention = self._left_act_qvalues.cpu().numpy()
1123
+ # left_highlight_coordinate = self._left_act_max_coordinate.cpu().numpy()
1124
+ # left_visualization = visualise_voxel(
1125
+ # voxel_grid, left_q_attention, left_highlight_coordinate
1126
+ # )
1127
+
1128
+ # return [
1129
+ # ImageSummary(
1130
+ # f"{self._name}/right_act_Qattention",
1131
+ # transforms.ToTensor()(right_visualization),
1132
+ # ),
1133
+ # ImageSummary(
1134
+ # f"{self._name}/left_act_Qattention",
1135
+ # transforms.ToTensor()(left_visualization),
1136
+ # ),
1137
+ # ]
1138
+ return []
1139
+
1140
+ def concat_weights(self, param, target_size, dims=-1):
1141
+ if param.size(dims) < target_size:  # check the dimension actually being widened
1142
+ param = torch.cat([param, param], dims)
1143
+ return param
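+ # NOTE: concat_weights widens a pretrained single-arm weight by tiling it along
+ # one dimension, e.g. a [128, 64] tensor becomes [256, 64] via
+ # concat_weights(w, 256, 0); load_weights below uses this to adapt 128-dim
+ # unimanual checkpoints to the 256-dim bimanual network.
+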
1144
+ def load_weights(self, savedir: str):
1145
+ device = (
1146
+ self._device
1147
+ if not self._training
1148
+ else torch.device("cuda:%d" % self._device)
1149
+ )
1150
+
1151
+ weight_file = os.path.join(savedir, "%s.pt" % self._name)
1152
+ state_dict = torch.load(weight_file, map_location=device)
1153
+ merged_state_dict = self._q.state_dict()
1154
+
1155
+ if not self._training:
1156
+ for k, v in state_dict.items():
1157
+ # already inside the eval-only branch, so strip the DataParallel "module" prefix directly
1158
+ k = k.replace("_qnet.module", "_qnet")
1159
+ if k in merged_state_dict:
1160
+ merged_state_dict[k] = v
1161
+ else:
1162
+ if "_voxelizer" not in k:
1163
+ logging.warning("key %s not found in checkpoint" % k)
1164
+ else:
1165
+ for k, v in state_dict.items():
1166
+ if not self._training:  # always False in this training branch; the rename below never runs
1167
+ k = k.replace("_qnet.module", "_qnet")
1168
+ # cross_attn
1169
+ if k.startswith("_qnet.module.decoder_cross_attn"):
1170
+ right_key = k.replace("_qnet.module.decoder_cross_attn", "_qnet.module.decoder_cross_attn_right")
1171
+ merged_state_dict[right_key] = v
1172
+ left_key = k.replace("_qnet.module.decoder_cross_attn", "_qnet.module.decoder_cross_attn_left")
1173
+ merged_state_dict[left_key] = v
1174
+ if self.anybimanual:
1175
+ if v.size(0) == 128:
1176
+ merged_state_dict[right_key] = self.concat_weights(v, 256, 0)
1177
+ merged_state_dict[left_key] = self.concat_weights(v, 256, 0)
1178
+ if v.size(-1) == 128:
1179
+ merged_state_dict[right_key] = self.concat_weights(v, 256)
1180
+ merged_state_dict[left_key] = self.concat_weights(v, 256)
1181
+ elif k == "_qnet.module.up0.conv_up.0.conv3d.weight":
1182
+ if self.anybimanual:
1183
+ if v.size(1) == 128:
1184
+ merged_state_dict[k] = self.concat_weights(v, 256, 1)
1185
+ else:
1186
+ merged_state_dict[k] = v
1187
+ # trans_decoder
1188
+ elif k.startswith("_qnet.module.trans_decoder"):
1189
+ right_key = k.replace("_qnet.module.trans_decoder", "_qnet.module.right_trans_decoder")
1190
+ merged_state_dict[right_key] = v
1191
+
1192
+ left_key = k.replace("_qnet.module.trans_decoder", "_qnet.module.left_trans_decoder")
1193
+ merged_state_dict[left_key] = v
1194
+ # dense0
1195
+ elif k.startswith("_qnet.module.dense0"):
1196
+ right_key = k.replace("_qnet.module.dense0", "_qnet.module.right_dense0")
1197
+ merged_state_dict[right_key] = v
1198
+
1199
+ left_key = k.replace("_qnet.module.dense0", "_qnet.module.left_dense0")
1200
+ merged_state_dict[left_key] = v
1201
+ if self.anybimanual:
1202
+ if v.size(-1) == 1024:
1203
+ merged_state_dict[right_key] = torch.cat([v, v[:, :512]], dim=-1)
1204
+ merged_state_dict[left_key] = torch.cat([v, v[:, :512]], dim=-1)
1205
+ # dense1
1206
+ elif k.startswith("_qnet.module.dense1"):
1207
+ right_key = k.replace("_qnet.module.dense1", "_qnet.module.right_dense1")
1208
+ merged_state_dict[right_key] = v
1209
+
1210
+ left_key = k.replace("_qnet.module.dense1", "_qnet.module.left_dense1")
1211
+ merged_state_dict[left_key] = v
1212
+ # collision
1213
+ elif k.startswith("_qnet.module.rot_grip_collision_ff"):
1214
+ right_key = k.replace("_qnet.module.rot_grip_collision_ff", "_qnet.module.right_rot_grip_collision_ff")
1215
+ merged_state_dict[right_key] = v
1216
+
1217
+ left_key = k.replace("_qnet.module.rot_grip_collision_ff", "_qnet.module.left_rot_grip_collision_ff")
1218
+ merged_state_dict[left_key] = v
1219
+ elif k.startswith("_qnet.module.cross_attend_blocks"):
1220
+ if self.anybimanual:
1221
+ if k.startswith("_qnet.module.cross_attend_blocks.0"):
1222
+ merged_state_dict[k] = v
1223
+ k_1 = k.replace("_qnet.module.cross_attend_blocks.0","_qnet.module.cross_attend_blocks.1")
1224
+ merged_state_dict[k_1] = v
1225
+ if self.anybimanual:
1226
+ if v.size(-1) == 128:
1227
+ merged_state_dict[k_1] = self.concat_weights(v, 256)
1228
+ else:
1229
+ k_2 = k.replace("_qnet.module.cross_attend_blocks.1","_qnet.module.cross_attend_blocks.2")
1230
+ k_3 = k.replace("_qnet.module.cross_attend_blocks.1","_qnet.module.cross_attend_blocks.3")
1231
+ merged_state_dict[k_2] = v
1232
+ merged_state_dict[k_3] = v
1233
+ if self.anybimanual:
1234
+ if v.size(-1) == 128:
1235
+ merged_state_dict[k_2] = self.concat_weights(v, 256)
1236
+ merged_state_dict[k_3] = self.concat_weights(v, 256)
1237
+ else:
1238
+ if k.startswith("_qnet.module.cross_attend_blocks.0"):
1239
+ merged_state_dict[k] = v
1240
+ else:
1241
+ merged_state_dict[k] = v
1242
+ k_2 = k.replace("_qnet.module.cross_attend_blocks.1","_qnet.module.cross_attend_blocks.2")
1243
+ merged_state_dict[k_2] = v
1244
+ if self.anybimanual:
1245
+ if v.size(-1) == 128:
1246
+ merged_state_dict[k_2] = self.concat_weights(v, 256)
1247
+ if self.anybimanual:
1248
+ if v.size(-1) == 128:
1249
+ merged_state_dict[k] = self.concat_weights(v, 256)
1250
+ # proprio
1251
+ elif k == '_qnet.module.proprio_preprocess.linear.weight':
1252
+ if v.shape[1] != 8:
1253
+ new_v = torch.cat([v, v], dim=1)
1254
+ merged_state_dict['_qnet.module.proprio_preprocess.linear.weight'] = new_v
1255
+ else:
1256
+ merged_state_dict[k] = v
1257
+ elif k == '_qnet.module.proprio_preprocess.linear.bias':
1258
+ merged_state_dict['_qnet.module.proprio_preprocess.linear.bias'] = v
1259
+ # pos_with_lang
1260
+ elif k == "_qnet.module.pos_encoding":
1261
+ if v.shape[1] < 154:  # (v.shape[1] != 8077 or v.shape[1] != 8154) is always True, so only this bound matters
1262
+ if self.anybimanual:
1263
+ lang_max_seq_len = 154
1264
+ else:
1265
+ lang_max_seq_len = 77
1266
+ spatial_size = v.shape[1]
1267
+ input_dim_before_seq = v.shape[-1]
1268
+ flattened_v = v.view(1, -1, input_dim_before_seq) # (1, spatial_size**3, self.input_dim_before_seq)
1269
+ new_pos_encoding = torch.randn(1, lang_max_seq_len, input_dim_before_seq, device=device)
1270
+ merged_pos_encoding = torch.cat([flattened_v, new_pos_encoding], dim=1) # (1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq)
1271
+ merged_state_dict["_qnet.module.pos_encoding"] = merged_pos_encoding
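+ # NOTE: the checkpoint's spatial position encodings are kept verbatim and
+ # freshly initialised encodings are appended for the language-token span, so
+ # only the new language positions start untrained.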
1272
+ else:
1273
+ merged_state_dict["_qnet.module.pos_encoding"] = v
1274
+ elif k in merged_state_dict:
1275
+ merged_state_dict[k] = v
1276
+
1277
+ # else:
1278
+ # if "_voxelizer" not in k:
1279
+ # logging.warning("key %s not found in checkpoint" % k)
1280
+
1281
+
1282
+ if not self._training:
1283
+ b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
1284
+ merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
1285
+ "_voxelizer._ones_max_coords"
1286
+ ][0:1]
1287
+ flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
1288
+ merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
1289
+ "_voxelizer._flat_output"
1290
+ ][0 : flat_shape // b]
1291
+ merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
1292
+ "_voxelizer._tiled_batch_indices"
1293
+ ][0:1]
1294
+ merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
1295
+ "_voxelizer._index_grid"
1296
+ ][0:1]
1297
+ self._q.load_state_dict(merged_state_dict)
1298
+
1299
+ if self.frozen:
1300
+ print("Freezing parameters from PerAct")
1301
+ for name, param in self._q.named_parameters():
1302
+ if name in state_dict:
1303
+ param.requires_grad = False
1304
+
1305
+ logging.info(
1306
+ "# Q Params: %d"
1307
+ % sum(
1308
+ p.numel()
1309
+ for name, p in self._q.named_parameters()
1310
+ if p.requires_grad and "clip" not in name
1311
+ )
1312
+ )
1313
+ print("loaded weights from %s" % weight_file)
1314
+
1315
+
1316
+ def save_weights(self, savedir: str):
1317
+ torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
third_party/AnyBimanual/agents/peract_bimanual/qattention_stack_agent.py ADDED
@@ -0,0 +1,209 @@
1
+ from typing import List
2
+
3
+ import torch
4
+ from yarr.agents.agent import Agent, ActResult, Summary
5
+
6
+ import numpy as np
7
+
8
+ from helpers import utils
9
+ from agents.peract_bimanual.qattention_peract_bc_agent import QAttentionPerActBCAgent
10
+
11
+ NAME = "QAttentionStackAgent"
12
+
13
+
14
+ class QAttentionStackAgent(Agent):
15
+ def __init__(
16
+ self,
17
+ qattention_agents: List[QAttentionPerActBCAgent],
18
+ rotation_resolution: float,
19
+ camera_names: List[str],
20
+ rotation_prediction_depth: int = 0,
21
+ ):
22
+ super(QAttentionStackAgent, self).__init__()
23
+ self._qattention_agents = qattention_agents
24
+ self._rotation_resolution = rotation_resolution
25
+ self._camera_names = camera_names
26
+ self._rotation_prediction_depth = rotation_prediction_depth
27
+
28
+ def build(self, training: bool, device=None) -> None:
29
+ self._device = device
30
+ if self._device is None:
31
+ self._device = torch.device("cpu")
32
+ for qa in self._qattention_agents:
33
+ qa.build(training, device)
34
+
35
+ def update(self, step: int, replay_sample: dict) -> dict:
36
+ priorities = 0
37
+ total_losses = 0.0
38
+ for qa in self._qattention_agents:
39
+ update_dict = qa.update(step, replay_sample)
40
+ replay_sample.update(update_dict)
41
+ total_losses += update_dict["total_loss"]
42
+ return {
43
+ "total_losses": total_losses,
44
+ }
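+ # NOTE: each depth writes its update_dict back into replay_sample, so deeper
+ # q-attention layers can condition on the voxel grids and bounds produced by
+ # shallower ones; with a single-depth PerAct stack the loop runs once.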
45
+
46
+ def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
47
+ observation_elements = {}
48
+ (
49
+ right_translation_results,
50
+ right_rot_grip_results,
51
+ right_ignore_collisions_results,
52
+ ) = ([], [], [])
53
+ (
54
+ left_translation_results,
55
+ left_rot_grip_results,
56
+ left_ignore_collisions_results,
57
+ ) = ([], [], [])
58
+
59
+ infos = {}
60
+ for depth, qagent in enumerate(self._qattention_agents):
61
+ act_results = qagent.act(step, observation, deterministic)
62
+ right_attention_coordinate = (
63
+ act_results.observation_elements["right_attention_coordinate"]
64
+ .cpu()
65
+ .numpy()
66
+ )
67
+ left_attention_coordinate = (
68
+ act_results.observation_elements["left_attention_coordinate"]
69
+ .cpu()
70
+ .numpy()
71
+ )
72
+ observation_elements[
73
+ "right_attention_coordinate_layer_%d" % depth
74
+ ] = right_attention_coordinate[0]
75
+ observation_elements[
76
+ "left_attention_coordinate_layer_%d" % depth
77
+ ] = left_attention_coordinate[0]
78
+
79
+ (
80
+ right_translation_idxs,
81
+ right_rot_grip_idxs,
82
+ right_ignore_collisions_idxs,
83
+ left_translation_idxs,
84
+ left_rot_grip_idxs,
85
+ left_ignore_collisions_idxs,
86
+ ) = act_results.action
87
+
88
+ right_translation_results.append(right_translation_idxs)
89
+ if right_rot_grip_idxs is not None:
90
+ right_rot_grip_results.append(right_rot_grip_idxs)
91
+ if right_ignore_collisions_idxs is not None:
92
+ right_ignore_collisions_results.append(right_ignore_collisions_idxs)
93
+
94
+ left_translation_results.append(left_translation_idxs)
95
+ if left_rot_grip_idxs is not None:
96
+ left_rot_grip_results.append(left_rot_grip_idxs)
97
+ if left_ignore_collisions_idxs is not None:
98
+ left_ignore_collisions_results.append(left_ignore_collisions_idxs)
99
+
100
+ observation[
101
+ "right_attention_coordinate"
102
+ ] = act_results.observation_elements["right_attention_coordinate"]
103
+ observation["left_attention_coordinate"] = act_results.observation_elements[
104
+ "left_attention_coordinate"
105
+ ]
106
+
107
+ observation["prev_layer_voxel_grid"] = act_results.observation_elements[
108
+ "prev_layer_voxel_grid"
109
+ ]
110
+ observation["prev_layer_bounds"] = act_results.observation_elements[
111
+ "prev_layer_bounds"
112
+ ]
113
+
114
+ for n in self._camera_names:
115
+ extrinsics = observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy()
116
+ intrinsics = observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy()
117
+ px, py = utils.point_to_pixel_index(
118
+ right_attention_coordinate[0], extrinsics, intrinsics
119
+ )
120
+ pc_t = torch.tensor(
121
+ [[[py, px]]], dtype=torch.float32, device=self._device
122
+ )
123
+ observation[f"right_{n}_pixel_coord"] = pc_t
124
+ observation_elements[f"right_{n}_pixel_coord"] = [py, px]
125
+
126
+ px, py = utils.point_to_pixel_index(
127
+ left_attention_coordinate[0], extrinsics, intrinsics
128
+ )
129
+ pc_t = torch.tensor(
130
+ [[[py, px]]], dtype=torch.float32, device=self._device
131
+ )
132
+ observation[f"left_{n}_pixel_coord"] = pc_t
133
+ observation_elements[f"left_{n}_pixel_coord"] = [py, px]
134
+ infos.update(act_results.info)
135
+
136
+ right_rgai = torch.cat(right_rot_grip_results, 1)[0].cpu().numpy()
137
+ # ..todo:: utils.correct_rotation_instability does nothing so we can ignore it
138
+ # right_rgai = utils.correct_rotation_instability(right_rgai, self._rotation_resolution)
139
+ right_ignore_collisions = (
140
+ torch.cat(right_ignore_collisions_results, 1)[0].cpu().numpy()
141
+ )
142
+ right_trans_action_indicies = (
143
+ torch.cat(right_translation_results, 1)[0].cpu().numpy()
144
+ )
145
+
146
+ observation_elements[
147
+ "right_trans_action_indicies"
148
+ ] = right_trans_action_indicies[:3]
149
+ observation_elements["right_rot_grip_action_indicies"] = right_rgai[:4]
150
+
151
+ left_rgai = torch.cat(left_rot_grip_results, 1)[0].cpu().numpy()
152
+ left_ignore_collisions = (
153
+ torch.cat(left_ignore_collisions_results, 1)[0].cpu().numpy()
154
+ )
155
+ left_trans_action_indicies = (
156
+ torch.cat(left_translation_results, 1)[0].cpu().numpy()
157
+ )
158
+
159
+ observation_elements["left_trans_action_indicies"] = left_trans_action_indicies[
160
+ 3:
161
+ ]
162
+ observation_elements["left_rot_grip_action_indicies"] = left_rgai[4:]
163
+
164
+ continuous_action = np.concatenate(
165
+ [
166
+ right_attention_coordinate[0],
167
+ utils.discrete_euler_to_quaternion(
168
+ right_rgai[-4:-1], self._rotation_resolution
169
+ ),
170
+ right_rgai[-1:],
171
+ right_ignore_collisions,
172
+ left_attention_coordinate[0],
173
+ utils.discrete_euler_to_quaternion(
174
+ left_rgai[-4:-1], self._rotation_resolution
175
+ ),
176
+ left_rgai[-1:],
177
+ left_ignore_collisions,
178
+ ]
179
+ )
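+ # NOTE: per arm the continuous action packs an xyz world coordinate (3), a
+ # quaternion recovered from the discrete euler bins (4), the gripper state (1)
+ # and the ignore-collision flag (1): a 9-dim right block then a 9-dim left block.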
180
+ return ActResult(
181
+ continuous_action, observation_elements=observation_elements, info=infos
182
+ )
183
+
184
+ def update_summaries(self) -> List[Summary]:
185
+ summaries = []
186
+ for qa in self._qattention_agents:
187
+ summaries.extend(qa.update_summaries())
188
+ return summaries
189
+
190
+ def update_wandb_summaries(self):
191
+ summaries = {}
192
+ for qa in self._qattention_agents:
193
+ summaries.update(qa.update_wandb_summaries())
194
+ return summaries
195
+
+ def act_summaries(self) -> List[Summary]:
196
+ s = []
197
+ for qa in self._qattention_agents:
198
+ s.extend(qa.act_summaries())
199
+ return s
200
+
201
+ def load_weights(self, savedir: str):
202
+ for qa in self._qattention_agents:
203
+ # print(dir(qa))  # debug output disabled
204
+ qa.load_weights(savedir)
205
+
206
+
207
+ def save_weights(self, savedir: str):
208
+ for qa in self._qattention_agents:
209
+ qa.save_weights(savedir)
third_party/AnyBimanual/agents/peract_bimanual/skill_manager.py ADDED
@@ -0,0 +1,70 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import transformers
4
+ from agents.peract_bimanual.trajectory_gpt2 import GPT2Model
5
+ import torch.nn.functional as F
6
+ class SkillManager(nn.Module):
7
+ def __init__(
8
+ self,
9
+ num_classes,
10
+ embedding_matrix=None,
11
+ voxel_dim=128,
12
+ lang_dim=128,
13
+ hidden_size=128,
14
+ output_dim=18,
15
+ max_voxels=8000,
16
+ max_lang_tokens=77,
17
+ **kwargs):
18
+ super().__init__()
19
+
20
+ self.hidden_size = hidden_size
21
+ self.output_dim = output_dim
22
+
23
+ # GPT-2 configuration
24
+ config = transformers.GPT2Config(
25
+ vocab_size=1, # not used
26
+ n_embd=hidden_size,
27
+ n_head=4,
28
+ n_ctx=1077,
29
+ )
30
+
31
+ self.max_voxels = max_voxels
32
+ self.max_lang_tokens = max_lang_tokens
33
+ self.embed_voxel = nn.Linear(voxel_dim, hidden_size)
34
+ self.embed_lang = nn.Linear(lang_dim, hidden_size)
35
+ self.transformer = GPT2Model(config)
36
+ self.embed_ln = nn.LayerNorm(hidden_size)
37
+ self.predict_logits = nn.Linear(hidden_size, output_dim)
38
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
39
+ self.num_class = num_classes
40
+ if embedding_matrix is not None:
41
+ self.embeddings_matrix = embedding_matrix.to(self.device)
42
+
43
+ def forward(self, voxel_embedding, language_embedding):
44
+ batch_size = voxel_embedding.shape[0]
45
+ voxel_embeddings = self.embed_voxel(voxel_embedding) # [b, 8000, hidden_size]
46
+ language_embeddings = self.embed_lang(language_embedding) # [b, 77, hidden_size]
47
+ voxel_embeddings = voxel_embeddings.permute(0, 2, 1) # [b, hidden_size, 8000]
48
+ voxel_embeddings = F.avg_pool1d(voxel_embeddings, kernel_size=16, stride=16) # [b, hidden_size, 500] (8000 / 16)
49
+ voxel_embeddings = voxel_embeddings.permute(0, 2, 1) # [b, 500, hidden_size]
50
+ inputs = torch.cat([language_embeddings, voxel_embeddings], dim=1) # [b, 577, hidden_size] = 77 lang + 500 voxel
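+ # NOTE: the transformer therefore sees a 577-token sequence (77 language tokens
+ # followed by 500 pooled voxel tokens), well inside the n_ctx=1077 budget set in
+ # the GPT2Config above.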
51
+ stacked_inputs = self.embed_ln(inputs)
52
+ attention_mask = torch.ones(
53
+ (batch_size, self.max_lang_tokens + self.max_voxels),
54
+ device=voxel_embedding.device,
55
+ dtype=torch.long # Ensure correct dtype
56
+ )
57
+ # these asserts are vacuous on a freshly created all-ones tensor; kept as sanity checks
+ assert torch.isfinite(attention_mask).all(), "attention_mask contains NaN or Inf"
58
+ assert torch.all(attention_mask == 1), "attention_mask contains values not equal to 1"
59
+ transformer_outputs = self.transformer(
60
+ inputs_embeds=stacked_inputs,
61
+ attention_mask=None,  # the all-ones mask built above is left unused; None gives full attention
62
+ )
63
+
64
+ hidden_state = transformer_outputs.last_hidden_state # [b, 577, hidden_size]
65
+ aggregated_hidden = hidden_state.mean(dim=1) # [b, hidden_size]
66
+ logits = self.predict_logits(aggregated_hidden) # [b, output_dim]
67
+ probs = F.softmax(logits, dim=1)
68
+ skill = torch.matmul(probs, self.embeddings_matrix.to(probs.device))
69
+ skill = skill.view(-1, 77, 512)
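+ # NOTE: the skill is a probability-weighted (convex) combination of the rows of
+ # the skill embedding matrix, reshaped into a CLIP-style [b, 77, 512] token
+ # sequence for downstream conditioning.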
70
+ return skill
third_party/AnyBimanual/agents/peract_bimanual/trajectory_gpt2.py ADDED
@@ -0,0 +1,775 @@
1
+ # coding=utf-8
2
+ # Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
3
+ # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """PyTorch OpenAI GPT-2 model."""
17
+
18
+ import os
19
+ from dataclasses import dataclass
20
+ from typing import List, Optional, Tuple
21
+
22
+ import torch
23
+ import torch.nn as nn
24
+ from torch.nn import CrossEntropyLoss, MSELoss
25
+
26
+ from transformers.activations import ACT2FN
27
+ from transformers.file_utils import (
28
+ ModelOutput,
29
+ add_code_sample_docstrings,
30
+ add_start_docstrings,
31
+ add_start_docstrings_to_model_forward,
32
+ replace_return_docstrings,
33
+ )
34
+ from transformers.modeling_outputs import (
35
+ BaseModelOutputWithPastAndCrossAttentions,
36
+ )
37
+ from transformers.modeling_utils import (
38
+ Conv1D,
39
+ PreTrainedModel,
40
+ SequenceSummary,
41
+ find_pruneable_heads_and_indices,
42
+ prune_conv1d_layer,
43
+ )
44
+ from transformers.utils import logging
45
+ from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
46
+ from transformers.models.gpt2.configuration_gpt2 import GPT2Config
47
+
48
+ logger = logging.get_logger(__name__)
49
+
50
+ _CONFIG_FOR_DOC = "GPT2Config"
51
+ _TOKENIZER_FOR_DOC = "GPT2Tokenizer"
52
+
53
+ GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
54
+ "gpt2",
55
+ "gpt2-medium",
56
+ "gpt2-large",
57
+ "gpt2-xl",
58
+ "distilgpt2",
59
+ # See all GPT-2 models at https://huggingface.co/models?filter=gpt2
60
+ ]
61
+
62
+
63
+ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
64
+ """Load tf checkpoints in a pytorch model"""
65
+ try:
66
+ import re
67
+
68
+ import tensorflow as tf
69
+ except ImportError:
70
+ logger.error(
71
+ "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
72
+ "https://www.tensorflow.org/install/ for installation instructions."
73
+ )
74
+ raise
75
+ tf_path = os.path.abspath(gpt2_checkpoint_path)
76
+ logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
77
+ # Load weights from TF model
78
+ init_vars = tf.train.list_variables(tf_path)
79
+ names = []
80
+ arrays = []
81
+ for name, shape in init_vars:
82
+ logger.info("Loading TF weight {} with shape {}".format(name, shape))
83
+ array = tf.train.load_variable(tf_path, name)
84
+ names.append(name)
85
+ arrays.append(array.squeeze())
86
+
87
+ for name, array in zip(names, arrays):
88
+ name = name[6:] # skip "model/"
89
+ name = name.split("/")
90
+ pointer = model
91
+ for m_name in name:
92
+ if re.fullmatch(r"[A-Za-z]+\d+", m_name):
93
+ scope_names = re.split(r"(\d+)", m_name)
94
+ else:
95
+ scope_names = [m_name]
96
+ if scope_names[0] == "w" or scope_names[0] == "g":
97
+ pointer = getattr(pointer, "weight")
98
+ elif scope_names[0] == "b":
99
+ pointer = getattr(pointer, "bias")
100
+ elif scope_names[0] == "wpe" or scope_names[0] == "wte":
101
+ pointer = getattr(pointer, scope_names[0])
102
+ pointer = getattr(pointer, "weight")
103
+ else:
104
+ pointer = getattr(pointer, scope_names[0])
105
+ if len(scope_names) >= 2:
106
+ num = int(scope_names[1])
107
+ pointer = pointer[num]
108
+ try:
109
+ assert (
110
+ pointer.shape == array.shape
111
+ ), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
112
+ except AssertionError as e:
113
+ e.args += (pointer.shape, array.shape)
114
+ raise
115
+ logger.info("Initialize PyTorch weight {}".format(name))
116
+ pointer.data = torch.from_numpy(array)
117
+ return model
118
+
119
+
120
+ class Attention(nn.Module):
121
+ def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
122
+ super().__init__()
123
+
124
+ n_state = nx # in Attention: n_state=768 (nx=n_embd)
125
+ # [switch nx => n_state from Block to Attention to keep identical to TF implem]
126
+ assert n_state % config.n_head == 0
127
+ self.register_buffer(
128
+ "bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
129
+ )
130
+ self.register_buffer("masked_bias", torch.tensor(-1e4))
131
+ self.n_head = config.n_head
132
+ self.split_size = n_state
133
+ self.scale = scale
134
+ self.is_cross_attention = is_cross_attention
135
+ if self.is_cross_attention:
136
+ self.c_attn = Conv1D(2 * n_state, nx)
137
+ self.q_attn = Conv1D(n_state, nx)
138
+ else:
139
+ self.c_attn = Conv1D(3 * n_state, nx)
140
+ self.c_proj = Conv1D(n_state, nx)
141
+ self.attn_dropout = nn.Dropout(config.attn_pdrop)
142
+ self.resid_dropout = nn.Dropout(config.resid_pdrop)
143
+ self.pruned_heads = set()
144
+
145
+ def prune_heads(self, heads):
146
+ if len(heads) == 0:
147
+ return
148
+ heads, index = find_pruneable_heads_and_indices(
149
+ heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
150
+ )
151
+ index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
152
+
153
+ # Prune conv1d layers
154
+ self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
155
+ self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
156
+
157
+ # Update hyper params
158
+ self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
159
+ self.n_head = self.n_head - len(heads)
160
+ self.pruned_heads = self.pruned_heads.union(heads)
161
+
162
+ def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
163
+ w = torch.matmul(q, k)
164
+ if self.scale:
165
+ w = w / (float(v.size(-1)) ** 0.5)
166
+ nd, ns = w.size(-2), w.size(-1)
167
+
168
+ if not self.is_cross_attention:
169
+ # if only "normal" attention layer implements causal mask
170
+ mask = self.bias[:, :, ns - nd: ns, :ns]
171
+ w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
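+ # NOTE: `bias` is a lower-triangular buffer, so each position may attend only to
+ # itself and earlier tokens; masked entries are filled with -1e4 rather than
+ # -inf, which keeps the softmax numerically safe in fp16.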
172
+
173
+ if attention_mask is not None:
174
+ # Apply the attention mask
175
+ w = w + attention_mask
176
+
177
+ w = nn.Softmax(dim=-1)(w)
178
+ w = self.attn_dropout(w)
179
+
180
+ # Mask heads if we want to
181
+ if head_mask is not None:
182
+ w = w * head_mask
183
+
184
+ outputs = [torch.matmul(w, v)]
185
+ if output_attentions:
186
+ outputs.append(w)
187
+ return outputs
188
+
189
+ def merge_heads(self, x):
190
+ x = x.permute(0, 2, 1, 3).contiguous()
191
+ new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
192
+ return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
193
+
194
+ def split_heads(self, x, k=False):
195
+ new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
196
+ x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
197
+ if k:
198
+ return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
199
+ else:
200
+ return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
201
+
202
+ def forward(
203
+ self,
204
+ hidden_states,
205
+ layer_past=None,
206
+ attention_mask=None,
207
+ head_mask=None,
208
+ encoder_hidden_states=None,
209
+ encoder_attention_mask=None,
210
+ use_cache=False,
211
+ output_attentions=False,
212
+ ):
213
+ if encoder_hidden_states is not None:
214
+ assert hasattr(
215
+ self, "q_attn"
216
+ ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
217
+ query = self.q_attn(hidden_states)
218
+ key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
219
+ attention_mask = encoder_attention_mask
220
+ else:
221
+ query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
222
+
223
+ query = self.split_heads(query)
224
+ key = self.split_heads(key, k=True)
225
+ value = self.split_heads(value)
226
+ if layer_past is not None:
227
+ past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
228
+ key = torch.cat((past_key, key), dim=-1)
229
+ value = torch.cat((past_value, value), dim=-2)
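+ # NOTE: cached keys are stored transposed (features x seq), so new keys append on
+ # the last dim while values append on the sequence dim; this is the standard
+ # GPT-2 incremental-decoding KV cache layout.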
230
+
231
+ if use_cache is True:
232
+ present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
233
+ else:
234
+ present = (None,)
235
+
236
+ attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
237
+ a = attn_outputs[0]
238
+
239
+ a = self.merge_heads(a)
240
+ a = self.c_proj(a)
241
+ a = self.resid_dropout(a)
242
+
243
+ outputs = [a, present] + attn_outputs[1:]
244
+ return outputs # a, present, (attentions)
245
+
246
+
247
+ class MLP(nn.Module):
248
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
249
+ super().__init__()
250
+ nx = config.n_embd
251
+ self.c_fc = Conv1D(n_state, nx)
252
+ self.c_proj = Conv1D(nx, n_state)
253
+ self.act = ACT2FN[config.activation_function]
254
+ self.dropout = nn.Dropout(config.resid_pdrop)
255
+
256
+ def forward(self, x):
257
+ h = self.act(self.c_fc(x))
258
+ h2 = self.c_proj(h)
259
+ return self.dropout(h2)
260
+
261
+
262
+ class AdapterMLP(nn.Module):
263
+ def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
264
+ super().__init__()
265
+ nx = config.n_embd
266
+ self.c_fc = Conv1D(n_state, nx)
267
+ self.c_proj = Conv1D(nx, n_state)
268
+ self.act = ACT2FN[config.activation_function]
269
+ self.dropout = nn.Dropout(config.resid_pdrop)
270
+
271
+ def forward(self, x):
272
+ h = self.act(self.c_fc(x))
273
+ h2 = self.c_proj(h)
274
+ return self.dropout(h2)
275
+
276
+
277
+ class Block(nn.Module):
278
+ def __init__(self, n_ctx, config, scale=False):
279
+ super().__init__()
280
+ hidden_size = config.n_embd
281
+ inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
282
+ self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
283
+ self.attn = Attention(hidden_size, n_ctx, config, scale)
284
+ self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
285
+ # self.adapter_ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
286
+ if config.add_cross_attention:
287
+ self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
288
+ self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
289
+ self.mlp = MLP(inner_dim, config)
290
+ # self.adapter_mlp = AdapterMLP(512, config) # ADAPTER
291
+
292
+ def forward(
293
+ self,
294
+ hidden_states,
295
+ layer_past=None,
296
+ attention_mask=None,
297
+ head_mask=None,
298
+ encoder_hidden_states=None,
299
+ encoder_attention_mask=None,
300
+ use_cache=False,
301
+ output_attentions=False,
302
+ ):
303
+ attn_outputs = self.attn(
304
+ self.ln_1(hidden_states),
305
+ layer_past=layer_past,
306
+ attention_mask=attention_mask,
307
+ head_mask=head_mask,
308
+ use_cache=use_cache,
309
+ output_attentions=output_attentions,
310
+ )
311
+ attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
312
+ outputs = attn_outputs[1:]
313
+ # residual connection
314
+ hidden_states = attn_output + hidden_states
315
+
316
+ if encoder_hidden_states is not None:
317
+ # add one self-attention block for cross-attention
318
+ assert hasattr(
319
+ self, "crossattention"
320
+ ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
321
+ cross_attn_outputs = self.crossattention(
322
+ self.ln_cross_attn(hidden_states),
323
+ attention_mask=attention_mask,
324
+ head_mask=head_mask,
325
+ encoder_hidden_states=encoder_hidden_states,
326
+ encoder_attention_mask=encoder_attention_mask,
327
+ output_attentions=output_attentions,
328
+ )
329
+ attn_output = cross_attn_outputs[0]
330
+ # residual connection
331
+ hidden_states = hidden_states + attn_output
332
+ outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
333
+
334
+ feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
335
+ # residual connection
336
+ hidden_states = hidden_states + feed_forward_hidden_states
337
+ # hidden_states = hidden_states + self.adapter_ln(self.adapter_mlp(hidden_states))
338
+
339
+ outputs = [hidden_states] + outputs
340
+ return outputs # hidden_states, present, (attentions, cross_attentions)
341
+
342
+
343
+ class GPT2PreTrainedModel(PreTrainedModel):
344
+ """
345
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
346
+ models.
347
+ """
348
+
349
+ config_class = GPT2Config
350
+ load_tf_weights = load_tf_weights_in_gpt2
351
+ base_model_prefix = "transformer"
352
+
353
+ def __init__(self, *inputs, **kwargs):
354
+ super().__init__(*inputs, **kwargs)
355
+
356
+ def _init_weights(self, module):
357
+ """Initialize the weights."""
358
+ if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
359
+ # Slightly different from the TF version which uses truncated_normal for initialization
360
+ # cf https://github.com/pytorch/pytorch/pull/5617
361
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
362
+ if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
363
+ module.bias.data.zero_()
364
+ elif isinstance(module, nn.LayerNorm):
365
+ module.bias.data.zero_()
366
+ module.weight.data.fill_(1.0)
367
+ # module.weight.data.fill_(.01) # KL: Adapter change
368
+
369
+
370
+ @dataclass
371
+ class GPT2DoubleHeadsModelOutput(ModelOutput):
372
+ """
373
+ Base class for outputs of models predicting if two sentences are consecutive or not.
374
+ Args:
375
+ loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
376
+ Language modeling loss.
377
+ mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
378
+ Multiple choice classification loss.
379
+ logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
380
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
381
+ mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
382
+ Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
383
+ past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
384
+ List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
385
+ batch_size, num_heads, sequence_length, embed_size_per_head)`).
386
+ Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
387
+ :obj:`past_key_values` input) to speed up sequential decoding.
388
+ hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
389
+ Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
390
+ of shape :obj:`(batch_size, sequence_length, hidden_size)`.
391
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
392
+ attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
393
+ Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
394
+ sequence_length, sequence_length)`.
395
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
396
+ heads.
397
+ """
398
+
399
+ loss: Optional[torch.FloatTensor] = None
400
+ mc_loss: Optional[torch.FloatTensor] = None
401
+ logits: torch.FloatTensor = None
402
+ mc_logits: torch.FloatTensor = None
403
+ past_key_values: Optional[List[torch.FloatTensor]] = None
404
+ hidden_states: Optional[Tuple[torch.FloatTensor]] = None
405
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
406
+
407
+
408
+ GPT2_START_DOCSTRING = r"""
409
+ This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
410
+ methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
411
+ pruning heads etc.)
412
+ This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
413
+ subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
414
+ general usage and behavior.
415
+ Parameters:
416
+ config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
417
+ Initializing with a config file does not load the weights associated with the model, only the
418
+ configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
419
+ weights.
420
+ """
421
+
422
+ GPT2_INPUTS_DOCSTRING = r"""
423
+ Args:
424
+ input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
425
+ :obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
426
+ ``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
427
+ sequence tokens in the vocabulary.
428
+ If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
429
+ passed as ``input_ids``.
430
+ Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
431
+ :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
432
+ details.
433
+ `What are input IDs? <../glossary.html#input-ids>`__
434
+ past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
435
+ Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
436
+ :obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
437
+ have their past given to this model should not be passed as ``input_ids`` as they have already been
438
+ computed.
439
+ attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
440
+ Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
441
+ - 1 for tokens that are **not masked**,
442
+ - 0 for tokens that are **masked**.
443
+ `What are attention masks? <../glossary.html#attention-mask>`__
444
+ token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
445
+ Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
446
+ 1]``:
447
+ - 0 corresponds to a `sentence A` token,
448
+ - 1 corresponds to a `sentence B` token.
449
+ `What are token type IDs? <../glossary.html#token-type-ids>`_
450
+ position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
451
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
452
+ config.max_position_embeddings - 1]``.
453
+ `What are position IDs? <../glossary.html#position-ids>`_
454
+ head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
455
+ Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
456
+ - 1 indicates the head is **not masked**,
457
+ - 0 indicates the head is **masked**.
458
+ inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
459
+ Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
460
+ This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
461
+ vectors than the model's internal embedding lookup matrix.
462
+ If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
463
+ :obj:`past_key_values`).
464
+ use_cache (:obj:`bool`, `optional`):
465
+ If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
466
+ decoding (see :obj:`past_key_values`).
467
+ output_attentions (:obj:`bool`, `optional`):
468
+ Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
469
+ tensors for more detail.
470
+ output_hidden_states (:obj:`bool`, `optional`):
471
+ Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
472
+ more detail.
473
+ return_dict (:obj:`bool`, `optional`):
474
+ Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
475
+ """
476
+ PARALLELIZE_DOCSTRING = r"""
477
+ Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
478
+ it will evenly distribute blocks across all devices.
479
+ Args:
480
+ device_map (:obj:`Dict[int, list]`, optional, defaults to None):
481
+ A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
482
+ automatically mapped to the first device (for esoteric reasons). That means that the first device should
483
+ have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
484
+ following number of attention modules:
485
+ - gpt2: 12
486
+ - gpt2-medium: 24
487
+ - gpt2-large: 36
488
+ - gpt2-xl: 48
489
+ Example::
490
+ # Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
491
+ model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
492
+ device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
493
+ 1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
494
+ 2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
495
+ 3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
496
+ model.parallelize(device_map)
497
+ """
498
+ DEPARALLELIZE_DOCSTRING = r"""
499
+ Moves the model to cpu from a model parallel state.
500
+ Example::
501
+ # On a 4 GPU machine with gpt2-large:
502
+ model = GPT2LMHeadModel.from_pretrained('gpt2-large')
503
+ device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
504
+ 1: [8, 9, 10, 11, 12, 13, 14, 15],
505
+ 2: [16, 17, 18, 19, 20, 21, 22, 23],
506
+ 3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
507
+ model.parallelize(device_map) # Splits the model across several devices
508
+ model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
509
+ """
510
+
511
+
512
+ @add_start_docstrings(
513
+ "The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
514
+ GPT2_START_DOCSTRING,
515
+ )
516
+ class GPT2Model(GPT2PreTrainedModel):
517
+ def __init__(self, config):
518
+ super().__init__(config)
519
+
520
+ self.wte = nn.Embedding(config.vocab_size, config.n_embd)
521
+ # self.wpe = nn.Embedding(config.n_positions, config.n_embd)
522
+ self.drop = nn.Dropout(config.embd_pdrop)
523
+ self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
524
+ self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
525
+
526
+ self.init_weights()
527
+ # Model parallel
528
+ self.model_parallel = False
529
+ self.device_map = None
530
+
531
+ self.use_layers = None
532
+
533
+     def set_layers(self, num_layers):
+         # Guard against None before comparing; otherwise the assert raises a TypeError.
+         if num_layers is not None:
+             assert 1 <= num_layers <= len(self.h)
+             num_layers -= 1
+         self.use_layers = num_layers
538
+
539
+ @add_start_docstrings(PARALLELIZE_DOCSTRING)
540
+ def parallelize(self, device_map=None):
541
+ # Check validity of device_map
542
+ self.device_map = (
543
+ get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
544
+ )
545
+ assert_device_map(self.device_map, len(self.h))
546
+ self.model_parallel = True
547
+ self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
548
+ self.last_device = "cuda:" + str(max(self.device_map.keys()))
549
+ self.wte = self.wte.to(self.first_device)
550
+ # self.wpe = self.wpe.to(self.first_device)  # wpe is disabled in __init__, nothing to move
551
+ # Load onto devices
552
+ for k, v in self.device_map.items():
553
+ for block in v:
554
+ cuda_device = "cuda:" + str(k)
555
+ self.h[block] = self.h[block].to(cuda_device)
556
+ # ln_f to last
557
+ self.ln_f = self.ln_f.to(self.last_device)
558
+
559
+ @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
560
+ def deparallelize(self):
561
+ self.model_parallel = False
562
+ self.device_map = None
563
+ self.first_device = "cpu"
564
+ self.last_device = "cpu"
565
+ self.wte = self.wte.to("cpu")
566
+ self.wpe = self.wpe.to("cpu")
567
+ for index in range(len(self.h)):
568
+ self.h[index] = self.h[index].to("cpu")
569
+ self.ln_f = self.ln_f.to("cpu")
570
+ torch.cuda.empty_cache()
571
+
572
+ def get_input_embeddings(self):
573
+ return self.wte
574
+
575
+ def set_input_embeddings(self, new_embeddings):
576
+ self.wte = new_embeddings
577
+
578
+ def _prune_heads(self, heads_to_prune):
579
+ """
580
+ Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
581
+ """
582
+ for layer, heads in heads_to_prune.items():
583
+ self.h[layer].attn.prune_heads(heads)
584
+
585
+ @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
586
+ @add_code_sample_docstrings(
587
+ processor_class=_TOKENIZER_FOR_DOC,
588
+ checkpoint="gpt2",
589
+ output_type=BaseModelOutputWithPastAndCrossAttentions,
590
+ config_class=_CONFIG_FOR_DOC,
591
+ )
592
+ def forward(
593
+ self,
594
+ input_ids=None,
595
+ past_key_values=None,
596
+ attention_mask=None,
597
+ token_type_ids=None,
598
+ position_ids=None,
599
+ head_mask=None,
600
+ inputs_embeds=None,
601
+ encoder_hidden_states=None,
602
+ encoder_attention_mask=None,
603
+ use_cache=None,
604
+ output_attentions=None,
605
+ output_hidden_states=None,
606
+ return_dict=None,
607
+ ):
608
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
609
+ output_hidden_states = (
610
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
611
+ )
612
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
613
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
614
+
615
+ if input_ids is not None and inputs_embeds is not None:
616
+ raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
617
+ elif input_ids is not None:
618
+ input_shape = input_ids.size()
619
+ input_ids = input_ids.view(-1, input_shape[-1])
620
+ batch_size = input_ids.shape[0]
621
+ elif inputs_embeds is not None:
622
+ input_shape = inputs_embeds.size()[:-1]
623
+ batch_size = inputs_embeds.shape[0]
624
+ else:
625
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
626
+
627
+ if token_type_ids is not None:
628
+ token_type_ids = token_type_ids.view(-1, input_shape[-1])
629
+ if position_ids is not None:
630
+ position_ids = position_ids.view(-1, input_shape[-1])
631
+
632
+ if past_key_values is None:
633
+ past_length = 0
634
+ past_key_values = [None] * len(self.h)
635
+ else:
636
+ past_length = past_key_values[0][0].size(-2)
637
+     # Resolve the device up front; the cross-attention branch below also needs it.
+     device = input_ids.device if input_ids is not None else inputs_embeds.device
+     if position_ids is None:
+         position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
+         position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
641
+
642
+ # Attention mask.
643
+ if attention_mask is not None:
644
+ assert batch_size > 0, "batch_size has to be defined and > 0"
645
+ attention_mask = attention_mask.view(batch_size, -1)
646
+ # We create a 3D attention mask from a 2D tensor mask.
647
+ # Sizes are [batch_size, 1, 1, to_seq_length]
648
+ # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
649
+ # this attention mask is more simple than the triangular masking of causal attention
650
+ # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
651
+ attention_mask = attention_mask[:, None, None, :]
652
+
653
+ # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
654
+ # masked positions, this operation will create a tensor which is 0.0 for
655
+ # positions we want to attend and -10000.0 for masked positions.
656
+ # Since we are adding it to the raw scores before the softmax, this is
657
+ # effectively the same as removing these entirely.
658
+ attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
659
+ attention_mask = (1.0 - attention_mask) * -10000.0
660
+
661
+ # If a 2D or 3D attention mask is provided for the cross-attention
662
+ # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
663
+ if self.config.add_cross_attention and encoder_hidden_states is not None:
664
+ encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
665
+ encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
666
+ if encoder_attention_mask is None:
667
+ encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
668
+ encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
669
+ else:
670
+ encoder_attention_mask = None
671
+
672
+ # Prepare head mask if needed
673
+ # 1.0 in head_mask indicate we keep the head
674
+ # attention_probs has shape bsz x n_heads x N x N
675
+ # head_mask has shape n_layer x batch x n_heads x N x N
676
+ head_mask = self.get_head_mask(head_mask, self.config.n_layer)
677
+
678
+ if inputs_embeds is None:
679
+ inputs_embeds = self.wte(input_ids)
680
+ # position_embeds = self.wpe(position_ids)
681
+ hidden_states = inputs_embeds # + position_embeds
682
+
683
+ if token_type_ids is not None:
684
+ token_type_embeds = self.wte(token_type_ids)
685
+ hidden_states = hidden_states + token_type_embeds
686
+
687
+ hidden_states = self.drop(hidden_states)
688
+
689
+ output_shape = input_shape + (hidden_states.size(-1),)
690
+
691
+ presents = () if use_cache else None
692
+ all_self_attentions = () if output_attentions else None
693
+ all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
694
+ all_hidden_states = () if output_hidden_states else None
695
+ for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
696
+
697
+ if self.use_layers is not None and i >= self.use_layers:
698
+ break
699
+
700
+ # Model parallel
701
+ if self.model_parallel:
702
+ torch.cuda.set_device(hidden_states.device)
703
+ # Ensure layer_past is on same device as hidden_states (might not be correct)
704
+ if layer_past is not None:
705
+ layer_past = layer_past.to(hidden_states.device)
706
+ # Ensure that attention_mask is always on the same device as hidden_states
707
+ if attention_mask is not None:
708
+ attention_mask = attention_mask.to(hidden_states.device)
709
+ if isinstance(head_mask, torch.Tensor):
710
+ head_mask = head_mask.to(hidden_states.device)
711
+ if output_hidden_states:
712
+ all_hidden_states = all_hidden_states + (hidden_states,)
713
+
714
+ if getattr(self.config, "gradient_checkpointing", False):
715
+
716
+ def create_custom_forward(module):
717
+ def custom_forward(*inputs):
718
+ # checkpointing only works with tuple returns, not with lists
719
+ return tuple(output for output in module(*inputs, use_cache, output_attentions))
720
+
721
+ return custom_forward
722
+
723
+ outputs = torch.utils.checkpoint.checkpoint(
724
+ create_custom_forward(block),
725
+ hidden_states,
726
+ layer_past,
727
+ attention_mask,
728
+ head_mask[i],
729
+ encoder_hidden_states,
730
+ encoder_attention_mask,
731
+ )
732
+ else:
733
+ outputs = block(
734
+ hidden_states,
735
+ layer_past=layer_past,
736
+ attention_mask=attention_mask,
737
+ head_mask=head_mask[i],
738
+ encoder_hidden_states=encoder_hidden_states,
739
+ encoder_attention_mask=encoder_attention_mask,
740
+ use_cache=use_cache,
741
+ output_attentions=output_attentions,
742
+ )
743
+
744
+ hidden_states, present = outputs[:2]
745
+ if use_cache is True:
746
+ presents = presents + (present,)
747
+
748
+ if output_attentions:
749
+ all_self_attentions = all_self_attentions + (outputs[2],)
750
+ if self.config.add_cross_attention:
751
+ all_cross_attentions = all_cross_attentions + (outputs[3],)
752
+
753
+ # Model Parallel: If it's the last layer for that device, put things on the next device
754
+ if self.model_parallel:
755
+ for k, v in self.device_map.items():
756
+ if i == v[-1] and "cuda:" + str(k) != self.last_device:
757
+ hidden_states = hidden_states.to("cuda:" + str(k + 1))
758
+
759
+ hidden_states = self.ln_f(hidden_states)
760
+
761
+ hidden_states = hidden_states.view(*output_shape)
762
+ # Add last hidden state
763
+ if output_hidden_states:
764
+ all_hidden_states = all_hidden_states + (hidden_states,)
765
+
766
+ if not return_dict:
767
+ return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
768
+
769
+ return BaseModelOutputWithPastAndCrossAttentions(
770
+ last_hidden_state=hidden_states,
771
+ past_key_values=presents,
772
+ hidden_states=all_hidden_states,
773
+ attentions=all_self_attentions,
774
+ cross_attentions=all_cross_attentions,
775
+ )
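
A minimal smoke-test sketch for the position-embedding-free `GPT2Model` defined above. It assumes the vendored file's `transformers` imports resolve as in upstream; the tiny config values are placeholders.

import torch
from transformers import GPT2Config  # assumed available, as in upstream transformers

# Trajectory models feed pre-computed embeddings, so position information
# must be baked into inputs_embeds by the caller (wpe is disabled above).
config = GPT2Config(vocab_size=1, n_embd=128, n_layer=4, n_head=4)
model = GPT2Model(config)

embeds = torch.randn(2, 10, 128)  # (batch, seq_len, n_embd)
out = model(inputs_embeds=embeds, return_dict=True)
print(out.last_hidden_state.shape)  # -> torch.Size([2, 10, 128])
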
third_party/AnyBimanual/agents/peract_bimanual/visual_aligner.py ADDED
@@ -0,0 +1,39 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class VisualAligner(nn.Module):
6
+ def __init__(self, input_dim=128, hidden_dim=256, mask_dim=128):
7
+ super(VisualAligner, self).__init__()
8
+
9
+ self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
10
+
11
+ self.conv_res1 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
12
+ self.conv_res2 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
13
+
14
+ self.conv2_right = nn.Conv1d(in_channels=hidden_dim, out_channels=mask_dim, kernel_size=3, padding=1)
15
+ self.conv2_left = nn.Conv1d(in_channels=hidden_dim, out_channels=mask_dim, kernel_size=3, padding=1)
16
+
17
+ self.activation = nn.ReLU()
18
+
19
+ def forward(self, ins):
20
+ ins = ins.transpose(1, 2)
21
+
22
+ features = self.activation(self.conv1(ins))
23
+
24
+ residual = features
25
+ features = self.activation(self.conv_res1(features))
26
+ features = self.conv_res2(features)
27
+ features = features + residual
28
+
29
+ mask_right = self.activation(self.conv2_right(features))
30
+ mask_left = self.activation(self.conv2_left(features))
31
+
32
+ mask_right = mask_right.transpose(1, 2)
33
+ mask_left = mask_left.transpose(1, 2)
34
+ ins = ins.transpose(1, 2)
35
+
36
+ masked_ins1 = ins * mask_right
37
+ masked_ins2 = ins * mask_left
38
+
39
+ return masked_ins1, masked_ins2
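
A quick shape-check sketch for `VisualAligner`; the batch and token counts are placeholders, and the last-axis feature width must equal `input_dim`, as the transposes in `forward` imply.

import torch

aligner = VisualAligner(input_dim=128, hidden_dim=256, mask_dim=128)
feats = torch.randn(4, 77, 128)   # (batch, num_tokens, input_dim)
right, left = aligner(feats)      # per-arm soft masks applied to the input
print(right.shape, left.shape)    # -> torch.Size([4, 77, 128]) twice
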
third_party/AnyBimanual/agents/replay_utils.py ADDED
@@ -0,0 +1,667 @@
1
+
2
+ import logging
3
+ from typing import List
4
+ import os
5
+ import numpy as np
6
+ from rlbench.backend.observation import Observation
7
+ from rlbench.observation_config import ObservationConfig
8
+ import rlbench.utils as rlbench_utils
9
+ from rlbench.demo import Demo
10
+ from yarr.replay_buffer.replay_buffer import ReplayBuffer
11
+
12
+ from helpers import demo_loading_utils, utils
13
+ from helpers import observation_utils
14
+ from helpers.clip.core.clip import tokenize
15
+
16
+
17
+ from yarr.replay_buffer.prioritized_replay_buffer import ObservationElement
18
+ from yarr.replay_buffer.replay_buffer import ReplayElement
19
+ from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
20
+
21
+
22
+ import torch
23
+ from torch.multiprocessing import Process, Value, Manager
24
+ from helpers.clip.core.clip import build_model, load_clip
25
+ from omegaconf import DictConfig
26
+
27
+
28
+ REWARD_SCALE = 100.0
29
+ LOW_DIM_SIZE = 4
30
+
31
+
32
+ def create_replay(cfg, replay_path):
33
+
34
+ if cfg.method.robot_name == "bimanual":
35
+ return create_bimanual_replay(
36
+ cfg.replay.batch_size,
37
+ cfg.replay.timesteps,
38
+ cfg.replay.prioritisation,
39
+ cfg.replay.task_uniform,
40
+ replay_path if cfg.replay.use_disk else None,
41
+ cfg.rlbench.cameras,
42
+ cfg.method.voxel_sizes,
43
+ cfg.rlbench.camera_resolution,
44
+ )
45
+ else:
46
+ return create_unimanual_replay(
47
+ cfg.replay.batch_size,
48
+ cfg.replay.timesteps,
49
+ cfg.replay.prioritisation,
50
+ cfg.replay.task_uniform,
51
+ replay_path if cfg.replay.use_disk else None,
52
+ cfg.rlbench.cameras,
53
+ cfg.method.voxel_sizes,
54
+ cfg.rlbench.camera_resolution,
55
+ )
56
+
57
+
58
+
59
+ def create_bimanual_replay(
60
+ batch_size: int,
61
+ timesteps: int,
62
+ prioritisation: bool,
63
+ task_uniform: bool,
64
+ save_dir: str,
65
+ cameras: list,
66
+ voxel_sizes,
67
+ image_size=[128, 128],
68
+ replay_size=3e5,
69
+ ):
70
+ trans_indicies_size = 3 * len(voxel_sizes)
71
+ rot_and_grip_indicies_size = 3 + 1
72
+ gripper_pose_size = 7
73
+ ignore_collisions_size = 1
74
+ max_token_seq_len = 77
75
+ lang_feat_dim = 1024
76
+ lang_emb_dim = 512
77
+
78
+ # low_dim_state
79
+ observation_elements = []
80
+ observation_elements.append(
81
+ ObservationElement("right_low_dim_state", (LOW_DIM_SIZE,), np.float32)
82
+ )
83
+ observation_elements.append(
84
+ ObservationElement("left_low_dim_state", (LOW_DIM_SIZE,), np.float32)
85
+ )
86
+
87
+ # rgb, depth, point cloud, intrinsics, extrinsics
88
+ for cname in cameras:
89
+ observation_elements.append(
90
+ # color, height, width
91
+ ObservationElement(
92
+ "%s_rgb" % cname,
93
+ (
94
+ 3,
95
+ image_size[1],
96
+ image_size[0],
97
+ ),
98
+ np.float32,
99
+ )
100
+ )
101
+ observation_elements.append(
102
+ ObservationElement("%s_point_cloud" % cname, (3, image_size[1], image_size[0]), np.float16)
103
+ ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
104
+ observation_elements.append(
105
+ ObservationElement(
106
+ "%s_camera_extrinsics" % cname,
107
+ (
108
+ 4,
109
+ 4,
110
+ ),
111
+ np.float32,
112
+ )
113
+ )
114
+ observation_elements.append(
115
+ ObservationElement(
116
+ "%s_camera_intrinsics" % cname,
117
+ (
118
+ 3,
119
+ 3,
120
+ ),
121
+ np.float32,
122
+ )
123
+ )
124
+
125
+ # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
126
+ for robot_name in ["right", "left"]:
127
+ observation_elements.extend(
128
+ [
129
+ ReplayElement(
130
+ f"{robot_name}_trans_action_indicies",
131
+ (trans_indicies_size,),
132
+ np.int32,
133
+ ),
134
+ ReplayElement(
135
+ f"{robot_name}_rot_grip_action_indicies",
136
+ (rot_and_grip_indicies_size,),
137
+ np.int32,
138
+ ),
139
+ ReplayElement(
140
+ f"{robot_name}_ignore_collisions",
141
+ (ignore_collisions_size,),
142
+ np.int32,
143
+ ),
144
+ ReplayElement(
145
+ f"{robot_name}_gripper_pose", (gripper_pose_size,), np.float32
146
+ ),
147
+ ]
148
+ )
149
+
150
+ observation_elements.extend(
151
+ [
152
+ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
153
+ ReplayElement(
154
+ "lang_token_embs",
155
+ (
156
+ max_token_seq_len,
157
+ lang_emb_dim,
158
+ ),
159
+ np.float32,
160
+ ), # extracted from CLIP's language encoder
161
+ ReplayElement("task", (), str),
162
+ ReplayElement(
163
+ "lang_goal", (1,), object
164
+ ), # language goal string for debugging and visualization
165
+ ]
166
+ )
167
+
168
+ extra_replay_elements = [
169
+ ReplayElement("demo", (), bool),
170
+ ]
171
+
172
+ replay_buffer = TaskUniformReplayBuffer(
173
+ save_dir=save_dir,
174
+ batch_size=batch_size,
175
+ timesteps=timesteps,
176
+ replay_capacity=int(replay_size),
177
+ action_shape=(8 * 2,),
178
+ action_dtype=np.float32,
179
+ reward_shape=(),
180
+ reward_dtype=np.float32,
181
+ update_horizon=1,
182
+ observation_elements=observation_elements,
183
+ extra_replay_elements=extra_replay_elements,
184
+ )
185
+ return replay_buffer
186
+
187
+ def create_unimanual_replay(
188
+ batch_size: int,
189
+ timesteps: int,
190
+ prioritisation: bool,
191
+ task_uniform: bool,
192
+ save_dir: str,
193
+ cameras: list,
194
+ voxel_sizes,
195
+ image_size=[128, 128],
196
+ replay_size=3e5,
197
+ ):
198
+ trans_indicies_size = 3 * len(voxel_sizes)
199
+ rot_and_grip_indicies_size = 3 + 1
200
+ gripper_pose_size = 7
201
+ ignore_collisions_size = 1
202
+ max_token_seq_len = 77
203
+ lang_feat_dim = 1024
204
+ lang_emb_dim = 512
205
+
206
+ # low_dim_state
207
+ observation_elements = []
208
+ observation_elements.append(
209
+ ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
210
+ )
211
+
212
+ # rgb, depth, point cloud, intrinsics, extrinsics
213
+ for cname in cameras:
214
+ observation_elements.append(
215
+ ObservationElement(
216
+ "%s_rgb" % cname,
217
+ (
218
+ 3,
219
+ *image_size,
220
+ ),
221
+ np.float32,
222
+ )
223
+ )
224
+ observation_elements.append(
225
+ ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
226
+ ) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
227
+ observation_elements.append(
228
+ ObservationElement(
229
+ "%s_camera_extrinsics" % cname,
230
+ (
231
+ 4,
232
+ 4,
233
+ ),
234
+ np.float32,
235
+ )
236
+ )
237
+ observation_elements.append(
238
+ ObservationElement(
239
+ "%s_camera_intrinsics" % cname,
240
+ (
241
+ 3,
242
+ 3,
243
+ ),
244
+ np.float32,
245
+ )
246
+ )
247
+
248
+ # discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
249
+ observation_elements.extend(
250
+ [
251
+ ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32),
252
+ ReplayElement(
253
+ "rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32
254
+ ),
255
+ ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32),
256
+ ReplayElement("gripper_pose", (gripper_pose_size,), np.float32),
257
+ ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
258
+ ReplayElement(
259
+ "lang_token_embs",
260
+ (
261
+ max_token_seq_len,
262
+ lang_emb_dim,
263
+ ),
264
+ np.float32,
265
+ ), # extracted from CLIP's language encoder
266
+ ReplayElement("task", (), str),
267
+ ReplayElement(
268
+ "lang_goal", (1,), object
269
+ ), # language goal string for debugging and visualization
270
+ ]
271
+ )
272
+
273
+ extra_replay_elements = [
274
+ ReplayElement("demo", (), bool),
275
+ ]
276
+
277
+ replay_buffer = TaskUniformReplayBuffer(
278
+ save_dir=save_dir,
279
+ batch_size=batch_size,
280
+ timesteps=timesteps,
281
+ replay_capacity=int(replay_size),
282
+ action_shape=(8,),
283
+ action_dtype=np.float32,
284
+ reward_shape=(),
285
+ reward_dtype=np.float32,
286
+ update_horizon=1,
287
+ observation_elements=observation_elements,
288
+ extra_replay_elements=extra_replay_elements,
289
+ )
290
+ return replay_buffer
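
A hedged construction sketch for `create_replay`: the field names mirror the attribute accesses above, while the concrete values and camera names are placeholders, not the project's defaults.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "method": {"robot_name": "bimanual", "voxel_sizes": [100]},
    "replay": {"batch_size": 8, "timesteps": 1, "prioritisation": False,
               "task_uniform": True, "use_disk": False},
    "rlbench": {"cameras": ["front", "overhead"],       # placeholder camera set
                "camera_resolution": [128, 128]},
})
replay = create_replay(cfg, replay_path="/tmp/replay")  # save_dir is None when use_disk=False
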
291
+
292
+
293
+
294
+ def _get_action(
295
+ obs_tp1: Observation,
296
+ obs_tm1: Observation,
297
+ rlbench_scene_bounds: List[float], # metric 3D bounds of the scene
298
+ voxel_sizes: List[int],
299
+ bounds_offset: List[float],
300
+ rotation_resolution: int,
301
+ crop_augmentation: bool,
302
+ ):
303
+ quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
304
+ if quat[-1] < 0:
305
+ quat = -quat
306
+ disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution)
307
+ disc_rot = utils.correct_rotation_instability(disc_rot, rotation_resolution)
308
+
309
+ attention_coordinate = obs_tp1.gripper_pose[:3]
310
+ trans_indicies, attention_coordinates = [], []
311
+ bounds = np.array(rlbench_scene_bounds)
312
+ ignore_collisions = int(obs_tm1.ignore_collisions)
313
+ for depth, vox_size in enumerate(
314
+ voxel_sizes
315
+ ): # only single voxelization-level is used in PerAct
316
+ if depth > 0:
317
+ if crop_augmentation:
318
+ shift = bounds_offset[depth - 1] * 0.75
319
+ attention_coordinate += np.random.uniform(-shift, shift, size=(3,))
320
+ bounds = np.concatenate(
321
+ [
322
+ attention_coordinate - bounds_offset[depth - 1],
323
+ attention_coordinate + bounds_offset[depth - 1],
324
+ ]
325
+ )
326
+ index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds)
327
+ trans_indicies.extend(index.tolist())
328
+ res = (bounds[3:] - bounds[:3]) / vox_size
329
+ attention_coordinate = bounds[:3] + res * index
330
+ attention_coordinates.append(attention_coordinate)
331
+
332
+ rot_and_grip_indicies = disc_rot.tolist()
333
+ grip = float(obs_tp1.gripper_open)
334
+ rot_and_grip_indicies.extend([int(obs_tp1.gripper_open)])
335
+ return (
336
+ trans_indicies,
337
+ rot_and_grip_indicies,
338
+ ignore_collisions,
339
+ np.concatenate([obs_tp1.gripper_pose, np.array([grip])]),
340
+ attention_coordinates,
341
+ )
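
An illustrative reimplementation of the discretisation `_get_action` delegates to `utils.point_to_voxel_index` (the real helper lives in `helpers.utils`; this is only a sketch of the arithmetic).

import numpy as np

def point_to_voxel_index_sketch(point, voxel_size, bounds):
    """Map a metric 3D point to integer voxel indices inside `bounds`."""
    bounds = np.asarray(bounds, dtype=np.float32)
    bb_min, bb_max = bounds[:3], bounds[3:]
    res = (bb_max - bb_min) / voxel_size          # metres per voxel along each axis
    idx = np.floor((np.asarray(point) - bb_min) / res).astype(np.int32)
    return np.clip(idx, 0, voxel_size - 1)        # clamp points on the far boundary

# Scene bounds in the [x_min, y_min, z_min, x_max, y_max, z_max] layout used above.
print(point_to_voxel_index_sketch([0.2, 0.0, 1.105], 100,
                                  [-0.3, -0.5, 0.6, 0.7, 0.5, 1.6]))  # -> [50 50 50]
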
342
+
343
+
344
+ def _add_keypoints_to_replay(
345
+ cfg: DictConfig,
346
+ task: str,
347
+ replay: ReplayBuffer,
348
+ inital_obs: Observation,
349
+ demo: Demo,
350
+ episode_keypoints: List[int],
351
+
352
+ description: str = "",
353
+ clip_model=None,
354
+ device="cpu",
355
+ ):
356
+
357
+ cameras = cfg.rlbench.cameras
358
+ rlbench_scene_bounds = cfg.rlbench.scene_bounds
359
+ voxel_sizes = cfg.method.voxel_sizes
360
+ bounds_offset = cfg.method.bounds_offset
361
+ rotation_resolution = cfg.method.rotation_resolution
362
+ crop_augmentation = cfg.method.crop_augmentation
363
+ robot_name = cfg.method.robot_name
364
+
365
+ prev_action = None
366
+ obs = inital_obs
367
+
368
+ for k, keypoint in enumerate(episode_keypoints):
369
+ obs_tp1 = demo[keypoint]
370
+ obs_tm1 = demo[max(0, keypoint - 1)]
371
+
372
+ if obs_tp1.is_bimanual and robot_name == "bimanual":
373
+ #assert isinstance(obs_tp1, BimanualObservation)
374
+ (
375
+ right_trans_indicies,
376
+ right_rot_grip_indicies,
377
+ right_ignore_collisions,
378
+ right_action,
379
+ right_attention_coordinates,
380
+ ) = _get_action(
381
+ obs_tp1.right,
382
+ obs_tm1.right,
383
+ rlbench_scene_bounds,
384
+ voxel_sizes,
385
+ bounds_offset,
386
+ rotation_resolution,
387
+ crop_augmentation,
388
+ )
389
+
390
+ (
391
+ left_trans_indicies,
392
+ left_rot_grip_indicies,
393
+ left_ignore_collisions,
394
+ left_action,
395
+ left_attention_coordinates,
396
+ ) = _get_action(
397
+ obs_tp1.left,
398
+ obs_tm1.left,
399
+ rlbench_scene_bounds,
400
+ voxel_sizes,
401
+ bounds_offset,
402
+ rotation_resolution,
403
+ crop_augmentation,
404
+ )
405
+
406
+ action = np.append(right_action, left_action)
407
+
408
+ right_ignore_collisions = np.array([right_ignore_collisions])
409
+ left_ignore_collisions = np.array([left_ignore_collisions])
410
+
411
+ elif robot_name == "unimanual":
412
+ (
413
+ trans_indicies,
414
+ rot_grip_indicies,
415
+ ignore_collisions,
416
+ action,
417
+ attention_coordinates,
418
+ ) = _get_action(
419
+ obs_tp1,
420
+ obs_tm1,
421
+ rlbench_scene_bounds,
422
+ voxel_sizes,
423
+ bounds_offset,
424
+ rotation_resolution,
425
+ crop_augmentation,
426
+ )
427
+ gripper_pose = obs_tp1.gripper_pose
428
+ elif obs_tp1.is_bimanual and robot_name == "right":
429
+ (
430
+ trans_indicies,
431
+ rot_grip_indicies,
432
+ ignore_collisions,
433
+ action,
434
+ attention_coordinates,
435
+ ) = _get_action(
436
+ obs_tp1.right,
437
+ obs_tm1.right,
438
+ rlbench_scene_bounds,
439
+ voxel_sizes,
440
+ bounds_offset,
441
+ rotation_resolution,
442
+ crop_augmentation,
443
+ )
444
+ gripper_pose = obs_tp1.right.gripper_pose
445
+ elif obs_tp1.is_bimanual and robot_name == "left":
446
+ (
447
+ trans_indicies,
448
+ rot_grip_indicies,
449
+ ignore_collisions,
450
+ action,
451
+ attention_coordinates,
452
+ ) = _get_action(
453
+ obs_tp1.left,
454
+ obs_tm1.left,
455
+ rlbench_scene_bounds,
456
+ voxel_sizes,
457
+ bounds_offset,
458
+ rotation_resolution,
459
+ crop_augmentation,
460
+ )
461
+ gripper_pose = obs_tp1.left.gripper_pose
462
+ else:
463
+ logging.error("Invalid robot name %s", cfg.method.robot_name)
464
+ raise Exception("Invalid robot name.")
465
+
466
+ terminal = k == len(episode_keypoints) - 1
467
+ reward = float(terminal) * REWARD_SCALE if terminal else 0
468
+
469
+ obs_dict = observation_utils.extract_obs(
470
+ obs,
471
+ t=k,
472
+ prev_action=prev_action,
473
+ cameras=cameras,
474
+ episode_length=cfg.rlbench.episode_length,
475
+ robot_name=robot_name
476
+ )
477
+ tokens = tokenize([description]).numpy()
478
+ token_tensor = torch.from_numpy(tokens).to(device)
479
+ sentence_emb, token_embs = clip_model.encode_text_with_embeddings(token_tensor)
480
+ obs_dict["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
481
+ obs_dict["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
482
+
483
+ prev_action = np.copy(action)
484
+
485
+ others = {"demo": True}
486
+ if robot_name == "bimanual":
487
+ final_obs = {
488
+ "right_trans_action_indicies": right_trans_indicies,
489
+ "right_rot_grip_action_indicies": right_rot_grip_indicies,
490
+ "right_gripper_pose": obs_tp1.right.gripper_pose,
491
+ "left_trans_action_indicies": left_trans_indicies,
492
+ "left_rot_grip_action_indicies": left_rot_grip_indicies,
493
+ "left_gripper_pose": obs_tp1.left.gripper_pose,
494
+ "task": task,
495
+ "lang_goal": np.array([description], dtype=object),
496
+ }
497
+ else:
498
+ final_obs = {
499
+ "trans_action_indicies": trans_indicies,
500
+ "rot_grip_action_indicies": rot_grip_indicies,
501
+ "gripper_pose": gripper_pose,
502
+ "task": task,
503
+ "lang_goal": np.array([description], dtype=object),
504
+ }
505
+
506
+ others.update(final_obs)
507
+ others.update(obs_dict)
508
+
509
+ timeout = False
510
+ replay.add(action, reward, terminal, timeout, **others)
511
+ obs = obs_tp1
512
+
513
+ # final step
514
+ obs_dict_tp1 = observation_utils.extract_obs(
515
+ obs_tp1,
516
+ t=k + 1,
517
+ prev_action=prev_action,
518
+ cameras=cameras,
519
+ episode_length=cfg.rlbench.episode_length,
520
+ robot_name=cfg.method.robot_name
521
+ )
522
+ obs_dict_tp1["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
523
+ obs_dict_tp1["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
524
+
525
+ obs_dict_tp1.pop("wrist_world_to_cam", None)
526
+ obs_dict_tp1.update(final_obs)
527
+ replay.add_final(**obs_dict_tp1)
528
+
529
+ def check_if_replay_exists(task: str, d_idx: int, replay_path: str):
530
+ replay_file = os.path.join(replay_path, f"{task}_replay_{d_idx}.pkl")
531
+ return os.path.exists(replay_file)
532
+
533
+ def fill_replay(
534
+ cfg: DictConfig,
535
+ obs_config: ObservationConfig,
536
+ rank: int,
537
+ replay: ReplayBuffer,
538
+ task: str,
539
+ clip_model=None,
540
+ device="cpu",
541
+ ):
542
+
543
+ num_demos=cfg.rlbench.demos
544
+ demo_augmentation=cfg.method.demo_augmentation
545
+ demo_augmentation_every_n=cfg.method.demo_augmentation_every_n
546
+ keypoint_method=cfg.method.keypoint_method
547
+
548
+
549
+ if clip_model is None:
550
+ model, _ = load_clip("RN50", jit=False, device=device)
551
+ clip_model = build_model(model.state_dict())
552
+ clip_model.to(device)
553
+ del model
554
+
555
+ task_folder = cfg.replay.task_folder
556
+ replay_path = os.path.join(
557
+ cfg.replay.path, task_folder
558
+ )
559
+ logging.debug("Filling %s replay ..." % task)
560
+ for d_idx in range(num_demos):
561
+ # load demo from disk
562
+ if check_if_replay_exists(task, d_idx, replay_path):
563
+ logging.info(f"Replay for demo {d_idx} already exists, skipping...")
564
+ continue
565
+ demo = rlbench_utils.get_stored_demos(
566
+ amount=1,
567
+ image_paths=False,
568
+ dataset_root=cfg.rlbench.demo_path,
569
+ variation_number=-1,
570
+ task_name=task,
571
+ obs_config=obs_config,
572
+ random_selection=False,
573
+ from_episode_number=d_idx,
574
+ )[0]
575
+
576
+ descs = demo._observations[0].misc["descriptions"]
577
+
578
+ # extract keypoints (a.k.a keyframes)
579
+ episode_keypoints = demo_loading_utils.keypoint_discovery(
580
+ demo, method=keypoint_method
581
+ )
582
+
583
+ if rank == 0:
584
+ logging.info(
585
+ f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
586
+ )
587
+
588
+ for i in range(len(demo) - 1):
589
+ if not demo_augmentation and i > 0:
590
+ break
591
+ if i % demo_augmentation_every_n != 0:
592
+ continue
593
+
594
+ obs = demo[i]
595
+ desc = descs[0]
596
+ # if our starting point is past one of the keypoints, then remove it
597
+ while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
598
+ episode_keypoints = episode_keypoints[1:]
599
+ if len(episode_keypoints) == 0:
600
+ break
601
+ _add_keypoints_to_replay(
602
+ cfg,
603
+ task,
604
+ replay,
605
+ obs,
606
+ demo,
607
+ episode_keypoints,
608
+ description=desc,
609
+ clip_model=clip_model,
610
+ device=device,
611
+ )
612
+ logging.debug("Replay %s filled with demos." % task)
613
+
614
+
615
+ def fill_multi_task_replay(
616
+ cfg: DictConfig,
617
+ obs_config: ObservationConfig,
618
+ rank: int,
619
+ replay: ReplayBuffer,
620
+ tasks: List[str],
621
+ clip_model=None,
622
+ ):
623
+
624
+ tasks = cfg.rlbench.tasks
625
+
626
+ manager = Manager()
627
+ store = manager.dict()
628
+
629
+ # create a MP dict for storing indicies
630
+ # TODO(mohit): this shouldn't be initialized here
631
+ del replay._task_idxs
632
+ task_idxs = manager.dict()
633
+ replay._task_idxs = task_idxs
634
+ replay._create_storage(store)
635
+ replay.add_count = Value("i", 0)
636
+
637
+ # fill replay buffer in parallel across tasks
638
+ max_parallel_processes = cfg.replay.max_parallel_processes
639
+ processes = []
640
+ n = np.arange(len(tasks))
641
+ split_n = utils.split_list(n, max_parallel_processes)
642
+ for split in split_n:
643
+ for e_idx, task_idx in enumerate(split):
644
+ task = tasks[int(task_idx)]
645
+ model_device = torch.device(
646
+ "cuda:%s" % (e_idx % torch.cuda.device_count())
647
+ if torch.cuda.is_available()
648
+ else "cpu"
649
+ )
650
+ p = Process(
651
+ target=fill_replay,
652
+ args=(
653
+ cfg,
654
+ obs_config,
655
+ rank,
656
+ replay,
657
+ task,
658
+ clip_model,
659
+ model_device
660
+ ),
661
+ )
662
+
663
+ p.start()
664
+ processes.append(p)
665
+
666
+ for p in processes:
667
+ p.join()
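
The parallel fill relies on a standard shared-state pattern from `torch.multiprocessing`; a minimal, self-contained sketch of that pattern (the worker and keys here are placeholders).

from torch.multiprocessing import Manager, Process

def _worker(store, task):
    store[task] = f"filled {task}"  # each process writes into the shared dict

if __name__ == "__main__":
    store = Manager().dict()        # proxy object visible to all workers
    procs = [Process(target=_worker, args=(store, t)) for t in ("lift", "push")]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
    print(dict(store))              # {'lift': 'filled lift', 'push': 'filled push'}
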
third_party/AnyBimanual/agents/rvt/__init__.py ADDED
@@ -0,0 +1,6 @@
1
+ """RVT package marker.
2
+
3
+ Keep package import side-effect free so downstream code can import the
4
+ visual stack without pulling in the full training launcher and its
5
+ optional dependencies.
6
+ """
third_party/AnyBimanual/agents/rvt/launch_utils.py ADDED
@@ -0,0 +1,221 @@
1
+ import os
2
+ from typing import List
3
+ import torch
4
+ import numpy as np
5
+
6
+ from omegaconf import DictConfig
7
+
8
+ from yarr.agents.agent import Agent
9
+ from yarr.agents.agent import ActResult
10
+ from yarr.agents.agent import Summary
11
+ from yarr.agents.agent import ScalarSummary
12
+ import wandb
13
+ from torch.nn.parallel import DistributedDataParallel as DDP
14
+ import pickle
15
+ from helpers.preprocess_agent import PreprocessAgent
16
+ from agents.rvt.rvt.models.skill_manager import SkillManager
17
+ from agents.rvt.rvt.models.visual_aligner import VisualAligner
18
+
19
+ from agents.rvt.rvt.mvt.mvt import MVT
20
+ from agents.rvt.rvt.models import rvt_agent
21
+ from agents.rvt.rvt.utils.peract_utils import (
22
+ CAMERAS,
23
+ SCENE_BOUNDS,
24
+ IMAGE_SIZE,
25
+ DATA_FOLDER,
26
+ )
27
+
28
+
29
+ import agents.rvt.rvt.config as exp_cfg_mod
+ import agents.rvt.rvt.mvt.config as mvt_cfg_mod
33
+
34
+ def create_agent(cfg: DictConfig):
35
+
36
+ exp_cfg = exp_cfg_mod.get_cfg_defaults()
37
+ exp_cfg.bs = cfg.replay.batch_size
38
+ exp_cfg.tasks = ','.join(cfg.rlbench.tasks)
39
+
40
+ exp_cfg.freeze()
41
+
42
+ mvt_cfg = mvt_cfg_mod.get_cfg_defaults()
43
+ mvt_cfg.proprio_dim = cfg.method.low_dim_size
44
+ mvt_cfg.freeze()
45
+
46
+ current_dir = os.path.dirname(os.path.abspath(__file__))
47
+ pkl_path = os.path.join(current_dir, "../../lang_token.pkl")
48
+ pkl_path = os.path.abspath(pkl_path)
49
+ with open(pkl_path, "rb") as f:
50
+ embeddings_dict = pickle.load(f)
51
+ flattened_embeddings = []
52
+ for key in embeddings_dict.keys():
53
+ embedding = torch.tensor(embeddings_dict[key])
54
+ flattened_embedding = embedding.view(-1)
55
+ flattened_embeddings.append(flattened_embedding)
56
+ embeddings_matrix = torch.stack(flattened_embeddings)
57
+ skill_manager = SkillManager(num_classes=18, embedding_matrix=embeddings_matrix)
58
+ visual_aligner = VisualAligner()
59
+ agent = RVTAgentWrapper(cfg.framework.checkpoint_name_prefix, cfg.rlbench, mvt_cfg, exp_cfg, skill_manager, visual_aligner)
60
+
61
+
62
+ preprocess_agent = PreprocessAgent(pose_agent=agent)
63
+ return preprocess_agent
64
+
65
+
66
+
67
+ class RVTAgentWrapper(Agent):
68
+
69
+ def __init__(self, checkpoint_name_prefix, rlbench_cfg, mvt_cfg, exp_cfg, skill_manager, visual_aligner):
70
+ self._checkpoint_filename = f"{checkpoint_name_prefix}.pt"
71
+ self.rvt_agent = None
72
+ self.rlbench_cfg = rlbench_cfg
73
+ self.mvt_cfg = mvt_cfg
74
+ self.exp_cfg = exp_cfg
75
+ self._summaries = {}
76
+ self.skill_manager = skill_manager
77
+ self.visual_aligner = visual_aligner
78
+
79
+ def build(self, training: bool, device=None) -> None:
80
+
81
+ import torch
82
+ torch.cuda.set_device(device)
83
+ torch.cuda.empty_cache()
84
+ self._device = device
85
+ if isinstance(device, int):
86
+ device = f"cuda:{device}"
87
+
88
+ rvt = MVT(
89
+ renderer_device=device,
90
+ **self.mvt_cfg,
91
+ )
92
+ rvt = rvt.to(device)
93
+
94
+ if training:
95
+ rvt = DDP(rvt, device_ids=[device])
96
+
97
+ self.rvt_agent = rvt_agent.RVTAgent(
98
+ network=rvt,
99
+ #image_resolution=self.rlbench_cfg.camera_resolution,
100
+ skill_manager=self.skill_manager,
101
+ visual_aligner=self.visual_aligner,
102
+ stage_two=False,
103
+ add_lang=self.mvt_cfg.add_lang,
104
+ scene_bounds=self.rlbench_cfg.scene_bounds,
105
+ cameras=self.rlbench_cfg.cameras,
106
+ log_dir="/tmp/eval_run",
107
+ **self.exp_cfg.peract,
108
+ **self.exp_cfg.rvt,
109
+
110
+ )
111
+
112
+ self.rvt_agent.build(training, device)
113
+
114
+ def update(self, step: int, replay_sample: dict) -> dict:
115
+ for k, v in replay_sample.items():
116
+ replay_sample[k] = v.unsqueeze(1)
117
+ # RVT is based on the PerAct's Colab version.
118
+ replay_sample["lang_goal_embs"] = replay_sample["lang_token_embs"]
119
+ replay_sample["tasks"] = self.exp_cfg.tasks.split(',')
120
+
121
+ update_dict = self.rvt_agent.update(step, replay_sample)
122
+
123
+
124
+ for key, val in self.rvt_agent.loss_log.items():
125
+ self._summaries[key] = np.mean(np.array(val))
126
+ device = self._device
127
+ rank = device
128
+ if step % 10 == 0 and rank == 0:
129
+ wandb.log({
130
+ 'train/grip_loss': update_dict["grip_loss"],
131
+ 'train/trans_loss': update_dict["trans_loss"],
132
+ 'train/rot_loss': (update_dict["rot_loss_x"]+update_dict["rot_loss_y"]+update_dict["rot_loss_z"]),
133
+ 'train/collision_loss': update_dict["collision_loss"],
134
+ 'train/total_loss': update_dict["total_loss"],
135
+ }, step=step)
136
+ self._wandb_summaries = {
137
+ 'losses/grip_loss': update_dict["grip_loss"],
138
+ 'losses/trans_loss': update_dict["trans_loss"],
139
+ 'losses/rot_loss': (update_dict["rot_loss_x"]+update_dict["rot_loss_y"]+update_dict["rot_loss_z"]),
140
+ 'losses/collision_loss': update_dict["collision_loss"],
141
+ 'losses/total_loss': update_dict["total_loss"],
142
+ }
143
+         return {
+             "total_losses": update_dict["total_loss"],
+         }
148
+
149
+ def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
150
+ return self.rvt_agent.act(step, observation, deterministic)
151
+
152
+ def reset(self) -> None:
153
+ self.rvt_agent.reset()
154
+
155
+ def update_summaries(self) -> List[Summary]:
156
+ summaries = []
157
+ for k, v in self._summaries.items():
158
+ summaries.append(ScalarSummary(f"RVT/{k}", v))
159
+ return summaries
160
+
161
+ def update_wandb_summaries(self):
162
+ summaries = dict()
163
+
164
+ for k, v in self._wandb_summaries.items():
165
+ summaries[k] = v
166
+ return summaries
167
+
168
+ def act_summaries(self) -> List[Summary]:
169
+ return []
170
+
171
+ def load_weights(self, savedir: str) -> None:
172
+ """
173
+ copied from RVT
174
+ """
175
+ device = torch.device("cuda:0")
176
+ weight_file = os.path.join(savedir, self._checkpoint_filename)
177
+ state_dict = torch.load(weight_file, map_location=device)
178
+
179
+ skill = self.rvt_agent.skill_manager
180
+ visual_aligner = self.rvt_agent.visual_aligner
181
+ model = self.rvt_agent._network
182
+ optimizer = self.rvt_agent._optimizer
183
+ lr_sched = self.rvt_agent._lr_sched
184
+
185
+ if isinstance(model, DDP):
186
+ model = model.module
187
+ model.load_state_dict(state_dict["model_state"])
188
+ optimizer.load_state_dict(state_dict["optimizer_state"])
189
+ lr_sched.load_state_dict(state_dict["lr_sched_state"])
190
+
191
+ return self.rvt_agent.load_clip()
192
+
193
+
194
+ def save_weights(self, savedir: str) -> None:
195
+
196
+ os.makedirs(savedir, exist_ok=True)
197
+ weight_file = os.path.join(savedir, self._checkpoint_filename)
198
+ skill = self.rvt_agent.skill_manager
199
+ visual_aligner = self.rvt_agent.visual_aligner
200
+ model = self.rvt_agent._network
201
+ optimizer = self.rvt_agent._optimizer
202
+ lr_sched = self.rvt_agent._lr_sched
203
+
204
+ if isinstance(model, DDP):
205
+ model = model.module
206
+
207
+ skill_state = skill.state_dict()
208
+ visual_aligner_state = visual_aligner.state_dict()
209
+ model_state = model.state_dict()
210
+
211
+ torch.save(
212
+ {
213
+ "skill_state": skill_state,
214
+ "visual_aligner_state": visual_aligner_state,
215
+ "model_state": model_state,
216
+ "optimizer_state": optimizer.state_dict(),
217
+ "lr_sched_state": lr_sched.state_dict(),
218
+ },
219
+ weight_file,
220
+ )
221
+
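
Given `save_weights` above, a saved checkpoint is a plain dict; a hedged inspection sketch (the path and the `rvt_agent.pt` prefix are placeholders).

import torch

ckpt = torch.load("logs/seed0/weights/10000/rvt_agent.pt", map_location="cpu")
print(sorted(ckpt.keys()))
# ['lr_sched_state', 'model_state', 'optimizer_state',
#  'skill_state', 'visual_aligner_state']
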
third_party/AnyBimanual/agents/rvt/rvt/config.py ADDED
@@ -0,0 +1,54 @@
1
+ # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4
+
5
+ from yacs.config import CfgNode as CN
6
+
7
+ _C = CN()
8
+
9
+ _C.agent = "our"
10
+ _C.tasks = "insert_onto_square_peg,open_drawer,place_wine_at_rack_location,light_bulb_in"
11
+ _C.exp_id = "def"
12
+ _C.resume = ""
13
+ # bs per device, effective bs is scaled by num device
14
+ _C.bs = 4
15
+ _C.epochs = 20
16
+ # number of dataloader workers, >= 0
17
+ _C.num_workers = 0
18
+ # 'transition_uniform' or 'task_uniform'
19
+ _C.sample_distribution_mode = 'transition_uniform'
20
+ _C.train_iter = 16 * 10000
21
+
22
+ # arguments present in both peract and rvt
23
+ # some of them donot support every possible combination in peract
24
+ _C.peract = CN()
25
+ _C.peract.lambda_weight_l2 = 1e-6
26
+ # lr should be thought on per sample basis
27
+ # effective lr is multiplied by bs * num_devices
28
+ _C.peract.lr = 2.5e-5
29
+ _C.peract.optimizer_type = "lamb"
30
+ _C.peract.warmup_steps = 0
31
+ _C.peract.lr_cos_dec = False
32
+ _C.peract.add_rgc_loss = True
33
+ _C.peract.num_rotation_classes = 72
34
+ _C.peract.amp = False
35
+ _C.peract.bnb = False
36
+ _C.peract.transform_augmentation = True
37
+ _C.peract.transform_augmentation_xyz = [0.1, 0.1, 0.1]
38
+ _C.peract.transform_augmentation_rpy = [0.0, 0.0, 20.0]
39
+
40
+ # arguments present in only rvt and not peract
41
+ _C.rvt = CN()
42
+ _C.rvt.gt_hm_sigma = 1.5
43
+ _C.rvt.img_aug = 0.1
44
+ _C.rvt.place_with_mean = True
45
+ _C.rvt.move_pc_in_bound = True
46
+
47
+ # arguments present in peract official
48
+ _C.peract_official = CN()
49
+ _C.peract_official.cfg_path = "configs/peract_official_config.yaml"
50
+
51
+
52
+ def get_cfg_defaults():
53
+ """Get a yacs CfgNode object with default values for my_project."""
54
+ return _C.clone()
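
Typical usage of the defaults above; `merge_from_list` and `freeze` are standard yacs `CfgNode` methods, and the override keys come from `_C`.

cfg = get_cfg_defaults()
cfg.merge_from_list(["bs", 8, "peract.lr", 1e-4])  # flat key/value overrides
cfg.freeze()                                       # make the node read-only
print(cfg.bs, cfg.peract.lr, cfg.rvt.img_aug)      # -> 8 0.0001 0.1
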
third_party/AnyBimanual/agents/rvt/rvt/configs/peract_official_config.yaml ADDED
@@ -0,0 +1,127 @@
1
+ # copied from: https://github.com/peract/peract/releases/download/v1.0.0/peract_600k.zip
2
+ method:
3
+ name: PERACT_BC
4
+ lr: 0.0005
5
+ lr_scheduler: false
6
+ num_warmup_steps: 3000
7
+ optimizer: lamb
8
+ activation: lrelu
9
+ norm: None
10
+ lambda_weight_l2: 1.0e-06
11
+ trans_loss_weight: 1.0
12
+ rot_loss_weight: 1.0
13
+ grip_loss_weight: 1.0
14
+ collision_loss_weight: 1.0
15
+ rotation_resolution: 5
16
+ image_crop_size: 64
17
+ bounds_offset:
18
+ - 0.15
19
+ voxel_sizes:
20
+ - 100
21
+ num_latents: 2048
22
+ latent_dim: 512
23
+ transformer_depth: 6
24
+ transformer_iterations: 1
25
+ cross_heads: 1
26
+ cross_dim_head: 64
27
+ latent_heads: 8
28
+ latent_dim_head: 64
29
+ pos_encoding_with_lang: false
30
+ lang_fusion_type: seq
31
+ voxel_patch_size: 5
32
+ voxel_patch_stride: 5
33
+ input_dropout: 0.1
34
+ attn_dropout: 0.1
35
+ decoder_dropout: 0.0
36
+ crop_augmentation: true
37
+ final_dim: 64
38
+ transform_augmentation:
39
+ apply_se3: true
40
+ aug_xyz:
41
+ - 0.125
42
+ - 0.125
43
+ - 0.125
44
+ aug_rpy:
45
+ - 0.0
46
+ - 0.0
47
+ - 0.0
48
+ aug_rot_resolution: 5
49
+ demo_augmentation: true
50
+ demo_augmentation_every_n: 10
51
+ no_skip_connection: false
52
+ no_perceiver: false
53
+ no_language: false
54
+ keypoint_method: heuristic
55
+ ddp:
56
+ master_addr: "localhost"
57
+ master_port: "29500"
58
+ num_devices: 1
59
+ rlbench:
60
+ task_name: multi
61
+ tasks:
62
+ - change_channel
63
+ - close_jar
64
+ - insert_onto_square_peg
65
+ - light_bulb_in
66
+ - meat_off_grill
67
+ - open_drawer
68
+ - place_cups
69
+ - place_shape_in_shape_sorter
70
+ - push_buttons
71
+ - put_groceries_in_cupboard
72
+ - put_item_in_drawer
73
+ - put_money_in_safe
74
+ - reach_and_drag
75
+ - stack_blocks
76
+ - stack_cups
77
+ - turn_tap
78
+ - set_clock_to_time
79
+ - place_wine_at_rack_location
80
+ - put_rubbish_in_color_bin
81
+ - slide_block_to_color_target
82
+ - sweep_to_dustpan_of_size
83
+ demos: 100
84
+ demo_path: /raid/dataset/
85
+ episode_length: 25
86
+ cameras:
87
+ - front
88
+ - left_shoulder
89
+ - right_shoulder
90
+ - wrist
91
+ camera_resolution:
92
+ - 128
93
+ - 128
94
+ scene_bounds:
95
+ - -0.3
96
+ - -0.5
97
+ - 0.6
98
+ - 0.7
99
+ - 0.5
100
+ - 1.6
101
+ include_lang_goal_in_obs: True
102
+ replay:
103
+ batch_size: 16
104
+ timesteps: 1
105
+ prioritisation: false
106
+ task_uniform: true
107
+ use_disk: true
108
+ path: /raid/arm/replay
109
+ max_parallel_processes: 32
110
+ framework:
111
+ log_freq: 100
112
+ save_freq: 10000
113
+ train_envs: 1
114
+ replay_ratio: 16
115
+ transitions_before_train: 200
116
+ tensorboard_logging: true
117
+ csv_logging: true
118
+ training_iterations: 600001
119
+ gpu: 0
120
+ env_gpu: 0
121
+ logdir: /home/user/workspace/logs_may16_n100
122
+ seeds: 1
123
+ start_seed: 0
124
+ load_existing_weights: true
125
+ num_weights_to_keep: 60
126
+ record_every_n: 5
127
+
third_party/AnyBimanual/agents/rvt/rvt/configs/rvt.yaml ADDED
@@ -0,0 +1,15 @@
1
+ exp_id: rvt
2
+ tasks: all
3
+ bs: 3
4
+ num_workers: 3
5
+ epochs: 15
6
+ sample_distribution_mode: task_uniform
7
+ peract:
8
+ lr: 1e-4
9
+ warmup_steps: 2000
10
+ optimizer_type: lamb
11
+ lr_cos_dec: True
12
+ transform_augmentation_xyz: [0.125, 0.125, 0.125]
13
+ transform_augmentation_rpy: [0.0, 0.0, 45.0]
14
+ rvt:
15
+ place_with_mean: False
third_party/AnyBimanual/agents/rvt/rvt/configs/rvt2.yaml ADDED
@@ -0,0 +1,19 @@
1
+ exp_id: rvt2
2
+ tasks: all
3
+ bs: 24
4
+ num_workers: 3
5
+ epochs: 15
6
+ sample_distribution_mode: task_uniform
7
+ peract:
8
+ lr: 1.25e-5
9
+ warmup_steps: 2000
10
+ optimizer_type: lamb
11
+ lr_cos_dec: True
12
+ transform_augmentation_xyz: [0.125, 0.125, 0.125]
13
+ transform_augmentation_rpy: [0.0, 0.0, 45.0]
14
+ amp: True
15
+ bnb: True
16
+ lambda_weight_l2: 1e-4
17
+ rvt:
18
+ place_with_mean: False
19
+ img_aug: 0.0
third_party/AnyBimanual/agents/rvt/rvt/eval.py ADDED
@@ -0,0 +1,556 @@
1
+ # Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # Licensed under the NVIDIA Source Code License [see LICENSE for details].
4
+
5
+ import os
6
+ import yaml
7
+ import csv
8
+ import torch
9
+ import cv2
10
+ import shutil
11
+
12
+ import numpy as np
13
+
14
+ from omegaconf import OmegaConf
15
+ from multiprocessing import Value
16
+ from tensorflow.python.summary.summary_iterator import summary_iterator
17
+ from copy import deepcopy
18
+
19
+ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
20
+ os.environ["BITSANDBYTES_NOWELCOME"] = "1"
21
+
22
+ from rlbench.backend import task as rlbench_task
23
+ from rlbench.backend.utils import task_file_to_task_class
24
+ from rlbench.action_modes.gripper_action_modes import Discrete
25
+ from rlbench.action_modes.action_mode import MoveArmThenGripper
26
+ from yarr.utils.rollout_generator import RolloutGenerator
27
+ from yarr.utils.stat_accumulator import SimpleAccumulator
28
+ from yarr.utils.log_writer import LogWriter
29
+ from yarr.agents.agent import VideoSummary
30
+
31
+ import rvt.mvt.config as default_mvt_cfg
32
+ import rvt.models.rvt_agent as rvt_agent
33
+ import rvt.config as default_exp_cfg
34
+
35
+ from rvt.mvt.mvt import MVT
36
+ from rvt.libs.peract.helpers import utils
37
+ from rvt.utils.custom_rlbench_env import (
38
+ CustomMultiTaskRLBenchEnv2 as CustomMultiTaskRLBenchEnv,
39
+ )
40
+ from rvt.utils.peract_utils import (
41
+ CAMERAS,
42
+ SCENE_BOUNDS,
43
+ IMAGE_SIZE,
44
+ get_official_peract,
45
+ )
46
+ from rvt.utils.rlbench_planning import (
47
+ EndEffectorPoseViaPlanning2 as EndEffectorPoseViaPlanning,
48
+ )
49
+ from rvt.utils.rvt_utils import (
50
+ TensorboardManager,
51
+ get_eval_parser,
52
+ RLBENCH_TASKS,
53
+ )
54
+ from rvt.utils.rvt_utils import load_agent as load_agent_state
55
+
56
+
57
+ def load_agent(
58
+ model_path=None,
59
+ peract_official=False,
60
+ peract_model_dir=None,
61
+ exp_cfg_path=None,
62
+ mvt_cfg_path=None,
63
+ eval_log_dir="",
64
+ device=0,
65
+ use_input_place_with_mean=False,
66
+ ):
67
+ device = f"cuda:{device}"
68
+
69
+ if not (peract_official):
70
+ assert model_path is not None
71
+
72
+ # load exp_cfg
73
+ model_folder = os.path.join(os.path.dirname(model_path))
74
+
75
+ exp_cfg = default_exp_cfg.get_cfg_defaults()
76
+ if exp_cfg_path is not None:
77
+ exp_cfg.merge_from_file(exp_cfg_path)
78
+ else:
79
+ exp_cfg.merge_from_file(os.path.join(model_folder, "exp_cfg.yaml"))
80
+
81
+ # NOTE: to not use place_with_mean in evaluation
82
+ # needed for rvt-1 but not rvt-2
83
+     # for backward compatibility; keep the original value around so the
+     # stage-two (rvt-2) branch below can restore it without a NameError
+     old_place_with_mean = exp_cfg.rvt.place_with_mean
+     if not use_input_place_with_mean:
+         exp_cfg.rvt.place_with_mean = True
87
+
88
+ exp_cfg.freeze()
89
+
90
+ # create agent
91
+ if exp_cfg.agent == "original":
92
+ # initialize PerceiverIO Transformer
93
+ VOXEL_SIZES = [100] # 100x100x100 voxels
94
+ NUM_LATENTS = 512 # PerceiverIO latents
95
+ BATCH_SIZE_TRAIN = 1
96
+ perceiver_encoder = PerceiverIO(
97
+ depth=6,
98
+ iterations=1,
99
+ voxel_size=VOXEL_SIZES[0],
100
+ initial_dim=3 + 3 + 1 + 3,
101
+ low_dim_size=4,
102
+ layer=0,
103
+ num_rotation_classes=72,
104
+ num_grip_classes=2,
105
+ num_collision_classes=2,
106
+ num_latents=NUM_LATENTS,
107
+ latent_dim=512,
108
+ cross_heads=1,
109
+ latent_heads=8,
110
+ cross_dim_head=64,
111
+ latent_dim_head=64,
112
+ weight_tie_layers=False,
113
+ activation="lrelu",
114
+ input_dropout=0.1,
115
+ attn_dropout=0.1,
116
+ decoder_dropout=0.0,
117
+ voxel_patch_size=5,
118
+ voxel_patch_stride=5,
119
+ final_dim=64,
120
+ )
121
+
122
+ # initialize PerceiverActor
123
+ agent = PerceiverActorAgent(
124
+ coordinate_bounds=SCENE_BOUNDS,
125
+ perceiver_encoder=perceiver_encoder,
126
+ camera_names=CAMERAS,
127
+ batch_size=BATCH_SIZE_TRAIN,
128
+ voxel_size=VOXEL_SIZES[0],
129
+ voxel_feature_size=3,
130
+ num_rotation_classes=72,
131
+ rotation_resolution=5,
132
+ image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
133
+ transform_augmentation=False,
134
+ **exp_cfg.peract,
135
+ )
136
+ elif exp_cfg.agent == "our":
137
+ mvt_cfg = default_mvt_cfg.get_cfg_defaults()
138
+ if mvt_cfg_path is not None:
139
+ mvt_cfg.merge_from_file(mvt_cfg_path)
140
+ else:
141
+ mvt_cfg.merge_from_file(os.path.join(model_folder, "mvt_cfg.yaml"))
142
+
143
+ mvt_cfg.freeze()
144
+
145
+ # for rvt-2 we do not change place_with_mean regardless of the arg
146
+ # done this way to ensure backward compatibility and allow the
147
+ # flexibility for rvt-1
148
+ if mvt_cfg.stage_two:
149
+ exp_cfg.defrost()
150
+ exp_cfg.rvt.place_with_mean = old_place_with_mean
151
+ exp_cfg.freeze()
152
+
153
+ rvt = MVT(
154
+ renderer_device=device,
155
+ **mvt_cfg,
156
+ )
157
+
158
+ agent = rvt_agent.RVTAgent(
159
+ network=rvt.to(device),
160
+ image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
161
+ add_lang=mvt_cfg.add_lang,
162
+ stage_two=mvt_cfg.stage_two,
163
+ rot_ver=mvt_cfg.rot_ver,
164
+ scene_bounds=SCENE_BOUNDS,
165
+ cameras=CAMERAS,
166
+ log_dir=f"{eval_log_dir}/eval_run",
167
+ **exp_cfg.peract,
168
+ **exp_cfg.rvt,
169
+ )
170
+ else:
171
+ raise NotImplementedError
172
+
173
+ agent.build(training=False, device=device)
174
+ load_agent_state(model_path, agent)
175
+ agent.eval()
176
+
177
+ elif peract_official: # load official peract model, using the provided code
178
+ try:
179
+ model_folder = os.path.join(os.path.abspath(peract_model_dir), "..", "..")
180
+ train_cfg_path = os.path.join(model_folder, "config.yaml")
181
+ agent = get_official_peract(train_cfg_path, False, device, bs=1)
182
+ except FileNotFoundError:
183
+ print("Config file not found, trying to load again in our format")
184
+ train_cfg_path = "configs/peract_official_config.yaml"
185
+ agent = get_official_peract(train_cfg_path, False, device, bs=1)
186
+ agent.load_weights(peract_model_dir)
187
+ agent.eval()
188
+
189
+ print("Agent Information")
190
+ print(agent)
191
+ return agent
192
+
193
+
194
+ @torch.no_grad()
195
+ def eval(
196
+ agent,
197
+ tasks,
198
+ eval_datafolder,
199
+ start_episode=0,
200
+ eval_episodes=25,
201
+ episode_length=25,
202
+ replay_ground_truth=False,
203
+ device=0,
204
+ headless=True,
205
+ logging=False,
206
+ log_dir=None,
207
+ verbose=True,
208
+ save_video=False,
209
+ ):
210
+ agent.eval()
211
+ if isinstance(agent, rvt_agent.RVTAgent):
212
+ agent.load_clip()
213
+
214
+ camera_resolution = [IMAGE_SIZE, IMAGE_SIZE]
215
+ obs_config = utils.create_obs_config(CAMERAS, camera_resolution, method_name="")
216
+
217
+ gripper_mode = Discrete()
218
+ arm_action_mode = EndEffectorPoseViaPlanning()
219
+ action_mode = MoveArmThenGripper(arm_action_mode, gripper_mode)
220
+
221
+ task_files = [
222
+ t.replace(".py", "")
223
+ for t in os.listdir(rlbench_task.TASKS_PATH)
224
+ if t != "__init__.py" and t.endswith(".py")
225
+ ]
226
+
227
+ task_classes = []
228
+ if tasks[0] == "all":
229
+ tasks = RLBENCH_TASKS
230
+ if verbose:
231
+ print(f"evaluate on {len(tasks)} tasks: ", tasks)
232
+
233
+ for task in tasks:
234
+ if task not in task_files:
235
+ raise ValueError("Task %s not recognised!." % task)
236
+ task_classes.append(task_file_to_task_class(task))
237
+
238
+ eval_env = CustomMultiTaskRLBenchEnv(
239
+ task_classes=task_classes,
240
+ observation_config=obs_config,
241
+ action_mode=action_mode,
242
+ dataset_root=eval_datafolder,
243
+ episode_length=episode_length,
244
+ headless=headless,
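+ # run all eval episodes of the current task before swapping to the next one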
245
+ swap_task_every=eval_episodes,
246
+ include_lang_goal_in_obs=True,
247
+ time_in_state=True,
248
+ record_every_n=1 if save_video else -1,
249
+ )
250
+
251
+ eval_env.eval = True
252
+
253
+ device = f"cuda:{device}"
254
+
255
+ if logging:
256
+ assert log_dir is not None
257
+
258
+ # create metric saving writer
259
+ csv_file = "eval_results.csv"
260
+ if not os.path.exists(os.path.join(log_dir, csv_file)):
261
+ with open(os.path.join(log_dir, csv_file), "w") as csv_fp:
262
+ fieldnames = ["task", "success rate", "length", "total_transitions"]
263
+ csv_writer = csv.DictWriter(csv_fp, fieldnames=fieldnames)
264
+ csv_writer.writeheader()
265
+
266
+ # evaluate agent
267
+ rollout_generator = RolloutGenerator(device)
268
+ stats_accumulator = SimpleAccumulator(eval_video_fps=30)
269
+
270
+ eval_env.launch()
271
+
272
+ current_task_id = -1
273
+
274
+ num_tasks = len(tasks)
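+ # shared integer step counter expected by the rollout generator (stays at -1 during eval)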
275
+ step_signal = Value("i", -1)
276
+
277
+ scores = []
278
+ for task_id in range(num_tasks):
279
+ task_rewards = []
280
+ for ep in range(start_episode, start_episode + eval_episodes):
281
+ episode_rollout = []
282
+ generator = rollout_generator.generator(
283
+ step_signal=step_signal,
284
+ env=eval_env,
285
+ agent=agent,
286
+ episode_length=episode_length,
287
+ timesteps=1,
288
+ eval=True,
289
+ eval_demo_seed=ep,
290
+ record_enabled=False,
291
+ replay_ground_truth=replay_ground_truth,
292
+ )
293
+ try:
294
+ for replay_transition in generator:
295
+ episode_rollout.append(replay_transition)
296
+ except StopIteration:
297
+ continue
298
+ except Exception as e:
299
+ eval_env.shutdown()
300
+ raise e
301
+
302
+ for transition in episode_rollout:
303
+ stats_accumulator.step(transition, True)
304
+ current_task_id = transition.info["active_task_id"]
305
+ assert current_task_id == task_id
306
+
307
+ task_name = tasks[task_id]
308
+ reward = episode_rollout[-1].reward
309
+ task_rewards.append(reward)
310
+ lang_goal = eval_env._lang_goal
311
+ if verbose:
312
+ print(
313
+ f"Evaluating {task_name} | Episode {ep} | Score: {reward} | Episode Length: {len(episode_rollout)} | Lang Goal: {lang_goal}"
314
+ )
315
+
316
+ # report summaries
317
+ summaries = []
318
+ summaries.extend(stats_accumulator.pop())
319
+ task_name = tasks[task_id]
320
+ if logging:
321
+ # write csv results first
322
+ with open(os.path.join(log_dir, csv_file), "a") as csv_fp:
323
+ fieldnames = ["task", "success rate", "length", "total_transitions"]
324
+ csv_writer = csv.DictWriter(csv_fp, fieldnames=fieldnames)
325
+ csv_results = {"task": task_name}
326
+ for s in summaries:
327
+ if s.name == "eval_envs/return":
328
+ csv_results["success rate"] = s.value
329
+ elif s.name == "eval_envs/length":
330
+ csv_results["length"] = s.value
331
+ elif s.name == "eval_envs/total_transitions":
332
+ csv_results["total_transitions"] = s.value
333
+ if "eval" in s.name:
334
+ s.name = "%s/%s" % (s.name, task_name)
335
+ csv_writer.writerow(csv_results)
336
+ else:
337
+ for s in summaries:
338
+ if "eval" in s.name:
339
+ s.name = "%s/%s" % (s.name, task_name)
340
+
341
+ if len(summaries) > 0:
342
+ task_score = [
343
+ s.value for s in summaries if f"eval_envs/return/{task_name}" in s.name
344
+ ][0]
345
+ else:
346
+ task_score = "unknown"
347
+
348
+ print(f"[Evaluation] Finished {task_name} | Final Score: {task_score}\n")
349
+
350
+ scores.append(task_score)
351
+
352
+ if save_video:
353
+ video_image_folder = "./tmp"
354
+ record_fps = 25
355
+ record_folder = os.path.join(log_dir, "videos")
356
+ os.makedirs(record_folder, exist_ok=True)
357
+ video_success_cnt = 0
358
+ video_fail_cnt = 0
359
+ video_cnt = 0
360
+ for summary in summaries:
361
+ if isinstance(summary, VideoSummary):
362
+ video = deepcopy(summary.value)
363
+ video = np.transpose(video, (0, 2, 3, 1))
364
+ video = video[:, :, :, ::-1]
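+ # channels were reversed above (RGB -> BGR) for cv2.imwrite below; a final
+ # reward above 99 (success reward is 100) marks a successful episode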
365
+ if task_rewards[video_cnt] > 99:
366
+ video_path = os.path.join(
367
+ record_folder,
368
+ f"{task_name}_success_{video_success_cnt}.mp4",
369
+ )
370
+ video_success_cnt += 1
371
+ else:
372
+ video_path = os.path.join(
373
+ record_folder, f"{task_name}_fail_{video_fail_cnt}.mp4"
374
+ )
375
+ video_fail_cnt += 1
376
+ video_cnt += 1
377
+ os.makedirs(video_image_folder, exist_ok=True)
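+ # dump all but the last 10 frames as numbered pngs for ffmpeg to encode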
378
+ for idx in range(len(video) - 10):
379
+ cv2.imwrite(
380
+ os.path.join(video_image_folder, f"{idx}.png"), video[idx]
381
+ )
382
+ images_path = os.path.join(video_image_folder, r"%d.png")
383
+ os.system(
384
+ "ffmpeg -i {} -vf palettegen palette.png -hide_banner -loglevel error".format(
385
+ images_path
386
+ )
387
+ )
388
+ os.system(
389
+ "ffmpeg -framerate {} -i {} -i palette.png -lavfi paletteuse {} -hide_banner -loglevel error".format(
390
+ record_fps, images_path, video_path
391
+ )
392
+ )
393
+ os.remove("palette.png")
394
+ shutil.rmtree(video_image_folder)
395
+
396
+ eval_env.shutdown()
397
+
398
+ # the per-task csv file is opened and closed via `with` above,
399
+ # so no explicit close is needed here
400
+
401
+ # set agent back to train mode
402
+ agent.train()
403
+
404
+ # unload CLIP to free memory
405
+ if isinstance(agent, rvt_agent.RVTAgent):
406
+ agent.unload_clip()
407
+ agent._network.free_mem()
408
+
409
+ return scores
410
+
411
+
412
+ def get_model_index(filename):
413
+ """
414
+ :param filename: path of a file of the form /.../model_<idx>.pth
415
+ :return: idx or None
416
+ """
417
+ if len(filename) >= 9 and filename[-4:] == ".pth":
418
+ try:
419
+ index = int(filename[:-4].split("_")[-1])
420
+ except ValueError:
421
+ index = None
422
+ else:
423
+ index = None
424
+ return index
425
+
426
+
427
+ def _eval(args):
428
+
429
+ model_paths = []
430
+ if not (args.peract_official):
431
+ assert args.model_name is not None
432
+ model_paths.append(os.path.join(args.model_folder, args.model_name))
433
+ else:
434
+ model_paths.append(None)
435
+
436
+ # skipping evaluated models
437
+ if args.skip:
438
+ """
439
+ to_skip: {
440
+ 0: {'light_bulb_in': False, .....}
441
+ 1: {'light_bulb_in': False, .....}
442
+ .
443
+ .
444
+ }
445
+ """
446
+ to_skip = {
447
+ get_model_index(x): {y: False for y in args.tasks} for x in model_paths
448
+ }
449
+
450
+ filenames = os.listdir(args.eval_log_dir)
451
+ for filename in filenames:
452
+ if not filename.startswith("events.out.tfevents."):
453
+ continue
454
+ summ = summary_iterator(f"{args.eval_log_dir}/{filename}")
455
+ # skipping the time log of the summary
456
+ try:
457
+ next(summ)
458
+ except Exception:
459
+ # moving to the next file
460
+ continue
461
+ for cur_summ in summ:
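+ # tags are assumed to be of the form "eval/<task_name>"; [5:] strips the "eval/" prefix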
462
+ cur_task = cur_summ.summary.value[0].tag[5:]
463
+ cur_step = cur_summ.step
464
+ if cur_step in to_skip:
465
+ to_skip[cur_step][cur_task] = True
466
+
467
+ tb = TensorboardManager(args.eval_log_dir)
468
+ for model_path in model_paths:
469
+ tasks_to_eval = deepcopy(args.tasks)
470
+
471
+ if args.peract_official:
472
+ model_idx = 0
473
+ else:
474
+ model_idx = get_model_index(model_path)
475
+ if model_idx is None:
476
+ model_idx = 0
477
+
478
+ if args.skip:
479
+ for _task in args.tasks:
480
+ if to_skip[model_idx][_task]:
481
+ tasks_to_eval.remove(_task)
482
+
483
+ if len(tasks_to_eval) == 0:
484
+ print(f"Skipping model_idx={model_idx} for args.tasks={args.tasks}")
485
+ continue
486
+
487
+ if not (args.peract_official):
488
+ agent = load_agent(
489
+ model_path=model_path,
490
+ exp_cfg_path=args.exp_cfg_path,
491
+ mvt_cfg_path=args.mvt_cfg_path,
492
+ eval_log_dir=args.eval_log_dir,
493
+ device=args.device,
494
+ use_input_place_with_mean=args.use_input_place_with_mean,
495
+ )
496
+
497
+ agent_eval_log_dir = os.path.join(
498
+ args.eval_log_dir, os.path.basename(model_path).split(".")[0]
499
+ )
500
+ else:
501
+ agent = load_agent(
502
+ peract_official=args.peract_official,
503
+ peract_model_dir=args.peract_model_dir,
504
+ device=args.device,
505
+ use_input_place_with_mean=args.use_input_place_with_mean,
506
+ )
507
+ agent_eval_log_dir = os.path.join(args.eval_log_dir, "final")
508
+
509
+ os.makedirs(agent_eval_log_dir, exist_ok=True)
510
+ scores = eval(
511
+ agent=agent,
512
+ tasks=tasks_to_eval,
513
+ eval_datafolder=args.eval_datafolder,
514
+ start_episode=args.start_episode,
515
+ eval_episodes=args.eval_episodes,
516
+ episode_length=args.episode_length,
517
+ replay_ground_truth=args.ground_truth,
518
+ device=args.device,
519
+ headless=args.headless,
520
+ logging=True,
521
+ log_dir=agent_eval_log_dir,
522
+ verbose=True,
523
+ save_video=args.save_video,
524
+ )
525
+ print(f"model {model_path}, scores {scores}")
526
+ task_scores = {}
527
+ for i in range(len(tasks_to_eval)):
528
+ task_scores[tasks_to_eval[i]] = scores[i]
529
+
530
+ print("save ", task_scores)
531
+ tb.update("eval", model_idx, task_scores)
532
+ tb.writer.flush()
533
+
534
+ tb.close()
535
+
536
+
537
+ if __name__ == "__main__":
538
+ parser = get_eval_parser()
539
+
540
+ args = parser.parse_args()
541
+
542
+ if args.log_name is None:
543
+ args.log_name = "none"
544
+
545
+ if not (args.peract_official):
546
+ args.eval_log_dir = os.path.join(args.model_folder, "eval", args.log_name)
547
+ else:
548
+ args.eval_log_dir = os.path.join(args.peract_model_dir, "eval", args.log_name)
549
+
550
+ os.makedirs(args.eval_log_dir, exist_ok=True)
551
+
552
+ # save the arguments for future reference
553
+ with open(os.path.join(args.eval_log_dir, "eval_config.yaml"), "w") as fp:
554
+ yaml.dump(args.__dict__, fp)
555
+
556
+ _eval(args)
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/.gitattributes ADDED
@@ -0,0 +1 @@
1
+ pcd_data.tar.gz filter=lfs diff=lfs merge=lfs -text
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/.gitignore ADDED
@@ -0,0 +1,4 @@
1
+ build/*
2
+ *.egg-info/*
3
+ *.so
4
+ */__pycache__/*
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/LICENSE ADDED
@@ -0,0 +1,97 @@
1
+ Copyright (c) 2022-2023, NVIDIA Corporation & affiliates. All rights reserved.
2
+
3
+
4
+ NVIDIA Source Code License for instant neural graphics primitives
5
+
6
+
7
+ =======================================================================
8
+
9
+ 1. Definitions
10
+
11
+ "Licensor" means any person or entity that distributes its Work.
12
+
13
+ "Software" means the original work of authorship made available under
14
+ this License.
15
+
16
+ "Work" means the Software and any additions to or derivative works of
17
+ the Software that are made available under this License.
18
+
19
+ The terms "reproduce," "reproduction," "derivative works," and
20
+ "distribution" have the meaning as provided under U.S. copyright law;
21
+ provided, however, that for the purposes of this License, derivative
22
+ works shall not include works that remain separable from, or merely
23
+ link (or bind by name) to the interfaces of, the Work.
24
+
25
+ Works, including the Software, are "made available" under this License
26
+ by including in or with the Work either (a) a copyright notice
27
+ referencing the applicability of this License to the Work, or (b) a
28
+ copy of this License.
29
+
30
+ 2. License Grants
31
+
32
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
33
+ License, each Licensor grants to you a perpetual, worldwide,
34
+ non-exclusive, royalty-free, copyright license to reproduce,
35
+ prepare derivative works of, publicly display, publicly perform,
36
+ sublicense and distribute its Work and any resulting derivative
37
+ works in any form.
38
+
39
+ 3. Limitations
40
+
41
+ 3.1 Redistribution. You may reproduce or distribute the Work only
42
+ if (a) you do so under this License, (b) you include a complete
43
+ copy of this License with your distribution, and (c) you retain
44
+ without modification any copyright, patent, trademark, or
45
+ attribution notices that are present in the Work.
46
+
47
+ 3.2 Derivative Works. You may specify that additional or different
48
+ terms apply to the use, reproduction, and distribution of your
49
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
50
+ provide that the use limitation in Section 3.3 applies to your
51
+ derivative works, and (b) you identify the specific derivative
52
+ works that are subject to Your Terms. Notwithstanding Your Terms,
53
+ this License (including the redistribution requirements in Section
54
+ 3.1) will continue to apply to the Work itself.
55
+
56
+ 3.3 Use Limitation. The Work and any derivative works thereof only
57
+ may be used or intended for use non-commercially. Notwithstanding
58
+ the foregoing, NVIDIA and its affiliates may use the Work and any
59
+ derivative works commercially. As used herein, "non-commercially"
60
+ means for research or evaluation purposes only.
61
+
62
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
63
+ against any Licensor (including any claim, cross-claim or
64
+ counterclaim in a lawsuit) to enforce any patents that you allege
65
+ are infringed by any Work, then your rights under this License from
66
+ such Licensor (including the grant in Section 2.1) will terminate
67
+ immediately.
68
+
69
+ 3.5 Trademarks. This License does not grant any rights to use any
70
+ Licensor’s or its affiliates’ names, logos, or trademarks, except
71
+ as necessary to reproduce the notices described in this License.
72
+
73
+ 3.6 Termination. If you violate any term of this License, then your
74
+ rights under this License (including the grant in Section 2.1) will
75
+ terminate immediately.
76
+
77
+ 4. Disclaimer of Warranty.
78
+
79
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
80
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
81
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
82
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
83
+ THIS LICENSE.
84
+
85
+ 5. Limitation of Liability.
86
+
87
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
88
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
89
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
90
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
91
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
92
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
93
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
94
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
95
+ THE POSSIBILITY OF SUCH DAMAGES.
96
+
97
+ =======================================================================
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/README.md ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ## Point Renderer
2
+ A minimal, lightweight CUDA-accelerated renderer of pointclouds.
3
+
4
+ <div align="center"><img src="demo.png"/></div>
5
+
6
+ ### Install
7
+
8
+ ```
9
+ pip install -r requirements.txt
10
+ pip install -e .
11
+ ```
12
+
13
+ ### Run
14
+
15
+ **Load Data**
16
+ Extract included pcd_data.tar.gz
17
+
18
+ ```
19
+ import numpy as np
20
+
21
+ data = np.load("pcd_data/w1280_h720/3.npy", allow_pickle=True)
22
+ data = data[None][0]
23
+ pc = data["pc"]
24
+ rgb = data["img_feat"]
25
+ ```
26
+
27
+ **Render the image**
28
+
29
+ ```
30
+ # Make the renderer
31
+ from point_renderer.renderer import PointRenderer
32
+ renderer = PointRenderer(device="cuda", perf_timer=False)
33
+
34
+ # Define a batch of cameras
35
+ img_size = (512, 512)
36
+ K = renderer.get_camera_intrinsics(hfov=70, img_size=img_size)
37
+ camera_poses = renderer.get_batch_of_camera_poses(
38
+ cam_positions=[[1.5, 1.5, 1.5],[-1.5, -1.5, -1.5]],
39
+ cam_lookats=[[0.0, 0.0, 0.0],[0.0, 0.0, 0.0]])
40
+
41
+ # Render the pointcloud from the given cameras
42
+ images, depths = renderer.render_batch(pc, rgb, camera_poses, K, img_size,
43
+ default_color=1.0,
44
+ splat_radius=0.005,
45
+ aa_factor=2
46
+ )
47
+
48
+ # Show the results
49
+ plt.imshow(images[0].detach().cpu().numpy()); plt.show()
50
+ plt.imshow(depths[0].detach().cpu().numpy()); plt.show()
51
+ plt.imshow(images[1].detach().cpu().numpy()); plt.show()
52
+ plt.imshow(depths[1].detach().cpu().numpy()); plt.show()
53
+ ```
54
+
55
+ .. Or run the jupyter notebook that has this same code above, and also all the benchmarks.
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/demo.png ADDED

Git LFS Details

  • SHA256: 67abe7c267d53a8a62e8a383412032dcedbc6373162309ba5652031cfc780a7d
  • Pointer size: 131 Bytes
  • Size of remote file: 253 kB
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/image_0_splat_2xaa.png ADDED
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/point_renderer/cameras.py ADDED
@@ -0,0 +1,119 @@
1
+ import torch
2
+
3
+ from point_renderer import ops
4
+ from functools import lru_cache
5
+
6
+ @lru_cache(maxsize=32)
7
+ def linalg_inv(poses):
8
+ return torch.linalg.inv(poses)
9
+
10
+ class Cameras:
11
+ def __init__(self, poses, intrinsics, img_size, inv_poses=None):
12
+ self.poses = poses
13
+ self.img_size = img_size
14
+ if inv_poses is None:
15
+ self.inv_poses = linalg_inv(poses)
16
+ else:
17
+ self.inv_poses = inv_poses
18
+ self.intrinsics = intrinsics
19
+
20
+ def __len__(self):
21
+ return len(self.poses)
22
+
23
+ def scale(self, constant):
24
+ self.intrinsics = self.intrinsics.clone()
25
+ self.intrinsics[:, :2, :3] *= constant
26
+
27
+ def is_orthographic(self):
28
+ raise ValueError("is_orthographic should be called on child classes only")
29
+
30
+ def is_perspective(self):
31
+ raise ValueError("is_perspective should be called on child classes only")
32
+
33
+
34
+ class PerspectiveCameras(Cameras):
35
+ def __init__(self, poses, intrinsics, img_size, inv_poses=None):
36
+ super().__init__(poses, intrinsics, img_size, inv_poses)
37
+
38
+ @classmethod
39
+ def from_lookat(cls, eyes, ats, ups, hfov, img_size, device="cpu"):
40
+ cam_poses = []
41
+ for eye, at, up in zip(eyes, ats, ups):
42
+ T = ops.lookat_to_cam_pose(eye, at, up, device=device)
43
+ cam_poses.append(T)
44
+ cam_poses = torch.stack(cam_poses, dim=0)
45
+ intrinsics = ops.fov_and_size_to_intrinsics(hfov, img_size, device=device)
46
+ intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
47
+ return PerspectiveCameras(cam_poses, intrinsics, img_size)
48
+
49
+ @classmethod
50
+ def from_rotation_and_translation(cls, R, T, S, hfov, img_size):
51
+ device = R.device
52
+ assert T.device == device
53
+ cam_poses = torch.zeros((R.shape[0], 4, 4), device=device, dtype=torch.float)
54
+ cam_poses[:, :3, :3] = R * S[None, :]
55
+ cam_poses[:, :3, 3] = T
56
+ cam_poses[:, 3, 3] = 1.0
57
+ intrinsics = ops.fov_and_size_to_intrinsics(hfov, img_size, device=device)
58
+ intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
59
+ return PerspectiveCameras(cam_poses, intrinsics, img_size)
60
+
61
+ def to(self, device):
62
+ return PerspectiveCameras(self.poses.to(device), self.intrinsics.to(device), self.inv_poses.to(device))
63
+
64
+ def is_orthographic(self):
65
+ return False
66
+
67
+ def is_perspective(self):
68
+ return True
69
+
70
+ class OrthographicCameras(Cameras):
71
+ def __init__(self, poses, intrinsics, img_size, inv_poses=None):
72
+ super().__init__(poses, intrinsics, img_size, inv_poses)
73
+
74
+ @classmethod
75
+ def from_lookat(cls, eyes, ats, ups, img_sizes_w, img_size_px, device="cpu"):
76
+ """
77
+ Args:
78
+ eyes: Nx3 tensor of camera coordinates
79
+ ats: Nx3 tensor of look-at directions
80
+ ups: Nx3 tensor of up-vectors
81
+ scale: Nx2 tensor defining image sizes in world coordinates
82
+ img_size: 2-dim tuple defining image size in pixels
83
+ Returns:
84
+ OrthographicCamera
85
+ """
86
+ if isinstance(img_sizes_w, list):
87
+ img_sizes_w = torch.tensor(img_sizes_w, device=device)[None, :].repeat((len(eyes), 1))
88
+
89
+ cam_poses = []
90
+ for eye, at, up in zip(eyes, ats, ups):
91
+ T = ops.lookat_to_cam_pose(eye, at, up, device=device)
92
+ cam_poses.append(T)
93
+ cam_poses = torch.stack(cam_poses, dim=0)
94
+ intrinsics = ops.orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device=device)
95
+ return OrthographicCameras(cam_poses, intrinsics, img_size_px)
96
+
97
+ @classmethod
98
+ def from_rotation_and_translation(cls, R, T, img_sizes_w, img_size_px, device="cpu"):
99
+ if isinstance(img_sizes_w, list):
100
+ img_sizes_w = torch.tensor(img_sizes_w, device=device)[None, :].repeat((len(R), 1))
101
+
102
+ device = R.device
103
+ assert T.device == device
104
+ cam_poses = torch.zeros((R.shape[0], 4, 4), device=device, dtype=torch.float)
105
+ cam_poses[:, :3, :3] = R
106
+ cam_poses[:, :3, 3] = T
107
+ cam_poses[:, 3, 3] = 1.0
108
+ intrinsics = ops.orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device=device)
109
+ intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
110
+ return OrthographicCameras(cam_poses, intrinsics, img_size_px)
111
+
112
+ def to(self, device):
113
+ return OrthographicCameras(self.poses.to(device), self.intrinsics.to(device), self.inv_poses.to(device))
114
+
115
+ def is_orthographic(self):
116
+ return True
117
+
118
+ def is_perspective(self):
119
+ return False