Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +9 -0
- environment/base_pip_freeze.txt +176 -0
- environment/base_python.txt +1 -0
- environment/env_list.txt +4 -0
- environment/hardware_snapshot.txt +1 -0
- environment/nvidia_smi.txt +22 -0
- environment/reconstruct_anybimanual_overlap_replay.sh +22 -0
- environment/rlbench_pip_freeze.txt +200 -0
- environment/rlbench_python.txt +1 -0
- environment/runtime_env_vars.sh +4 -0
- environment/setup_same_hardware.sh +25 -0
- environment/uname.txt +1 -0
- handoff/instructions4.md +591 -0
- history/VLAarchtests_previous_README.md +172 -0
- metadata/source_sizes.txt +4 -0
- metadata/staged_size.txt +1 -0
- metadata/staged_tree_top2.txt +64 -0
- third_party/AnyBimanual/agents/__init__.py +0 -0
- third_party/AnyBimanual/agents/agent_factory.py +101 -0
- third_party/AnyBimanual/agents/peract_bc/__init__.py +1 -0
- third_party/AnyBimanual/agents/peract_bc/launch_utils.py +128 -0
- third_party/AnyBimanual/agents/peract_bc/perceiver_lang_io.py +481 -0
- third_party/AnyBimanual/agents/peract_bc/qattention_peract_bc_agent.py +939 -0
- third_party/AnyBimanual/agents/peract_bc/qattention_stack_agent.py +132 -0
- third_party/AnyBimanual/agents/peract_bc/skill_manager.py +70 -0
- third_party/AnyBimanual/agents/peract_bc/trajectory_gpt2.py +775 -0
- third_party/AnyBimanual/agents/peract_bc/visual_aligner.py +39 -0
- third_party/AnyBimanual/agents/peract_bimanual/__init__.py +1 -0
- third_party/AnyBimanual/agents/peract_bimanual/launch_utils.py +117 -0
- third_party/AnyBimanual/agents/peract_bimanual/perceiver_lang_io.py +628 -0
- third_party/AnyBimanual/agents/peract_bimanual/qattention_peract_bc_agent.py +1317 -0
- third_party/AnyBimanual/agents/peract_bimanual/qattention_stack_agent.py +209 -0
- third_party/AnyBimanual/agents/peract_bimanual/skill_manager.py +70 -0
- third_party/AnyBimanual/agents/peract_bimanual/trajectory_gpt2.py +775 -0
- third_party/AnyBimanual/agents/peract_bimanual/visual_aligner.py +39 -0
- third_party/AnyBimanual/agents/replay_utils.py +667 -0
- third_party/AnyBimanual/agents/rvt/__init__.py +6 -0
- third_party/AnyBimanual/agents/rvt/launch_utils.py +221 -0
- third_party/AnyBimanual/agents/rvt/rvt/config.py +54 -0
- third_party/AnyBimanual/agents/rvt/rvt/configs/peract_official_config.yaml +127 -0
- third_party/AnyBimanual/agents/rvt/rvt/configs/rvt.yaml +15 -0
- third_party/AnyBimanual/agents/rvt/rvt/configs/rvt2.yaml +19 -0
- third_party/AnyBimanual/agents/rvt/rvt/eval.py +556 -0
- third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/.gitattributes +1 -0
- third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/.gitignore +4 -0
- third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/LICENSE +97 -0
- third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/README.md +55 -0
- third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/demo.png +3 -0
- third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/image_0_splat_2xaa.png +0 -0
- third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/point_renderer/cameras.py +119 -0
.gitattributes
CHANGED
|
@@ -14068,3 +14068,12 @@ baselines/AnyBimanual_overlap_replay/multi/10000-14999/14031.replay filter=lfs d
|
|
| 14068 |
baselines/AnyBimanual_overlap_replay/multi/10000-14999/14030.replay filter=lfs diff=lfs merge=lfs -text
|
| 14069 |
baselines/AnyBimanual_overlap_replay/multi/10000-14999/14032.replay filter=lfs diff=lfs merge=lfs -text
|
| 14070 |
baselines/AnyBimanual_overlap_replay/multi/10000-14999/14029.replay filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14068 |
baselines/AnyBimanual_overlap_replay/multi/10000-14999/14030.replay filter=lfs diff=lfs merge=lfs -text
|
| 14069 |
baselines/AnyBimanual_overlap_replay/multi/10000-14999/14032.replay filter=lfs diff=lfs merge=lfs -text
|
| 14070 |
baselines/AnyBimanual_overlap_replay/multi/10000-14999/14029.replay filter=lfs diff=lfs merge=lfs -text
|
| 14071 |
+
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/demo.png filter=lfs diff=lfs merge=lfs -text
|
| 14072 |
+
third_party/AnyBimanual/third_party/RLBench/readme_files/task_grid.png filter=lfs diff=lfs merge=lfs -text
|
| 14073 |
+
third_party/AnyBimanual/third_party/PyRep/tutorials/images/kinematics_group.png filter=lfs diff=lfs merge=lfs -text
|
| 14074 |
+
third_party/AnyBimanual/third_party/PyRep/tutorials/images/collision_collections.png filter=lfs diff=lfs merge=lfs -text
|
| 14075 |
+
third_party/AnyBimanual/third_party/PyRep/tests/assets/test_scene_robots.ttt filter=lfs diff=lfs merge=lfs -text
|
| 14076 |
+
third_party/AnyBimanual/third_party/PyRep/tests/assets/test_scene_mobiles_with_arms.ttt filter=lfs diff=lfs merge=lfs -text
|
| 14077 |
+
third_party/AnyBimanual/third_party/PyRep/tests/assets/test_scene_mobiles.ttt filter=lfs diff=lfs merge=lfs -text
|
| 14078 |
+
third_party/AnyBimanual/third_party/PyRep/tests/assets/test_scene.ttt filter=lfs diff=lfs merge=lfs -text
|
| 14079 |
+
third_party/AnyBimanual/third_party/PyRep/tests/assets/cracker_box/texture_map.png filter=lfs diff=lfs merge=lfs -text
|
environment/base_pip_freeze.txt
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
accelerate==1.13.0
|
| 2 |
+
annotated-doc==0.0.4
|
| 3 |
+
antlr4-python3-runtime==4.9.3
|
| 4 |
+
anyio==4.6.0
|
| 5 |
+
argon2-cffi==23.1.0
|
| 6 |
+
argon2-cffi-bindings==21.2.0
|
| 7 |
+
arrow==1.3.0
|
| 8 |
+
asttokens==2.4.1
|
| 9 |
+
async-lru==2.0.4
|
| 10 |
+
attrs==24.2.0
|
| 11 |
+
babel==2.16.0
|
| 12 |
+
beautifulsoup4==4.12.3
|
| 13 |
+
bleach==6.1.0
|
| 14 |
+
blinker==1.4
|
| 15 |
+
certifi==2024.8.30
|
| 16 |
+
cffi==1.17.1
|
| 17 |
+
charset-normalizer==3.3.2
|
| 18 |
+
click==8.3.1
|
| 19 |
+
comm==0.2.2
|
| 20 |
+
cryptography==3.4.8
|
| 21 |
+
cuda-bindings==12.9.4
|
| 22 |
+
cuda-pathfinder==1.2.2
|
| 23 |
+
cuda-toolkit==12.8.1
|
| 24 |
+
dbus-python==1.2.18
|
| 25 |
+
debugpy==1.8.5
|
| 26 |
+
decorator==5.1.1
|
| 27 |
+
defusedxml==0.7.1
|
| 28 |
+
distro==1.7.0
|
| 29 |
+
entrypoints==0.4
|
| 30 |
+
execnet==2.1.2
|
| 31 |
+
executing==2.1.0
|
| 32 |
+
fastjsonschema==2.20.0
|
| 33 |
+
filelock==3.13.1
|
| 34 |
+
fqdn==1.5.1
|
| 35 |
+
fsspec==2024.2.0
|
| 36 |
+
h11==0.14.0
|
| 37 |
+
hf-xet==1.4.2
|
| 38 |
+
httpcore==1.0.5
|
| 39 |
+
httplib2==0.20.2
|
| 40 |
+
httpx==0.27.2
|
| 41 |
+
huggingface_hub==1.8.0
|
| 42 |
+
idna==3.10
|
| 43 |
+
importlib-metadata==4.6.4
|
| 44 |
+
iniconfig==2.3.0
|
| 45 |
+
ipykernel==6.29.5
|
| 46 |
+
ipython==8.27.0
|
| 47 |
+
ipython-genutils==0.2.0
|
| 48 |
+
ipywidgets==8.1.5
|
| 49 |
+
isoduration==20.11.0
|
| 50 |
+
jedi==0.19.1
|
| 51 |
+
jeepney==0.7.1
|
| 52 |
+
Jinja2==3.1.3
|
| 53 |
+
json5==0.9.25
|
| 54 |
+
jsonpointer==3.0.0
|
| 55 |
+
jsonschema==4.23.0
|
| 56 |
+
jsonschema-specifications==2023.12.1
|
| 57 |
+
jupyter-archive==3.4.0
|
| 58 |
+
jupyter-events==0.10.0
|
| 59 |
+
jupyter-highlight-selected-word==0.2.0
|
| 60 |
+
jupyter-lsp==2.2.5
|
| 61 |
+
jupyter_client==7.4.9
|
| 62 |
+
jupyter_contrib_core==0.4.2
|
| 63 |
+
jupyter_contrib_nbextensions==0.7.0
|
| 64 |
+
jupyter_core==5.7.2
|
| 65 |
+
jupyter_nbextensions_configurator==0.6.4
|
| 66 |
+
jupyter_server==2.14.2
|
| 67 |
+
jupyter_server_terminals==0.5.3
|
| 68 |
+
jupyterlab==4.2.5
|
| 69 |
+
jupyterlab_pygments==0.3.0
|
| 70 |
+
jupyterlab_server==2.27.3
|
| 71 |
+
jupyterlab_widgets==3.0.13
|
| 72 |
+
keyring==23.5.0
|
| 73 |
+
launchpadlib==1.10.16
|
| 74 |
+
lazr.restfulclient==0.14.4
|
| 75 |
+
lazr.uri==1.0.6
|
| 76 |
+
lxml==5.3.0
|
| 77 |
+
markdown-it-py==4.0.0
|
| 78 |
+
MarkupSafe==2.1.5
|
| 79 |
+
matplotlib-inline==0.1.7
|
| 80 |
+
mdurl==0.1.2
|
| 81 |
+
mistune==3.0.2
|
| 82 |
+
more-itertools==8.10.0
|
| 83 |
+
mpmath==1.3.0
|
| 84 |
+
nbclassic==1.1.0
|
| 85 |
+
nbclient==0.10.0
|
| 86 |
+
nbconvert==7.16.4
|
| 87 |
+
nbformat==5.10.4
|
| 88 |
+
nest-asyncio==1.6.0
|
| 89 |
+
networkx==3.2.1
|
| 90 |
+
notebook==6.5.5
|
| 91 |
+
notebook_shim==0.2.4
|
| 92 |
+
numpy==1.26.3
|
| 93 |
+
nvidia-cublas-cu12==12.8.4.1
|
| 94 |
+
nvidia-cuda-cupti-cu12==12.8.90
|
| 95 |
+
nvidia-cuda-nvrtc-cu12==12.8.93
|
| 96 |
+
nvidia-cuda-runtime-cu12==12.8.90
|
| 97 |
+
nvidia-cudnn-cu12==9.19.0.56
|
| 98 |
+
nvidia-cufft-cu12==11.3.3.83
|
| 99 |
+
nvidia-cufile-cu12==1.13.1.3
|
| 100 |
+
nvidia-curand-cu12==10.3.9.90
|
| 101 |
+
nvidia-cusolver-cu12==11.7.3.90
|
| 102 |
+
nvidia-cusparse-cu12==12.5.8.93
|
| 103 |
+
nvidia-cusparselt-cu12==0.7.1
|
| 104 |
+
nvidia-nccl-cu12==2.28.9
|
| 105 |
+
nvidia-nvjitlink-cu12==12.8.93
|
| 106 |
+
nvidia-nvshmem-cu12==3.4.5
|
| 107 |
+
nvidia-nvtx-cu12==12.8.90
|
| 108 |
+
oauthlib==3.2.0
|
| 109 |
+
omegaconf==2.3.0
|
| 110 |
+
overrides==7.7.0
|
| 111 |
+
packaging==24.1
|
| 112 |
+
pandocfilters==1.5.1
|
| 113 |
+
parso==0.8.4
|
| 114 |
+
pexpect==4.9.0
|
| 115 |
+
pillow==10.2.0
|
| 116 |
+
platformdirs==4.3.6
|
| 117 |
+
pluggy==1.6.0
|
| 118 |
+
prometheus_client==0.21.0
|
| 119 |
+
prompt_toolkit==3.0.47
|
| 120 |
+
psutil==6.0.0
|
| 121 |
+
ptyprocess==0.7.0
|
| 122 |
+
pure_eval==0.2.3
|
| 123 |
+
py-spy==0.4.1
|
| 124 |
+
pycparser==2.22
|
| 125 |
+
Pygments==2.18.0
|
| 126 |
+
PyGObject==3.42.1
|
| 127 |
+
PyJWT==2.3.0
|
| 128 |
+
pyparsing==2.4.7
|
| 129 |
+
pytest==9.0.2
|
| 130 |
+
pytest-xdist==3.8.0
|
| 131 |
+
python-apt==2.4.0+ubuntu4
|
| 132 |
+
python-dateutil==2.9.0.post0
|
| 133 |
+
python-json-logger==2.0.7
|
| 134 |
+
PyYAML==6.0.2
|
| 135 |
+
pyzmq==24.0.1
|
| 136 |
+
referencing==0.35.1
|
| 137 |
+
regex==2026.3.32
|
| 138 |
+
requests==2.32.3
|
| 139 |
+
rfc3339-validator==0.1.4
|
| 140 |
+
rfc3986-validator==0.1.1
|
| 141 |
+
rich==14.3.3
|
| 142 |
+
rpds-py==0.20.0
|
| 143 |
+
safetensors==0.7.0
|
| 144 |
+
SecretStorage==3.3.1
|
| 145 |
+
Send2Trash==1.8.3
|
| 146 |
+
sentencepiece==0.2.1
|
| 147 |
+
shellingham==1.5.4
|
| 148 |
+
six==1.16.0
|
| 149 |
+
sniffio==1.3.1
|
| 150 |
+
soupsieve==2.6
|
| 151 |
+
stack-data==0.6.3
|
| 152 |
+
sympy==1.14.0
|
| 153 |
+
systemd-python==234
|
| 154 |
+
terminado==0.18.1
|
| 155 |
+
tinycss2==1.3.0
|
| 156 |
+
tokenizers==0.22.2
|
| 157 |
+
torch==2.11.0+cu128
|
| 158 |
+
torchaudio==2.11.0+cu128
|
| 159 |
+
torchvision==0.26.0+cu128
|
| 160 |
+
tornado==6.4.1
|
| 161 |
+
tqdm==4.67.3
|
| 162 |
+
traitlets==5.14.3
|
| 163 |
+
transformers==5.4.0
|
| 164 |
+
triton==3.6.0
|
| 165 |
+
typer==0.24.1
|
| 166 |
+
types-python-dateutil==2.9.0.20240906
|
| 167 |
+
typing_extensions==4.15.0
|
| 168 |
+
uri-template==1.3.0
|
| 169 |
+
urllib3==2.2.3
|
| 170 |
+
wadllib==1.3.6
|
| 171 |
+
wcwidth==0.2.13
|
| 172 |
+
webcolors==24.8.0
|
| 173 |
+
webencodings==0.5.1
|
| 174 |
+
websocket-client==1.8.0
|
| 175 |
+
widgetsnbextension==4.0.13
|
| 176 |
+
zipp==1.0.0
|
environment/base_python.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Python 3.11.10
|
environment/env_list.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Name Active Path
|
| 2 |
+
────────────────────────────────────────────
|
| 3 |
+
base * /workspace
|
| 4 |
+
rlbench /workspace/envs/rlbench
|
environment/hardware_snapshot.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
NVIDIA RTX PRO 6000 Blackwell Server Edition, 580.126.09, 97887 MiB
|
environment/nvidia_smi.txt
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Mon Mar 30 14:48:02 2026
|
| 2 |
+
+-----------------------------------------------------------------------------------------+
|
| 3 |
+
| NVIDIA-SMI 580.126.09 Driver Version: 580.126.09 CUDA Version: 13.0 |
|
| 4 |
+
+-----------------------------------------+------------------------+----------------------+
|
| 5 |
+
| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |
|
| 6 |
+
| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |
|
| 7 |
+
| | | MIG M. |
|
| 8 |
+
|=========================================+========================+======================|
|
| 9 |
+
| 0 NVIDIA RTX PRO 6000 Blac... On | 00000000:F4:00.0 Off | 0 |
|
| 10 |
+
| N/A 34C P0 83W / 600W | 6110MiB / 97887MiB | 0% Default |
|
| 11 |
+
| | | Disabled |
|
| 12 |
+
+-----------------------------------------+------------------------+----------------------+
|
| 13 |
+
|
| 14 |
+
+-----------------------------------------------------------------------------------------+
|
| 15 |
+
| Processes: |
|
| 16 |
+
| GPU GI CI PID Type Process name GPU Memory |
|
| 17 |
+
| ID ID Usage |
|
| 18 |
+
|=========================================================================================|
|
| 19 |
+
| 0 N/A N/A 181865 G /usr/lib/xorg/Xorg 97MiB |
|
| 20 |
+
| 0 N/A N/A 278028 C python 570MiB |
|
| 21 |
+
| 0 N/A N/A 278251 C+G ...space/envs/rlbench/bin/python 5401MiB |
|
| 22 |
+
+-----------------------------------------------------------------------------------------+
|
environment/reconstruct_anybimanual_overlap_replay.sh
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
if [ "$#" -ne 2 ]; then
|
| 5 |
+
echo "usage: $0 <sharded_multi_dir> <flat_output_dir>" >&2
|
| 6 |
+
echo "example: $0 baselines/AnyBimanual_overlap_replay/multi /tmp/multi_flat" >&2
|
| 7 |
+
exit 1
|
| 8 |
+
fi
|
| 9 |
+
|
| 10 |
+
src="$1"
|
| 11 |
+
dst="$2"
|
| 12 |
+
|
| 13 |
+
mkdir -p "$dst"
|
| 14 |
+
|
| 15 |
+
find "$src" -mindepth 2 -maxdepth 2 -type f -name '*.replay' -print0 \
|
| 16 |
+
| xargs -0 -n 1 -P 32 bash -c '
|
| 17 |
+
f="$1"
|
| 18 |
+
out="$2/$(basename "$f")"
|
| 19 |
+
ln "$f" "$out"
|
| 20 |
+
' _ '{}' "$dst"
|
| 21 |
+
|
| 22 |
+
echo "reconstructed flat replay directory at: $dst"
|
environment/rlbench_pip_freeze.txt
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
absl-py==2.1.0
|
| 2 |
+
accelerate==0.31.0
|
| 3 |
+
addict==2.4.0
|
| 4 |
+
aiohappyeyeballs==2.6.1
|
| 5 |
+
aiohttp==3.13.4
|
| 6 |
+
aiosignal==1.4.0
|
| 7 |
+
antlr4-python3-runtime==4.9.3
|
| 8 |
+
asttokens==3.0.1
|
| 9 |
+
async-timeout==5.0.1
|
| 10 |
+
attrs==26.1.0
|
| 11 |
+
backports.zstd @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_backports.zstd_1767044984/work
|
| 12 |
+
blinker==1.9.0
|
| 13 |
+
blosc==1.11.4
|
| 14 |
+
Brotli @ file:///home/conda/feedstock_root/build_artifacts/brotli-split_1764016952863/work
|
| 15 |
+
cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
|
| 16 |
+
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1772001073725/work/certifi
|
| 17 |
+
cffi @ file:///home/conda/feedstock_root/build_artifacts/cffi_1761202865726/work
|
| 18 |
+
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1773659966602/work
|
| 19 |
+
click==8.3.1
|
| 20 |
+
click-prompt==0.5.1
|
| 21 |
+
clip @ git+https://github.com/openai/CLIP.git@d05afc436d78f1c48dc0dbf8e5980a9d471f35f6
|
| 22 |
+
cloudpickle==3.1.2
|
| 23 |
+
comm==0.2.3
|
| 24 |
+
ConfigArgParse==1.7.5
|
| 25 |
+
contourpy @ file:///home/conda/feedstock_root/build_artifacts/contourpy_1744743067588/work
|
| 26 |
+
cuda-bindings==12.9.4
|
| 27 |
+
cuda-pathfinder==1.2.2
|
| 28 |
+
cuda-toolkit==12.8.1
|
| 29 |
+
cycler @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_cycler_1764466758/work
|
| 30 |
+
dash==4.1.0
|
| 31 |
+
decorator==5.2.1
|
| 32 |
+
docker-pycreds==0.4.0
|
| 33 |
+
einops==0.8.0
|
| 34 |
+
exceptiongroup==1.3.1
|
| 35 |
+
executing==2.2.1
|
| 36 |
+
Farama-Notifications==0.0.4
|
| 37 |
+
fastjsonschema==2.21.2
|
| 38 |
+
filelock @ file:///home/conda/feedstock_root/build_artifacts/filelock_1773313889543/work
|
| 39 |
+
Flask==3.1.3
|
| 40 |
+
fonttools @ file:///home/conda/feedstock_root/build_artifacts/fonttools_1773137064424/work
|
| 41 |
+
freetype-py==2.5.1
|
| 42 |
+
frozenlist==1.8.0
|
| 43 |
+
fsspec==2026.3.0
|
| 44 |
+
ftfy==6.2.0
|
| 45 |
+
gitdb==4.0.12
|
| 46 |
+
GitPython==3.1.46
|
| 47 |
+
gmpy2 @ file:///home/conda/feedstock_root/build_artifacts/gmpy2_1773244929835/work
|
| 48 |
+
grpcio==1.78.0
|
| 49 |
+
gym==0.26.2
|
| 50 |
+
gym-notices==0.1.0
|
| 51 |
+
gymnasium==1.0.0a2
|
| 52 |
+
h2 @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_h2_1756364871/work
|
| 53 |
+
h5py @ file:///home/conda/feedstock_root/build_artifacts/h5py_1774712049671/work
|
| 54 |
+
hf-xet==1.4.2
|
| 55 |
+
hpack @ file:///home/conda/feedstock_root/build_artifacts/hpack_1737618293087/work
|
| 56 |
+
huggingface_hub==0.36.2
|
| 57 |
+
hydra-core==1.3.2
|
| 58 |
+
hyperframe @ file:///home/conda/feedstock_root/build_artifacts/hyperframe_1737618333194/work
|
| 59 |
+
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1760286409563/work
|
| 60 |
+
imageio @ file:///home/conda/feedstock_root/build_artifacts/imageio_1738273805233/work
|
| 61 |
+
imgaug==0.4.0
|
| 62 |
+
importlib_metadata==9.0.0
|
| 63 |
+
ipython==8.39.0
|
| 64 |
+
ipywidgets==8.1.8
|
| 65 |
+
itsdangerous==2.2.0
|
| 66 |
+
jedi==0.19.2
|
| 67 |
+
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_jinja2_1764517220/work
|
| 68 |
+
joblib==1.5.3
|
| 69 |
+
jsonschema==4.26.0
|
| 70 |
+
jsonschema-specifications==2025.9.1
|
| 71 |
+
jupyter_core==5.9.1
|
| 72 |
+
jupyterlab_widgets==3.0.16
|
| 73 |
+
kiwisolver @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_kiwisolver_1773067043/work
|
| 74 |
+
lazy-loader==0.5
|
| 75 |
+
Markdown==3.10.2
|
| 76 |
+
markdown-it-py==4.0.0
|
| 77 |
+
MarkupSafe @ file:///home/conda/feedstock_root/build_artifacts/markupsafe_1772444934960/work
|
| 78 |
+
matplotlib @ file:///home/conda/feedstock_root/build_artifacts/matplotlib-suite_1715976200404/work
|
| 79 |
+
matplotlib-inline==0.2.1
|
| 80 |
+
mdurl==0.1.2
|
| 81 |
+
meshcat==0.3.2
|
| 82 |
+
moviepy==2.2.1
|
| 83 |
+
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1773661943568/work
|
| 84 |
+
multidict==6.7.1
|
| 85 |
+
munkres==1.1.4
|
| 86 |
+
narwhals==2.18.1
|
| 87 |
+
natsort==8.4.0
|
| 88 |
+
nbformat==5.10.4
|
| 89 |
+
nest-asyncio==1.6.0
|
| 90 |
+
networkx @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_networkx_1731521053/work
|
| 91 |
+
numpy==1.26.4
|
| 92 |
+
nvidia-cublas-cu12==12.8.4.1
|
| 93 |
+
nvidia-cuda-cupti-cu12==12.8.90
|
| 94 |
+
nvidia-cuda-nvrtc-cu12==12.8.93
|
| 95 |
+
nvidia-cuda-runtime-cu12==12.8.90
|
| 96 |
+
nvidia-cudnn-cu12==9.19.0.56
|
| 97 |
+
nvidia-cufft-cu12==11.3.3.83
|
| 98 |
+
nvidia-cufile-cu12==1.13.1.3
|
| 99 |
+
nvidia-curand-cu12==10.3.9.90
|
| 100 |
+
nvidia-cusolver-cu12==11.7.3.90
|
| 101 |
+
nvidia-cusparse-cu12==12.5.8.93
|
| 102 |
+
nvidia-cusparselt-cu12==0.7.1
|
| 103 |
+
nvidia-nccl-cu12==2.28.9
|
| 104 |
+
nvidia-nvjitlink-cu12==12.8.93
|
| 105 |
+
nvidia-nvshmem-cu12==3.4.5
|
| 106 |
+
nvidia-nvtx-cu12==12.8.90
|
| 107 |
+
omegaconf==2.3.0
|
| 108 |
+
open3d==0.19.0
|
| 109 |
+
openai==0.28.1
|
| 110 |
+
opencv-python==4.10.0.84
|
| 111 |
+
packaging @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_packaging_1769093650/work
|
| 112 |
+
pandas @ file:///home/conda/feedstock_root/build_artifacts/pandas_1744430447393/work
|
| 113 |
+
parso==0.8.6
|
| 114 |
+
-e git+https://github.com/markusgrotz/peract_bimanual.git@bb0232a6ba3fe116566e9568f0c7af980ed6703d#egg=peract_bimanual
|
| 115 |
+
perceiver-pytorch==0.8.8
|
| 116 |
+
pexpect==4.9.0
|
| 117 |
+
pillow==12.1.1
|
| 118 |
+
platformdirs==4.9.4
|
| 119 |
+
plotly==6.6.0
|
| 120 |
+
ply @ file:///home/conda/feedstock_root/build_artifacts/ply_1733239724146/work
|
| 121 |
+
poetry-core==2.3.2
|
| 122 |
+
prompt_toolkit==3.0.52
|
| 123 |
+
propcache==0.4.1
|
| 124 |
+
protobuf==5.29.6
|
| 125 |
+
psutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_psutil_1769678154/work
|
| 126 |
+
ptyprocess==0.7.0
|
| 127 |
+
pure_eval==0.2.3
|
| 128 |
+
pycparser @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pycparser_1733195786/work
|
| 129 |
+
pyglet==2.1.13
|
| 130 |
+
Pygments==2.20.0
|
| 131 |
+
pyngrok==7.5.1
|
| 132 |
+
PyOpenGL==3.1.0
|
| 133 |
+
pyparsing @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pyparsing_1769003998/work
|
| 134 |
+
PyQt5==5.15.11
|
| 135 |
+
PyQt5_sip==12.17.0
|
| 136 |
+
pyquaternion==0.9.9
|
| 137 |
+
pyrender==0.1.45
|
| 138 |
+
-e git+https://github.com/markusgrotz/PyRep.git@b8bd1d7a3182adcd570d001649c0849047ebf197#egg=PyRep
|
| 139 |
+
PySocks @ file:///home/conda/feedstock_root/build_artifacts/pysocks_1733217236728/work
|
| 140 |
+
python-dateutil @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_python-dateutil_1751104122/work
|
| 141 |
+
pytorch-lamb==1.0.0
|
| 142 |
+
pytz @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_pytz_1773679724/work
|
| 143 |
+
PyYAML @ file:///home/conda/feedstock_root/build_artifacts/pyyaml_1770223234623/work
|
| 144 |
+
pyzmq==27.1.0
|
| 145 |
+
referencing==0.37.0
|
| 146 |
+
regex==2024.5.15
|
| 147 |
+
requests @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_requests_1774462091/work
|
| 148 |
+
retrying==1.4.2
|
| 149 |
+
# Editable install with no version control (reveal-vla-bimanual==0.1.0)
|
| 150 |
+
-e /workspace/reveal_vla_bimanual
|
| 151 |
+
rich==13.9.4
|
| 152 |
+
rich-click==1.8.9
|
| 153 |
+
-e git+https://github.com/markusgrotz/RLBench.git@8af748c51287989294e00c9c670e3330a0e35ed5#egg=rlbench
|
| 154 |
+
rpds-py==0.30.0
|
| 155 |
+
safetensors==0.4.3
|
| 156 |
+
scikit-image==0.25.2
|
| 157 |
+
scikit-learn==1.7.2
|
| 158 |
+
scipy @ file:///home/conda/feedstock_root/build_artifacts/scipy-split_1716470219380/work/dist/scipy-1.13.1-cp310-cp310-linux_x86_64.whl#sha256=a4ff22b6dc27b61196be51695f53f9b0676e7c1bc564872b51fc3c41b79ae80b
|
| 159 |
+
segment-anything==1.0
|
| 160 |
+
sentry-sdk==2.56.0
|
| 161 |
+
setproctitle==1.3.7
|
| 162 |
+
shapely==2.1.2
|
| 163 |
+
sip @ file:///home/conda/feedstock_root/build_artifacts/sip_1759437834046/work
|
| 164 |
+
six @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_six_1753199211/work
|
| 165 |
+
smmap==5.0.3
|
| 166 |
+
stack-data==0.6.3
|
| 167 |
+
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1771952240620/work
|
| 168 |
+
tensorboard==2.16.2
|
| 169 |
+
tensorboard-data-server==0.7.2
|
| 170 |
+
tensorboardX==2.6.4
|
| 171 |
+
termcolor==3.3.0
|
| 172 |
+
threadpoolctl==3.6.0
|
| 173 |
+
tifffile==2025.5.10
|
| 174 |
+
timeout-decorator==0.5.0
|
| 175 |
+
timm==1.0.26
|
| 176 |
+
tokenizers==0.19.1
|
| 177 |
+
toml @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_toml_1764486833/work
|
| 178 |
+
tomli @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_tomli_1774492402/work
|
| 179 |
+
torch==2.11.0+cu128
|
| 180 |
+
torchaudio==2.11.0+cu128
|
| 181 |
+
torchvision==0.26.0+cu128
|
| 182 |
+
tornado @ file:///home/conda/feedstock_root/build_artifacts/tornado_1774357896577/work
|
| 183 |
+
tqdm @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_tqdm_1770153424/work
|
| 184 |
+
traitlets==5.14.3
|
| 185 |
+
transformers==4.41.2
|
| 186 |
+
transforms3d==0.4.1
|
| 187 |
+
trimesh @ file:///home/conda/feedstock_root/build_artifacts/trimesh_1774412449209/work
|
| 188 |
+
triton==3.6.0
|
| 189 |
+
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/bld/rattler-build_typing_extensions_1756220668/work
|
| 190 |
+
tzdata @ file:///home/conda/feedstock_root/build_artifacts/python-tzdata_1765719872007/work
|
| 191 |
+
u-msgpack-python==2.8.0
|
| 192 |
+
unicodedata2 @ file:///home/conda/feedstock_root/build_artifacts/unicodedata2_1770908960326/work
|
| 193 |
+
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1767817748113/work
|
| 194 |
+
wandb==0.18.0
|
| 195 |
+
wcwidth==0.2.14
|
| 196 |
+
Werkzeug==3.1.7
|
| 197 |
+
widgetsnbextension==4.0.15
|
| 198 |
+
yarl==1.23.0
|
| 199 |
+
-e git+https://github.com/markusgrotz/YARR.git@6822ff78602c77878b27d4cfe759ce029c67bffb#egg=yarr
|
| 200 |
+
zipp==3.23.0
|
environment/rlbench_python.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Python 3.10.20
|
environment/runtime_env_vars.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
export HF_HOME=/workspace/.cache/huggingface
|
| 2 |
+
export MAMBA_ROOT_PREFIX=/workspace/.micromamba
|
| 3 |
+
export DISPLAY=:99
|
| 4 |
+
export PYTHONPATH=/workspace/VLAarchtests/code/reveal_vla_bimanual
|
environment/setup_same_hardware.sh
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
set -euo pipefail
|
| 3 |
+
|
| 4 |
+
ROOT_DIR="${ROOT_DIR:-/workspace}"
|
| 5 |
+
|
| 6 |
+
echo "[setup] expected hardware: NVIDIA RTX PRO 6000 Blackwell Server Edition"
|
| 7 |
+
echo "[setup] expected OS family: Ubuntu 24.04 / Linux 6.8"
|
| 8 |
+
|
| 9 |
+
if [ ! -d "$ROOT_DIR/VLAarchtests" ]; then
|
| 10 |
+
echo "[setup] expected staged tree at $ROOT_DIR/VLAarchtests" >&2
|
| 11 |
+
fi
|
| 12 |
+
|
| 13 |
+
echo "[setup] base python:"
|
| 14 |
+
python --version || true
|
| 15 |
+
|
| 16 |
+
if [ -x "$ROOT_DIR/.tools/micromamba/bin/micromamba" ]; then
|
| 17 |
+
echo "[setup] micromamba envs:"
|
| 18 |
+
"$ROOT_DIR/.tools/micromamba/bin/micromamba" env list || true
|
| 19 |
+
fi
|
| 20 |
+
|
| 21 |
+
echo "[setup] current GPU snapshot:"
|
| 22 |
+
nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader || true
|
| 23 |
+
|
| 24 |
+
echo "[setup] this repo includes package snapshots in environment/"
|
| 25 |
+
echo "[setup] recreate the rlbench env at $ROOT_DIR/envs/rlbench before running the overlap baseline path"
|
environment/uname.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Linux 129a2ec0f4a9 6.8.0-106-generic #106-Ubuntu SMP PREEMPT_DYNAMIC Fri Mar 6 07:58:08 UTC 2026 x86_64 x86_64 x86_64 GNU/Linux
|
handoff/instructions4.md
ADDED
|
@@ -0,0 +1,591 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Developer handoff: 10-hour sim sprint on 1×RTX PRO 6000 for elastic-occlusion bimanual reveal/retrieve
|
| 2 |
+
|
| 3 |
+
## Scope
|
| 4 |
+
|
| 5 |
+
This handoff is for the current simulation phase only. The purpose is not to produce publication-grade evidence. The purpose is to use one short sprint to extract the most decision-relevant signal possible before the custom three-task teleoperation benchmark exists.
|
| 6 |
+
|
| 7 |
+
This document does **not** include explicit instructions for future teleoperation data collection.
|
| 8 |
+
|
| 9 |
+
The target problem remains the same: bimanual reveal and retrieve under elastic occlusion, with task families that map to (1) foliage reveal with safe actor insertion and retrieval, (2) bag opening plus retrieval, and (3) folded-cloth / suitcase reveal with minimal fold disruption.
|
| 10 |
+
|
| 11 |
+
## Hard constraints for this sprint
|
| 12 |
+
|
| 13 |
+
The sprint must assume the following:
|
| 14 |
+
|
| 15 |
+
- Hardware: **1× RTX PRO 6000** workstation GPU.
|
| 16 |
+
- Wall-clock budget: **~10 hours total**.
|
| 17 |
+
- Deliverable standard: **decision-quality results**, not paper-quality results. They must be rigorous nonetheless - absolutely no data leaks.
|
| 18 |
+
- The output of the sprint must be enough to shape the next development cycle.
|
| 19 |
+
|
| 20 |
+
This means the sprint is not allowed to expand into a broad refactor, a large hyperparameter search, an external benchmark integration project, or a foundation-model migration. The objective is to get the strongest signal per hour.
|
| 21 |
+
|
| 22 |
+
## What this sprint needs to answer
|
| 23 |
+
|
| 24 |
+
At the end of the 10-hour window, the repo should let us answer these questions with reasonable confidence:
|
| 25 |
+
|
| 26 |
+
1. Does the current full architecture look directionally better than trivial and structured baselines on rough proxy versions of the three real task families?
|
| 27 |
+
2. Which parts of the architecture appear to matter most right now: explicit task conditioning, geometry, memory, planner, or candidate family structure?
|
| 28 |
+
3. For each of the three target tasks, is the current architecture best described as **promising**, **uncertain**, or **weak** under the proxy tests?
|
| 29 |
+
4. If the system is weak, is the weakness coming mostly from perception/state estimation, memory, retrieve gating/planning, or proxy mismatch?
|
| 30 |
+
5. Which next engineering changes are most likely to widen the eventual real-task performance gap, and which ones are unlikely to matter?
|
| 31 |
+
|
| 32 |
+
That is the bar for success in this sprint. It is acceptable if the results are approximate. It is not acceptable if the results are too noisy or too poorly instrumented to support those decisions.
|
| 33 |
+
|
| 34 |
+
## What this sprint can and cannot tell us
|
| 35 |
+
|
| 36 |
+
This sprint **can** tell us whether the current structured architecture is showing the right dependencies under stress. For example, it can tell us whether memory actually helps under reocclusion, whether geometry actually helps under camera perturbation, whether retrieve gating blocks premature retrieve, and whether the planner adds value beyond trivial candidate selection.
|
| 37 |
+
|
| 38 |
+
This sprint **cannot** tell us true real-world performance on live foliage, real bag interiors, or folded clothes in a suitcase. It also cannot tell us whether the final production backbone should be CLIP, OpenVLA-style, LingBot-style, or something else. Those are later decisions.
|
| 39 |
+
|
| 40 |
+
So the correct target is not “exact future performance.” The correct target is “a useful approximation of whether the current structure is pointed in the right direction, and where the biggest current bottlenecks are.”
|
| 41 |
+
|
| 42 |
+
## High-level decisions to lock now
|
| 43 |
+
|
| 44 |
+
1. Treat the current **compact-phase CLIP/RGB-D handoff branch** as the only reference branch for this sprint.
|
| 45 |
+
2. Keep the explicit reveal-state stack. Do **not** rewrite the repo into a monolithic end-to-end VLA policy in this sprint.
|
| 46 |
+
3. Keep RLBench only as a smoke test for three-camera and bimanual integration. Do **not** use RLBench mean success as the main selector for reveal/retrieve architecture changes.
|
| 47 |
+
4. Do **not** switch to a new backbone in this sprint. The goal here is to evaluate the structure, not to spend the 10-hour budget on trunk migration.
|
| 48 |
+
5. Prefer **eval-time knockouts and toggles** over fully retrained matched ablations in this sprint. Retrained matched ablations remain important later, but they are too expensive for the current time budget.
|
| 49 |
+
|
| 50 |
+
## Immediate read of the current repo
|
| 51 |
+
|
| 52 |
+
The current repo is still a good scaffold, but the signal quality is weaker than the structure quality. The strongest part of the system remains the decomposition itself: explicit reveal-state fields, dual scene/belief memory, task-conditioned proposal families, and a planner that reasons about persistence, reocclusion, support, and actor feasibility. The weakest part is that the current evidence is still too easy to misread. The proxy results already suggest that the large spatial rollout path is not the right place to spend another iteration right now, while the compact-phase line remains the most credible base. That should be treated as settled for this sprint.
|
| 53 |
+
|
| 54 |
+
## Sprint strategy
|
| 55 |
+
|
| 56 |
+
The fastest path to useful conclusions is:
|
| 57 |
+
|
| 58 |
+
1. Remove the most obvious confounds.
|
| 59 |
+
2. Add just enough stress slicing and logging to make the proxy benchmark informative.
|
| 60 |
+
3. Run a **small but fixed stratified benchmark** that is reused for every comparison.
|
| 61 |
+
4. Compare the base model against trivial baselines and a few **eval-time architecture knockouts**.
|
| 62 |
+
5. Use the results to build a task-by-task bottleneck map.
|
| 63 |
+
|
| 64 |
+
This is deliberately narrower than the earlier broad handoff. The point is to obtain meaningful conclusions within one GPU and one workday.
|
| 65 |
+
|
| 66 |
+
## Do now vs defer
|
| 67 |
+
|
| 68 |
+
### Must do in this sprint
|
| 69 |
+
|
| 70 |
+
These are the changes that are worth the time even under a 10-hour cap.
|
| 71 |
+
|
| 72 |
+
1. Explicit task metadata must override text routing.
|
| 73 |
+
2. History camera geometry must propagate correctly.
|
| 74 |
+
3. The compact-phase branch must become the default base config.
|
| 75 |
+
4. The proxy benchmark must become stratified by task and stress slice.
|
| 76 |
+
5. The benchmark runner must support simple baselines and eval-time architecture knockouts.
|
| 77 |
+
6. Reporting must include the small set of task-specific metrics that actually matter for the three target tasks.
|
| 78 |
+
|
| 79 |
+
### Explicitly defer until after this sprint
|
| 80 |
+
|
| 81 |
+
These are good ideas, but they should not consume the current 10-hour budget.
|
| 82 |
+
|
| 83 |
+
1. Foundation trunk migration (OpenVLA, LingBot, π0.5, etc.).
|
| 84 |
+
2. External deformable benchmark integration.
|
| 85 |
+
3. Full matched retraining ablation suite.
|
| 86 |
+
4. Major loss redesign or long retraining campaigns.
|
| 87 |
+
5. Large nuisance sweeps beyond the narrow stress slices listed below.
|
| 88 |
+
6. Spatial rollout branch rescue work.
|
| 89 |
+
|
| 90 |
+
## Mandatory repo changes for this sprint
|
| 91 |
+
|
| 92 |
+
### 1. Replace heuristic task routing as the primary path
|
| 93 |
+
|
| 94 |
+
**Why now:** this is a real confound and cheap to fix.
|
| 95 |
+
|
| 96 |
+
Files to edit:
|
| 97 |
+
|
| 98 |
+
- `code/reveal_vla_bimanual/models/policy.py`
|
| 99 |
+
- `code/reveal_vla_bimanual/models/action_decoder.py`
|
| 100 |
+
- `code/reveal_vla_bimanual/sim_reveal/dataset.py`
|
| 101 |
+
- `code/reveal_vla_bimanual/sim_reveal/generate_dataset.py`
|
| 102 |
+
- `code/reveal_vla_bimanual/eval/run_reveal_benchmark.py`
|
| 103 |
+
|
| 104 |
+
Required changes:
|
| 105 |
+
|
| 106 |
+
- Add `task_name` and `task_id` to every proxy training and evaluation example.
|
| 107 |
+
- Make explicit task metadata override any text-based inference everywhere.
|
| 108 |
+
- Keep keyword routing only as a fallback for legacy examples that do not carry task metadata.
|
| 109 |
+
- Surface the resolved task family in benchmark logs so mistakes are easy to see.
|
| 110 |
+
|
| 111 |
+
Acceptance criterion:
|
| 112 |
+
|
| 113 |
+
- A misleading prompt string must not change the task family when `task_name` is present.
|
| 114 |
+
|
| 115 |
+
### 2. Fix history geometry propagation
|
| 116 |
+
|
| 117 |
+
**Why now:** if geometry is broken in history, the current geometry ablations are not trustworthy.
|
| 118 |
+
|
| 119 |
+
Files to edit:
|
| 120 |
+
|
| 121 |
+
- `code/reveal_vla_bimanual/models/policy.py`
|
| 122 |
+
- any history batching utility that currently drops camera matrices
|
| 123 |
+
- proxy dataset serialization if history camera metadata is missing
|
| 124 |
+
|
| 125 |
+
Required changes:
|
| 126 |
+
|
| 127 |
+
- Save and batch history camera intrinsics and extrinsics.
|
| 128 |
+
- Pass them through the history encoder when geometry and camera-pose tokens are enabled.
|
| 129 |
+
- Add a validity mask if some history frames do not have full camera metadata.
|
| 130 |
+
- Add a debug log or assertion path that makes it obvious whether history geometry is really being used.
|
| 131 |
+
|
| 132 |
+
Acceptance criterion:
|
| 133 |
+
|
| 134 |
+
- A geometry-enabled run must receive non-null history camera tensors in the forward path.
|
| 135 |
+
|
| 136 |
+
### 3. Freeze the compact-phase branch as the main base
|
| 137 |
+
|
| 138 |
+
**Why now:** the time budget does not allow another architecture round on the weaker spatial branch.
|
| 139 |
+
|
| 140 |
+
Files to edit:
|
| 141 |
+
|
| 142 |
+
- `code/reveal_vla_bimanual/train/configs/*.yaml`
|
| 143 |
+
- `code/reveal_vla_bimanual/eval/run_ablations.py`
|
| 144 |
+
- any training or eval launcher that still points to older dummy ablation configs
|
| 145 |
+
|
| 146 |
+
Required changes:
|
| 147 |
+
|
| 148 |
+
- Create one new base config derived from the compact-phase recipe. Suggested filename:
|
| 149 |
+
- `proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact_phase_v7_base.yaml`
|
| 150 |
+
- Mark the current spatial configs as experimental.
|
| 151 |
+
- Make the new v7 base the default for this sprint.
|
| 152 |
+
- Do not create a broad new family of retrained ablation configs now. That is for later.
|
| 153 |
+
|
| 154 |
+
Acceptance criterion:
|
| 155 |
+
|
| 156 |
+
- The benchmark runner should use the v7 compact-phase config by default unless explicitly told otherwise.
|
| 157 |
+
|
| 158 |
+
### 4. Add a narrow but informative proxy stress suite
|
| 159 |
+
|
| 160 |
+
**Why now:** the current proxies are too undifferentiated. The sprint only needs enough stress structure to make results interpretable.
|
| 161 |
+
|
| 162 |
+
Files to edit:
|
| 163 |
+
|
| 164 |
+
- `code/reveal_vla_bimanual/sim_reveal/procedural_envs.py`
|
| 165 |
+
- `code/reveal_vla_bimanual/sim_reveal/proxy_specs.py`
|
| 166 |
+
- `code/reveal_vla_bimanual/sim_reveal/dataset.py`
|
| 167 |
+
- `code/reveal_vla_bimanual/sim_reveal/generate_dataset.py`
|
| 168 |
+
|
| 169 |
+
Required changes:
|
| 170 |
+
|
| 171 |
+
Add only these stress slices for this sprint:
|
| 172 |
+
|
| 173 |
+
- `nominal`
|
| 174 |
+
- `high_reocclusion`
|
| 175 |
+
- `camera_perturbation`
|
| 176 |
+
- one task-specific critical slice per task:
|
| 177 |
+
- foliage: `tight_corridor_high_collateral`
|
| 178 |
+
- bag: `one_sided_slip`
|
| 179 |
+
- cloth: `fold_sensitive_long_persistence`
|
| 180 |
+
|
| 181 |
+
Also add:
|
| 182 |
+
|
| 183 |
+
- `difficulty_bin` with only `medium` and `hard` for this sprint (skip easy and extreme to save time and focus on decision-relevant cases)
|
| 184 |
+
- episode metadata for the sampled nuisance parameters used in each stress slice
|
| 185 |
+
- per-step traces for visibility, support, access/corridor, reocclusion risk, disturbance, and chosen candidate family
|
| 186 |
+
|
| 187 |
+
Acceptance criterion:
|
| 188 |
+
|
| 189 |
+
- Every benchmark report must be sliceable by task family, stress slice, and difficulty bin.
|
| 190 |
+
|
| 191 |
+
### 5. Add simple baselines and oracle-style planner evaluation
|
| 192 |
+
|
| 193 |
+
**Why now:** without these, base-model numbers are hard to interpret.
|
| 194 |
+
|
| 195 |
+
Files to edit/add:
|
| 196 |
+
|
| 197 |
+
- `code/reveal_vla_bimanual/eval/run_reveal_benchmark.py`
|
| 198 |
+
- `code/reveal_vla_bimanual/eval/metrics.py`
|
| 199 |
+
- add `code/reveal_vla_bimanual/eval/run_proxy_random_eval.py`
|
| 200 |
+
- add `code/reveal_vla_bimanual/eval/run_proxy_candidate0_eval.py`
|
| 201 |
+
- add `code/reveal_vla_bimanual/eval/run_planner_oracle_eval.py`
|
| 202 |
+
- add `code/reveal_vla_bimanual/eval/run_proxy_scripted_eval.py` if the existing scripted path is not already callable directly
|
| 203 |
+
|
| 204 |
+
Required changes:
|
| 205 |
+
|
| 206 |
+
- Add random candidate selection.
|
| 207 |
+
- Add candidate-0 selection.
|
| 208 |
+
- Add scripted teacher execution.
|
| 209 |
+
- Add oracle-planner evaluation that uses proxy candidate summaries directly.
|
| 210 |
+
- Add support for eval-time architecture toggles:
|
| 211 |
+
- `--disable_planner`
|
| 212 |
+
- `--disable_memory`
|
| 213 |
+
- `--disable_task_conditioning`
|
| 214 |
+
- `--disable_geometry`
|
| 215 |
+
- `--disable_camera_pose` (optional if it is cheap)
|
| 216 |
+
|
| 217 |
+
Acceptance criterion:
|
| 218 |
+
|
| 219 |
+
- The benchmark must be able to compare the same checkpoint against trivial baselines and against architecture knockouts without retraining all variants.
|
| 220 |
+
|
| 221 |
+
### 6. Strengthen reporting, but only where it matters
|
| 222 |
+
|
| 223 |
+
**Why now:** the current sprint only succeeds if it produces conclusions, not just numbers.
|
| 224 |
+
|
| 225 |
+
Files to edit:
|
| 226 |
+
|
| 227 |
+
- `code/reveal_vla_bimanual/eval/metrics.py`
|
| 228 |
+
- `code/reveal_vla_bimanual/eval/run_reveal_benchmark.py`
|
| 229 |
+
|
| 230 |
+
Required metric outputs:
|
| 231 |
+
|
| 232 |
+
Global:
|
| 233 |
+
|
| 234 |
+
- overall proxy success
|
| 235 |
+
- per-task success
|
| 236 |
+
- success by stress slice
|
| 237 |
+
- success by difficulty bin
|
| 238 |
+
- premature retrieve rate
|
| 239 |
+
- reocclusion-after-reveal rate
|
| 240 |
+
- planner regret (where oracle summaries are available)
|
| 241 |
+
|
| 242 |
+
Task-specific headline metrics:
|
| 243 |
+
|
| 244 |
+
Foliage:
|
| 245 |
+
|
| 246 |
+
- visibility integral
|
| 247 |
+
- corridor availability
|
| 248 |
+
- collateral motion / damage proxy
|
| 249 |
+
- actor-feasibility floor before retrieve
|
| 250 |
+
|
| 251 |
+
Bag:
|
| 252 |
+
|
| 253 |
+
- mouth aperture
|
| 254 |
+
- hold persistence
|
| 255 |
+
- rim slip rate
|
| 256 |
+
- insertable corridor
|
| 257 |
+
|
| 258 |
+
Cloth:
|
| 259 |
+
|
| 260 |
+
- fold preservation
|
| 261 |
+
- layer separation quality
|
| 262 |
+
- top-layer stability
|
| 263 |
+
- lift-too-high rate
|
| 264 |
+
|
| 265 |
+
Required report shape:
|
| 266 |
+
|
| 267 |
+
- one overall table
|
| 268 |
+
- one task × stress slice table
|
| 269 |
+
- per-episode JSON traces for failure clustering later
|
| 270 |
+
|
| 271 |
+
Acceptance criterion:
|
| 272 |
+
|
| 273 |
+
- A single report should make it obvious whether the model is failing because of reocclusion, geometry sensitivity, premature retrieve, or task-specific degradation.
|
| 274 |
+
|
| 275 |
+
## Changes that are useful only if they are cheap
|
| 276 |
+
|
| 277 |
+
These are allowed only if the mandatory work finishes early.
|
| 278 |
+
|
| 279 |
+
### A. Light training rebalance
|
| 280 |
+
|
| 281 |
+
Only do this if it can be implemented in under about one hour.
|
| 282 |
+
|
| 283 |
+
Allowed small changes:
|
| 284 |
+
|
| 285 |
+
- oversample obvious hard negatives already present in the proxy dataset
|
| 286 |
+
- slightly increase loss weight on unsafe-retrieve ranking errors
|
| 287 |
+
- log candidate ranking diagnostics during training
|
| 288 |
+
|
| 289 |
+
Do **not** do a broad loss redesign in this sprint.
|
| 290 |
+
|
| 291 |
+
### B. Trunk modularity prep
|
| 292 |
+
|
| 293 |
+
Only do this if the mandatory work is already complete.
|
| 294 |
+
|
| 295 |
+
Allowed small changes:
|
| 296 |
+
|
| 297 |
+
- define a simple trunk adapter interface around the current CLIP path
|
| 298 |
+
- avoid touching planner, memory, or reveal-head code
|
| 299 |
+
|
| 300 |
+
Do **not** attempt a real new-backbone integration in this sprint.
|
| 301 |
+
|
| 302 |
+
## 10-hour execution plan
|
| 303 |
+
|
| 304 |
+
This schedule is the intended wall-clock plan. It is aggressive but realistic if the scope stays narrow.
|
| 305 |
+
|
| 306 |
+
### Hour 0 to 1.5: remove confounds
|
| 307 |
+
|
| 308 |
+
Complete:
|
| 309 |
+
|
| 310 |
+
- explicit task metadata path
|
| 311 |
+
- history geometry propagation
|
| 312 |
+
- v7 compact-phase base config
|
| 313 |
+
- eval-time toggle plumbing in the benchmark runner
|
| 314 |
+
|
| 315 |
+
Output expected by the end of this block:
|
| 316 |
+
|
| 317 |
+
- code compiles
|
| 318 |
+
- a tiny smoke run confirms task routing and history geometry are active
|
| 319 |
+
|
| 320 |
+
### Hour 1.5 to 3: build the fixed eval set and reporting
|
| 321 |
+
|
| 322 |
+
Complete:
|
| 323 |
+
|
| 324 |
+
- stratified proxy eval set generation
|
| 325 |
+
- task/stress/difficulty metadata roundtrip
|
| 326 |
+
- benchmark tables and per-episode JSON traces
|
| 327 |
+
- random, candidate-0, scripted, and oracle evaluation entry points
|
| 328 |
+
|
| 329 |
+
Output expected by the end of this block:
|
| 330 |
+
|
| 331 |
+
- one fixed benchmark episode set reused by every later run
|
| 332 |
+
|
| 333 |
+
### Hour 3 to 5.5: produce one base-model result
|
| 334 |
+
|
| 335 |
+
Preferred path:
|
| 336 |
+
|
| 337 |
+
- reuse the best existing compact-phase checkpoint if it still loads cleanly after the metadata and geometry fixes
|
| 338 |
+
- if needed, run a short warm-start fine-tune from that checkpoint rather than starting from scratch
|
| 339 |
+
|
| 340 |
+
Do not spend this block on multi-seed training. One strong base run is more valuable than several weak incomplete runs.
|
| 341 |
+
|
| 342 |
+
Output expected by the end of this block:
|
| 343 |
+
|
| 344 |
+
- one evaluated base model on the full fixed proxy suite
|
| 345 |
+
|
| 346 |
+
### Hour 5.5 to 8.5: run baselines and eval-time knockouts
|
| 347 |
+
|
| 348 |
+
Required comparisons:
|
| 349 |
+
|
| 350 |
+
- random
|
| 351 |
+
- candidate-0
|
| 352 |
+
- scripted teacher
|
| 353 |
+
- oracle planner
|
| 354 |
+
- base model
|
| 355 |
+
- base model with planner disabled
|
| 356 |
+
- base model with memory disabled
|
| 357 |
+
- base model with task conditioning disabled
|
| 358 |
+
- base model with geometry disabled
|
| 359 |
+
- base model with camera pose disabled if cheap enough
|
| 360 |
+
|
| 361 |
+
Output expected by the end of this block:
|
| 362 |
+
|
| 363 |
+
- a complete comparison table on the same episodes
|
| 364 |
+
|
| 365 |
+
### Hour 8.5 to 10: summarize and extract conclusions
|
| 366 |
+
|
| 367 |
+
Complete:
|
| 368 |
+
|
| 369 |
+
- task-by-task bottleneck summary
|
| 370 |
+
- approximate transfer-readiness labels for foliage, bag, and cloth
|
| 371 |
+
- ranked next-step engineering priorities
|
| 372 |
+
|
| 373 |
+
Output expected by the end of this block:
|
| 374 |
+
|
| 375 |
+
- one benchmark summary table
|
| 376 |
+
- one short conclusion memo in the repo or artifact directory
|
| 377 |
+
|
| 378 |
+
## Fixed benchmark design for the 10-hour sprint
|
| 379 |
+
|
| 380 |
+
Use one small, fixed, stratified benchmark. Reuse the same episode seeds for all variants.
|
| 381 |
+
|
| 382 |
+
Recommended size:
|
| 383 |
+
|
| 384 |
+
- **300 total episodes**
|
| 385 |
+
- **100 per task family**
|
| 386 |
+
- per task family:
|
| 387 |
+
- 20 `nominal` / `medium`
|
| 388 |
+
- 20 `nominal` / `hard`
|
| 389 |
+
- 20 `high_reocclusion`
|
| 390 |
+
- 20 `camera_perturbation`
|
| 391 |
+
- 20 task-specific critical slice
|
| 392 |
+
|
| 393 |
+
This is small enough to run repeatedly and large enough to show directional differences if the logging is good.
|
| 394 |
+
|
| 395 |
+
## Required tests to add now
|
| 396 |
+
|
| 397 |
+
The test suite in this sprint is not meant to be exhaustive. It is meant to prevent false conclusions.
|
| 398 |
+
|
| 399 |
+
### Unit tests
|
| 400 |
+
|
| 401 |
+
`tests/test_explicit_task_metadata_overrides_text.py`
|
| 402 |
+
|
| 403 |
+
- Batch has `task_name="bag"` and misleading foliage text.
|
| 404 |
+
- Assert that bag proposal families and bag task heads are used.
|
| 405 |
+
|
| 406 |
+
`tests/test_text_routing_only_used_as_fallback.py`
|
| 407 |
+
|
| 408 |
+
- Assert that keyword routing is skipped when task metadata exists.
|
| 409 |
+
|
| 410 |
+
`tests/test_history_camera_geometry_propagates.py`
|
| 411 |
+
|
| 412 |
+
- Assert that history frames receive non-null camera tensors when geometry is enabled.
|
| 413 |
+
|
| 414 |
+
`tests/test_history_geometry_validity_mask.py`
|
| 415 |
+
|
| 416 |
+
- Assert that mixed valid/invalid history geometry uses a validity mask rather than silent nulling.
|
| 417 |
+
|
| 418 |
+
`tests/test_eval_toggle_paths_work.py`
|
| 419 |
+
|
| 420 |
+
- Assert that planner, memory, task-conditioning, and geometry toggles actually change the execution path.
|
| 421 |
+
|
| 422 |
+
`tests/test_benchmark_report_contains_task_and_stress_slices.py`
|
| 423 |
+
|
| 424 |
+
- Assert that the output report includes task family, stress slice, and difficulty bin tables.
|
| 425 |
+
|
| 426 |
+
### Integration tests
|
| 427 |
+
|
| 428 |
+
`tests/test_proxy_stress_profile_metadata_roundtrip.py`
|
| 429 |
+
|
| 430 |
+
- Generate all sprint stress slices and assert the metadata survives dataset serialization and evaluation.
|
| 431 |
+
|
| 432 |
+
`tests/test_planner_beats_random_on_oracle_candidates.py`
|
| 433 |
+
|
| 434 |
+
- Use oracle candidate summaries and assert the planner beats random and candidate-0 on regret/top-1.
|
| 435 |
+
|
| 436 |
+
`tests/test_memory_matters_under_high_reocclusion.py`
|
| 437 |
+
|
| 438 |
+
- Compare full memory vs disabled-memory on a small high-reocclusion slice and assert a directional drop.
|
| 439 |
+
|
| 440 |
+
`tests/test_geometry_matters_under_camera_perturbation.py`
|
| 441 |
+
|
| 442 |
+
- Compare geometry-on vs geometry-off on a small camera-perturbation slice and assert a directional drop.
|
| 443 |
+
|
| 444 |
+
`tests/test_retrieve_gating_blocks_premature_retrieve.py`
|
| 445 |
+
|
| 446 |
+
- Feed candidates where raw retrieve looks tempting but support/persistence/corridor are unsafe and assert that retrieve is rejected.
|
| 447 |
+
|
| 448 |
+
These tests are enough for this sprint. Do not expand the suite unless the mandatory implementation is already done.
|
| 449 |
+
|
| 450 |
+
## Benchmark runs to perform in this sprint
|
| 451 |
+
|
| 452 |
+
### 1. Baselines
|
| 453 |
+
|
| 454 |
+
Run on the fixed 300-episode benchmark:
|
| 455 |
+
|
| 456 |
+
- random candidate selection
|
| 457 |
+
- candidate-0 selection
|
| 458 |
+
- scripted teacher
|
| 459 |
+
- oracle planner
|
| 460 |
+
|
| 461 |
+
These establish the floor, the structured upper bound for the current proposal set, and whether the learned planner adds anything.
|
| 462 |
+
|
| 463 |
+
### 2. Base model run
|
| 464 |
+
|
| 465 |
+
Run the v7 compact-phase base on the same 300 episodes.
|
| 466 |
+
|
| 467 |
+
### 3. Eval-time architecture knockouts
|
| 468 |
+
|
| 469 |
+
Run the same checkpoint with these toggles:
|
| 470 |
+
|
| 471 |
+
- no planner
|
| 472 |
+
- no memory
|
| 473 |
+
- no task conditioning
|
| 474 |
+
- no geometry
|
| 475 |
+
- no camera pose (only if cheap)
|
| 476 |
+
|
| 477 |
+
These are not final scientific ablations. They are fast directional probes that tell us what appears to matter right now.
|
| 478 |
+
|
| 479 |
+
### 4. Optional cheap extra run
|
| 480 |
+
|
| 481 |
+
Only if time remains:
|
| 482 |
+
|
| 483 |
+
- a short warm-start fine-tune with a small hard-negative rebalance
|
| 484 |
+
|
| 485 |
+
Treat this only as bonus signal. It is not part of the minimum sprint success condition.
|
| 486 |
+
|
| 487 |
+
## Decision rubric for interpreting the results
|
| 488 |
+
|
| 489 |
+
At the end of the sprint, do **not** only report the raw numbers. Convert the results into a bottleneck map.
|
| 490 |
+
|
| 491 |
+
### Signals that the architecture is directionally healthy
|
| 492 |
+
|
| 493 |
+
- Base clearly beats random and candidate-0 on all three task families.
|
| 494 |
+
- Oracle planner beats random and candidate-0 by a wide margin, showing the proposal/planner structure is at least usable.
|
| 495 |
+
- Disabling planner hurts most on premature-retrieve and task-specific stress slices.
|
| 496 |
+
- Disabling memory hurts most on `high_reocclusion` and long-persistence cloth cases.
|
| 497 |
+
- Disabling geometry hurts most on `camera_perturbation`.
|
| 498 |
+
- Task conditioning matters on the mixed-task benchmark and under misleading text.
|
| 499 |
+
|
| 500 |
+
### Signals that a component is currently unproven
|
| 501 |
+
|
| 502 |
+
- Base is only slightly above candidate-0 or random.
|
| 503 |
+
- Oracle planner is weak, which means the proposal set or planner utility is not yet reliable.
|
| 504 |
+
- Memory removal is almost flat on `high_reocclusion`.
|
| 505 |
+
- Geometry removal is almost flat on `camera_perturbation`.
|
| 506 |
+
- Planner removal is flat, which suggests that the learned scores or the candidate shortlist are not carrying useful structure.
|
| 507 |
+
|
| 508 |
+
### Task-specific interpretation
|
| 509 |
+
|
| 510 |
+
For foliage, judge transfer-readiness mainly from:
|
| 511 |
+
|
| 512 |
+
- corridor availability
|
| 513 |
+
- actor-feasibility floor before retrieve
|
| 514 |
+
- collateral motion / damage proxy
|
| 515 |
+
- robustness under `high_reocclusion` and `tight_corridor_high_collateral`
|
| 516 |
+
|
| 517 |
+
For bag, judge transfer-readiness mainly from:
|
| 518 |
+
|
| 519 |
+
- mouth aperture
|
| 520 |
+
- hold persistence
|
| 521 |
+
- rim slip rate
|
| 522 |
+
- robustness under `one_sided_slip`
|
| 523 |
+
|
| 524 |
+
For cloth, judge transfer-readiness mainly from:
|
| 525 |
+
|
| 526 |
+
- fold preservation
|
| 527 |
+
- layer separation quality
|
| 528 |
+
- top-layer stability
|
| 529 |
+
- lift-too-high rate
|
| 530 |
+
- robustness under `fold_sensitive_long_persistence`
|
| 531 |
+
|
| 532 |
+
At the end of the sprint, label each task family as:
|
| 533 |
+
|
| 534 |
+
- **Promising**: base beats weak baselines clearly and the expected architectural components matter on the right stress slices.
|
| 535 |
+
- **Uncertain**: base is somewhat above baselines but at least one critical stress slice or component dependency is weak.
|
| 536 |
+
- **Weak**: base is near trivial baselines or the critical stress slices fail badly enough that the current structure is not yet convincing.
|
| 537 |
+
|
| 538 |
+
## Practical GPU/runtime guidance
|
| 539 |
+
|
| 540 |
+
This sprint assumes a single 96 GB workstation GPU. That is enough for the current CLIP-based compact-phase line and repeated proxy evaluations, but it is not enough time to justify broad parallel experiments.
|
| 541 |
+
|
| 542 |
+
Use the following operating rules:
|
| 543 |
+
|
| 544 |
+
- run experiments sequentially
|
| 545 |
+
- prefer one strong base run over many partial runs
|
| 546 |
+
- use mixed precision if already supported
|
| 547 |
+
- keep evaluation batch sizes modest and stable
|
| 548 |
+
- avoid large retraining loops or many seeds
|
| 549 |
+
- reuse the same fixed benchmark episodes for every comparison
|
| 550 |
+
|
| 551 |
+
The main bottleneck in this sprint should be engineering and interpretation, not raw VRAM.
|
| 552 |
+
|
| 553 |
+
## Things not to do in this sprint
|
| 554 |
+
|
| 555 |
+
Do not switch backbones. Do not integrate an external simulator. Do not rescue the spatial rollout branch. Do not run broad hyperparameter sweeps. Do not attempt a five-seed retraining ablation matrix. Do not use RLBench averages as the main argument for or against architecture changes meant for foliage, bag, or folded cloth.
|
| 556 |
+
|
| 557 |
+
## Minimal deliverables at the 10-hour mark
|
| 558 |
+
|
| 559 |
+
At the end of the sprint, the repo or artifact directory should contain:
|
| 560 |
+
|
| 561 |
+
1. the new `v7` compact-phase base config
|
| 562 |
+
2. the fixed 300-episode stratified benchmark definition or metadata file
|
| 563 |
+
3. updated benchmark runner with stress-slice reporting
|
| 564 |
+
4. random, candidate-0, scripted, and oracle evaluation runners
|
| 565 |
+
5. the required unit and integration tests listed above
|
| 566 |
+
6. one complete result table comparing:
|
| 567 |
+
- random
|
| 568 |
+
- candidate-0
|
| 569 |
+
- scripted teacher
|
| 570 |
+
- oracle planner
|
| 571 |
+
- base model
|
| 572 |
+
- no planner
|
| 573 |
+
- no memory
|
| 574 |
+
- no task conditioning
|
| 575 |
+
- no geometry
|
| 576 |
+
- optional no camera pose
|
| 577 |
+
7. one short decision memo that states:
|
| 578 |
+
- approximate transfer-readiness for foliage, bag, and cloth
|
| 579 |
+
- which architectural components look most important
|
| 580 |
+
- which current weakness is most likely to block real-task success
|
| 581 |
+
- what should be strengthened next
|
| 582 |
+
|
| 583 |
+
## Success condition for this sprint
|
| 584 |
+
|
| 585 |
+
This sprint is successful if, by the end of 10 hours, we can say something like the following and defend it with benchmark evidence:
|
| 586 |
+
|
| 587 |
+
- “The structured architecture appears meaningfully better than trivial baselines on foliage and bag proxies, but cloth remains fragile because fold preservation degrades under long persistence.”
|
| 588 |
+
- or “The planner structure looks sound under oracle candidates, but the learned state estimate is still too weak, so the next work should target perception/memory rather than planner redesign.”
|
| 589 |
+
- or “Geometry and task conditioning matter, but memory does not yet move the reocclusion slice, so the current memory story is still unproven.”
|
| 590 |
+
|
| 591 |
+
If we can make claims of that form with actual run outputs, the sprint has done its job.
|
history/VLAarchtests_previous_README.md
ADDED
|
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
tags:
|
| 3 |
+
- robotics
|
| 4 |
+
- vision-language-action
|
| 5 |
+
- bimanual-manipulation
|
| 6 |
+
- rlbench
|
| 7 |
+
- rgbd
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
# VLAarchtests
|
| 11 |
+
|
| 12 |
+
Bundle uploaded from `/workspace` runpod sessions dated `2026-03-25 UTC` and `2026-03-26 UTC`.
|
| 13 |
+
|
| 14 |
+
## Top-Level Contents
|
| 15 |
+
|
| 16 |
+
- `code/reveal_vla_bimanual/`
|
| 17 |
+
- project code used for the proxy and RLBench runs in this bundle
|
| 18 |
+
- `artifacts/data/reveal_proxy/`
|
| 19 |
+
- proxy dataset bundles used by the handoff runs
|
| 20 |
+
- `artifacts/outputs/r3d/`
|
| 21 |
+
- previously uploaded R3D proxy outputs already present in the bundle
|
| 22 |
+
- `artifacts/outputs/r3d_handoff/`
|
| 23 |
+
- handoff proxy checkpoints
|
| 24 |
+
- `artifacts/outputs/r3d_handoff_phase/`
|
| 25 |
+
- phase-supervised handoff proxy checkpoints
|
| 26 |
+
- `artifacts/outputs/rlbench_current/`
|
| 27 |
+
- RLBench checkpoints from the current session
|
| 28 |
+
- `artifacts/reports/`
|
| 29 |
+
- proxy and RLBench result files copied from `/workspace/reports`
|
| 30 |
+
- `environment/`
|
| 31 |
+
- same-machine setup files and validation helpers
|
| 32 |
+
- `tests/`
|
| 33 |
+
- local test suite
|
| 34 |
+
- `handoff/instructions.md`
|
| 35 |
+
- instruction file used for the handoff work
|
| 36 |
+
- `MODEL_INDEX.md`
|
| 37 |
+
- checkpoint and result index
|
| 38 |
+
- `results/session_results_20260326.md`
|
| 39 |
+
- raw result tables for the `2026-03-25/26` work
|
| 40 |
+
|
| 41 |
+
## Code Added Or Updated
|
| 42 |
+
|
| 43 |
+
### Core model, memory, planner, and dataset paths
|
| 44 |
+
|
| 45 |
+
- `code/reveal_vla_bimanual/models/backbones.py`
|
| 46 |
+
- `code/reveal_vla_bimanual/models/multiview_fusion.py`
|
| 47 |
+
- `code/reveal_vla_bimanual/models/observation_memory.py`
|
| 48 |
+
- `code/reveal_vla_bimanual/models/reveal_head.py`
|
| 49 |
+
- `code/reveal_vla_bimanual/models/world_model.py`
|
| 50 |
+
- `code/reveal_vla_bimanual/models/action_decoder.py`
|
| 51 |
+
- `code/reveal_vla_bimanual/models/planner.py`
|
| 52 |
+
- `code/reveal_vla_bimanual/models/policy.py`
|
| 53 |
+
- `code/reveal_vla_bimanual/train/losses.py`
|
| 54 |
+
- `code/reveal_vla_bimanual/sim_reveal/dataset.py`
|
| 55 |
+
- `code/reveal_vla_bimanual/sim_reveal/procedural_envs.py`
|
| 56 |
+
- `code/reveal_vla_bimanual/sim_rlbench/dataset.py`
|
| 57 |
+
|
| 58 |
+
### Training and evaluation paths
|
| 59 |
+
|
| 60 |
+
- `code/reveal_vla_bimanual/train/run_rlbench_experiment.py`
|
| 61 |
+
- `code/reveal_vla_bimanual/eval/run_reveal_benchmark.py`
|
| 62 |
+
- `code/reveal_vla_bimanual/eval/run_ablations.py`
|
| 63 |
+
- `code/reveal_vla_bimanual/eval/run_teacher_audit.py`
|
| 64 |
+
- `code/reveal_vla_bimanual/eval/run_rlbench_rollout_eval.py`
|
| 65 |
+
- `code/reveal_vla_bimanual/eval/run_rlbench_knn_eval.py`
|
| 66 |
+
|
| 67 |
+
### Added or updated training configs
|
| 68 |
+
|
| 69 |
+
- `code/reveal_vla_bimanual/train/configs/proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact.yaml`
|
| 70 |
+
- `code/reveal_vla_bimanual/train/configs/proxy_interaction_r3d_stage3_clip_rgbd_handoff_spatial.yaml`
|
| 71 |
+
- `code/reveal_vla_bimanual/train/configs/proxy_interaction_r3d_stage3_clip_rgbd_handoff_compact_phase.yaml`
|
| 72 |
+
- `code/reveal_vla_bimanual/train/configs/proxy_interaction_r3d_stage3_clip_rgbd_handoff_spatial_phase.yaml`
|
| 73 |
+
- `code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_clip_current_valid9.yaml`
|
| 74 |
+
- `code/reveal_vla_bimanual/train/configs/rlbench_subset3_backbone_only_clip_current_common23.yaml`
|
| 75 |
+
- `code/reveal_vla_bimanual/train/configs/rlbench_lift_ball_backbone_only_clip_current_wide.yaml`
|
| 76 |
+
- `code/reveal_vla_bimanual/train/configs/rlbench_lift_ball_backbone_only_clip_step1.yaml`
|
| 77 |
+
- `code/reveal_vla_bimanual/train/configs/rlbench_push_box_backbone_only_clip_step1.yaml`
|
| 78 |
+
|
| 79 |
+
### Test files
|
| 80 |
+
|
| 81 |
+
The staged `tests/` directory contains `32` test modules plus `conftest.py`, including:
|
| 82 |
+
|
| 83 |
+
- geometry and camera rotation coverage
|
| 84 |
+
- phase-label and candidate-ranking coverage
|
| 85 |
+
- planner gradient-flow and reocclusion gating coverage
|
| 86 |
+
- world-model null-rollout, field-consistency, and task-adapter coverage
|
| 87 |
+
- proxy scripted benchmark and teacher-audit coverage
|
| 88 |
+
|
| 89 |
+
## Verification
|
| 90 |
+
|
| 91 |
+
- local test command:
|
| 92 |
+
- `PYTHONPATH=/workspace/VLAarchtests_work/code/reveal_vla_bimanual python -m pytest -q /workspace/VLAarchtests_work/tests`
|
| 93 |
+
- result:
|
| 94 |
+
- `33 passed`
|
| 95 |
+
|
| 96 |
+
## Raw Result Files
|
| 97 |
+
|
| 98 |
+
### Proxy and handoff results
|
| 99 |
+
|
| 100 |
+
- `artifacts/reports/reveal_smoke_mod/reveal_benchmark.json`
|
| 101 |
+
- `artifacts/reports/reveal_smoke_nogeom/reveal_benchmark.json`
|
| 102 |
+
- `artifacts/reports/reveal_smoke_noplanner/reveal_benchmark.json`
|
| 103 |
+
- `artifacts/reports/reveal_handoff_compare_serious/reveal_benchmark.json`
|
| 104 |
+
- `artifacts/reports/reveal_handoff_compare_serious_compact/reveal_benchmark.json`
|
| 105 |
+
- `artifacts/reports/reveal_phase_compare_serious_compact/reveal_benchmark.json`
|
| 106 |
+
- `artifacts/reports/reveal_phase_compare_serious_spatial_compactwm/reveal_benchmark.json`
|
| 107 |
+
- `artifacts/reports/reveal_phase_ablations_compact/ablations.json`
|
| 108 |
+
- `artifacts/reports/reveal_teacher_audit_serious/teacher_audit.json`
|
| 109 |
+
|
| 110 |
+
### RLBench result files
|
| 111 |
+
|
| 112 |
+
- `artifacts/reports/rlbench_dual_buttons_baseline_len100_ep1_ik_rescale/rollout_eval.json`
|
| 113 |
+
- `artifacts/reports/rlbench_dual_buttons_common23_len100_ep1_ik_rescale/rollout_eval.json`
|
| 114 |
+
- `artifacts/reports/rlbench_push_box_common23_len100_ep1_ik_rescale/rollout_eval.json`
|
| 115 |
+
- `artifacts/reports/rlbench_lift_ball_wide_len160_ep1_ik_c1/rollout_eval.json`
|
| 116 |
+
- `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1/rollout_eval.json`
|
| 117 |
+
- `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1_s005/rollout_eval.json`
|
| 118 |
+
- `artifacts/reports/rlbench_push_box_knn_step1_ep1/rollout_eval.json`
|
| 119 |
+
- `artifacts/reports/rlbench_push_box_knn_step1_ep5/rollout_eval.json`
|
| 120 |
+
- `artifacts/reports/rlbench_push_box_knn_step1_ep5_top1_dense/rollout_eval.json`
|
| 121 |
+
|
| 122 |
+
## Raw Result Tables
|
| 123 |
+
|
| 124 |
+
### Proxy serious runs
|
| 125 |
+
|
| 126 |
+
| Artifact | File | Raw values |
|
| 127 |
+
| --- | --- | --- |
|
| 128 |
+
| spatial handoff vs released baseline | `artifacts/reports/reveal_handoff_compare_serious/reveal_benchmark.json` | baseline mean success `0.5833`, handoff mean success `0.2167` |
|
| 129 |
+
| spatial-trained checkpoint with compact world model vs released baseline | `artifacts/reports/reveal_handoff_compare_serious_compact/reveal_benchmark.json` | baseline mean success `0.5833`, handoff mean success `0.5200` |
|
| 130 |
+
| compact-phase vs released baseline | `artifacts/reports/reveal_phase_compare_serious_compact/reveal_benchmark.json` | baseline mean success `0.5833`, compact-phase mean success `0.5133` |
|
| 131 |
+
| spatial-phase with compact world model vs released baseline | `artifacts/reports/reveal_phase_compare_serious_spatial_compactwm/reveal_benchmark.json` | baseline mean success `0.5833`, spatial-phase compact-world-model mean success `0.4933` |
|
| 132 |
+
|
| 133 |
+
### Proxy ablations
|
| 134 |
+
|
| 135 |
+
| Artifact | File | Raw values |
|
| 136 |
+
| --- | --- | --- |
|
| 137 |
+
| compact-phase ablations | `artifacts/reports/reveal_phase_ablations_compact/ablations.json` | full `0.5133`, `no_geometry` `0.5133`, `no_spatial_memory` `0.4967`, `compact_world_model` `0.5133`, `no_planner` `0.4333`, `gaussian_candidates_only` `0.4667`, `no_task_head` `0.5133`, `no_support_mode_conditioning` `0.5133` |
|
| 138 |
+
|
| 139 |
+
### RLBench direct-policy runs
|
| 140 |
+
|
| 141 |
+
| Artifact | File | Raw values |
|
| 142 |
+
| --- | --- | --- |
|
| 143 |
+
| lift-ball wide checkpoint, one-step replanning | `artifacts/reports/rlbench_lift_ball_wide_len160_ep1_ik_c1/rollout_eval.json` | mean success `0.0`, mean return `0.0`, path recoveries `[148]`, noop fallbacks `[11]` |
|
| 144 |
+
| push-box step-1 checkpoint, one-step replanning | `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1/rollout_eval.json` | mean success `0.0`, mean return `0.0`, path recoveries `[177]`, noop fallbacks `[0]` |
|
| 145 |
+
| push-box step-1 checkpoint, one-step replanning, `delta_scale=0.05` | `artifacts/reports/rlbench_push_box_step1_ep1_ik_c1_s005/rollout_eval.json` | mean success `0.0`, mean return `0.0`, path recoveries `[180]`, noop fallbacks `[0]` |
|
| 146 |
+
|
| 147 |
+
### RLBench retrieval runs
|
| 148 |
+
|
| 149 |
+
| Artifact | File | Raw values |
|
| 150 |
+
| --- | --- | --- |
|
| 151 |
+
| push-box kNN, `bank_stride=4`, `top_k=5`, `time_window=8`, `episodes=1` | `artifacts/reports/rlbench_push_box_knn_step1_ep1/rollout_eval.json` | mean success `1.0`, mean return `1.0`, bank size `2815` |
|
| 152 |
+
| push-box kNN, `bank_stride=4`, `top_k=5`, `time_window=8`, `episodes=5` | `artifacts/reports/rlbench_push_box_knn_step1_ep5/rollout_eval.json` | successes `[0.0, 1.0, 0.0, 0.0, 0.0]`, mean success `0.2`, bank size `2815` |
|
| 153 |
+
| push-box kNN, `bank_stride=1`, `top_k=1`, `time_window=4`, `episodes=5` | `artifacts/reports/rlbench_push_box_knn_step1_ep5_top1_dense/rollout_eval.json` | successes `[0.0, 0.0, 1.0, 1.0, 0.0]`, mean success `0.4`, bank size `11259` |
|
| 154 |
+
|
| 155 |
+
## Environment Recreation Files
|
| 156 |
+
|
| 157 |
+
- `environment/setup_same_machine.sh`
|
| 158 |
+
- `environment/validate_same_machine.sh`
|
| 159 |
+
- `environment/run_peract2_13_rollouts.sh`
|
| 160 |
+
- `environment/runtime_env_vars.sh`
|
| 161 |
+
- `environment/hardware_snapshot.txt`
|
| 162 |
+
- `environment/glxinfo_B.txt`
|
| 163 |
+
- `environment/upstream_revisions.txt`
|
| 164 |
+
- `environment/system_packages_same_machine.txt`
|
| 165 |
+
- `environment/rlbench_env_export.yaml`
|
| 166 |
+
- `environment/rlbench_env_explicit.txt`
|
| 167 |
+
- `environment/rlbench_pip_freeze.txt`
|
| 168 |
+
- `environment/reveal_env_export.yaml`
|
| 169 |
+
- `environment/reveal_env_explicit.txt`
|
| 170 |
+
- `environment/reveal_pip_freeze.txt`
|
| 171 |
+
|
| 172 |
+
Detailed raw tables for the `2026-03-25/26` work are in `results/session_results_20260326.md`.
|
metadata/source_sizes.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
13G /workspace/VLAarchtests
|
| 2 |
+
2.2G /workspace/third_party/AnyBimanual
|
| 3 |
+
54G /workspace/baselines
|
| 4 |
+
219M /workspace/reports
|
metadata/staged_size.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
69G /workspace/hf_publish/VLAarchtests2
|
metadata/staged_tree_top2.txt
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/workspace/hf_publish/VLAarchtests2/CHANGE_AND_TEST_LOG.md
|
| 2 |
+
/workspace/hf_publish/VLAarchtests2/MODEL_AND_ARTIFACT_INDEX.md
|
| 3 |
+
/workspace/hf_publish/VLAarchtests2/README.md
|
| 4 |
+
/workspace/hf_publish/VLAarchtests2/RESULTS_RAW.md
|
| 5 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests
|
| 6 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/.cache
|
| 7 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/.gitattributes
|
| 8 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/.gitignore
|
| 9 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/.pytest_cache
|
| 10 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/FILE_MANIFEST.txt
|
| 11 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/MODEL_INDEX.md
|
| 12 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/README.md
|
| 13 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/artifacts
|
| 14 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/code
|
| 15 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/results
|
| 16 |
+
/workspace/hf_publish/VLAarchtests2/VLAarchtests/tests
|
| 17 |
+
/workspace/hf_publish/VLAarchtests2/baselines
|
| 18 |
+
/workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual
|
| 19 |
+
/workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_dummy_demo_root
|
| 20 |
+
/workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_evalroot
|
| 21 |
+
/workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_overlap_replay
|
| 22 |
+
/workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_overlap_runs
|
| 23 |
+
/workspace/hf_publish/VLAarchtests2/baselines/AnyBimanual_subset3_demo_root
|
| 24 |
+
/workspace/hf_publish/VLAarchtests2/environment
|
| 25 |
+
/workspace/hf_publish/VLAarchtests2/environment/base_pip_freeze.txt
|
| 26 |
+
/workspace/hf_publish/VLAarchtests2/environment/base_python.txt
|
| 27 |
+
/workspace/hf_publish/VLAarchtests2/environment/env_list.txt
|
| 28 |
+
/workspace/hf_publish/VLAarchtests2/environment/hardware_snapshot.txt
|
| 29 |
+
/workspace/hf_publish/VLAarchtests2/environment/nvidia_smi.txt
|
| 30 |
+
/workspace/hf_publish/VLAarchtests2/environment/rlbench_pip_freeze.txt
|
| 31 |
+
/workspace/hf_publish/VLAarchtests2/environment/rlbench_python.txt
|
| 32 |
+
/workspace/hf_publish/VLAarchtests2/environment/runtime_env_vars.sh
|
| 33 |
+
/workspace/hf_publish/VLAarchtests2/environment/setup_same_hardware.sh
|
| 34 |
+
/workspace/hf_publish/VLAarchtests2/environment/uname.txt
|
| 35 |
+
/workspace/hf_publish/VLAarchtests2/handoff
|
| 36 |
+
/workspace/hf_publish/VLAarchtests2/handoff/instructions4.md
|
| 37 |
+
/workspace/hf_publish/VLAarchtests2/history
|
| 38 |
+
/workspace/hf_publish/VLAarchtests2/history/VLAarchtests_previous_README.md
|
| 39 |
+
/workspace/hf_publish/VLAarchtests2/metadata
|
| 40 |
+
/workspace/hf_publish/VLAarchtests2/metadata/source_sizes.txt
|
| 41 |
+
/workspace/hf_publish/VLAarchtests2/metadata/staged_size.txt
|
| 42 |
+
/workspace/hf_publish/VLAarchtests2/metadata/staged_tree_top2.txt
|
| 43 |
+
/workspace/hf_publish/VLAarchtests2/reports
|
| 44 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_resume1000_chain.log
|
| 45 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_resume1000_eval.log
|
| 46 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_resume1000_eval_watcher.log
|
| 47 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_resume1000_train.log
|
| 48 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb2_train.log
|
| 49 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb3_eval.log
|
| 50 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb3_train.log
|
| 51 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb3_train_presavefix.log
|
| 52 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_nowandb_train.log
|
| 53 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_fixpretrain_train.log
|
| 54 |
+
/workspace/hf_publish/VLAarchtests2/reports/anybimanual_subset3_overlap_smoke200_train.log
|
| 55 |
+
/workspace/hf_publish/VLAarchtests2/reports/peract2_13_launch_smoke_live
|
| 56 |
+
/workspace/hf_publish/VLAarchtests2/reports/rlbench_common23_exec_calib
|
| 57 |
+
/workspace/hf_publish/VLAarchtests2/reports/rlbench_common23_exec_calib_iso
|
| 58 |
+
/workspace/hf_publish/VLAarchtests2/reports/rlbench_debug_common23_pushbox_ep1
|
| 59 |
+
/workspace/hf_publish/VLAarchtests2/reports/rlbench_general_debug
|
| 60 |
+
/workspace/hf_publish/VLAarchtests2/reports/rlbench_subset3_common23_live_ep1
|
| 61 |
+
/workspace/hf_publish/VLAarchtests2/reports/run_bag_selector_iter9_prebuild.log
|
| 62 |
+
/workspace/hf_publish/VLAarchtests2/reports/true_baseline_compare_subset3_v1
|
| 63 |
+
/workspace/hf_publish/VLAarchtests2/third_party
|
| 64 |
+
/workspace/hf_publish/VLAarchtests2/third_party/AnyBimanual
|
third_party/AnyBimanual/agents/__init__.py
ADDED
|
File without changes
|
third_party/AnyBimanual/agents/agent_factory.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from yarr.agents.agent import BimanualAgent
|
| 8 |
+
from yarr.agents.agent import LeaderFollowerAgent
|
| 9 |
+
from yarr.agents.agent import Agent
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
supported_agents = {"leader_follower": ("PERACT_BC", "RVT"),
|
| 13 |
+
"independent" : ("PERACT_BC", "RVT"),
|
| 14 |
+
"bimanual": ("BIMANUAL_PERACT", "ACT_BC_LANG"),
|
| 15 |
+
"unimanual": ()}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_agent(cfg: DictConfig) -> Agent:
|
| 19 |
+
|
| 20 |
+
method_name = cfg.method.name
|
| 21 |
+
agent_type = cfg.method.agent_type
|
| 22 |
+
|
| 23 |
+
logging.info("Using method %s with type %s", method_name, agent_type)
|
| 24 |
+
|
| 25 |
+
assert(method_name in supported_agents[agent_type])
|
| 26 |
+
|
| 27 |
+
agent_fn = agent_fn_by_name(method_name)
|
| 28 |
+
|
| 29 |
+
if agent_type == "leader_follower":
|
| 30 |
+
checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix
|
| 31 |
+
cfg.method.robot_name = "right"
|
| 32 |
+
cfg.framework.checkpoint_name_prefix = f"{checkpoint_name_prefix}_{method_name.lower()}_leader"
|
| 33 |
+
leader_agent = agent_fn(cfg)
|
| 34 |
+
|
| 35 |
+
cfg.method.robot_name = "left"
|
| 36 |
+
cfg.framework.checkpoint_name_prefix = f"{checkpoint_name_prefix}_{method_name.lower()}_follower"
|
| 37 |
+
cfg.method.low_dim_size = cfg.method.low_dim_size + 8 # also add the action size
|
| 38 |
+
follower_agent = agent_fn(cfg)
|
| 39 |
+
|
| 40 |
+
cfg.method.robot_name = "bimanual"
|
| 41 |
+
|
| 42 |
+
return LeaderFollowerAgent(leader_agent, follower_agent)
|
| 43 |
+
|
| 44 |
+
elif agent_type == "independent":
|
| 45 |
+
checkpoint_name_prefix = cfg.framework.checkpoint_name_prefix
|
| 46 |
+
cfg.method.robot_name = "right"
|
| 47 |
+
cfg.framework.checkpoint_name_prefix = f"{checkpoint_name_prefix}_{method_name.lower()}_right"
|
| 48 |
+
right_agent = agent_fn(cfg)
|
| 49 |
+
|
| 50 |
+
cfg.method.robot_name = "left"
|
| 51 |
+
cfg.framework.checkpoint_name_prefix = f"{checkpoint_name_prefix}_{method_name.lower()}_left"
|
| 52 |
+
left_agent = agent_fn(cfg)
|
| 53 |
+
|
| 54 |
+
cfg.method.robot_name = "bimanual"
|
| 55 |
+
|
| 56 |
+
return BimanualAgent(right_agent, left_agent)
|
| 57 |
+
elif agent_type == "bimanual" or agent_type == "unimanual":
|
| 58 |
+
|
| 59 |
+
return agent_fn(cfg)
|
| 60 |
+
else:
|
| 61 |
+
raise Exception("invalid agent type")
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def agent_fn_by_name(method_name: str) -> Agent:
|
| 65 |
+
if method_name == "ARM":
|
| 66 |
+
from agents import arm
|
| 67 |
+
|
| 68 |
+
raise NotImplementedError("ARM not yet supported for eval.py")
|
| 69 |
+
elif method_name == "BC_LANG":
|
| 70 |
+
from agents.baselines import bc_lang
|
| 71 |
+
|
| 72 |
+
return bc_lang.launch_utils.create_agent
|
| 73 |
+
elif method_name == "VIT_BC_LANG":
|
| 74 |
+
from agents.baselines import vit_bc_lang
|
| 75 |
+
|
| 76 |
+
return vit_bc_lang.launch_utils.create_agent
|
| 77 |
+
elif method_name == "C2FARM_LINGUNET_BC":
|
| 78 |
+
from agents import c2farm_lingunet_bc
|
| 79 |
+
|
| 80 |
+
return c2farm_lingunet_bc.launch_utils.create_agent
|
| 81 |
+
elif method_name.startswith("PERACT_BC"):
|
| 82 |
+
from agents import peract_bc
|
| 83 |
+
|
| 84 |
+
return peract_bc.launch_utils.create_agent
|
| 85 |
+
elif method_name.startswith("BIMANUAL_PERACT"):
|
| 86 |
+
from agents import peract_bimanual
|
| 87 |
+
|
| 88 |
+
return peract_bimanual.launch_utils.create_agent
|
| 89 |
+
elif method_name.startswith("RVT"):
|
| 90 |
+
from agents import rvt
|
| 91 |
+
|
| 92 |
+
return rvt.launch_utils.create_agent
|
| 93 |
+
elif method_name.startswith("ACT_BC_LANG"):
|
| 94 |
+
from agents import act_bc_lang
|
| 95 |
+
|
| 96 |
+
return act_bc_lang.launch_utils.create_agent
|
| 97 |
+
elif method_name == "PERACT_RL":
|
| 98 |
+
raise NotImplementedError("PERACT_RL not yet supported for eval.py")
|
| 99 |
+
|
| 100 |
+
else:
|
| 101 |
+
raise ValueError("Method %s does not exists." % method_name)
|
third_party/AnyBimanual/agents/peract_bc/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.peract_bc.launch_utils
|
third_party/AnyBimanual/agents/peract_bc/launch_utils.py
ADDED
|
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from ARM
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 7 |
+
from agents.peract_bc.perceiver_lang_io import PerceiverVoxelLangEncoder
|
| 8 |
+
from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
|
| 9 |
+
from agents.peract_bc.qattention_stack_agent import QAttentionStackAgent
|
| 10 |
+
import pickle
|
| 11 |
+
import torch
|
| 12 |
+
from agents.peract_bc.skill_manager import SkillManager
|
| 13 |
+
from agents.peract_bc.visual_aligner import VisualAligner
|
| 14 |
+
from omegaconf import DictConfig
|
| 15 |
+
import os
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def create_agent(cfg: DictConfig):
|
| 19 |
+
LATENT_SIZE = 64
|
| 20 |
+
depth_0bounds = cfg.rlbench.scene_bounds
|
| 21 |
+
cam_resolution = cfg.rlbench.camera_resolution
|
| 22 |
+
|
| 23 |
+
num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
|
| 24 |
+
qattention_agents = []
|
| 25 |
+
|
| 26 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 27 |
+
pkl_path = os.path.join(current_dir, "../../lang_token.pkl")
|
| 28 |
+
pkl_path = os.path.abspath(pkl_path)
|
| 29 |
+
with open(pkl_path, "rb") as f:
|
| 30 |
+
embeddings_dict = pickle.load(f)
|
| 31 |
+
flattened_embeddings = []
|
| 32 |
+
for key in embeddings_dict.keys():
|
| 33 |
+
embedding = torch.tensor(embeddings_dict[key])
|
| 34 |
+
flattened_embedding = embedding.view(-1)
|
| 35 |
+
flattened_embeddings.append(flattened_embedding)
|
| 36 |
+
embeddings_matrix = torch.stack(flattened_embeddings)
|
| 37 |
+
# The released AnyBimanual checkpoints were trained with a wider skill-manager
|
| 38 |
+
# hidden size than the public repo currently hardcodes. Keep the original
|
| 39 |
+
# default for non-AnyBimanual runs, but use the released width when loading
|
| 40 |
+
# the AnyBimanual checkpoint family.
|
| 41 |
+
skill_manager_hidden_size = int(
|
| 42 |
+
getattr(cfg.framework, "skill_manager_hidden_size", 256 if cfg.framework.anybimanual else 128)
|
| 43 |
+
)
|
| 44 |
+
skill_manager = SkillManager(
|
| 45 |
+
num_classes=18,
|
| 46 |
+
embedding_matrix=embeddings_matrix,
|
| 47 |
+
hidden_size=skill_manager_hidden_size,
|
| 48 |
+
)
|
| 49 |
+
visual_aligner = VisualAligner()
|
| 50 |
+
|
| 51 |
+
for depth, vox_size in enumerate(cfg.method.voxel_sizes):
|
| 52 |
+
last = depth == len(cfg.method.voxel_sizes) - 1
|
| 53 |
+
perceiver_encoder = PerceiverVoxelLangEncoder(
|
| 54 |
+
depth=cfg.method.transformer_depth,
|
| 55 |
+
iterations=cfg.method.transformer_iterations,
|
| 56 |
+
voxel_size=vox_size,
|
| 57 |
+
initial_dim=3 + 3 + 1 + 3,
|
| 58 |
+
low_dim_size=cfg.method.low_dim_size,
|
| 59 |
+
layer=depth,
|
| 60 |
+
num_rotation_classes=num_rotation_classes if last else 0,
|
| 61 |
+
num_grip_classes=2 if last else 0,
|
| 62 |
+
num_collision_classes=2 if last else 0,
|
| 63 |
+
input_axis=3,
|
| 64 |
+
num_latents=cfg.method.num_latents,
|
| 65 |
+
latent_dim=cfg.method.latent_dim,
|
| 66 |
+
cross_heads=cfg.method.cross_heads,
|
| 67 |
+
latent_heads=cfg.method.latent_heads,
|
| 68 |
+
cross_dim_head=cfg.method.cross_dim_head,
|
| 69 |
+
latent_dim_head=cfg.method.latent_dim_head,
|
| 70 |
+
weight_tie_layers=False,
|
| 71 |
+
activation=cfg.method.activation,
|
| 72 |
+
pos_encoding_with_lang=cfg.method.pos_encoding_with_lang,
|
| 73 |
+
input_dropout=cfg.method.input_dropout,
|
| 74 |
+
attn_dropout=cfg.method.attn_dropout,
|
| 75 |
+
decoder_dropout=cfg.method.decoder_dropout,
|
| 76 |
+
lang_fusion_type=cfg.method.lang_fusion_type,
|
| 77 |
+
voxel_patch_size=cfg.method.voxel_patch_size,
|
| 78 |
+
voxel_patch_stride=cfg.method.voxel_patch_stride,
|
| 79 |
+
no_skip_connection=cfg.method.no_skip_connection,
|
| 80 |
+
no_perceiver=cfg.method.no_perceiver,
|
| 81 |
+
no_language=cfg.method.no_language,
|
| 82 |
+
final_dim=cfg.method.final_dim,
|
| 83 |
+
anybimanual=cfg.framework.anybimanual,
|
| 84 |
+
skill_manager = skill_manager,
|
| 85 |
+
visual_aligner = visual_aligner
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
qattention_agent = QAttentionPerActBCAgent(
|
| 89 |
+
layer=depth,
|
| 90 |
+
coordinate_bounds=depth_0bounds,
|
| 91 |
+
perceiver_encoder=perceiver_encoder,
|
| 92 |
+
camera_names=cfg.rlbench.cameras,
|
| 93 |
+
voxel_size=vox_size,
|
| 94 |
+
bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
|
| 95 |
+
image_crop_size=cfg.method.image_crop_size,
|
| 96 |
+
lr=cfg.method.lr,
|
| 97 |
+
training_iterations=cfg.framework.training_iterations,
|
| 98 |
+
lr_scheduler=cfg.method.lr_scheduler,
|
| 99 |
+
num_warmup_steps=cfg.method.num_warmup_steps,
|
| 100 |
+
trans_loss_weight=cfg.method.trans_loss_weight,
|
| 101 |
+
rot_loss_weight=cfg.method.rot_loss_weight,
|
| 102 |
+
grip_loss_weight=cfg.method.grip_loss_weight,
|
| 103 |
+
collision_loss_weight=cfg.method.collision_loss_weight,
|
| 104 |
+
include_low_dim_state=True,
|
| 105 |
+
image_resolution=cam_resolution,
|
| 106 |
+
batch_size=cfg.replay.batch_size,
|
| 107 |
+
voxel_feature_size=3,
|
| 108 |
+
lambda_weight_l2=cfg.method.lambda_weight_l2,
|
| 109 |
+
num_rotation_classes=num_rotation_classes,
|
| 110 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 111 |
+
transform_augmentation=cfg.method.transform_augmentation.apply_se3,
|
| 112 |
+
transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
|
| 113 |
+
transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
|
| 114 |
+
transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
|
| 115 |
+
optimizer_type=cfg.method.optimizer,
|
| 116 |
+
num_devices=cfg.ddp.num_devices,
|
| 117 |
+
checkpoint_name_prefix=cfg.framework.checkpoint_name_prefix,
|
| 118 |
+
anybimanual=cfg.framework.anybimanual,
|
| 119 |
+
)
|
| 120 |
+
qattention_agents.append(qattention_agent)
|
| 121 |
+
|
| 122 |
+
rotation_agent = QAttentionStackAgent(
|
| 123 |
+
qattention_agents=qattention_agents,
|
| 124 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 125 |
+
camera_names=cfg.rlbench.cameras,
|
| 126 |
+
)
|
| 127 |
+
preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
|
| 128 |
+
return preprocess_agent
|
third_party/AnyBimanual/agents/peract_bc/perceiver_lang_io.py
ADDED
|
@@ -0,0 +1,481 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Perceiver IO implementation adpated for manipulation
|
| 2 |
+
# Source: https://github.com/lucidrains/perceiver-pytorch
|
| 3 |
+
# License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from einops import repeat
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from perceiver_pytorch.perceiver_pytorch import cache_fn
|
| 12 |
+
from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention
|
| 13 |
+
|
| 14 |
+
from helpers.network_utils import (
|
| 15 |
+
DenseBlock,
|
| 16 |
+
SpatialSoftmax3D,
|
| 17 |
+
Conv3DBlock,
|
| 18 |
+
Conv3DUpsampleBlock,
|
| 19 |
+
)
|
| 20 |
+
def symmetric_kl_divergence(left, right):
|
| 21 |
+
eps = 1e-2
|
| 22 |
+
left_prob = torch.clamp(F.log_softmax(left, dim=-1), min=-10, max=10)
|
| 23 |
+
right_prob = torch.clamp(F.log_softmax(right, dim=-1), min=-10, max=10)
|
| 24 |
+
|
| 25 |
+
kl_left_to_right = F.kl_div(left_prob, right_prob.exp(), reduction="batchmean")*eps
|
| 26 |
+
kl_right_to_left = F.kl_div(right_prob, left_prob.exp(), reduction="batchmean")*eps
|
| 27 |
+
|
| 28 |
+
symmetric_kl = -(kl_left_to_right + kl_right_to_left) / 2.0
|
| 29 |
+
return symmetric_kl
|
| 30 |
+
|
| 31 |
+
def l1_norm(tensor):
|
| 32 |
+
return torch.sum(torch.abs(tensor)) + 1e-4 * torch.norm(tensor)
|
| 33 |
+
|
| 34 |
+
def l2_1_norm(tensor):
|
| 35 |
+
l2_norm_per_skill = torch.norm(tensor, dim=-1)
|
| 36 |
+
return torch.sum(l2_norm_per_skill)
|
| 37 |
+
|
| 38 |
+
# PerceiverIO adapted for 6-DoF manipulation
|
| 39 |
+
class PerceiverVoxelLangEncoder(nn.Module):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
depth, # number of self-attention layers
|
| 43 |
+
iterations, # number cross-attention iterations (PerceiverIO uses just 1)
|
| 44 |
+
voxel_size, # N voxels per side (size: N*N*N)
|
| 45 |
+
initial_dim, # 10 dimensions - dimension of the input sequence to be encoded
|
| 46 |
+
low_dim_size, # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep}
|
| 47 |
+
layer=0,
|
| 48 |
+
num_rotation_classes=72, # 5 degree increments (5*72=360) for each of the 3-axis
|
| 49 |
+
num_grip_classes=2, # open or not open
|
| 50 |
+
num_collision_classes=2, # collisions allowed or not allowed
|
| 51 |
+
input_axis=3, # 3D tensors have 3 axes
|
| 52 |
+
num_latents=512, # number of latent vectors
|
| 53 |
+
im_channels=64, # intermediate channel size
|
| 54 |
+
latent_dim=512, # dimensions of latent vectors
|
| 55 |
+
cross_heads=1, # number of cross-attention heads
|
| 56 |
+
latent_heads=8, # number of latent heads
|
| 57 |
+
cross_dim_head=64,
|
| 58 |
+
latent_dim_head=64,
|
| 59 |
+
activation="relu",
|
| 60 |
+
weight_tie_layers=False,
|
| 61 |
+
pos_encoding_with_lang=True,
|
| 62 |
+
input_dropout=0.1,
|
| 63 |
+
attn_dropout=0.1,
|
| 64 |
+
decoder_dropout=0.0,
|
| 65 |
+
lang_fusion_type="seq",
|
| 66 |
+
voxel_patch_size=9,
|
| 67 |
+
voxel_patch_stride=8,
|
| 68 |
+
no_skip_connection=False,
|
| 69 |
+
no_perceiver=False,
|
| 70 |
+
no_language=False,
|
| 71 |
+
final_dim=64,
|
| 72 |
+
anybimanual=False,
|
| 73 |
+
skill_manager=None,
|
| 74 |
+
visual_aligner=None,
|
| 75 |
+
):
|
| 76 |
+
super().__init__()
|
| 77 |
+
self.depth = depth
|
| 78 |
+
self.layer = layer
|
| 79 |
+
self.init_dim = int(initial_dim)
|
| 80 |
+
self.iterations = iterations
|
| 81 |
+
self.input_axis = input_axis
|
| 82 |
+
self.voxel_size = voxel_size
|
| 83 |
+
self.low_dim_size = low_dim_size
|
| 84 |
+
self.im_channels = im_channels
|
| 85 |
+
self.pos_encoding_with_lang = pos_encoding_with_lang
|
| 86 |
+
self.lang_fusion_type = lang_fusion_type
|
| 87 |
+
self.voxel_patch_size = voxel_patch_size
|
| 88 |
+
self.voxel_patch_stride = voxel_patch_stride
|
| 89 |
+
self.num_rotation_classes = num_rotation_classes
|
| 90 |
+
self.num_grip_classes = num_grip_classes
|
| 91 |
+
self.num_collision_classes = num_collision_classes
|
| 92 |
+
self.final_dim = final_dim
|
| 93 |
+
self.input_dropout = input_dropout
|
| 94 |
+
self.attn_dropout = attn_dropout
|
| 95 |
+
self.decoder_dropout = decoder_dropout
|
| 96 |
+
self.no_skip_connection = no_skip_connection
|
| 97 |
+
self.no_perceiver = no_perceiver
|
| 98 |
+
self.no_language = no_language
|
| 99 |
+
self.anybimanual = anybimanual
|
| 100 |
+
self.skill_manager = skill_manager
|
| 101 |
+
self.visual_aligner = visual_aligner
|
| 102 |
+
# patchified input dimensions
|
| 103 |
+
spatial_size = voxel_size // self.voxel_patch_stride # 100/5 = 20
|
| 104 |
+
|
| 105 |
+
# 64 voxel features + 64 proprio features (+ 64 lang goal features if concattenated)
|
| 106 |
+
self.input_dim_before_seq = (
|
| 107 |
+
self.im_channels * 3
|
| 108 |
+
if self.lang_fusion_type == "concat"
|
| 109 |
+
else self.im_channels * 2
|
| 110 |
+
)
|
| 111 |
+
if self.anybimanual:
|
| 112 |
+
self.input_dim_before_seq_ = self.input_dim_before_seq*2
|
| 113 |
+
else:
|
| 114 |
+
self.input_dim_before_seq_ = self.input_dim_before_seq
|
| 115 |
+
# CLIP language feature dimensions
|
| 116 |
+
if self.anybimanual:
|
| 117 |
+
lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 154
|
| 118 |
+
else:
|
| 119 |
+
lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77
|
| 120 |
+
|
| 121 |
+
self.lang_max_seq_len = lang_max_seq_len
|
| 122 |
+
# learnable positional encoding
|
| 123 |
+
if self.pos_encoding_with_lang:
|
| 124 |
+
self.pos_encoding = nn.Parameter(
|
| 125 |
+
torch.randn(
|
| 126 |
+
1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
|
| 127 |
+
)
|
| 128 |
+
)
|
| 129 |
+
else:
|
| 130 |
+
# assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.'
|
| 131 |
+
self.pos_encoding = nn.Parameter(
|
| 132 |
+
torch.randn(
|
| 133 |
+
1,
|
| 134 |
+
spatial_size,
|
| 135 |
+
spatial_size,
|
| 136 |
+
spatial_size,
|
| 137 |
+
self.input_dim_before_seq,
|
| 138 |
+
)
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# voxel input preprocessing 1x1 conv encoder
|
| 142 |
+
self.input_preprocess = Conv3DBlock(
|
| 143 |
+
self.init_dim,
|
| 144 |
+
self.im_channels,
|
| 145 |
+
kernel_sizes=1,
|
| 146 |
+
strides=1,
|
| 147 |
+
norm=None,
|
| 148 |
+
activation=activation,
|
| 149 |
+
)
|
| 150 |
+
|
| 151 |
+
# patchify conv
|
| 152 |
+
self.patchify = Conv3DBlock(
|
| 153 |
+
self.input_preprocess.out_channels,
|
| 154 |
+
self.im_channels,
|
| 155 |
+
kernel_sizes=self.voxel_patch_size,
|
| 156 |
+
strides=self.voxel_patch_stride,
|
| 157 |
+
norm=None,
|
| 158 |
+
activation=activation,
|
| 159 |
+
)
|
| 160 |
+
|
| 161 |
+
# language preprocess
|
| 162 |
+
if self.lang_fusion_type == "concat":
|
| 163 |
+
self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
|
| 164 |
+
elif self.lang_fusion_type == "seq":
|
| 165 |
+
self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)
|
| 166 |
+
|
| 167 |
+
# proprioception
|
| 168 |
+
if self.low_dim_size > 0:
|
| 169 |
+
self.proprio_preprocess = DenseBlock(
|
| 170 |
+
self.low_dim_size,
|
| 171 |
+
self.im_channels,
|
| 172 |
+
norm=None,
|
| 173 |
+
activation=activation,
|
| 174 |
+
)
|
| 175 |
+
|
| 176 |
+
# pooling functions
|
| 177 |
+
self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
|
| 178 |
+
self.global_maxp = nn.AdaptiveMaxPool3d(1)
|
| 179 |
+
|
| 180 |
+
# 1st 3D softmax
|
| 181 |
+
self.ss0 = SpatialSoftmax3D(
|
| 182 |
+
self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
|
| 183 |
+
)
|
| 184 |
+
flat_size = self.im_channels * 4
|
| 185 |
+
|
| 186 |
+
# latent vectors (that are randomly initialized)
|
| 187 |
+
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
|
| 188 |
+
|
| 189 |
+
# encoder cross attention
|
| 190 |
+
self.cross_attend_blocks = nn.ModuleList(
|
| 191 |
+
[
|
| 192 |
+
PreNorm(
|
| 193 |
+
latent_dim,
|
| 194 |
+
Attention(
|
| 195 |
+
latent_dim,
|
| 196 |
+
self.input_dim_before_seq_,
|
| 197 |
+
heads=cross_heads,
|
| 198 |
+
dim_head=cross_dim_head,
|
| 199 |
+
dropout=input_dropout,
|
| 200 |
+
),
|
| 201 |
+
context_dim=self.input_dim_before_seq_,
|
| 202 |
+
),
|
| 203 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
| 204 |
+
]
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
get_latent_attn = lambda: PreNorm(
|
| 208 |
+
latent_dim,
|
| 209 |
+
Attention(
|
| 210 |
+
latent_dim,
|
| 211 |
+
heads=latent_heads,
|
| 212 |
+
dim_head=latent_dim_head,
|
| 213 |
+
dropout=attn_dropout,
|
| 214 |
+
),
|
| 215 |
+
)
|
| 216 |
+
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
|
| 217 |
+
get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
|
| 218 |
+
|
| 219 |
+
# self attention layers
|
| 220 |
+
self.layers = nn.ModuleList([])
|
| 221 |
+
cache_args = {"_cache": weight_tie_layers}
|
| 222 |
+
|
| 223 |
+
for i in range(depth):
|
| 224 |
+
self.layers.append(
|
| 225 |
+
nn.ModuleList(
|
| 226 |
+
[get_latent_attn(**cache_args), get_latent_ff(**cache_args)]
|
| 227 |
+
)
|
| 228 |
+
)
|
| 229 |
+
|
| 230 |
+
# decoder cross attention
|
| 231 |
+
self.decoder_cross_attn = PreNorm(
|
| 232 |
+
self.input_dim_before_seq_,
|
| 233 |
+
Attention(
|
| 234 |
+
self.input_dim_before_seq_,
|
| 235 |
+
latent_dim,
|
| 236 |
+
heads=cross_heads,
|
| 237 |
+
dim_head=cross_dim_head,
|
| 238 |
+
dropout=decoder_dropout,
|
| 239 |
+
),
|
| 240 |
+
context_dim=latent_dim,
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
# upsample conv
|
| 244 |
+
self.up0 = Conv3DUpsampleBlock(
|
| 245 |
+
self.input_dim_before_seq_,
|
| 246 |
+
self.final_dim,
|
| 247 |
+
kernel_sizes=self.voxel_patch_size,
|
| 248 |
+
strides=self.voxel_patch_stride,
|
| 249 |
+
norm=None,
|
| 250 |
+
activation=activation,
|
| 251 |
+
)
|
| 252 |
+
|
| 253 |
+
# 2nd 3D softmax
|
| 254 |
+
self.ss1 = SpatialSoftmax3D(
|
| 255 |
+
spatial_size, spatial_size, spatial_size, self.input_dim_before_seq_
|
| 256 |
+
)
|
| 257 |
+
|
| 258 |
+
flat_size += self.input_dim_before_seq_ * 4
|
| 259 |
+
|
| 260 |
+
# final 3D softmax
|
| 261 |
+
self.final = Conv3DBlock(
|
| 262 |
+
self.im_channels
|
| 263 |
+
if (self.no_perceiver or self.no_skip_connection)
|
| 264 |
+
else self.im_channels * 2,
|
| 265 |
+
self.im_channels,
|
| 266 |
+
kernel_sizes=3,
|
| 267 |
+
strides=1,
|
| 268 |
+
norm=None,
|
| 269 |
+
activation=activation,
|
| 270 |
+
)
|
| 271 |
+
|
| 272 |
+
self.trans_decoder = Conv3DBlock(
|
| 273 |
+
self.final_dim,
|
| 274 |
+
1,
|
| 275 |
+
kernel_sizes=3,
|
| 276 |
+
strides=1,
|
| 277 |
+
norm=None,
|
| 278 |
+
activation=None,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
# rotation, gripper, and collision MLP layers
|
| 282 |
+
if self.num_rotation_classes > 0:
|
| 283 |
+
self.ss_final = SpatialSoftmax3D(
|
| 284 |
+
self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
flat_size += self.im_channels * 4
|
| 288 |
+
|
| 289 |
+
self.dense0 = DenseBlock(flat_size, 256, None, activation)
|
| 290 |
+
self.dense1 = DenseBlock(256, self.final_dim, None, activation)
|
| 291 |
+
|
| 292 |
+
self.rot_grip_collision_ff = DenseBlock(
|
| 293 |
+
self.final_dim,
|
| 294 |
+
self.num_rotation_classes * 3
|
| 295 |
+
+ self.num_grip_classes
|
| 296 |
+
+ self.num_collision_classes,
|
| 297 |
+
None,
|
| 298 |
+
None,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
def encode_text(self, x):
|
| 302 |
+
with torch.no_grad():
|
| 303 |
+
text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)
|
| 304 |
+
|
| 305 |
+
text_feat = text_feat.detach()
|
| 306 |
+
text_emb = text_emb.detach()
|
| 307 |
+
text_mask = torch.where(x == 0, x, 1) # [1, max_token_len]
|
| 308 |
+
return text_feat, text_emb
|
| 309 |
+
|
| 310 |
+
def forward(
|
| 311 |
+
self,
|
| 312 |
+
ins,
|
| 313 |
+
proprio,
|
| 314 |
+
lang_goal_emb,
|
| 315 |
+
lang_token_embs,
|
| 316 |
+
prev_layer_voxel_grid,
|
| 317 |
+
bounds,
|
| 318 |
+
prev_layer_bounds,
|
| 319 |
+
mask=None,
|
| 320 |
+
arm=None,
|
| 321 |
+
):
|
| 322 |
+
# preprocess input
|
| 323 |
+
d0 = self.input_preprocess(ins) # [B,10,100,100,100] -> [B,64,100,100,100]
|
| 324 |
+
|
| 325 |
+
# aggregated features from 1st softmax and maxpool for MLP decoders
|
| 326 |
+
feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]
|
| 327 |
+
|
| 328 |
+
# patchify input (5x5x5 patches)
|
| 329 |
+
ins = self.patchify(d0) # [B,64,100,100,100] -> [B,64,20,20,20]
|
| 330 |
+
|
| 331 |
+
b, c, d, h, w, device = *ins.shape, ins.device
|
| 332 |
+
axis = [d, h, w]
|
| 333 |
+
assert (
|
| 334 |
+
len(axis) == self.input_axis
|
| 335 |
+
), "input must have the same number of axis as input_axis"
|
| 336 |
+
|
| 337 |
+
# concat proprio
|
| 338 |
+
if self.low_dim_size > 0:
|
| 339 |
+
p = self.proprio_preprocess(proprio) # [B,4] -> [B,64]
|
| 340 |
+
p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 341 |
+
ins = torch.cat([ins, p], dim=1) # [B,128,20,20,20]
|
| 342 |
+
|
| 343 |
+
# language ablation
|
| 344 |
+
if self.no_language:
|
| 345 |
+
lang_goal_emb = torch.zeros_like(lang_goal_emb)
|
| 346 |
+
lang_token_embs = torch.zeros_like(lang_token_embs)
|
| 347 |
+
|
| 348 |
+
# option 1: tile and concat lang goal to input
|
| 349 |
+
if self.lang_fusion_type == "concat":
|
| 350 |
+
lang_emb = lang_goal_emb
|
| 351 |
+
lang_emb = lang_emb.to(dtype=ins.dtype)
|
| 352 |
+
l = self.lang_preprocess(lang_emb)
|
| 353 |
+
l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 354 |
+
ins = torch.cat([ins, l], dim=1)
|
| 355 |
+
|
| 356 |
+
# channel last
|
| 357 |
+
ins = rearrange(ins, "b d ... -> b ... d") # [B,20,20,20,128]
|
| 358 |
+
|
| 359 |
+
# add pos encoding to grid
|
| 360 |
+
if not self.pos_encoding_with_lang:
|
| 361 |
+
ins = ins + self.pos_encoding
|
| 362 |
+
|
| 363 |
+
######################## NOTE #############################
|
| 364 |
+
# NOTE: If you add positional encodings ^here the lang embs
|
| 365 |
+
# won't have positional encodings. I accidently forgot
|
| 366 |
+
# to turn this off for all the experiments in the paper.
|
| 367 |
+
# So I guess those models were using language embs
|
| 368 |
+
# as a bag of words :( But it doesn't matter much for
|
| 369 |
+
# RLBench tasks since we don't test for novel instructions
|
| 370 |
+
# at test time anyway. The recommend way is to add
|
| 371 |
+
# positional encodings to the final input sequence
|
| 372 |
+
# fed into the Perceiver Transformer, as done below
|
| 373 |
+
# (and also in the Colab tutorial).
|
| 374 |
+
###########################################################
|
| 375 |
+
|
| 376 |
+
# concat to channels of and flatten axis
|
| 377 |
+
queries_orig_shape = ins.shape
|
| 378 |
+
|
| 379 |
+
# rearrange input to be channel last
|
| 380 |
+
ins = rearrange(ins, "b ... d -> b (...) d") # [B,8000,128]
|
| 381 |
+
ins_wo_prev_layers = ins
|
| 382 |
+
# option 2: add lang token embs as a sequence
|
| 383 |
+
if self.anybimanual:
|
| 384 |
+
l = self.lang_preprocess(lang_token_embs) # [B,77,512] -> [B,77,128]
|
| 385 |
+
mask_right, mask_left = self.visual_aligner(ins)
|
| 386 |
+
L_voxel = symmetric_kl_divergence(mask_left, mask_right)
|
| 387 |
+
right_skill = self.skill_manager(mask_right, l)
|
| 388 |
+
left_skill = self.skill_manager(mask_left, l)
|
| 389 |
+
right_skill = self.lang_preprocess(right_skill)
|
| 390 |
+
left_skill = self.lang_preprocess(left_skill)
|
| 391 |
+
L_skill = (
|
| 392 |
+
l1_norm(left_skill) + l1_norm(right_skill) +
|
| 393 |
+
0.01 * (l2_1_norm(left_skill) + l2_1_norm(right_skill))
|
| 394 |
+
)
|
| 395 |
+
l_right = torch.cat((right_skill, l), dim=1)
|
| 396 |
+
ins_right = torch.cat((l_right, mask_right), dim=1)
|
| 397 |
+
l_left = torch.cat((left_skill, l), dim=1)
|
| 398 |
+
ins_left = torch.cat((l_left, mask_left), dim=1)
|
| 399 |
+
if arm == "right":
|
| 400 |
+
skill = right_skill
|
| 401 |
+
ins_ = ins_right
|
| 402 |
+
else:
|
| 403 |
+
skill = left_skill
|
| 404 |
+
ins_ = ins_left
|
| 405 |
+
if self.pos_encoding_with_lang:
|
| 406 |
+
ins_ = ins_ + self.pos_encoding
|
| 407 |
+
else:
|
| 408 |
+
if self.lang_fusion_type == "seq":
|
| 409 |
+
l = self.lang_preprocess(lang_token_embs) # [B,77,1024] -> [B,77,128]
|
| 410 |
+
ins = torch.cat((l, ins), dim=1) # [B,8077,128]
|
| 411 |
+
# add pos encoding to language + flattened grid (the recommended way)
|
| 412 |
+
if self.pos_encoding_with_lang:
|
| 413 |
+
ins = ins + self.pos_encoding
|
| 414 |
+
|
| 415 |
+
if self.anybimanual:
|
| 416 |
+
skill_l = torch.cat((skill, l), dim=1)
|
| 417 |
+
ins = torch.cat((skill_l, ins),dim=1)
|
| 418 |
+
ins = torch.cat((ins_, ins),dim=2)
|
| 419 |
+
# batchify latents
|
| 420 |
+
x = repeat(self.latents, "n d -> b n d", b=b)
|
| 421 |
+
|
| 422 |
+
cross_attn, cross_ff = self.cross_attend_blocks
|
| 423 |
+
|
| 424 |
+
for it in range(self.iterations):
|
| 425 |
+
# encoder cross attention
|
| 426 |
+
x = cross_attn(x, context=ins, mask=mask) + x
|
| 427 |
+
x = cross_ff(x) + x
|
| 428 |
+
|
| 429 |
+
# self-attention layers
|
| 430 |
+
for self_attn, self_ff in self.layers:
|
| 431 |
+
x = self_attn(x) + x
|
| 432 |
+
x = self_ff(x) + x
|
| 433 |
+
|
| 434 |
+
# decoder cross attention
|
| 435 |
+
latents = self.decoder_cross_attn(ins, context=x)
|
| 436 |
+
# crop out the language part of the output sequence
|
| 437 |
+
if self.lang_fusion_type == "seq":
|
| 438 |
+
latents = latents[:, self.lang_max_seq_len :]
|
| 439 |
+
|
| 440 |
+
# reshape back to voxel grid
|
| 441 |
+
latents = latents.view(
|
| 442 |
+
b, *queries_orig_shape[1:-1], latents.shape[-1]
|
| 443 |
+
) # [B,20,20,20,64]
|
| 444 |
+
latents = rearrange(latents, "b ... d -> b d ...") # [B,64,20,20,20]
|
| 445 |
+
|
| 446 |
+
# aggregated features from 2nd softmax and maxpool for MLP decoders
|
| 447 |
+
feats.extend(
|
| 448 |
+
[self.ss1(latents.contiguous()), self.global_maxp(latents).view(b, -1)]
|
| 449 |
+
)
|
| 450 |
+
|
| 451 |
+
# upsample
|
| 452 |
+
u0 = self.up0(latents)
|
| 453 |
+
|
| 454 |
+
# ablations
|
| 455 |
+
if self.no_skip_connection:
|
| 456 |
+
u = self.final(u0)
|
| 457 |
+
elif self.no_perceiver:
|
| 458 |
+
u = self.final(d0)
|
| 459 |
+
else:
|
| 460 |
+
u = self.final(torch.cat([d0, u0], dim=1))
|
| 461 |
+
|
| 462 |
+
# translation decoder
|
| 463 |
+
trans = self.trans_decoder(u)
|
| 464 |
+
|
| 465 |
+
# rotation, gripper, and collision MLPs
|
| 466 |
+
rot_and_grip_out = None
|
| 467 |
+
if self.num_rotation_classes > 0:
|
| 468 |
+
feats.extend(
|
| 469 |
+
[self.ss_final(u.contiguous()), self.global_maxp(u).view(b, -1)]
|
| 470 |
+
)
|
| 471 |
+
|
| 472 |
+
dense0 = self.dense0(torch.cat(feats, dim=1))
|
| 473 |
+
dense1 = self.dense1(dense0) # [B,72*3+2+2]
|
| 474 |
+
|
| 475 |
+
rot_and_grip_collision_out = self.rot_grip_collision_ff(dense1)
|
| 476 |
+
rot_and_grip_out = rot_and_grip_collision_out[
|
| 477 |
+
:, : -self.num_collision_classes
|
| 478 |
+
]
|
| 479 |
+
collision_out = rot_and_grip_collision_out[:, -self.num_collision_classes :]
|
| 480 |
+
|
| 481 |
+
return trans, rot_and_grip_out, collision_out
|
third_party/AnyBimanual/agents/peract_bc/qattention_peract_bc_agent.py
ADDED
|
@@ -0,0 +1,939 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from pytorch3d import transforms as torch3d_tf
|
| 12 |
+
from yarr.agents.agent import (
|
| 13 |
+
Agent,
|
| 14 |
+
ActResult,
|
| 15 |
+
ScalarSummary,
|
| 16 |
+
HistogramSummary,
|
| 17 |
+
ImageSummary,
|
| 18 |
+
Summary,
|
| 19 |
+
)
|
| 20 |
+
import matplotlib.pyplot as plt
|
| 21 |
+
import PIL.Image as Image
|
| 22 |
+
import wandb
|
| 23 |
+
import io
|
| 24 |
+
from termcolor import colored, cprint
|
| 25 |
+
from helpers import utils
|
| 26 |
+
from helpers.utils import visualise_voxel, stack_on_channel
|
| 27 |
+
from voxel.voxel_grid import VoxelGrid
|
| 28 |
+
from einops import rearrange
|
| 29 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 30 |
+
|
| 31 |
+
import transformers
|
| 32 |
+
from helpers.optim.lamb import Lamb
|
| 33 |
+
|
| 34 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class QFunction(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
perceiver_encoder: nn.Module,
|
| 41 |
+
voxelizer: VoxelGrid,
|
| 42 |
+
bounds_offset: float,
|
| 43 |
+
rotation_resolution: float,
|
| 44 |
+
device,
|
| 45 |
+
training,
|
| 46 |
+
):
|
| 47 |
+
super(QFunction, self).__init__()
|
| 48 |
+
self._rotation_resolution = rotation_resolution
|
| 49 |
+
self._voxelizer = voxelizer
|
| 50 |
+
self._bounds_offset = bounds_offset
|
| 51 |
+
self._qnet = perceiver_encoder.to(device)
|
| 52 |
+
|
| 53 |
+
# distributed training
|
| 54 |
+
if training:
|
| 55 |
+
self._qnet = DDP(self._qnet, device_ids=[device], find_unused_parameters=True)
|
| 56 |
+
|
| 57 |
+
def _argmax_3d(self, tensor_orig):
|
| 58 |
+
b, c, d, h, w = tensor_orig.shape # c will be one
|
| 59 |
+
idxs = tensor_orig.view(b, c, -1).argmax(-1)
|
| 60 |
+
indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
|
| 61 |
+
return indices
|
| 62 |
+
|
| 63 |
+
def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
|
| 64 |
+
coords = self._argmax_3d(q_trans)
|
| 65 |
+
rot_and_grip_indicies = None
|
| 66 |
+
ignore_collision = None
|
| 67 |
+
if q_rot_grip is not None:
|
| 68 |
+
q_rot = torch.stack(
|
| 69 |
+
torch.split(
|
| 70 |
+
q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 71 |
+
),
|
| 72 |
+
dim=1,
|
| 73 |
+
)
|
| 74 |
+
rot_and_grip_indicies = torch.cat(
|
| 75 |
+
[
|
| 76 |
+
q_rot[:, 0:1].argmax(-1),
|
| 77 |
+
q_rot[:, 1:2].argmax(-1),
|
| 78 |
+
q_rot[:, 2:3].argmax(-1),
|
| 79 |
+
q_rot_grip[:, -2:].argmax(-1, keepdim=True),
|
| 80 |
+
],
|
| 81 |
+
-1,
|
| 82 |
+
)
|
| 83 |
+
ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
|
| 84 |
+
return coords, rot_and_grip_indicies, ignore_collision
|
| 85 |
+
|
| 86 |
+
def forward(
|
| 87 |
+
self,
|
| 88 |
+
rgb_pcd,
|
| 89 |
+
proprio,
|
| 90 |
+
pcd,
|
| 91 |
+
lang_goal_emb,
|
| 92 |
+
lang_token_embs,
|
| 93 |
+
bounds=None,
|
| 94 |
+
prev_bounds=None,
|
| 95 |
+
prev_layer_voxel_grid=None,
|
| 96 |
+
arm=None,
|
| 97 |
+
):
|
| 98 |
+
# rgb_pcd will be list of list (list of [rgb, pcd])
|
| 99 |
+
b = rgb_pcd[0][0].shape[0]
|
| 100 |
+
pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
|
| 101 |
+
|
| 102 |
+
# flatten RGBs and Pointclouds
|
| 103 |
+
rgb = [rp[0] for rp in rgb_pcd]
|
| 104 |
+
feat_size = rgb[0].shape[1]
|
| 105 |
+
flat_imag_features = torch.cat(
|
| 106 |
+
[p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# construct voxel grid
|
| 110 |
+
voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
|
| 111 |
+
pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# swap to channels fist
|
| 115 |
+
voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
|
| 116 |
+
|
| 117 |
+
# batch bounds if necessary
|
| 118 |
+
if bounds.shape[0] != b:
|
| 119 |
+
bounds = bounds.repeat(b, 1)
|
| 120 |
+
|
| 121 |
+
# forward pass
|
| 122 |
+
q_trans, q_rot_and_grip, q_ignore_collisions = self._qnet(
|
| 123 |
+
voxel_grid,
|
| 124 |
+
proprio,
|
| 125 |
+
lang_goal_emb,
|
| 126 |
+
lang_token_embs,
|
| 127 |
+
prev_layer_voxel_grid,
|
| 128 |
+
bounds,
|
| 129 |
+
prev_bounds,
|
| 130 |
+
arm=arm,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
return q_trans, q_rot_and_grip, q_ignore_collisions, voxel_grid
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
class QAttentionPerActBCAgent(Agent):
|
| 137 |
+
def __init__(
|
| 138 |
+
self,
|
| 139 |
+
layer: int,
|
| 140 |
+
coordinate_bounds: list,
|
| 141 |
+
perceiver_encoder: nn.Module,
|
| 142 |
+
camera_names: list,
|
| 143 |
+
batch_size: int,
|
| 144 |
+
voxel_size: int,
|
| 145 |
+
bounds_offset: float,
|
| 146 |
+
voxel_feature_size: int,
|
| 147 |
+
image_crop_size: int,
|
| 148 |
+
num_rotation_classes: int,
|
| 149 |
+
rotation_resolution: float,
|
| 150 |
+
lr: float = 0.0001,
|
| 151 |
+
lr_scheduler: bool = False,
|
| 152 |
+
training_iterations: int = 100000,
|
| 153 |
+
num_warmup_steps: int = 20000,
|
| 154 |
+
trans_loss_weight: float = 1.0,
|
| 155 |
+
rot_loss_weight: float = 1.0,
|
| 156 |
+
grip_loss_weight: float = 1.0,
|
| 157 |
+
collision_loss_weight: float = 1.0,
|
| 158 |
+
include_low_dim_state: bool = False,
|
| 159 |
+
image_resolution: list = None,
|
| 160 |
+
lambda_weight_l2: float = 0.0,
|
| 161 |
+
transform_augmentation: bool = True,
|
| 162 |
+
transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
|
| 163 |
+
transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
|
| 164 |
+
transform_augmentation_rot_resolution: int = 5,
|
| 165 |
+
optimizer_type: str = "adam",
|
| 166 |
+
num_devices: int = 1,
|
| 167 |
+
checkpoint_name_prefix=None,
|
| 168 |
+
anybimanual = False,
|
| 169 |
+
cfg=None,
|
| 170 |
+
):
|
| 171 |
+
self._layer = layer
|
| 172 |
+
self._coordinate_bounds = coordinate_bounds
|
| 173 |
+
self._perceiver_encoder = perceiver_encoder
|
| 174 |
+
self._voxel_feature_size = voxel_feature_size
|
| 175 |
+
self._bounds_offset = bounds_offset
|
| 176 |
+
self._image_crop_size = image_crop_size
|
| 177 |
+
self._lr = lr
|
| 178 |
+
self._lr_scheduler = lr_scheduler
|
| 179 |
+
self._training_iterations = training_iterations
|
| 180 |
+
self._num_warmup_steps = num_warmup_steps
|
| 181 |
+
self._trans_loss_weight = trans_loss_weight
|
| 182 |
+
self._rot_loss_weight = rot_loss_weight
|
| 183 |
+
self._grip_loss_weight = grip_loss_weight
|
| 184 |
+
self._collision_loss_weight = collision_loss_weight
|
| 185 |
+
self._include_low_dim_state = include_low_dim_state
|
| 186 |
+
self._image_resolution = image_resolution or [128, 128]
|
| 187 |
+
self._voxel_size = voxel_size
|
| 188 |
+
self._camera_names = camera_names
|
| 189 |
+
self._num_cameras = len(camera_names)
|
| 190 |
+
self._batch_size = batch_size
|
| 191 |
+
self._lambda_weight_l2 = lambda_weight_l2
|
| 192 |
+
self._transform_augmentation = transform_augmentation
|
| 193 |
+
self._transform_augmentation_xyz = torch.from_numpy(
|
| 194 |
+
np.array(transform_augmentation_xyz)
|
| 195 |
+
)
|
| 196 |
+
self._transform_augmentation_rpy = transform_augmentation_rpy
|
| 197 |
+
self._transform_augmentation_rot_resolution = (
|
| 198 |
+
transform_augmentation_rot_resolution
|
| 199 |
+
)
|
| 200 |
+
self._optimizer_type = optimizer_type
|
| 201 |
+
self._num_devices = num_devices
|
| 202 |
+
self._num_rotation_classes = num_rotation_classes
|
| 203 |
+
self._rotation_resolution = rotation_resolution
|
| 204 |
+
|
| 205 |
+
self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
|
| 206 |
+
checkpoint_name_prefix = checkpoint_name_prefix or "QAttentionAgent"
|
| 207 |
+
self._name = f"{checkpoint_name_prefix}_layer_{self._layer}"
|
| 208 |
+
self.anybimanual = anybimanual
|
| 209 |
+
self.cfg=cfg
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
def build(self, training: bool, device: torch.device = None):
|
| 213 |
+
self._training = training
|
| 214 |
+
|
| 215 |
+
if device is None:
|
| 216 |
+
device = torch.device("cpu")
|
| 217 |
+
|
| 218 |
+
self._device = device
|
| 219 |
+
|
| 220 |
+
self._voxelizer = VoxelGrid(
|
| 221 |
+
coord_bounds=self._coordinate_bounds,
|
| 222 |
+
voxel_size=self._voxel_size,
|
| 223 |
+
device=device,
|
| 224 |
+
batch_size=self._batch_size if training else 1,
|
| 225 |
+
feature_size=self._voxel_feature_size,
|
| 226 |
+
max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
|
| 227 |
+
)
|
| 228 |
+
|
| 229 |
+
self._q = (
|
| 230 |
+
QFunction(
|
| 231 |
+
self._perceiver_encoder,
|
| 232 |
+
self._voxelizer,
|
| 233 |
+
self._bounds_offset,
|
| 234 |
+
self._rotation_resolution,
|
| 235 |
+
device,
|
| 236 |
+
training,
|
| 237 |
+
)
|
| 238 |
+
.to(device)
|
| 239 |
+
.train(training)
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
grid_for_crop = (
|
| 243 |
+
torch.arange(0, self._image_crop_size, device=device)
|
| 244 |
+
.unsqueeze(0)
|
| 245 |
+
.repeat(self._image_crop_size, 1)
|
| 246 |
+
.unsqueeze(-1)
|
| 247 |
+
)
|
| 248 |
+
self._grid_for_crop = torch.cat(
|
| 249 |
+
[grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
|
| 250 |
+
).unsqueeze(0)
|
| 251 |
+
|
| 252 |
+
self._coordinate_bounds = torch.tensor(
|
| 253 |
+
self._coordinate_bounds, device=device
|
| 254 |
+
).unsqueeze(0)
|
| 255 |
+
|
| 256 |
+
if self._training:
|
| 257 |
+
# optimizer
|
| 258 |
+
if self._optimizer_type == "lamb":
|
| 259 |
+
self._optimizer = Lamb(
|
| 260 |
+
self._q.parameters(),
|
| 261 |
+
lr=self._lr,
|
| 262 |
+
weight_decay=self._lambda_weight_l2,
|
| 263 |
+
betas=(0.9, 0.999),
|
| 264 |
+
adam=False,
|
| 265 |
+
)
|
| 266 |
+
elif self._optimizer_type == "adam":
|
| 267 |
+
self._optimizer = torch.optim.Adam(
|
| 268 |
+
self._q.parameters(),
|
| 269 |
+
lr=self._lr,
|
| 270 |
+
weight_decay=self._lambda_weight_l2,
|
| 271 |
+
)
|
| 272 |
+
else:
|
| 273 |
+
raise Exception("Unknown optimizer type")
|
| 274 |
+
|
| 275 |
+
# learning rate scheduler
|
| 276 |
+
if self._lr_scheduler:
|
| 277 |
+
self._scheduler = (
|
| 278 |
+
transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
|
| 279 |
+
self._optimizer,
|
| 280 |
+
num_warmup_steps=self._num_warmup_steps,
|
| 281 |
+
num_training_steps=self._training_iterations,
|
| 282 |
+
num_cycles=self._training_iterations // 10000,
|
| 283 |
+
)
|
| 284 |
+
)
|
| 285 |
+
|
| 286 |
+
# one-hot zero tensors
|
| 287 |
+
self._action_trans_one_hot_zeros = torch.zeros(
|
| 288 |
+
(
|
| 289 |
+
self._batch_size,
|
| 290 |
+
1,
|
| 291 |
+
self._voxel_size,
|
| 292 |
+
self._voxel_size,
|
| 293 |
+
self._voxel_size,
|
| 294 |
+
),
|
| 295 |
+
dtype=int,
|
| 296 |
+
device=device,
|
| 297 |
+
)
|
| 298 |
+
self._action_rot_x_one_hot_zeros = torch.zeros(
|
| 299 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 300 |
+
)
|
| 301 |
+
self._action_rot_y_one_hot_zeros = torch.zeros(
|
| 302 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 303 |
+
)
|
| 304 |
+
self._action_rot_z_one_hot_zeros = torch.zeros(
|
| 305 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 306 |
+
)
|
| 307 |
+
self._action_grip_one_hot_zeros = torch.zeros(
|
| 308 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 309 |
+
)
|
| 310 |
+
self._action_ignore_collisions_one_hot_zeros = torch.zeros(
|
| 311 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 312 |
+
)
|
| 313 |
+
|
| 314 |
+
# print total params
|
| 315 |
+
logging.info(
|
| 316 |
+
"# Q Params: %d"
|
| 317 |
+
% sum(
|
| 318 |
+
p.numel()
|
| 319 |
+
for name, p in self._q.named_parameters()
|
| 320 |
+
if p.requires_grad and "clip" not in name
|
| 321 |
+
)
|
| 322 |
+
)
|
| 323 |
+
else:
|
| 324 |
+
for param in self._q.parameters():
|
| 325 |
+
param.requires_grad = False
|
| 326 |
+
|
| 327 |
+
# load CLIP for encoding language goals during evaluation
|
| 328 |
+
model, _ = load_clip("RN50", jit=False)
|
| 329 |
+
self._clip_rn50 = build_model(model.state_dict())
|
| 330 |
+
self._clip_rn50 = self._clip_rn50.float().to(device)
|
| 331 |
+
self._clip_rn50.eval()
|
| 332 |
+
del model
|
| 333 |
+
|
| 334 |
+
self._voxelizer.to(device)
|
| 335 |
+
self._q.to(device)
|
| 336 |
+
|
| 337 |
+
def _extract_crop(self, pixel_action, observation):
|
| 338 |
+
# Pixel action will now be (B, 2)
|
| 339 |
+
# observation = stack_on_channel(observation)
|
| 340 |
+
h = observation.shape[-1]
|
| 341 |
+
top_left_corner = torch.clamp(
|
| 342 |
+
pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
|
| 343 |
+
)
|
| 344 |
+
grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
|
| 345 |
+
grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
|
| 346 |
+
# Used for cropping the images across a batch
|
| 347 |
+
# swap fro y x, to x, y
|
| 348 |
+
grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
|
| 349 |
+
crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
|
| 350 |
+
return crop
|
| 351 |
+
|
| 352 |
+
def _preprocess_inputs(self, replay_sample):
|
| 353 |
+
obs = []
|
| 354 |
+
pcds = []
|
| 355 |
+
self._crop_summary = []
|
| 356 |
+
for n in self._camera_names:
|
| 357 |
+
rgb = replay_sample["%s_rgb" % n]
|
| 358 |
+
pcd = replay_sample["%s_point_cloud" % n]
|
| 359 |
+
|
| 360 |
+
obs.append([rgb, pcd])
|
| 361 |
+
pcds.append(pcd)
|
| 362 |
+
return obs, pcds
|
| 363 |
+
|
| 364 |
+
def _act_preprocess_inputs(self, observation):
|
| 365 |
+
obs, pcds = [], []
|
| 366 |
+
for n in self._camera_names:
|
| 367 |
+
rgb = observation["%s_rgb" % n]
|
| 368 |
+
pcd = observation["%s_point_cloud" % n]
|
| 369 |
+
|
| 370 |
+
obs.append([rgb, pcd])
|
| 371 |
+
pcds.append(pcd)
|
| 372 |
+
return obs, pcds
|
| 373 |
+
|
| 374 |
+
def _get_value_from_voxel_index(self, q, voxel_idx):
|
| 375 |
+
b, c, d, h, w = q.shape
|
| 376 |
+
q_trans_flat = q.view(b, c, d * h * w)
|
| 377 |
+
flat_indicies = (
|
| 378 |
+
voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
|
| 379 |
+
)[:, None].int()
|
| 380 |
+
highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
|
| 381 |
+
chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
|
| 382 |
+
..., 0
|
| 383 |
+
] # (B, trans + rot + grip)
|
| 384 |
+
return chosen_voxel_values
|
| 385 |
+
|
| 386 |
+
def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
|
| 387 |
+
q_rot = torch.stack(
|
| 388 |
+
torch.split(
|
| 389 |
+
rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 390 |
+
),
|
| 391 |
+
dim=1,
|
| 392 |
+
) # B, 3, 72
|
| 393 |
+
q_grip = rot_grip_q[:, -2:]
|
| 394 |
+
rot_and_grip_values = torch.cat(
|
| 395 |
+
[
|
| 396 |
+
q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
|
| 397 |
+
q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
|
| 398 |
+
q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
|
| 399 |
+
q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
|
| 400 |
+
],
|
| 401 |
+
-1,
|
| 402 |
+
)
|
| 403 |
+
return rot_and_grip_values
|
| 404 |
+
|
| 405 |
+
def _celoss(self, pred, labels):
|
| 406 |
+
return self._cross_entropy_loss(pred, labels.argmax(-1))
|
| 407 |
+
|
| 408 |
+
def _softmax_q_trans(self, q):
|
| 409 |
+
q_shape = q.shape
|
| 410 |
+
return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
|
| 411 |
+
|
| 412 |
+
def _softmax_q_rot_grip(self, q_rot_grip):
|
| 413 |
+
q_rot_x_flat = q_rot_grip[
|
| 414 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 415 |
+
]
|
| 416 |
+
q_rot_y_flat = q_rot_grip[
|
| 417 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 418 |
+
]
|
| 419 |
+
q_rot_z_flat = q_rot_grip[
|
| 420 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 421 |
+
]
|
| 422 |
+
q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 423 |
+
|
| 424 |
+
q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
|
| 425 |
+
q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
|
| 426 |
+
q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
|
| 427 |
+
q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
|
| 428 |
+
|
| 429 |
+
return torch.cat(
|
| 430 |
+
[
|
| 431 |
+
q_rot_x_flat_softmax,
|
| 432 |
+
q_rot_y_flat_softmax,
|
| 433 |
+
q_rot_z_flat_softmax,
|
| 434 |
+
q_grip_flat_softmax,
|
| 435 |
+
],
|
| 436 |
+
dim=1,
|
| 437 |
+
)
|
| 438 |
+
|
| 439 |
+
def _softmax_ignore_collision(self, q_collision):
|
| 440 |
+
q_collision_softmax = F.softmax(q_collision, dim=1)
|
| 441 |
+
return q_collision_softmax
|
| 442 |
+
|
| 443 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 444 |
+
action_trans = replay_sample["trans_action_indicies"][
|
| 445 |
+
:, self._layer * 3 : self._layer * 3 + 3
|
| 446 |
+
].int()
|
| 447 |
+
action_rot_grip = replay_sample["rot_grip_action_indicies"].int()
|
| 448 |
+
action_gripper_pose = replay_sample["gripper_pose"]
|
| 449 |
+
action_ignore_collisions = replay_sample["ignore_collisions"].int()
|
| 450 |
+
lang_goal_emb = replay_sample["lang_goal_emb"].float()
|
| 451 |
+
lang_token_embs = replay_sample["lang_token_embs"].float()
|
| 452 |
+
prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
|
| 453 |
+
prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
|
| 454 |
+
device = self._device
|
| 455 |
+
rank = device
|
| 456 |
+
bounds = self._coordinate_bounds.to(device)
|
| 457 |
+
if self._layer > 0:
|
| 458 |
+
cp = replay_sample["attention_coordinate_layer_%d" % (self._layer - 1)]
|
| 459 |
+
bounds = torch.cat(
|
| 460 |
+
[cp - self._bounds_offset, cp + self._bounds_offset], dim=1
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
proprio = None
|
| 464 |
+
if self._include_low_dim_state:
|
| 465 |
+
proprio = replay_sample["low_dim_state"]
|
| 466 |
+
|
| 467 |
+
obs, pcd = self._preprocess_inputs(replay_sample)
|
| 468 |
+
if proprio.shape[-1] == 4:
|
| 469 |
+
arm = "right"
|
| 470 |
+
else:
|
| 471 |
+
arm = "left"
|
| 472 |
+
# batch size
|
| 473 |
+
bs = pcd[0].shape[0]
|
| 474 |
+
|
| 475 |
+
# SE(3) augmentation of point clouds and actions
|
| 476 |
+
if self._transform_augmentation:
|
| 477 |
+
from voxel import augmentation
|
| 478 |
+
action_trans, action_rot_grip, pcd = augmentation.apply_se3_augmentation(
|
| 479 |
+
pcd,
|
| 480 |
+
action_gripper_pose,
|
| 481 |
+
action_trans,
|
| 482 |
+
action_rot_grip,
|
| 483 |
+
bounds,
|
| 484 |
+
self._layer,
|
| 485 |
+
self._transform_augmentation_xyz,
|
| 486 |
+
self._transform_augmentation_rpy,
|
| 487 |
+
self._transform_augmentation_rot_resolution,
|
| 488 |
+
self._voxel_size,
|
| 489 |
+
self._rotation_resolution,
|
| 490 |
+
self._device,
|
| 491 |
+
)
|
| 492 |
+
|
| 493 |
+
# forward pass
|
| 494 |
+
q_trans, q_rot_grip, q_collision, voxel_grid = self._q(
|
| 495 |
+
obs,
|
| 496 |
+
proprio,
|
| 497 |
+
pcd,
|
| 498 |
+
lang_goal_emb,
|
| 499 |
+
lang_token_embs,
|
| 500 |
+
bounds,
|
| 501 |
+
prev_layer_bounds,
|
| 502 |
+
prev_layer_voxel_grid,
|
| 503 |
+
arm=arm,
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# argmax to choose best action
|
| 507 |
+
(
|
| 508 |
+
coords,
|
| 509 |
+
rot_and_grip_indicies,
|
| 510 |
+
ignore_collision_indicies,
|
| 511 |
+
) = self._q.choose_highest_action(q_trans, q_rot_grip, q_collision)
|
| 512 |
+
|
| 513 |
+
q_trans_loss, q_rot_loss, q_grip_loss, q_collision_loss = 0.0, 0.0, 0.0, 0.0
|
| 514 |
+
|
| 515 |
+
# translation one-hot
|
| 516 |
+
action_trans_one_hot = self._action_trans_one_hot_zeros.clone()
|
| 517 |
+
for b in range(bs):
|
| 518 |
+
gt_coord = action_trans[b, :].int()
|
| 519 |
+
action_trans_one_hot[b, :, gt_coord[0], gt_coord[1], gt_coord[2]] = 1
|
| 520 |
+
|
| 521 |
+
# translation loss
|
| 522 |
+
q_trans_flat = q_trans.view(bs, -1)
|
| 523 |
+
action_trans_one_hot_flat = action_trans_one_hot.view(bs, -1)
|
| 524 |
+
q_trans_loss = self._celoss(q_trans_flat, action_trans_one_hot_flat)
|
| 525 |
+
|
| 526 |
+
with_rot_and_grip = rot_and_grip_indicies is not None
|
| 527 |
+
if with_rot_and_grip:
|
| 528 |
+
# rotation, gripper, and collision one-hots
|
| 529 |
+
action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
|
| 530 |
+
action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
|
| 531 |
+
action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
|
| 532 |
+
action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
|
| 533 |
+
action_ignore_collisions_one_hot = (
|
| 534 |
+
self._action_ignore_collisions_one_hot_zeros.clone()
|
| 535 |
+
)
|
| 536 |
+
|
| 537 |
+
for b in range(bs):
|
| 538 |
+
gt_rot_grip = action_rot_grip[b, :].int()
|
| 539 |
+
action_rot_x_one_hot[b, gt_rot_grip[0]] = 1
|
| 540 |
+
action_rot_y_one_hot[b, gt_rot_grip[1]] = 1
|
| 541 |
+
action_rot_z_one_hot[b, gt_rot_grip[2]] = 1
|
| 542 |
+
action_grip_one_hot[b, gt_rot_grip[3]] = 1
|
| 543 |
+
|
| 544 |
+
gt_ignore_collisions = action_ignore_collisions[b, :].int()
|
| 545 |
+
action_ignore_collisions_one_hot[b, gt_ignore_collisions[0]] = 1
|
| 546 |
+
|
| 547 |
+
# flatten predictions
|
| 548 |
+
q_rot_x_flat = q_rot_grip[
|
| 549 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 550 |
+
]
|
| 551 |
+
q_rot_y_flat = q_rot_grip[
|
| 552 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 553 |
+
]
|
| 554 |
+
q_rot_z_flat = q_rot_grip[
|
| 555 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 556 |
+
]
|
| 557 |
+
q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 558 |
+
q_ignore_collisions_flat = q_collision
|
| 559 |
+
|
| 560 |
+
# rotation loss
|
| 561 |
+
q_rot_loss += self._celoss(q_rot_x_flat, action_rot_x_one_hot)
|
| 562 |
+
q_rot_loss += self._celoss(q_rot_y_flat, action_rot_y_one_hot)
|
| 563 |
+
q_rot_loss += self._celoss(q_rot_z_flat, action_rot_z_one_hot)
|
| 564 |
+
|
| 565 |
+
# gripper loss
|
| 566 |
+
q_grip_loss += self._celoss(q_grip_flat, action_grip_one_hot)
|
| 567 |
+
|
| 568 |
+
# collision loss
|
| 569 |
+
q_collision_loss += self._celoss(
|
| 570 |
+
q_ignore_collisions_flat, action_ignore_collisions_one_hot
|
| 571 |
+
)
|
| 572 |
+
|
| 573 |
+
combined_losses = (
|
| 574 |
+
(q_trans_loss * self._trans_loss_weight)
|
| 575 |
+
+ (q_rot_loss * self._rot_loss_weight)
|
| 576 |
+
+ (q_grip_loss * self._grip_loss_weight)
|
| 577 |
+
+ (q_collision_loss * self._collision_loss_weight)
|
| 578 |
+
)
|
| 579 |
+
total_loss = combined_losses.mean()
|
| 580 |
+
if step % 10 == 0 and rank == 0 and wandb.run is not None:
|
| 581 |
+
wandb.log({
|
| 582 |
+
'train/grip_loss': q_grip_loss.mean(),
|
| 583 |
+
'train/trans_loss': q_trans_loss.mean(),
|
| 584 |
+
'train/rot_loss': q_rot_loss.mean(),
|
| 585 |
+
'train/collision_loss': q_collision_loss.mean(),
|
| 586 |
+
'train/total_loss': total_loss,
|
| 587 |
+
}, step=step)
|
| 588 |
+
|
| 589 |
+
self._optimizer.zero_grad()
|
| 590 |
+
total_loss.backward()
|
| 591 |
+
self._optimizer.step()
|
| 592 |
+
|
| 593 |
+
self._summaries = {
|
| 594 |
+
"losses/total_loss": total_loss,
|
| 595 |
+
"losses/trans_loss": q_trans_loss.mean(),
|
| 596 |
+
"losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 597 |
+
"losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 598 |
+
"losses/collision_loss": q_collision_loss.mean()
|
| 599 |
+
if with_rot_and_grip
|
| 600 |
+
else 0.0,
|
| 601 |
+
}
|
| 602 |
+
self._wandb_summaries = {
|
| 603 |
+
'losses/total_loss': total_loss,
|
| 604 |
+
'losses/trans_loss': q_trans_loss.mean(),
|
| 605 |
+
'losses/rot_loss': q_rot_loss.mean() if with_rot_and_grip else 0.,
|
| 606 |
+
'losses/grip_loss': q_grip_loss.mean() if with_rot_and_grip else 0.,
|
| 607 |
+
'losses/collision_loss': q_collision_loss.mean() if with_rot_and_grip else 0.
|
| 608 |
+
}
|
| 609 |
+
if self._lr_scheduler:
|
| 610 |
+
self._scheduler.step()
|
| 611 |
+
self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
|
| 612 |
+
|
| 613 |
+
self._vis_voxel_grid = voxel_grid[0]
|
| 614 |
+
self._vis_translation_qvalue = self._softmax_q_trans(q_trans[0])
|
| 615 |
+
self._vis_max_coordinate = coords[0]
|
| 616 |
+
self._vis_gt_coordinate = action_trans[0]
|
| 617 |
+
|
| 618 |
+
# Note: PerAct doesn't use multi-layer voxel grids like C2FARM
|
| 619 |
+
# stack prev_layer_voxel_grid(s) from previous layers into a list
|
| 620 |
+
if prev_layer_voxel_grid is None:
|
| 621 |
+
prev_layer_voxel_grid = [voxel_grid]
|
| 622 |
+
else:
|
| 623 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
|
| 624 |
+
|
| 625 |
+
# stack prev_layer_bound(s) from previous layers into a list
|
| 626 |
+
if prev_layer_bounds is None:
|
| 627 |
+
prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
|
| 628 |
+
else:
|
| 629 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 630 |
+
|
| 631 |
+
q_trans_vis=True
|
| 632 |
+
log_freq = getattr(getattr(getattr(self, "cfg", None), "framework", None), "log_freq", None)
|
| 633 |
+
if log_freq and step % log_freq == 0 and rank == 0:
|
| 634 |
+
print(f"{arm}_arm_predict: {self._vis_max_coordinate}")
|
| 635 |
+
print(f"{arm}_gt: {self._vis_gt_coordinate}")
|
| 636 |
+
rendered_img = visualise_voxel(
|
| 637 |
+
voxel_grid[0].cpu().detach().numpy(), # [10, 100, 100, 100]
|
| 638 |
+
self._vis_translation_qvalue.detach().cpu().numpy() if q_trans_vis else None,
|
| 639 |
+
self._vis_max_coordinate.detach().cpu().numpy(),
|
| 640 |
+
self._vis_gt_coordinate.detach().cpu().numpy(),
|
| 641 |
+
voxel_size=0.045,
|
| 642 |
+
# voxel_size=0.1, # more focus ??
|
| 643 |
+
rotation_amount=np.deg2rad(-90),
|
| 644 |
+
highlight_alpha=1.0,
|
| 645 |
+
alpha=0.4,
|
| 646 |
+
)
|
| 647 |
+
os.makedirs('recon', exist_ok=True)
|
| 648 |
+
# plot three images in one row with subplots:
|
| 649 |
+
rgb_src = obs[0][0][0].squeeze(0).permute(1, 2, 0) / 2 + 0.5
|
| 650 |
+
|
| 651 |
+
fig, axs = plt.subplots(1, 4, figsize=(9, 3))
|
| 652 |
+
# src
|
| 653 |
+
axs[0].imshow(rgb_src.cpu().numpy())
|
| 654 |
+
axs[0].title.set_text('src')
|
| 655 |
+
|
| 656 |
+
axs[1].imshow(rendered_img)
|
| 657 |
+
axs[1].text(0, 40, 'predicted', color='blue')
|
| 658 |
+
axs[1].text(0, 80, 'gt', color='red')
|
| 659 |
+
for ax in axs:
|
| 660 |
+
ax.axis('off')
|
| 661 |
+
plt.tight_layout()
|
| 662 |
+
|
| 663 |
+
if rank == 0:
|
| 664 |
+
if wandb.run is not None:
|
| 665 |
+
buf = io.BytesIO()
|
| 666 |
+
plt.savefig(buf, format='png')
|
| 667 |
+
buf.seek(0)
|
| 668 |
+
|
| 669 |
+
image = Image.open(buf)
|
| 670 |
+
wandb.log({"eval/recon_img": wandb.Image(image)}, step=step)
|
| 671 |
+
|
| 672 |
+
buf.close()
|
| 673 |
+
cprint(f'Saved to wandb', 'cyan')
|
| 674 |
+
else:
|
| 675 |
+
plt.savefig(f'recon/{step}_rgb.png')
|
| 676 |
+
workdir = os.getcwd()
|
| 677 |
+
cprint(f'Saved {workdir}/recon/{step}_rgb.png locally', 'cyan')
|
| 678 |
+
return {
|
| 679 |
+
"total_loss": total_loss,
|
| 680 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 681 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 682 |
+
}
|
| 683 |
+
|
| 684 |
+
def update_wandb_summaries(self):
|
| 685 |
+
summaries = dict()
|
| 686 |
+
for k, v in self._wandb_summaries.items():
|
| 687 |
+
summaries[k] = v
|
| 688 |
+
return summaries
|
| 689 |
+
|
| 690 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 691 |
+
deterministic = True
|
| 692 |
+
bounds = self._coordinate_bounds
|
| 693 |
+
prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
|
| 694 |
+
prev_layer_bounds = observation.get("prev_layer_bounds", None)
|
| 695 |
+
lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
|
| 696 |
+
|
| 697 |
+
# extract CLIP language embs
|
| 698 |
+
with torch.no_grad():
|
| 699 |
+
lang_goal_tokens = lang_goal_tokens.to(device=self._device)
|
| 700 |
+
(
|
| 701 |
+
lang_goal_emb,
|
| 702 |
+
lang_token_embs,
|
| 703 |
+
) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
|
| 704 |
+
|
| 705 |
+
# voxelization resolution
|
| 706 |
+
res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
|
| 707 |
+
max_rot_index = int(360 // self._rotation_resolution)
|
| 708 |
+
proprio = None
|
| 709 |
+
|
| 710 |
+
if self._include_low_dim_state:
|
| 711 |
+
proprio = observation["low_dim_state"]
|
| 712 |
+
proprio = proprio[0].to(self._device)
|
| 713 |
+
|
| 714 |
+
obs, pcd = self._act_preprocess_inputs(observation)
|
| 715 |
+
|
| 716 |
+
# correct batch size and device
|
| 717 |
+
obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]
|
| 718 |
+
pcd = [p[0].to(self._device) for p in pcd]
|
| 719 |
+
lang_goal_emb = lang_goal_emb.to(self._device)
|
| 720 |
+
lang_token_embs = lang_token_embs.to(self._device)
|
| 721 |
+
bounds = torch.as_tensor(bounds, device=self._device)
|
| 722 |
+
prev_layer_voxel_grid = (
|
| 723 |
+
prev_layer_voxel_grid.to(self._device)
|
| 724 |
+
if prev_layer_voxel_grid is not None
|
| 725 |
+
else None
|
| 726 |
+
)
|
| 727 |
+
prev_layer_bounds = (
|
| 728 |
+
prev_layer_bounds.to(self._device)
|
| 729 |
+
if prev_layer_bounds is not None
|
| 730 |
+
else None
|
| 731 |
+
)
|
| 732 |
+
|
| 733 |
+
# inference
|
| 734 |
+
q_trans, q_rot_grip, q_ignore_collisions, vox_grid = self._q(
|
| 735 |
+
obs,
|
| 736 |
+
proprio,
|
| 737 |
+
pcd,
|
| 738 |
+
lang_goal_emb,
|
| 739 |
+
lang_token_embs,
|
| 740 |
+
bounds,
|
| 741 |
+
prev_layer_bounds,
|
| 742 |
+
prev_layer_voxel_grid,
|
| 743 |
+
)
|
| 744 |
+
|
| 745 |
+
# softmax Q predictions
|
| 746 |
+
q_trans = self._softmax_q_trans(q_trans)
|
| 747 |
+
q_rot_grip = (
|
| 748 |
+
self._softmax_q_rot_grip(q_rot_grip)
|
| 749 |
+
if q_rot_grip is not None
|
| 750 |
+
else q_rot_grip
|
| 751 |
+
)
|
| 752 |
+
q_ignore_collisions = (
|
| 753 |
+
self._softmax_ignore_collision(q_ignore_collisions)
|
| 754 |
+
if q_ignore_collisions is not None
|
| 755 |
+
else q_ignore_collisions
|
| 756 |
+
)
|
| 757 |
+
|
| 758 |
+
# argmax Q predictions
|
| 759 |
+
(
|
| 760 |
+
coords,
|
| 761 |
+
rot_and_grip_indicies,
|
| 762 |
+
ignore_collisions,
|
| 763 |
+
) = self._q.choose_highest_action(q_trans, q_rot_grip, q_ignore_collisions)
|
| 764 |
+
|
| 765 |
+
rot_grip_action = rot_and_grip_indicies if q_rot_grip is not None else None
|
| 766 |
+
ignore_collisions_action = (
|
| 767 |
+
ignore_collisions.int() if ignore_collisions is not None else None
|
| 768 |
+
)
|
| 769 |
+
|
| 770 |
+
coords = coords.int()
|
| 771 |
+
attention_coordinate = bounds[:, :3] + res * coords + res / 2
|
| 772 |
+
|
| 773 |
+
# stack prev_layer_voxel_grid(s) into a list
|
| 774 |
+
# NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
|
| 775 |
+
if prev_layer_voxel_grid is None:
|
| 776 |
+
prev_layer_voxel_grid = [vox_grid]
|
| 777 |
+
else:
|
| 778 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
|
| 779 |
+
|
| 780 |
+
if prev_layer_bounds is None:
|
| 781 |
+
prev_layer_bounds = [bounds]
|
| 782 |
+
else:
|
| 783 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 784 |
+
|
| 785 |
+
observation_elements = {
|
| 786 |
+
"attention_coordinate": attention_coordinate,
|
| 787 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 788 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 789 |
+
}
|
| 790 |
+
info = {
|
| 791 |
+
"voxel_grid_depth%d" % self._layer: vox_grid,
|
| 792 |
+
"q_depth%d" % self._layer: q_trans,
|
| 793 |
+
"voxel_idx_depth%d" % self._layer: coords,
|
| 794 |
+
}
|
| 795 |
+
self._act_voxel_grid = vox_grid[0]
|
| 796 |
+
self._act_max_coordinate = coords[0]
|
| 797 |
+
self._act_qvalues = q_trans[0].detach()
|
| 798 |
+
return ActResult(
|
| 799 |
+
(coords, rot_grip_action, ignore_collisions_action),
|
| 800 |
+
observation_elements=observation_elements,
|
| 801 |
+
info=info,
|
| 802 |
+
)
|
| 803 |
+
|
| 804 |
+
def update_summaries(self) -> List[Summary]:
|
| 805 |
+
summaries = [
|
| 806 |
+
ImageSummary(
|
| 807 |
+
"%s/update_qattention" % self._name,
|
| 808 |
+
transforms.ToTensor()(
|
| 809 |
+
visualise_voxel(
|
| 810 |
+
self._vis_voxel_grid.detach().cpu().numpy(),
|
| 811 |
+
self._vis_translation_qvalue.detach().cpu().numpy(),
|
| 812 |
+
self._vis_max_coordinate.detach().cpu().numpy(),
|
| 813 |
+
self._vis_gt_coordinate.detach().cpu().numpy(),
|
| 814 |
+
)
|
| 815 |
+
),
|
| 816 |
+
)
|
| 817 |
+
]
|
| 818 |
+
|
| 819 |
+
for n, v in self._summaries.items():
|
| 820 |
+
summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
|
| 821 |
+
|
| 822 |
+
for name, crop in self._crop_summary:
|
| 823 |
+
crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
|
| 824 |
+
summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
|
| 825 |
+
|
| 826 |
+
for tag, param in self._q.named_parameters():
|
| 827 |
+
# assert not torch.isnan(param.grad.abs() <= 1.0).all()
|
| 828 |
+
summaries.append(
|
| 829 |
+
HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
|
| 830 |
+
)
|
| 831 |
+
summaries.append(
|
| 832 |
+
HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
|
| 833 |
+
)
|
| 834 |
+
|
| 835 |
+
return summaries
|
| 836 |
+
|
| 837 |
+
def act_summaries(self) -> List[Summary]:
|
| 838 |
+
return [
|
| 839 |
+
ImageSummary(
|
| 840 |
+
"%s/act_Qattention" % self._name,
|
| 841 |
+
transforms.ToTensor()(
|
| 842 |
+
visualise_voxel(
|
| 843 |
+
self._act_voxel_grid.cpu().numpy(),
|
| 844 |
+
self._act_qvalues.cpu().numpy(),
|
| 845 |
+
self._act_max_coordinate.cpu().numpy(),
|
| 846 |
+
)
|
| 847 |
+
),
|
| 848 |
+
)
|
| 849 |
+
]
|
| 850 |
+
def concat_weights(self, param, target_size, dims=-1):
|
| 851 |
+
if param.size(-1) < target_size:
|
| 852 |
+
param = torch.cat([param, param], dims)
|
| 853 |
+
return param
|
| 854 |
+
|
| 855 |
+
def load_weights(self, savedir: str):
|
| 856 |
+
device = (
|
| 857 |
+
self._device
|
| 858 |
+
if not self._training
|
| 859 |
+
else torch.device("cuda:%d" % self._device)
|
| 860 |
+
)
|
| 861 |
+
weight_file = os.path.join(savedir, "%s.pt" % self._name)
|
| 862 |
+
state_dict = torch.load(weight_file, map_location=device)
|
| 863 |
+
|
| 864 |
+
# load only keys that are in the current model
|
| 865 |
+
merged_state_dict = self._q.state_dict()
|
| 866 |
+
if not self._training:
|
| 867 |
+
for k, v in state_dict.items():
|
| 868 |
+
if not self._training:
|
| 869 |
+
k = k.replace("_qnet.module", "_qnet")
|
| 870 |
+
if k in merged_state_dict:
|
| 871 |
+
merged_state_dict[k] = v
|
| 872 |
+
else:
|
| 873 |
+
if "_voxelizer" not in k:
|
| 874 |
+
logging.warning("key %s not found in checkpoint" % k)
|
| 875 |
+
else:
|
| 876 |
+
for k, v in state_dict.items():
|
| 877 |
+
if not self._training:
|
| 878 |
+
k = k.replace("_qnet.module", "_qnet")
|
| 879 |
+
elif k == "_qnet.module.pos_encoding":
|
| 880 |
+
if (v.shape[1] != 8077 or v.shape[1] != 8154) and v.shape[1] < 154:
|
| 881 |
+
if self.anybimanual:
|
| 882 |
+
lang_max_seq_len = 154
|
| 883 |
+
else:
|
| 884 |
+
lang_max_seq_len = 77
|
| 885 |
+
spatial_size = v.shape[1]
|
| 886 |
+
input_dim_before_seq = v.shape[-1]
|
| 887 |
+
flattened_v = v.view(1, -1, input_dim_before_seq) # (1, spatial_size**3, self.input_dim_before_seq)
|
| 888 |
+
new_pos_encoding = torch.randn(1, lang_max_seq_len, input_dim_before_seq, device=device)
|
| 889 |
+
merged_pos_encoding = torch.cat([flattened_v, new_pos_encoding], dim=1) # (1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq)
|
| 890 |
+
merged_state_dict["_qnet.module.pos_encoding"] = merged_pos_encoding
|
| 891 |
+
else:
|
| 892 |
+
merged_state_dict["_qnet.module.pos_encoding"] = v
|
| 893 |
+
elif k.startswith("_qnet.module.cross_attend_blocks"):
|
| 894 |
+
if self.anybimanual:
|
| 895 |
+
if v.size(-1) == 128:
|
| 896 |
+
merged_state_dict[k] = self.concat_weights(v, 256)
|
| 897 |
+
elif k.startswith("_qnet.module.decoder_cross_attn"):
|
| 898 |
+
if self.anybimanual:
|
| 899 |
+
if v.size(0) == 128:
|
| 900 |
+
merged_state_dict[k] = self.concat_weights(v, 256, 0)
|
| 901 |
+
merged_state_dict[k] = self.concat_weights(v, 256, 0)
|
| 902 |
+
if v.size(-1) == 128:
|
| 903 |
+
merged_state_dict[k] = self.concat_weights(v, 256)
|
| 904 |
+
merged_state_dict[k] = self.concat_weights(v, 256)
|
| 905 |
+
elif k == "_qnet.module.up0.conv_up.0.conv3d.weight":
|
| 906 |
+
if self.anybimanual:
|
| 907 |
+
if v.size(1) == 128:
|
| 908 |
+
merged_state_dict[k] = self.concat_weights(v, 256, 1)
|
| 909 |
+
elif k.startswith("_qnet.module.dense0"):
|
| 910 |
+
if self.anybimanual:
|
| 911 |
+
if v.size(-1) == 1024:
|
| 912 |
+
merged_state_dict[k] = torch.cat([v, v[:, :512]], dim=-1)
|
| 913 |
+
elif k in merged_state_dict:
|
| 914 |
+
merged_state_dict[k] = v
|
| 915 |
+
else:
|
| 916 |
+
if "_voxelizer" not in k:
|
| 917 |
+
logging.warning("key %s not found in checkpoint" % k)
|
| 918 |
+
|
| 919 |
+
if not self._training:
|
| 920 |
+
# reshape voxelizer weights
|
| 921 |
+
b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
|
| 922 |
+
merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
|
| 923 |
+
"_voxelizer._ones_max_coords"
|
| 924 |
+
][0:1]
|
| 925 |
+
flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
|
| 926 |
+
merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
|
| 927 |
+
"_voxelizer._flat_output"
|
| 928 |
+
][0 : flat_shape // b]
|
| 929 |
+
merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
|
| 930 |
+
"_voxelizer._tiled_batch_indices"
|
| 931 |
+
][0:1]
|
| 932 |
+
merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
|
| 933 |
+
"_voxelizer._index_grid"
|
| 934 |
+
][0:1]
|
| 935 |
+
self._q.load_state_dict(merged_state_dict)
|
| 936 |
+
print("loaded weights from %s" % weight_file)
|
| 937 |
+
|
| 938 |
+
def save_weights(self, savedir: str):
|
| 939 |
+
torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
|
third_party/AnyBimanual/agents/peract_bc/qattention_stack_agent.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from yarr.agents.agent import Agent, ActResult, Summary
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from helpers import utils
|
| 9 |
+
from agents.peract_bc.qattention_peract_bc_agent import QAttentionPerActBCAgent
|
| 10 |
+
|
| 11 |
+
NAME = "QAttentionStackAgent"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QAttentionStackAgent(Agent):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
qattention_agents: List[QAttentionPerActBCAgent],
|
| 18 |
+
rotation_resolution: float,
|
| 19 |
+
camera_names: List[str],
|
| 20 |
+
rotation_prediction_depth: int = 0,
|
| 21 |
+
):
|
| 22 |
+
super(QAttentionStackAgent, self).__init__()
|
| 23 |
+
self._qattention_agents = qattention_agents
|
| 24 |
+
self._rotation_resolution = rotation_resolution
|
| 25 |
+
self._camera_names = camera_names
|
| 26 |
+
self._rotation_prediction_depth = rotation_prediction_depth
|
| 27 |
+
|
| 28 |
+
def build(self, training: bool, device=None) -> None:
|
| 29 |
+
self._device = device
|
| 30 |
+
if self._device is None:
|
| 31 |
+
self._device = torch.device("cpu")
|
| 32 |
+
for qa in self._qattention_agents:
|
| 33 |
+
qa.build(training, device)
|
| 34 |
+
|
| 35 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 36 |
+
priorities = 0
|
| 37 |
+
total_losses = 0.0
|
| 38 |
+
for qa in self._qattention_agents:
|
| 39 |
+
update_dict = qa.update(step, replay_sample)
|
| 40 |
+
replay_sample.update(update_dict)
|
| 41 |
+
total_losses += update_dict["total_loss"]
|
| 42 |
+
return {
|
| 43 |
+
"total_losses": total_losses,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 47 |
+
observation_elements = {}
|
| 48 |
+
translation_results, rot_grip_results, ignore_collisions_results = [], [], []
|
| 49 |
+
infos = {}
|
| 50 |
+
for depth, qagent in enumerate(self._qattention_agents):
|
| 51 |
+
act_results = qagent.act(step, observation, deterministic)
|
| 52 |
+
attention_coordinate = (
|
| 53 |
+
act_results.observation_elements["attention_coordinate"].cpu().numpy()
|
| 54 |
+
)
|
| 55 |
+
observation_elements[
|
| 56 |
+
"attention_coordinate_layer_%d" % depth
|
| 57 |
+
] = attention_coordinate[0]
|
| 58 |
+
|
| 59 |
+
translation_idxs, rot_grip_idxs, ignore_collisions_idxs = act_results.action
|
| 60 |
+
translation_results.append(translation_idxs)
|
| 61 |
+
if rot_grip_idxs is not None:
|
| 62 |
+
rot_grip_results.append(rot_grip_idxs)
|
| 63 |
+
if ignore_collisions_idxs is not None:
|
| 64 |
+
ignore_collisions_results.append(ignore_collisions_idxs)
|
| 65 |
+
|
| 66 |
+
observation["attention_coordinate"] = act_results.observation_elements[
|
| 67 |
+
"attention_coordinate"
|
| 68 |
+
]
|
| 69 |
+
observation["prev_layer_voxel_grid"] = act_results.observation_elements[
|
| 70 |
+
"prev_layer_voxel_grid"
|
| 71 |
+
]
|
| 72 |
+
observation["prev_layer_bounds"] = act_results.observation_elements[
|
| 73 |
+
"prev_layer_bounds"
|
| 74 |
+
]
|
| 75 |
+
|
| 76 |
+
for n in self._camera_names:
|
| 77 |
+
px, py = utils.point_to_pixel_index(
|
| 78 |
+
attention_coordinate[0],
|
| 79 |
+
observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy(),
|
| 80 |
+
observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy(),
|
| 81 |
+
)
|
| 82 |
+
pc_t = torch.tensor(
|
| 83 |
+
[[[py, px]]], dtype=torch.float32, device=self._device
|
| 84 |
+
)
|
| 85 |
+
observation["%s_pixel_coord" % n] = pc_t
|
| 86 |
+
observation_elements["%s_pixel_coord" % n] = [py, px]
|
| 87 |
+
|
| 88 |
+
infos.update(act_results.info)
|
| 89 |
+
|
| 90 |
+
rgai = torch.cat(rot_grip_results, 1)[0].cpu().numpy()
|
| 91 |
+
ignore_collisions = float(
|
| 92 |
+
torch.cat(ignore_collisions_results, 1)[0].cpu().numpy()
|
| 93 |
+
)
|
| 94 |
+
observation_elements["trans_action_indicies"] = (
|
| 95 |
+
torch.cat(translation_results, 1)[0].cpu().numpy()
|
| 96 |
+
)
|
| 97 |
+
observation_elements["rot_grip_action_indicies"] = rgai
|
| 98 |
+
continuous_action = np.concatenate(
|
| 99 |
+
[
|
| 100 |
+
act_results.observation_elements["attention_coordinate"]
|
| 101 |
+
.cpu()
|
| 102 |
+
.numpy()[0],
|
| 103 |
+
utils.discrete_euler_to_quaternion(
|
| 104 |
+
rgai[-4:-1], self._rotation_resolution
|
| 105 |
+
),
|
| 106 |
+
rgai[-1:],
|
| 107 |
+
[ignore_collisions],
|
| 108 |
+
]
|
| 109 |
+
)
|
| 110 |
+
return ActResult(
|
| 111 |
+
continuous_action, observation_elements=observation_elements, info=infos
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
def update_summaries(self) -> List[Summary]:
|
| 115 |
+
summaries = []
|
| 116 |
+
for qa in self._qattention_agents:
|
| 117 |
+
summaries.extend(qa.update_summaries())
|
| 118 |
+
return summaries
|
| 119 |
+
|
| 120 |
+
def act_summaries(self) -> List[Summary]:
|
| 121 |
+
s = []
|
| 122 |
+
for qa in self._qattention_agents:
|
| 123 |
+
s.extend(qa.act_summaries())
|
| 124 |
+
return s
|
| 125 |
+
|
| 126 |
+
def load_weights(self, savedir: str):
|
| 127 |
+
for qa in self._qattention_agents:
|
| 128 |
+
qa.load_weights(savedir)
|
| 129 |
+
|
| 130 |
+
def save_weights(self, savedir: str):
|
| 131 |
+
for qa in self._qattention_agents:
|
| 132 |
+
qa.save_weights(savedir)
|
third_party/AnyBimanual/agents/peract_bc/skill_manager.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import transformers
|
| 4 |
+
from agents.peract_bimanual.trajectory_gpt2 import GPT2Model
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
class SkillManager(nn.Module):
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
num_classes,
|
| 10 |
+
embedding_matrix=None,
|
| 11 |
+
voxel_dim=128,
|
| 12 |
+
lang_dim=128,
|
| 13 |
+
hidden_size=128,
|
| 14 |
+
output_dim=18,
|
| 15 |
+
max_voxels=8000,
|
| 16 |
+
max_lang_tokens=77,
|
| 17 |
+
**kwargs):
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
self.hidden_size = hidden_size
|
| 21 |
+
self.output_dim = output_dim
|
| 22 |
+
|
| 23 |
+
# GPT-2 configuration
|
| 24 |
+
config = transformers.GPT2Config(
|
| 25 |
+
vocab_size=1, # not used
|
| 26 |
+
n_embd=hidden_size,
|
| 27 |
+
n_head=4,
|
| 28 |
+
n_ctx=1077,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
self.max_voxels = max_voxels
|
| 32 |
+
self.max_lang_tokens = max_lang_tokens
|
| 33 |
+
self.embed_voxel = nn.Linear(voxel_dim, hidden_size)
|
| 34 |
+
self.embed_lang = nn.Linear(lang_dim, hidden_size)
|
| 35 |
+
self.transformer = GPT2Model(config)
|
| 36 |
+
self.embed_ln = nn.LayerNorm(hidden_size)
|
| 37 |
+
self.predict_logits = nn.Linear(hidden_size, output_dim)
|
| 38 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
self.num_class = num_classes
|
| 40 |
+
if embedding_matrix is not None:
|
| 41 |
+
self.embeddings_matrix = embedding_matrix.to(self.device)
|
| 42 |
+
|
| 43 |
+
def forward(self, voxel_embedding, language_embedding):
|
| 44 |
+
batch_size = voxel_embedding.shape[0]
|
| 45 |
+
voxel_embeddings = self.embed_voxel(voxel_embedding) # [b, 8000, hidden_size]
|
| 46 |
+
language_embeddings = self.embed_lang(language_embedding) # [b, 77, hidden_size]
|
| 47 |
+
voxel_embeddings = voxel_embeddings.permute(0, 2, 1) # [b, hidden_size, 8000]
|
| 48 |
+
voxel_embeddings = F.avg_pool1d(voxel_embeddings, kernel_size=16, stride=16) # [b, hidden_size, 1000]
|
| 49 |
+
voxel_embeddings = voxel_embeddings.permute(0, 2, 1) # [b, 1000, hidden_size]
|
| 50 |
+
inputs = torch.cat([language_embeddings, voxel_embeddings], dim=1) # [b, 8077, hidden_size]
|
| 51 |
+
stacked_inputs = self.embed_ln(inputs)
|
| 52 |
+
attention_mask = torch.ones(
|
| 53 |
+
(batch_size, self.max_lang_tokens + self.max_voxels),
|
| 54 |
+
device=voxel_embedding.device,
|
| 55 |
+
dtype=torch.long # Ensure correct dtype
|
| 56 |
+
)
|
| 57 |
+
assert torch.isfinite(attention_mask).all(), "attention_mask contains NaN or Inf"
|
| 58 |
+
assert torch.all((attention_mask == 1)), "attention_mask contains values not equal to 1"
|
| 59 |
+
transformer_outputs = self.transformer(
|
| 60 |
+
inputs_embeds=stacked_inputs,
|
| 61 |
+
attention_mask=None,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
hidden_state = transformer_outputs.last_hidden_state # [b, 8077, hidden_size]
|
| 65 |
+
aggregated_hidden = hidden_state.mean(dim=1) # [b, hidden_size]
|
| 66 |
+
logits = self.predict_logits(aggregated_hidden) # [b, output_dim]
|
| 67 |
+
probs = F.softmax(logits, dim=1)
|
| 68 |
+
skill = torch.matmul(probs, self.embeddings_matrix.to(probs.device))
|
| 69 |
+
skill = skill.view(-1,77,512)
|
| 70 |
+
return skill
|
third_party/AnyBimanual/agents/peract_bc/trajectory_gpt2.py
ADDED
|
@@ -0,0 +1,775 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch OpenAI GPT-2 model."""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from transformers.activations import ACT2FN
|
| 27 |
+
from transformers.file_utils import (
|
| 28 |
+
ModelOutput,
|
| 29 |
+
add_code_sample_docstrings,
|
| 30 |
+
add_start_docstrings,
|
| 31 |
+
add_start_docstrings_to_model_forward,
|
| 32 |
+
replace_return_docstrings,
|
| 33 |
+
)
|
| 34 |
+
from transformers.modeling_outputs import (
|
| 35 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 36 |
+
)
|
| 37 |
+
from transformers.modeling_utils import (
|
| 38 |
+
Conv1D,
|
| 39 |
+
PreTrainedModel,
|
| 40 |
+
SequenceSummary,
|
| 41 |
+
find_pruneable_heads_and_indices,
|
| 42 |
+
prune_conv1d_layer,
|
| 43 |
+
)
|
| 44 |
+
from transformers.utils import logging
|
| 45 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
| 46 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__)
|
| 49 |
+
|
| 50 |
+
_CONFIG_FOR_DOC = "GPT2Config"
|
| 51 |
+
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
| 52 |
+
|
| 53 |
+
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 54 |
+
"gpt2",
|
| 55 |
+
"gpt2-medium",
|
| 56 |
+
"gpt2-large",
|
| 57 |
+
"gpt2-xl",
|
| 58 |
+
"distilgpt2",
|
| 59 |
+
# See all GPT-2 models at https://huggingface.co/models?filter=gpt2
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
| 64 |
+
"""Load tf checkpoints in a pytorch model"""
|
| 65 |
+
try:
|
| 66 |
+
import re
|
| 67 |
+
|
| 68 |
+
import tensorflow as tf
|
| 69 |
+
except ImportError:
|
| 70 |
+
logger.error(
|
| 71 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
| 72 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
| 73 |
+
)
|
| 74 |
+
raise
|
| 75 |
+
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
| 76 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
| 77 |
+
# Load weights from TF model
|
| 78 |
+
init_vars = tf.train.list_variables(tf_path)
|
| 79 |
+
names = []
|
| 80 |
+
arrays = []
|
| 81 |
+
for name, shape in init_vars:
|
| 82 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
| 83 |
+
array = tf.train.load_variable(tf_path, name)
|
| 84 |
+
names.append(name)
|
| 85 |
+
arrays.append(array.squeeze())
|
| 86 |
+
|
| 87 |
+
for name, array in zip(names, arrays):
|
| 88 |
+
name = name[6:] # skip "model/"
|
| 89 |
+
name = name.split("/")
|
| 90 |
+
pointer = model
|
| 91 |
+
for m_name in name:
|
| 92 |
+
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
| 93 |
+
scope_names = re.split(r"(\d+)", m_name)
|
| 94 |
+
else:
|
| 95 |
+
scope_names = [m_name]
|
| 96 |
+
if scope_names[0] == "w" or scope_names[0] == "g":
|
| 97 |
+
pointer = getattr(pointer, "weight")
|
| 98 |
+
elif scope_names[0] == "b":
|
| 99 |
+
pointer = getattr(pointer, "bias")
|
| 100 |
+
elif scope_names[0] == "wpe" or scope_names[0] == "wte":
|
| 101 |
+
pointer = getattr(pointer, scope_names[0])
|
| 102 |
+
pointer = getattr(pointer, "weight")
|
| 103 |
+
else:
|
| 104 |
+
pointer = getattr(pointer, scope_names[0])
|
| 105 |
+
if len(scope_names) >= 2:
|
| 106 |
+
num = int(scope_names[1])
|
| 107 |
+
pointer = pointer[num]
|
| 108 |
+
try:
|
| 109 |
+
assert (
|
| 110 |
+
pointer.shape == array.shape
|
| 111 |
+
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
| 112 |
+
except AssertionError as e:
|
| 113 |
+
e.args += (pointer.shape, array.shape)
|
| 114 |
+
raise
|
| 115 |
+
logger.info("Initialize PyTorch weight {}".format(name))
|
| 116 |
+
pointer.data = torch.from_numpy(array)
|
| 117 |
+
return model
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Attention(nn.Module):
|
| 121 |
+
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
|
| 122 |
+
super().__init__()
|
| 123 |
+
|
| 124 |
+
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
| 125 |
+
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
| 126 |
+
assert n_state % config.n_head == 0
|
| 127 |
+
self.register_buffer(
|
| 128 |
+
"bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
|
| 129 |
+
)
|
| 130 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
| 131 |
+
self.n_head = config.n_head
|
| 132 |
+
self.split_size = n_state
|
| 133 |
+
self.scale = scale
|
| 134 |
+
self.is_cross_attention = is_cross_attention
|
| 135 |
+
if self.is_cross_attention:
|
| 136 |
+
self.c_attn = Conv1D(2 * n_state, nx)
|
| 137 |
+
self.q_attn = Conv1D(n_state, nx)
|
| 138 |
+
else:
|
| 139 |
+
self.c_attn = Conv1D(3 * n_state, nx)
|
| 140 |
+
self.c_proj = Conv1D(n_state, nx)
|
| 141 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
| 142 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 143 |
+
self.pruned_heads = set()
|
| 144 |
+
|
| 145 |
+
def prune_heads(self, heads):
|
| 146 |
+
if len(heads) == 0:
|
| 147 |
+
return
|
| 148 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 149 |
+
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
|
| 150 |
+
)
|
| 151 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
| 152 |
+
|
| 153 |
+
# Prune conv1d layers
|
| 154 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
| 155 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
| 156 |
+
|
| 157 |
+
# Update hyper params
|
| 158 |
+
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
| 159 |
+
self.n_head = self.n_head - len(heads)
|
| 160 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 161 |
+
|
| 162 |
+
def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
|
| 163 |
+
w = torch.matmul(q, k)
|
| 164 |
+
if self.scale:
|
| 165 |
+
w = w / (float(v.size(-1)) ** 0.5)
|
| 166 |
+
nd, ns = w.size(-2), w.size(-1)
|
| 167 |
+
|
| 168 |
+
if not self.is_cross_attention:
|
| 169 |
+
# if only "normal" attention layer implements causal mask
|
| 170 |
+
mask = self.bias[:, :, ns - nd: ns, :ns]
|
| 171 |
+
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
|
| 172 |
+
|
| 173 |
+
if attention_mask is not None:
|
| 174 |
+
# Apply the attention mask
|
| 175 |
+
w = w + attention_mask
|
| 176 |
+
|
| 177 |
+
w = nn.Softmax(dim=-1)(w)
|
| 178 |
+
w = self.attn_dropout(w)
|
| 179 |
+
|
| 180 |
+
# Mask heads if we want to
|
| 181 |
+
if head_mask is not None:
|
| 182 |
+
w = w * head_mask
|
| 183 |
+
|
| 184 |
+
outputs = [torch.matmul(w, v)]
|
| 185 |
+
if output_attentions:
|
| 186 |
+
outputs.append(w)
|
| 187 |
+
return outputs
|
| 188 |
+
|
| 189 |
+
def merge_heads(self, x):
|
| 190 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 191 |
+
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
| 192 |
+
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
|
| 193 |
+
|
| 194 |
+
def split_heads(self, x, k=False):
|
| 195 |
+
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
| 196 |
+
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
| 197 |
+
if k:
|
| 198 |
+
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
|
| 199 |
+
else:
|
| 200 |
+
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
| 201 |
+
|
| 202 |
+
def forward(
|
| 203 |
+
self,
|
| 204 |
+
hidden_states,
|
| 205 |
+
layer_past=None,
|
| 206 |
+
attention_mask=None,
|
| 207 |
+
head_mask=None,
|
| 208 |
+
encoder_hidden_states=None,
|
| 209 |
+
encoder_attention_mask=None,
|
| 210 |
+
use_cache=False,
|
| 211 |
+
output_attentions=False,
|
| 212 |
+
):
|
| 213 |
+
if encoder_hidden_states is not None:
|
| 214 |
+
assert hasattr(
|
| 215 |
+
self, "q_attn"
|
| 216 |
+
), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
|
| 217 |
+
query = self.q_attn(hidden_states)
|
| 218 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
| 219 |
+
attention_mask = encoder_attention_mask
|
| 220 |
+
else:
|
| 221 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
| 222 |
+
|
| 223 |
+
query = self.split_heads(query)
|
| 224 |
+
key = self.split_heads(key, k=True)
|
| 225 |
+
value = self.split_heads(value)
|
| 226 |
+
if layer_past is not None:
|
| 227 |
+
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
|
| 228 |
+
key = torch.cat((past_key, key), dim=-1)
|
| 229 |
+
value = torch.cat((past_value, value), dim=-2)
|
| 230 |
+
|
| 231 |
+
if use_cache is True:
|
| 232 |
+
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
| 233 |
+
else:
|
| 234 |
+
present = (None,)
|
| 235 |
+
|
| 236 |
+
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
|
| 237 |
+
a = attn_outputs[0]
|
| 238 |
+
|
| 239 |
+
a = self.merge_heads(a)
|
| 240 |
+
a = self.c_proj(a)
|
| 241 |
+
a = self.resid_dropout(a)
|
| 242 |
+
|
| 243 |
+
outputs = [a, present] + attn_outputs[1:]
|
| 244 |
+
return outputs # a, present, (attentions)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class MLP(nn.Module):
|
| 248 |
+
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
| 249 |
+
super().__init__()
|
| 250 |
+
nx = config.n_embd
|
| 251 |
+
self.c_fc = Conv1D(n_state, nx)
|
| 252 |
+
self.c_proj = Conv1D(nx, n_state)
|
| 253 |
+
self.act = ACT2FN[config.activation_function]
|
| 254 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 255 |
+
|
| 256 |
+
def forward(self, x):
|
| 257 |
+
h = self.act(self.c_fc(x))
|
| 258 |
+
h2 = self.c_proj(h)
|
| 259 |
+
return self.dropout(h2)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class AdapterMLP(nn.Module):
|
| 263 |
+
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
| 264 |
+
super().__init__()
|
| 265 |
+
nx = config.n_embd
|
| 266 |
+
self.c_fc = Conv1D(n_state, nx)
|
| 267 |
+
self.c_proj = Conv1D(nx, n_state)
|
| 268 |
+
self.act = ACT2FN[config.activation_function]
|
| 269 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 270 |
+
|
| 271 |
+
def forward(self, x):
|
| 272 |
+
h = self.act(self.c_fc(x))
|
| 273 |
+
h2 = self.c_proj(h)
|
| 274 |
+
return self.dropout(h2)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class Block(nn.Module):
|
| 278 |
+
def __init__(self, n_ctx, config, scale=False):
|
| 279 |
+
super().__init__()
|
| 280 |
+
hidden_size = config.n_embd
|
| 281 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
| 282 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 283 |
+
self.attn = Attention(hidden_size, n_ctx, config, scale)
|
| 284 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 285 |
+
# self.adapter_ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 286 |
+
if config.add_cross_attention:
|
| 287 |
+
self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
|
| 288 |
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 289 |
+
self.mlp = MLP(inner_dim, config)
|
| 290 |
+
# self.adapter_mlp = AdapterMLP(512, config) # ADAPTER
|
| 291 |
+
|
| 292 |
+
def forward(
|
| 293 |
+
self,
|
| 294 |
+
hidden_states,
|
| 295 |
+
layer_past=None,
|
| 296 |
+
attention_mask=None,
|
| 297 |
+
head_mask=None,
|
| 298 |
+
encoder_hidden_states=None,
|
| 299 |
+
encoder_attention_mask=None,
|
| 300 |
+
use_cache=False,
|
| 301 |
+
output_attentions=False,
|
| 302 |
+
):
|
| 303 |
+
attn_outputs = self.attn(
|
| 304 |
+
self.ln_1(hidden_states),
|
| 305 |
+
layer_past=layer_past,
|
| 306 |
+
attention_mask=attention_mask,
|
| 307 |
+
head_mask=head_mask,
|
| 308 |
+
use_cache=use_cache,
|
| 309 |
+
output_attentions=output_attentions,
|
| 310 |
+
)
|
| 311 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
| 312 |
+
outputs = attn_outputs[1:]
|
| 313 |
+
# residual connection
|
| 314 |
+
hidden_states = attn_output + hidden_states
|
| 315 |
+
|
| 316 |
+
if encoder_hidden_states is not None:
|
| 317 |
+
# add one self-attention block for cross-attention
|
| 318 |
+
assert hasattr(
|
| 319 |
+
self, "crossattention"
|
| 320 |
+
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
| 321 |
+
cross_attn_outputs = self.crossattention(
|
| 322 |
+
self.ln_cross_attn(hidden_states),
|
| 323 |
+
attention_mask=attention_mask,
|
| 324 |
+
head_mask=head_mask,
|
| 325 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 326 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 327 |
+
output_attentions=output_attentions,
|
| 328 |
+
)
|
| 329 |
+
attn_output = cross_attn_outputs[0]
|
| 330 |
+
# residual connection
|
| 331 |
+
hidden_states = hidden_states + attn_output
|
| 332 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
| 333 |
+
|
| 334 |
+
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
|
| 335 |
+
# residual connection
|
| 336 |
+
hidden_states = hidden_states + feed_forward_hidden_states
|
| 337 |
+
# hidden_states = hidden_states + self.adapter_ln(self.adapter_mlp(hidden_states))
|
| 338 |
+
|
| 339 |
+
outputs = [hidden_states] + outputs
|
| 340 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class GPT2PreTrainedModel(PreTrainedModel):
|
| 344 |
+
"""
|
| 345 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 346 |
+
models.
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
config_class = GPT2Config
|
| 350 |
+
load_tf_weights = load_tf_weights_in_gpt2
|
| 351 |
+
base_model_prefix = "transformer"
|
| 352 |
+
|
| 353 |
+
def __init__(self, *inputs, **kwargs):
|
| 354 |
+
super().__init__(*inputs, **kwargs)
|
| 355 |
+
|
| 356 |
+
def _init_weights(self, module):
|
| 357 |
+
"""Initialize the weights."""
|
| 358 |
+
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
| 359 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 360 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 361 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 362 |
+
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
|
| 363 |
+
module.bias.data.zero_()
|
| 364 |
+
elif isinstance(module, nn.LayerNorm):
|
| 365 |
+
module.bias.data.zero_()
|
| 366 |
+
module.weight.data.fill_(1.0)
|
| 367 |
+
# module.weight.data.fill_(.01) # KL: Adapter change
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
@dataclass
|
| 371 |
+
class GPT2DoubleHeadsModelOutput(ModelOutput):
|
| 372 |
+
"""
|
| 373 |
+
Base class for outputs of models predicting if two sentences are consecutive or not.
|
| 374 |
+
Args:
|
| 375 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
| 376 |
+
Language modeling loss.
|
| 377 |
+
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
| 378 |
+
Multiple choice classification loss.
|
| 379 |
+
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
| 380 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 381 |
+
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
| 382 |
+
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
| 383 |
+
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
| 384 |
+
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
|
| 385 |
+
batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
| 386 |
+
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
| 387 |
+
:obj:`past_key_values` input) to speed up sequential decoding.
|
| 388 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 389 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 390 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 391 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 392 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 393 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
| 394 |
+
sequence_length, sequence_length)`.
|
| 395 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 396 |
+
heads.
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
loss: Optional[torch.FloatTensor] = None
|
| 400 |
+
mc_loss: Optional[torch.FloatTensor] = None
|
| 401 |
+
logits: torch.FloatTensor = None
|
| 402 |
+
mc_logits: torch.FloatTensor = None
|
| 403 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 404 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 405 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
GPT2_START_DOCSTRING = r"""
|
| 409 |
+
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
| 410 |
+
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
| 411 |
+
pruning heads etc.)
|
| 412 |
+
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
| 413 |
+
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
| 414 |
+
general usage and behavior.
|
| 415 |
+
Parameters:
|
| 416 |
+
config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
|
| 417 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 418 |
+
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
| 419 |
+
weights.
|
| 420 |
+
"""
|
| 421 |
+
|
| 422 |
+
GPT2_INPUTS_DOCSTRING = r"""
|
| 423 |
+
Args:
|
| 424 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
| 425 |
+
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
| 426 |
+
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
| 427 |
+
sequence tokens in the vocabulary.
|
| 428 |
+
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
|
| 429 |
+
passed as ``input_ids``.
|
| 430 |
+
Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
|
| 431 |
+
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
| 432 |
+
details.
|
| 433 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
| 434 |
+
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
| 435 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
| 436 |
+
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
| 437 |
+
have their past given to this model should not be passed as ``input_ids`` as they have already been
|
| 438 |
+
computed.
|
| 439 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 440 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
| 441 |
+
- 1 for tokens that are **not masked**,
|
| 442 |
+
- 0 for tokens that are **masked**.
|
| 443 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
| 444 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
|
| 445 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
| 446 |
+
1]``:
|
| 447 |
+
- 0 corresponds to a `sentence A` token,
|
| 448 |
+
- 1 corresponds to a `sentence B` token.
|
| 449 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
| 450 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 451 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
| 452 |
+
config.max_position_embeddings - 1]``.
|
| 453 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
| 454 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
| 455 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
| 456 |
+
- 1 indicates the head is **not masked**,
|
| 457 |
+
- 0 indicates the head is **masked**.
|
| 458 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 459 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
| 460 |
+
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
| 461 |
+
vectors than the model's internal embedding lookup matrix.
|
| 462 |
+
If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
|
| 463 |
+
:obj:`past_key_values`).
|
| 464 |
+
use_cache (:obj:`bool`, `optional`):
|
| 465 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 466 |
+
decoding (see :obj:`past_key_values`).
|
| 467 |
+
output_attentions (:obj:`bool`, `optional`):
|
| 468 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
| 469 |
+
tensors for more detail.
|
| 470 |
+
output_hidden_states (:obj:`bool`, `optional`):
|
| 471 |
+
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
| 472 |
+
more detail.
|
| 473 |
+
return_dict (:obj:`bool`, `optional`):
|
| 474 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
| 475 |
+
"""
|
| 476 |
+
PARALLELIZE_DOCSTRING = r"""
|
| 477 |
+
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
| 478 |
+
it will evenly distribute blocks across all devices.
|
| 479 |
+
Args:
|
| 480 |
+
device_map (:obj:`Dict[int, list]`, optional, defaults to None):
|
| 481 |
+
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
| 482 |
+
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
| 483 |
+
have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
|
| 484 |
+
following number of attention modules:
|
| 485 |
+
- gpt2: 12
|
| 486 |
+
- gpt2-medium: 24
|
| 487 |
+
- gpt2-large: 36
|
| 488 |
+
- gpt2-xl: 48
|
| 489 |
+
Example::
|
| 490 |
+
# Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
|
| 491 |
+
model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
|
| 492 |
+
device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
|
| 493 |
+
1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
|
| 494 |
+
2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
|
| 495 |
+
3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
|
| 496 |
+
model.parallelize(device_map)
|
| 497 |
+
"""
|
| 498 |
+
DEPARALLELIZE_DOCSTRING = r"""
|
| 499 |
+
Moves the model to cpu from a model parallel state.
|
| 500 |
+
Example::
|
| 501 |
+
# On a 4 GPU machine with gpt2-large:
|
| 502 |
+
model = GPT2LMHeadModel.from_pretrained('gpt2-large')
|
| 503 |
+
device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
|
| 504 |
+
1: [8, 9, 10, 11, 12, 13, 14, 15],
|
| 505 |
+
2: [16, 17, 18, 19, 20, 21, 22, 23],
|
| 506 |
+
3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
|
| 507 |
+
model.parallelize(device_map) # Splits the model across several devices
|
| 508 |
+
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
@add_start_docstrings(
|
| 513 |
+
"The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
| 514 |
+
GPT2_START_DOCSTRING,
|
| 515 |
+
)
|
| 516 |
+
class GPT2Model(GPT2PreTrainedModel):
|
| 517 |
+
def __init__(self, config):
|
| 518 |
+
super().__init__(config)
|
| 519 |
+
|
| 520 |
+
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
| 521 |
+
# self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
| 522 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
| 523 |
+
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
|
| 524 |
+
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
| 525 |
+
|
| 526 |
+
self.init_weights()
|
| 527 |
+
# Model parallel
|
| 528 |
+
self.model_parallel = False
|
| 529 |
+
self.device_map = None
|
| 530 |
+
|
| 531 |
+
self.use_layers = None
|
| 532 |
+
|
| 533 |
+
def set_layers(self, num_layers):
|
| 534 |
+
assert 1 <= num_layers <= len(self.h)
|
| 535 |
+
if num_layers is not None:
|
| 536 |
+
num_layers -= 1
|
| 537 |
+
self.use_layers = num_layers
|
| 538 |
+
|
| 539 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
| 540 |
+
def parallelize(self, device_map=None):
|
| 541 |
+
# Check validity of device_map
|
| 542 |
+
self.device_map = (
|
| 543 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
| 544 |
+
)
|
| 545 |
+
assert_device_map(self.device_map, len(self.h))
|
| 546 |
+
self.model_parallel = True
|
| 547 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
|
| 548 |
+
self.last_device = "cuda:" + str(max(self.device_map.keys()))
|
| 549 |
+
self.wte = self.wte.to(self.first_device)
|
| 550 |
+
self.wpe = self.wpe.to(self.first_device)
|
| 551 |
+
# Load onto devices
|
| 552 |
+
for k, v in self.device_map.items():
|
| 553 |
+
for block in v:
|
| 554 |
+
cuda_device = "cuda:" + str(k)
|
| 555 |
+
self.h[block] = self.h[block].to(cuda_device)
|
| 556 |
+
# ln_f to last
|
| 557 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
| 558 |
+
|
| 559 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
| 560 |
+
def deparallelize(self):
|
| 561 |
+
self.model_parallel = False
|
| 562 |
+
self.device_map = None
|
| 563 |
+
self.first_device = "cpu"
|
| 564 |
+
self.last_device = "cpu"
|
| 565 |
+
self.wte = self.wte.to("cpu")
|
| 566 |
+
self.wpe = self.wpe.to("cpu")
|
| 567 |
+
for index in range(len(self.h)):
|
| 568 |
+
self.h[index] = self.h[index].to("cpu")
|
| 569 |
+
self.ln_f = self.ln_f.to("cpu")
|
| 570 |
+
torch.cuda.empty_cache()
|
| 571 |
+
|
| 572 |
+
def get_input_embeddings(self):
|
| 573 |
+
return self.wte
|
| 574 |
+
|
| 575 |
+
def set_input_embeddings(self, new_embeddings):
|
| 576 |
+
self.wte = new_embeddings
|
| 577 |
+
|
| 578 |
+
def _prune_heads(self, heads_to_prune):
|
| 579 |
+
"""
|
| 580 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
| 581 |
+
"""
|
| 582 |
+
for layer, heads in heads_to_prune.items():
|
| 583 |
+
self.h[layer].attn.prune_heads(heads)
|
| 584 |
+
|
| 585 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
| 586 |
+
@add_code_sample_docstrings(
|
| 587 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
| 588 |
+
checkpoint="gpt2",
|
| 589 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
| 590 |
+
config_class=_CONFIG_FOR_DOC,
|
| 591 |
+
)
|
| 592 |
+
def forward(
|
| 593 |
+
self,
|
| 594 |
+
input_ids=None,
|
| 595 |
+
past_key_values=None,
|
| 596 |
+
attention_mask=None,
|
| 597 |
+
token_type_ids=None,
|
| 598 |
+
position_ids=None,
|
| 599 |
+
head_mask=None,
|
| 600 |
+
inputs_embeds=None,
|
| 601 |
+
encoder_hidden_states=None,
|
| 602 |
+
encoder_attention_mask=None,
|
| 603 |
+
use_cache=None,
|
| 604 |
+
output_attentions=None,
|
| 605 |
+
output_hidden_states=None,
|
| 606 |
+
return_dict=None,
|
| 607 |
+
):
|
| 608 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 609 |
+
output_hidden_states = (
|
| 610 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 611 |
+
)
|
| 612 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 613 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 614 |
+
|
| 615 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 616 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 617 |
+
elif input_ids is not None:
|
| 618 |
+
input_shape = input_ids.size()
|
| 619 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 620 |
+
batch_size = input_ids.shape[0]
|
| 621 |
+
elif inputs_embeds is not None:
|
| 622 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 623 |
+
batch_size = inputs_embeds.shape[0]
|
| 624 |
+
else:
|
| 625 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 626 |
+
|
| 627 |
+
if token_type_ids is not None:
|
| 628 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
| 629 |
+
if position_ids is not None:
|
| 630 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
| 631 |
+
|
| 632 |
+
if past_key_values is None:
|
| 633 |
+
past_length = 0
|
| 634 |
+
past_key_values = [None] * len(self.h)
|
| 635 |
+
else:
|
| 636 |
+
past_length = past_key_values[0][0].size(-2)
|
| 637 |
+
if position_ids is None:
|
| 638 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 639 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
| 640 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
| 641 |
+
|
| 642 |
+
# Attention mask.
|
| 643 |
+
if attention_mask is not None:
|
| 644 |
+
assert batch_size > 0, "batch_size has to be defined and > 0"
|
| 645 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
| 646 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 647 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
| 648 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
| 649 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
| 650 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
| 651 |
+
attention_mask = attention_mask[:, None, None, :]
|
| 652 |
+
|
| 653 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 654 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 655 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 656 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 657 |
+
# effectively the same as removing these entirely.
|
| 658 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 659 |
+
attention_mask = (1.0 - attention_mask) * -10000.0
|
| 660 |
+
|
| 661 |
+
# If a 2D ou 3D attention mask is provided for the cross-attention
|
| 662 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 663 |
+
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
| 664 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 665 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 666 |
+
if encoder_attention_mask is None:
|
| 667 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 668 |
+
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 669 |
+
else:
|
| 670 |
+
encoder_attention_mask = None
|
| 671 |
+
|
| 672 |
+
# Prepare head mask if needed
|
| 673 |
+
# 1.0 in head_mask indicate we keep the head
|
| 674 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 675 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
| 676 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
| 677 |
+
|
| 678 |
+
if inputs_embeds is None:
|
| 679 |
+
inputs_embeds = self.wte(input_ids)
|
| 680 |
+
# position_embeds = self.wpe(position_ids)
|
| 681 |
+
hidden_states = inputs_embeds # + position_embeds
|
| 682 |
+
|
| 683 |
+
if token_type_ids is not None:
|
| 684 |
+
token_type_embeds = self.wte(token_type_ids)
|
| 685 |
+
hidden_states = hidden_states + token_type_embeds
|
| 686 |
+
|
| 687 |
+
hidden_states = self.drop(hidden_states)
|
| 688 |
+
|
| 689 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
| 690 |
+
|
| 691 |
+
presents = () if use_cache else None
|
| 692 |
+
all_self_attentions = () if output_attentions else None
|
| 693 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 694 |
+
all_hidden_states = () if output_hidden_states else None
|
| 695 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
| 696 |
+
|
| 697 |
+
if self.use_layers is not None and i >= self.use_layers:
|
| 698 |
+
break
|
| 699 |
+
|
| 700 |
+
# Model parallel
|
| 701 |
+
if self.model_parallel:
|
| 702 |
+
torch.cuda.set_device(hidden_states.device)
|
| 703 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
| 704 |
+
if layer_past is not None:
|
| 705 |
+
layer_past = layer_past.to(hidden_states.device)
|
| 706 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
| 707 |
+
if attention_mask is not None:
|
| 708 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
| 709 |
+
if isinstance(head_mask, torch.Tensor):
|
| 710 |
+
head_mask = head_mask.to(hidden_states.device)
|
| 711 |
+
if output_hidden_states:
|
| 712 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 713 |
+
|
| 714 |
+
if getattr(self.config, "gradient_checkpointing", False):
|
| 715 |
+
|
| 716 |
+
def create_custom_forward(module):
|
| 717 |
+
def custom_forward(*inputs):
|
| 718 |
+
# checkpointing only works with tuple returns, not with lists
|
| 719 |
+
return tuple(output for output in module(*inputs, use_cache, output_attentions))
|
| 720 |
+
|
| 721 |
+
return custom_forward
|
| 722 |
+
|
| 723 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
| 724 |
+
create_custom_forward(block),
|
| 725 |
+
hidden_states,
|
| 726 |
+
layer_past,
|
| 727 |
+
attention_mask,
|
| 728 |
+
head_mask[i],
|
| 729 |
+
encoder_hidden_states,
|
| 730 |
+
encoder_attention_mask,
|
| 731 |
+
)
|
| 732 |
+
else:
|
| 733 |
+
outputs = block(
|
| 734 |
+
hidden_states,
|
| 735 |
+
layer_past=layer_past,
|
| 736 |
+
attention_mask=attention_mask,
|
| 737 |
+
head_mask=head_mask[i],
|
| 738 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 739 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 740 |
+
use_cache=use_cache,
|
| 741 |
+
output_attentions=output_attentions,
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
hidden_states, present = outputs[:2]
|
| 745 |
+
if use_cache is True:
|
| 746 |
+
presents = presents + (present,)
|
| 747 |
+
|
| 748 |
+
if output_attentions:
|
| 749 |
+
all_self_attentions = all_self_attentions + (outputs[2],)
|
| 750 |
+
if self.config.add_cross_attention:
|
| 751 |
+
all_cross_attentions = all_cross_attentions + (outputs[3],)
|
| 752 |
+
|
| 753 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
| 754 |
+
if self.model_parallel:
|
| 755 |
+
for k, v in self.device_map.items():
|
| 756 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
| 757 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
| 758 |
+
|
| 759 |
+
hidden_states = self.ln_f(hidden_states)
|
| 760 |
+
|
| 761 |
+
hidden_states = hidden_states.view(*output_shape)
|
| 762 |
+
# Add last hidden state
|
| 763 |
+
if output_hidden_states:
|
| 764 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 765 |
+
|
| 766 |
+
if not return_dict:
|
| 767 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
| 768 |
+
|
| 769 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 770 |
+
last_hidden_state=hidden_states,
|
| 771 |
+
past_key_values=presents,
|
| 772 |
+
hidden_states=all_hidden_states,
|
| 773 |
+
attentions=all_self_attentions,
|
| 774 |
+
cross_attentions=all_cross_attentions,
|
| 775 |
+
)
|
third_party/AnyBimanual/agents/peract_bc/visual_aligner.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class VisualAligner(nn.Module):
|
| 6 |
+
def __init__(self, input_dim=128, hidden_dim=256, mask_dim=128):
|
| 7 |
+
super(VisualAligner, self).__init__()
|
| 8 |
+
|
| 9 |
+
self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
|
| 10 |
+
|
| 11 |
+
self.conv_res1 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
|
| 12 |
+
self.conv_res2 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
|
| 13 |
+
|
| 14 |
+
self.conv2_right = nn.Conv1d(in_channels=hidden_dim, out_channels=mask_dim, kernel_size=3, padding=1)
|
| 15 |
+
self.conv2_left = nn.Conv1d(in_channels=hidden_dim, out_channels=mask_dim, kernel_size=3, padding=1)
|
| 16 |
+
|
| 17 |
+
self.activation = nn.ReLU()
|
| 18 |
+
|
| 19 |
+
def forward(self, ins):
|
| 20 |
+
ins = ins.transpose(1, 2)
|
| 21 |
+
|
| 22 |
+
features = self.activation(self.conv1(ins))
|
| 23 |
+
|
| 24 |
+
residual = features
|
| 25 |
+
features = self.activation(self.conv_res1(features))
|
| 26 |
+
features = self.conv_res2(features)
|
| 27 |
+
features = features + residual
|
| 28 |
+
|
| 29 |
+
mask_right = self.activation(self.conv2_right(features))
|
| 30 |
+
mask_left = self.activation(self.conv2_left(features))
|
| 31 |
+
|
| 32 |
+
mask_right = mask_right.transpose(1, 2)
|
| 33 |
+
mask_left = mask_left.transpose(1, 2)
|
| 34 |
+
ins = ins.transpose(1, 2)
|
| 35 |
+
|
| 36 |
+
masked_ins1 = ins * mask_right
|
| 37 |
+
masked_ins2 = ins * mask_left
|
| 38 |
+
|
| 39 |
+
return masked_ins1, masked_ins2
|
third_party/AnyBimanual/agents/peract_bimanual/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
import agents.peract_bimanual.launch_utils
|
third_party/AnyBimanual/agents/peract_bimanual/launch_utils.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Adapted from ARM
|
| 2 |
+
# Source: https://github.com/stepjam/ARM
|
| 3 |
+
# License: https://github.com/stepjam/ARM/LICENSE
|
| 4 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 5 |
+
|
| 6 |
+
from agents.peract_bimanual.perceiver_lang_io import PerceiverVoxelLangEncoder
|
| 7 |
+
from agents.peract_bimanual.qattention_peract_bc_agent import QAttentionPerActBCAgent
|
| 8 |
+
from agents.peract_bimanual.qattention_stack_agent import QAttentionStackAgent
|
| 9 |
+
from agents.peract_bimanual.skill_manager import SkillManager
|
| 10 |
+
from agents.peract_bimanual.visual_aligner import VisualAligner
|
| 11 |
+
from omegaconf import DictConfig
|
| 12 |
+
import pickle
|
| 13 |
+
import torch
|
| 14 |
+
import os
|
| 15 |
+
def create_agent(cfg: DictConfig):
|
| 16 |
+
depth_0bounds = cfg.rlbench.scene_bounds
|
| 17 |
+
cam_resolution = cfg.rlbench.camera_resolution
|
| 18 |
+
|
| 19 |
+
num_rotation_classes = int(360.0 // cfg.method.rotation_resolution)
|
| 20 |
+
qattention_agents = []
|
| 21 |
+
|
| 22 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 23 |
+
pkl_path = os.path.join(current_dir, "../../lang_token.pkl")
|
| 24 |
+
pkl_path = os.path.abspath(pkl_path)
|
| 25 |
+
with open(pkl_path, "rb") as f:
|
| 26 |
+
embeddings_dict = pickle.load(f)
|
| 27 |
+
flattened_embeddings = []
|
| 28 |
+
for key in embeddings_dict.keys():
|
| 29 |
+
embedding = torch.tensor(embeddings_dict[key])
|
| 30 |
+
flattened_embedding = embedding.view(-1)
|
| 31 |
+
flattened_embeddings.append(flattened_embedding)
|
| 32 |
+
embeddings_matrix = torch.stack(flattened_embeddings)
|
| 33 |
+
|
| 34 |
+
skill_manager = SkillManager(num_classes=18,embedding_matrix=embeddings_matrix)
|
| 35 |
+
visual_aligner = VisualAligner()
|
| 36 |
+
for depth, vox_size in enumerate(cfg.method.voxel_sizes):
|
| 37 |
+
last = depth == len(cfg.method.voxel_sizes) - 1
|
| 38 |
+
if cfg.framework.use_skill:
|
| 39 |
+
perceiver_encoder = PerceiverVoxelLangEncoder(
|
| 40 |
+
depth=cfg.method.transformer_depth,
|
| 41 |
+
iterations=cfg.method.transformer_iterations,
|
| 42 |
+
voxel_size=vox_size,
|
| 43 |
+
initial_dim=3 + 3 + 1 + 3,
|
| 44 |
+
low_dim_size=cfg.method.low_dim_size,
|
| 45 |
+
layer=depth,
|
| 46 |
+
num_rotation_classes=num_rotation_classes if last else 0,
|
| 47 |
+
num_grip_classes=2 if last else 0,
|
| 48 |
+
num_collision_classes=2 if last else 0,
|
| 49 |
+
input_axis=3,
|
| 50 |
+
num_latents=cfg.method.num_latents,
|
| 51 |
+
latent_dim=cfg.method.latent_dim,
|
| 52 |
+
cross_heads=cfg.method.cross_heads,
|
| 53 |
+
latent_heads=cfg.method.latent_heads,
|
| 54 |
+
cross_dim_head=cfg.method.cross_dim_head,
|
| 55 |
+
latent_dim_head=cfg.method.latent_dim_head,
|
| 56 |
+
weight_tie_layers=False,
|
| 57 |
+
activation=cfg.method.activation,
|
| 58 |
+
pos_encoding_with_lang=cfg.method.pos_encoding_with_lang,
|
| 59 |
+
input_dropout=cfg.method.input_dropout,
|
| 60 |
+
attn_dropout=cfg.method.attn_dropout,
|
| 61 |
+
decoder_dropout=cfg.method.decoder_dropout,
|
| 62 |
+
lang_fusion_type=cfg.method.lang_fusion_type,
|
| 63 |
+
voxel_patch_size=cfg.method.voxel_patch_size,
|
| 64 |
+
voxel_patch_stride=cfg.method.voxel_patch_stride,
|
| 65 |
+
no_skip_connection=cfg.method.no_skip_connection,
|
| 66 |
+
no_perceiver=cfg.method.no_perceiver,
|
| 67 |
+
no_language=cfg.method.no_language,
|
| 68 |
+
final_dim=cfg.method.final_dim,
|
| 69 |
+
anybimanual=cfg.framework.anybimanual,
|
| 70 |
+
skill_manager = skill_manager,
|
| 71 |
+
visual_aligner = visual_aligner
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
qattention_agent = QAttentionPerActBCAgent(
|
| 75 |
+
layer=depth,
|
| 76 |
+
coordinate_bounds=depth_0bounds,
|
| 77 |
+
perceiver_encoder=perceiver_encoder,
|
| 78 |
+
camera_names=cfg.rlbench.cameras,
|
| 79 |
+
voxel_size=vox_size,
|
| 80 |
+
bounds_offset=cfg.method.bounds_offset[depth - 1] if depth > 0 else None,
|
| 81 |
+
image_crop_size=cfg.method.image_crop_size,
|
| 82 |
+
lr=cfg.method.lr,
|
| 83 |
+
training_iterations=cfg.framework.training_iterations,
|
| 84 |
+
lr_scheduler=cfg.method.lr_scheduler,
|
| 85 |
+
num_warmup_steps=cfg.method.num_warmup_steps,
|
| 86 |
+
trans_loss_weight=cfg.method.trans_loss_weight,
|
| 87 |
+
rot_loss_weight=cfg.method.rot_loss_weight,
|
| 88 |
+
grip_loss_weight=cfg.method.grip_loss_weight,
|
| 89 |
+
collision_loss_weight=cfg.method.collision_loss_weight,
|
| 90 |
+
include_low_dim_state=True,
|
| 91 |
+
image_resolution=cam_resolution,
|
| 92 |
+
batch_size=cfg.replay.batch_size,
|
| 93 |
+
voxel_feature_size=3,
|
| 94 |
+
lambda_weight_l2=cfg.method.lambda_weight_l2,
|
| 95 |
+
num_rotation_classes=num_rotation_classes,
|
| 96 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 97 |
+
transform_augmentation=cfg.method.transform_augmentation.apply_se3,
|
| 98 |
+
transform_augmentation_xyz=cfg.method.transform_augmentation.aug_xyz,
|
| 99 |
+
transform_augmentation_rpy=cfg.method.transform_augmentation.aug_rpy,
|
| 100 |
+
transform_augmentation_rot_resolution=cfg.method.transform_augmentation.aug_rot_resolution,
|
| 101 |
+
optimizer_type=cfg.method.optimizer,
|
| 102 |
+
num_devices=cfg.ddp.num_devices,
|
| 103 |
+
anybimanual=cfg.framework.anybimanual,
|
| 104 |
+
load_exists_weights = cfg.framework.load_existing_weights,
|
| 105 |
+
frozen = cfg.framework.frozen,
|
| 106 |
+
cfg = cfg,
|
| 107 |
+
aug_type=cfg.framework.augmentation_type,
|
| 108 |
+
)
|
| 109 |
+
qattention_agents.append(qattention_agent)
|
| 110 |
+
|
| 111 |
+
rotation_agent = QAttentionStackAgent(
|
| 112 |
+
qattention_agents=qattention_agents,
|
| 113 |
+
rotation_resolution=cfg.method.rotation_resolution,
|
| 114 |
+
camera_names=cfg.rlbench.cameras,
|
| 115 |
+
)
|
| 116 |
+
preprocess_agent = PreprocessAgent(pose_agent=rotation_agent)
|
| 117 |
+
return preprocess_agent
|
third_party/AnyBimanual/agents/peract_bimanual/perceiver_lang_io.py
ADDED
|
@@ -0,0 +1,628 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Perceiver IO implementation adpated for manipulation
|
| 2 |
+
# Source: https://github.com/lucidrains/perceiver-pytorch
|
| 3 |
+
# License: https://github.com/lucidrains/perceiver-pytorch/blob/main/LICENSE
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from einops import repeat
|
| 10 |
+
import numpy as np
|
| 11 |
+
from perceiver_pytorch.perceiver_pytorch import cache_fn
|
| 12 |
+
from perceiver_pytorch.perceiver_pytorch import PreNorm, FeedForward, Attention
|
| 13 |
+
|
| 14 |
+
from helpers.network_utils import (
|
| 15 |
+
DenseBlock,
|
| 16 |
+
SpatialSoftmax3D,
|
| 17 |
+
Conv3DBlock,
|
| 18 |
+
Conv3DUpsampleBlock,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
def symmetric_kl_divergence(left, right):
|
| 22 |
+
eps = 1e-2
|
| 23 |
+
left_prob = torch.clamp(F.log_softmax(left, dim=-1), min=-10, max=10)
|
| 24 |
+
right_prob = torch.clamp(F.log_softmax(right, dim=-1), min=-10, max=10)
|
| 25 |
+
|
| 26 |
+
kl_left_to_right = F.kl_div(left_prob, right_prob.exp(), reduction="batchmean")*eps
|
| 27 |
+
kl_right_to_left = F.kl_div(right_prob, left_prob.exp(), reduction="batchmean")*eps
|
| 28 |
+
|
| 29 |
+
symmetric_kl = -(kl_left_to_right + kl_right_to_left) / 2.0
|
| 30 |
+
return symmetric_kl
|
| 31 |
+
|
| 32 |
+
def l1_norm(tensor):
|
| 33 |
+
return torch.sum(torch.abs(tensor)) + 1e-4 * torch.norm(tensor)
|
| 34 |
+
|
| 35 |
+
def l2_1_norm(tensor):
|
| 36 |
+
l2_norm_per_skill = torch.norm(tensor, dim=-1)
|
| 37 |
+
return torch.sum(l2_norm_per_skill)
|
| 38 |
+
|
| 39 |
+
torch.autograd.set_detect_anomaly(True)
|
| 40 |
+
# PerceiverIO adapted for 6-DoF manipulation
|
| 41 |
+
class PerceiverVoxelLangEncoder(nn.Module):
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
depth, # number of self-attention layers
|
| 45 |
+
iterations, # number cross-attention iterations (PerceiverIO uses just 1)
|
| 46 |
+
voxel_size, # N voxels per side (size: N*N*N)
|
| 47 |
+
initial_dim, # 10 dimensions - dimension of the input sequence to be encoded
|
| 48 |
+
low_dim_size, # 4 dimensions - proprioception: {gripper_open, left_finger, right_finger, timestep}
|
| 49 |
+
layer=0,
|
| 50 |
+
num_rotation_classes=72, # 5 degree increments (5*72=360) for each of the 3-axis
|
| 51 |
+
num_grip_classes=2, # open or not open
|
| 52 |
+
num_collision_classes=2, # collisions allowed or not allowed
|
| 53 |
+
input_axis=3, # 3D tensors have 3 axes
|
| 54 |
+
num_latents=512, # number of latent vectors
|
| 55 |
+
im_channels=64, # intermediate channel size
|
| 56 |
+
latent_dim=512, # dimensions of latent vectors
|
| 57 |
+
cross_heads=1, # number of cross-attention heads
|
| 58 |
+
latent_heads=8, # number of latent heads
|
| 59 |
+
cross_dim_head=64,
|
| 60 |
+
latent_dim_head=64,
|
| 61 |
+
activation="relu",
|
| 62 |
+
weight_tie_layers=False,
|
| 63 |
+
pos_encoding_with_lang=True,
|
| 64 |
+
input_dropout=0.1,
|
| 65 |
+
attn_dropout=0.1,
|
| 66 |
+
decoder_dropout=0.0,
|
| 67 |
+
lang_fusion_type="seq",
|
| 68 |
+
voxel_patch_size=9,
|
| 69 |
+
voxel_patch_stride=8,
|
| 70 |
+
no_skip_connection=False,
|
| 71 |
+
no_perceiver=False,
|
| 72 |
+
no_language=False,
|
| 73 |
+
final_dim=64,
|
| 74 |
+
anybimanual=False,
|
| 75 |
+
skill_manager=None,
|
| 76 |
+
visual_aligner=None,
|
| 77 |
+
):
|
| 78 |
+
super().__init__()
|
| 79 |
+
self.depth = depth
|
| 80 |
+
self.layer = layer
|
| 81 |
+
self.init_dim = int(initial_dim)
|
| 82 |
+
self.iterations = iterations
|
| 83 |
+
self.input_axis = input_axis
|
| 84 |
+
self.voxel_size = voxel_size
|
| 85 |
+
self.low_dim_size = low_dim_size
|
| 86 |
+
self.im_channels = im_channels
|
| 87 |
+
self.pos_encoding_with_lang = pos_encoding_with_lang
|
| 88 |
+
self.lang_fusion_type = lang_fusion_type
|
| 89 |
+
self.voxel_patch_size = voxel_patch_size
|
| 90 |
+
self.voxel_patch_stride = voxel_patch_stride
|
| 91 |
+
self.num_rotation_classes = num_rotation_classes
|
| 92 |
+
self.num_grip_classes = num_grip_classes
|
| 93 |
+
self.num_collision_classes = num_collision_classes
|
| 94 |
+
self.final_dim = final_dim
|
| 95 |
+
self.input_dropout = input_dropout
|
| 96 |
+
self.attn_dropout = attn_dropout
|
| 97 |
+
self.decoder_dropout = decoder_dropout
|
| 98 |
+
self.no_skip_connection = no_skip_connection
|
| 99 |
+
self.no_perceiver = no_perceiver
|
| 100 |
+
self.no_language = no_language
|
| 101 |
+
self.anybimanual = anybimanual
|
| 102 |
+
self.skill_manager = skill_manager
|
| 103 |
+
self.visual_aligner = visual_aligner
|
| 104 |
+
# patchified input dimensions
|
| 105 |
+
spatial_size = voxel_size // self.voxel_patch_stride # 100/5 = 20
|
| 106 |
+
# 64 voxel features + 64 proprio features (+ 64 lang goal features if concattenated)
|
| 107 |
+
self.input_dim_before_seq = (
|
| 108 |
+
self.im_channels * 3
|
| 109 |
+
if self.lang_fusion_type == "concat"
|
| 110 |
+
else self.im_channels * 2
|
| 111 |
+
)
|
| 112 |
+
if self.anybimanual:
|
| 113 |
+
self.input_dim_before_seq_ = self.input_dim_before_seq*2
|
| 114 |
+
else:
|
| 115 |
+
self.input_dim_before_seq_ = self.input_dim_before_seq
|
| 116 |
+
# CLIP language feature dimensions
|
| 117 |
+
if self.anybimanual:
|
| 118 |
+
lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 154
|
| 119 |
+
else:
|
| 120 |
+
lang_feat_dim, lang_emb_dim, lang_max_seq_len = 1024, 512, 77
|
| 121 |
+
|
| 122 |
+
# learnable positional encoding
|
| 123 |
+
# peract2 pos_encoding_with_lang = True / peract = Falses?
|
| 124 |
+
if self.pos_encoding_with_lang:
|
| 125 |
+
self.pos_encoding = nn.Parameter(
|
| 126 |
+
torch.randn(
|
| 127 |
+
1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq
|
| 128 |
+
)
|
| 129 |
+
)
|
| 130 |
+
else:
|
| 131 |
+
# assert self.lang_fusion_type == 'concat', 'Only concat is supported for pos encoding without lang.'
|
| 132 |
+
self.pos_encoding = nn.Parameter(
|
| 133 |
+
torch.randn(
|
| 134 |
+
1,
|
| 135 |
+
spatial_size,
|
| 136 |
+
spatial_size,
|
| 137 |
+
spatial_size,
|
| 138 |
+
self.input_dim_before_seq,
|
| 139 |
+
)
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
# voxel input preprocessing 1x1 conv encoder
|
| 143 |
+
self.input_preprocess = Conv3DBlock(
|
| 144 |
+
self.init_dim,
|
| 145 |
+
self.im_channels,
|
| 146 |
+
kernel_sizes=1,
|
| 147 |
+
strides=1,
|
| 148 |
+
norm=None,
|
| 149 |
+
activation=activation,
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# patchify conv
|
| 153 |
+
self.patchify = Conv3DBlock(
|
| 154 |
+
self.input_preprocess.out_channels,
|
| 155 |
+
self.im_channels,
|
| 156 |
+
kernel_sizes=self.voxel_patch_size,
|
| 157 |
+
strides=self.voxel_patch_stride,
|
| 158 |
+
norm=None,
|
| 159 |
+
activation=activation,
|
| 160 |
+
)
|
| 161 |
+
# language preprocess
|
| 162 |
+
if self.lang_fusion_type == "concat":
|
| 163 |
+
self.lang_preprocess = nn.Linear(lang_feat_dim, self.im_channels)
|
| 164 |
+
elif self.lang_fusion_type == "seq":
|
| 165 |
+
self.lang_preprocess = nn.Linear(lang_emb_dim, self.im_channels * 2)
|
| 166 |
+
|
| 167 |
+
# proprioception
|
| 168 |
+
if self.low_dim_size > 0:
|
| 169 |
+
self.proprio_preprocess = DenseBlock(
|
| 170 |
+
self.low_dim_size,
|
| 171 |
+
self.im_channels,
|
| 172 |
+
norm=None,
|
| 173 |
+
activation=activation,
|
| 174 |
+
)
|
| 175 |
+
# pooling functions
|
| 176 |
+
self.local_maxp = nn.MaxPool3d(3, 2, padding=1)
|
| 177 |
+
self.global_maxp = nn.AdaptiveMaxPool3d(1)
|
| 178 |
+
|
| 179 |
+
# 1st 3D softmax
|
| 180 |
+
self.ss0 = SpatialSoftmax3D(
|
| 181 |
+
self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
|
| 182 |
+
)
|
| 183 |
+
flat_size = self.im_channels * 4
|
| 184 |
+
|
| 185 |
+
# latent vectors (that are randomly initialized)
|
| 186 |
+
self.latents = nn.Parameter(torch.randn(num_latents, latent_dim))
|
| 187 |
+
|
| 188 |
+
if self.anybimanual:
|
| 189 |
+
self.cross_attend_blocks = nn.ModuleList(
|
| 190 |
+
[
|
| 191 |
+
PreNorm(
|
| 192 |
+
latent_dim,
|
| 193 |
+
Attention(
|
| 194 |
+
latent_dim,
|
| 195 |
+
self.input_dim_before_seq_,
|
| 196 |
+
heads=cross_heads,
|
| 197 |
+
dim_head=cross_dim_head,
|
| 198 |
+
dropout=input_dropout,
|
| 199 |
+
),
|
| 200 |
+
context_dim=self.input_dim_before_seq_,
|
| 201 |
+
),
|
| 202 |
+
PreNorm(
|
| 203 |
+
latent_dim,
|
| 204 |
+
Attention(
|
| 205 |
+
latent_dim,
|
| 206 |
+
self.input_dim_before_seq_,
|
| 207 |
+
heads=cross_heads,
|
| 208 |
+
dim_head=cross_dim_head,
|
| 209 |
+
dropout=input_dropout,
|
| 210 |
+
),
|
| 211 |
+
context_dim=self.input_dim_before_seq_,
|
| 212 |
+
),
|
| 213 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
| 214 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
| 215 |
+
]
|
| 216 |
+
)
|
| 217 |
+
else:
|
| 218 |
+
# encoder cross attention
|
| 219 |
+
self.cross_attend_blocks = nn.ModuleList(
|
| 220 |
+
[
|
| 221 |
+
PreNorm(
|
| 222 |
+
latent_dim,
|
| 223 |
+
Attention(
|
| 224 |
+
latent_dim,
|
| 225 |
+
self.input_dim_before_seq_,
|
| 226 |
+
heads=cross_heads,
|
| 227 |
+
dim_head=cross_dim_head,
|
| 228 |
+
dropout=input_dropout,
|
| 229 |
+
),
|
| 230 |
+
context_dim=self.input_dim_before_seq_,
|
| 231 |
+
),
|
| 232 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
| 233 |
+
PreNorm(latent_dim, FeedForward(latent_dim)),
|
| 234 |
+
]
|
| 235 |
+
)
|
| 236 |
+
|
| 237 |
+
get_latent_attn = lambda: PreNorm(
|
| 238 |
+
latent_dim,
|
| 239 |
+
Attention(
|
| 240 |
+
latent_dim,
|
| 241 |
+
heads=latent_heads,
|
| 242 |
+
dim_head=latent_dim_head,
|
| 243 |
+
dropout=attn_dropout,
|
| 244 |
+
),
|
| 245 |
+
)
|
| 246 |
+
get_latent_ff = lambda: PreNorm(latent_dim, FeedForward(latent_dim))
|
| 247 |
+
get_latent_attn, get_latent_ff = map(cache_fn, (get_latent_attn, get_latent_ff))
|
| 248 |
+
|
| 249 |
+
# self attention layers
|
| 250 |
+
self.layers = nn.ModuleList([])
|
| 251 |
+
cache_args = {"_cache": weight_tie_layers}
|
| 252 |
+
|
| 253 |
+
for i in range(depth):
|
| 254 |
+
self.layers.append(
|
| 255 |
+
nn.ModuleList(
|
| 256 |
+
[get_latent_attn(**cache_args), get_latent_ff(**cache_args),
|
| 257 |
+
get_latent_attn(**cache_args), get_latent_ff(**cache_args)]
|
| 258 |
+
)
|
| 259 |
+
)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
self.combined_latent_attn = get_latent_attn(**cache_args)
|
| 263 |
+
self.combined_latent_ff = get_latent_ff(**cache_args)
|
| 264 |
+
|
| 265 |
+
|
| 266 |
+
# decoder cross attention
|
| 267 |
+
self.decoder_cross_attn_right = PreNorm(
|
| 268 |
+
self.input_dim_before_seq_,
|
| 269 |
+
Attention(
|
| 270 |
+
self.input_dim_before_seq_,
|
| 271 |
+
latent_dim,
|
| 272 |
+
heads=cross_heads,
|
| 273 |
+
dim_head=cross_dim_head,
|
| 274 |
+
dropout=decoder_dropout,
|
| 275 |
+
),
|
| 276 |
+
context_dim=latent_dim,
|
| 277 |
+
)
|
| 278 |
+
|
| 279 |
+
self.decoder_cross_attn_left = PreNorm(
|
| 280 |
+
self.input_dim_before_seq_,
|
| 281 |
+
Attention(
|
| 282 |
+
self.input_dim_before_seq_,
|
| 283 |
+
latent_dim,
|
| 284 |
+
heads=cross_heads,
|
| 285 |
+
dim_head=cross_dim_head,
|
| 286 |
+
dropout=decoder_dropout,
|
| 287 |
+
),
|
| 288 |
+
context_dim=latent_dim,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
# upsample conv
|
| 292 |
+
self.up0 = Conv3DUpsampleBlock(
|
| 293 |
+
self.input_dim_before_seq_,
|
| 294 |
+
self.final_dim,
|
| 295 |
+
kernel_sizes=self.voxel_patch_size,
|
| 296 |
+
strides=self.voxel_patch_stride,
|
| 297 |
+
norm=None,
|
| 298 |
+
activation=activation,
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# 2nd 3D softmax
|
| 302 |
+
self.ss1 = SpatialSoftmax3D(
|
| 303 |
+
spatial_size, spatial_size, spatial_size, self.input_dim_before_seq_
|
| 304 |
+
)
|
| 305 |
+
|
| 306 |
+
flat_size += self.input_dim_before_seq_ * 4
|
| 307 |
+
|
| 308 |
+
# final 3D softmax
|
| 309 |
+
self.final = Conv3DBlock(
|
| 310 |
+
self.im_channels
|
| 311 |
+
if (self.no_perceiver or self.no_skip_connection)
|
| 312 |
+
else self.im_channels * 2,
|
| 313 |
+
self.im_channels,
|
| 314 |
+
kernel_sizes=3,
|
| 315 |
+
strides=1,
|
| 316 |
+
norm=None,
|
| 317 |
+
activation=activation,
|
| 318 |
+
)
|
| 319 |
+
|
| 320 |
+
self.right_trans_decoder = Conv3DBlock(
|
| 321 |
+
self.final_dim,
|
| 322 |
+
1,
|
| 323 |
+
kernel_sizes=3,
|
| 324 |
+
strides=1,
|
| 325 |
+
norm=None,
|
| 326 |
+
activation=None,
|
| 327 |
+
)
|
| 328 |
+
|
| 329 |
+
self.left_trans_decoder = Conv3DBlock(
|
| 330 |
+
self.final_dim,
|
| 331 |
+
1,
|
| 332 |
+
kernel_sizes=3,
|
| 333 |
+
strides=1,
|
| 334 |
+
norm=None,
|
| 335 |
+
activation=None,
|
| 336 |
+
)
|
| 337 |
+
|
| 338 |
+
# rotation, gripper, and collision MLP layers
|
| 339 |
+
if self.num_rotation_classes > 0:
|
| 340 |
+
self.ss_final = SpatialSoftmax3D(
|
| 341 |
+
self.voxel_size, self.voxel_size, self.voxel_size, self.im_channels
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
flat_size += self.im_channels * 4
|
| 345 |
+
|
| 346 |
+
self.right_dense0 = DenseBlock(flat_size, 256, None, activation)
|
| 347 |
+
self.right_dense1 = DenseBlock(256, self.final_dim, None, activation)
|
| 348 |
+
|
| 349 |
+
self.left_dense0 = DenseBlock(flat_size, 256, None, activation)
|
| 350 |
+
self.left_dense1 = DenseBlock(256, self.final_dim, None, activation)
|
| 351 |
+
|
| 352 |
+
self.right_rot_grip_collision_ff = DenseBlock(
|
| 353 |
+
self.final_dim,
|
| 354 |
+
self.num_rotation_classes * 3
|
| 355 |
+
+ self.num_grip_classes
|
| 356 |
+
+ self.num_collision_classes,
|
| 357 |
+
None,
|
| 358 |
+
None,
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
self.left_rot_grip_collision_ff = DenseBlock(
|
| 362 |
+
self.final_dim,
|
| 363 |
+
self.num_rotation_classes * 3
|
| 364 |
+
+ self.num_grip_classes
|
| 365 |
+
+ self.num_collision_classes,
|
| 366 |
+
None,
|
| 367 |
+
None,
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
def encode_text(self, x):
|
| 371 |
+
with torch.no_grad():
|
| 372 |
+
text_feat, text_emb = self._clip_rn50.encode_text_with_embeddings(x)
|
| 373 |
+
|
| 374 |
+
text_feat = text_feat.detach()
|
| 375 |
+
text_emb = text_emb.detach()
|
| 376 |
+
text_mask = torch.where(x == 0, x, 1) # [1, max_token_len]
|
| 377 |
+
return text_feat, text_emb
|
| 378 |
+
|
| 379 |
+
def forward(
|
| 380 |
+
self,
|
| 381 |
+
ins,
|
| 382 |
+
proprio,
|
| 383 |
+
lang_goal_emb,
|
| 384 |
+
lang_token_embs,
|
| 385 |
+
prev_layer_voxel_grid,
|
| 386 |
+
bounds,
|
| 387 |
+
prev_layer_bounds,
|
| 388 |
+
mask=None,
|
| 389 |
+
):
|
| 390 |
+
# preprocess input
|
| 391 |
+
ins_numpy = str(ins.cpu().numpy())
|
| 392 |
+
d0 = self.input_preprocess(ins) # [B,10,100,100,100] -> [B,64,100,100,100]
|
| 393 |
+
|
| 394 |
+
# aggregated features from 1st softmax and maxpool for MLP decoders
|
| 395 |
+
feats = [self.ss0(d0.contiguous()), self.global_maxp(d0).view(ins.shape[0], -1)]
|
| 396 |
+
|
| 397 |
+
# patchify input (5x5x5 patches)
|
| 398 |
+
ins = self.patchify(d0) # [B,64,100,100,100] -> [B,64,20,20,20]
|
| 399 |
+
|
| 400 |
+
b, c, d, h, w, device = *ins.shape, ins.device
|
| 401 |
+
axis = [d, h, w]
|
| 402 |
+
assert (
|
| 403 |
+
len(axis) == self.input_axis
|
| 404 |
+
), "input must have the same number of axis as input_axis"
|
| 405 |
+
# concat proprio
|
| 406 |
+
if self.low_dim_size > 0:
|
| 407 |
+
p = self.proprio_preprocess(proprio) # [B,8] -> [B,64]
|
| 408 |
+
p = p.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 409 |
+
ins = torch.cat([ins, p], dim=1) # [B,128,20,20,20]
|
| 410 |
+
|
| 411 |
+
# language ablation
|
| 412 |
+
if self.no_language:
|
| 413 |
+
lang_goal_emb = torch.zeros_like(lang_goal_emb)
|
| 414 |
+
lang_token_embs = torch.zeros_like(lang_token_embs)
|
| 415 |
+
|
| 416 |
+
# option 1: tile and concat lang goal to input
|
| 417 |
+
if self.lang_fusion_type == "concat":
|
| 418 |
+
lang_emb = lang_goal_emb
|
| 419 |
+
lang_emb = lang_emb.to(dtype=ins.dtype)
|
| 420 |
+
l = self.lang_preprocess(lang_emb)
|
| 421 |
+
l = l.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).repeat(1, 1, d, h, w)
|
| 422 |
+
ins = torch.cat([ins, l], dim=1)
|
| 423 |
+
|
| 424 |
+
# channel last
|
| 425 |
+
ins = rearrange(ins, "b d ... -> b ... d") # [B,20,20,20,128]
|
| 426 |
+
# add pos encoding to grid
|
| 427 |
+
if not self.pos_encoding_with_lang:
|
| 428 |
+
ins = ins + self.pos_encoding
|
| 429 |
+
|
| 430 |
+
######################## NOTE #############################
|
| 431 |
+
# NOTE: If you add positional encodings ^here the lang embs
|
| 432 |
+
# won't have positional encodings. I accidently forgot
|
| 433 |
+
# to turn this off for all the experiments in the paper.
|
| 434 |
+
# So I guess those models were using language embs
|
| 435 |
+
# as a bag of words :( But it doesn't matter much for
|
| 436 |
+
# RLBench tasks since we don't test for novel instructions
|
| 437 |
+
# at test time anyway. The recommend way is to add
|
| 438 |
+
# positional encodings to the final input sequence
|
| 439 |
+
# fed into the Perceiver Transformer, as done below
|
| 440 |
+
# (and also in the Colab tutorial).
|
| 441 |
+
###########################################################
|
| 442 |
+
|
| 443 |
+
# concat to channels of and flatten axis
|
| 444 |
+
queries_orig_shape = ins.shape
|
| 445 |
+
|
| 446 |
+
# rearrange input to be channel last
|
| 447 |
+
ins = rearrange(ins, "b ... d -> b (...) d") # [B,8000,128]
|
| 448 |
+
ins_wo_prev_layers = ins
|
| 449 |
+
|
| 450 |
+
# option 2: add lang token embs as a sequence
|
| 451 |
+
if self.anybimanual:
|
| 452 |
+
if self.lang_fusion_type == "seq":
|
| 453 |
+
l = self.lang_preprocess(lang_token_embs) # [B,77,512] -> [B,77,128]
|
| 454 |
+
mask_right, mask_left = self.visual_aligner(ins)
|
| 455 |
+
L_voxel = symmetric_kl_divergence(mask_left, mask_right)
|
| 456 |
+
right_skill = self.skill_manager(mask_right, l)
|
| 457 |
+
left_skill = self.skill_manager(mask_left, l)
|
| 458 |
+
right_skill = self.lang_preprocess(right_skill)
|
| 459 |
+
left_skill = self.lang_preprocess(left_skill)
|
| 460 |
+
L_skill = (
|
| 461 |
+
l1_norm(left_skill) + l1_norm(right_skill) +
|
| 462 |
+
0.01 * (l2_1_norm(left_skill) + l2_1_norm(right_skill))
|
| 463 |
+
)
|
| 464 |
+
l_right = torch.cat((right_skill, l), dim=1)
|
| 465 |
+
ins_right = torch.cat((l_right, mask_right), dim=1)
|
| 466 |
+
l_left = torch.cat((left_skill, l), dim=1)
|
| 467 |
+
ins_left = torch.cat((l_left, mask_left), dim=1)
|
| 468 |
+
if self.pos_encoding_with_lang:
|
| 469 |
+
ins_right = ins_right + self.pos_encoding
|
| 470 |
+
ins_left = ins_left + self.pos_encoding
|
| 471 |
+
else:
|
| 472 |
+
if self.lang_fusion_type == "seq":
|
| 473 |
+
# print(lang_token_embs.requires_grad) # False
|
| 474 |
+
l = self.lang_preprocess(lang_token_embs) # [B,77,512] -> [B,77,128]
|
| 475 |
+
# print(l.requires_grad) # True
|
| 476 |
+
ins = torch.cat((l, ins), dim=1)
|
| 477 |
+
# add pos encoding to language + flattened grid (the recommended way)
|
| 478 |
+
if self.pos_encoding_with_lang:
|
| 479 |
+
ins = ins + self.pos_encoding
|
| 480 |
+
|
| 481 |
+
# batchify latents
|
| 482 |
+
if self.anybimanual:
|
| 483 |
+
x = repeat(self.latents, "n d -> b n d", b=b)
|
| 484 |
+
cross_attn_right, cross_attn_left, cross_ff_right, cross_ff_left = self.cross_attend_blocks
|
| 485 |
+
else:
|
| 486 |
+
x = repeat(self.latents, "n d -> b n d", b=b)
|
| 487 |
+
cross_attn, cross_ff_right, cross_ff_left = self.cross_attend_blocks
|
| 488 |
+
|
| 489 |
+
if self.anybimanual:
|
| 490 |
+
ins_r = torch.cat((l_right, ins),dim=1)
|
| 491 |
+
ins_l = torch.cat((l_left, ins), dim=1)
|
| 492 |
+
ins_right = torch.cat((ins_right, ins_r), dim=2)
|
| 493 |
+
ins_left = torch.cat((ins_left, ins_l), dim=2)
|
| 494 |
+
for it in range(self.iterations):
|
| 495 |
+
# encoder cross attention
|
| 496 |
+
if self.anybimanual:
|
| 497 |
+
x_r, x_l = x.chunk(2, dim=1)
|
| 498 |
+
x_right = cross_attn_right(x_r, context=ins_right, mask=mask) + x_r
|
| 499 |
+
x_left = cross_attn_left(x_l, context=ins_left, mask=mask) + x_l
|
| 500 |
+
else:
|
| 501 |
+
x = cross_attn(x, context=ins, mask=mask) + x
|
| 502 |
+
x_right, x_left = x.chunk(2, dim=1)
|
| 503 |
+
x_right = cross_ff_right(x_right) + x_right
|
| 504 |
+
x_left = cross_ff_left(x_left) + x_left
|
| 505 |
+
# self-attention layers
|
| 506 |
+
for self_attn_right, self_ff_right, self_attn_left, self_ff_left in self.layers:
|
| 507 |
+
|
| 508 |
+
x_right = self_attn_right(x_right) + x_right
|
| 509 |
+
x_right = self_ff_right(x_right) + x_right
|
| 510 |
+
|
| 511 |
+
x_left = self_attn_left(x_left) + x_left
|
| 512 |
+
x_left = self_ff_left(x_left) + x_left
|
| 513 |
+
|
| 514 |
+
x = torch.concat([x_right, x_left], dim=1)
|
| 515 |
+
x = self.combined_latent_attn(x) + x
|
| 516 |
+
x = self.combined_latent_ff(x) + x
|
| 517 |
+
|
| 518 |
+
x_right, x_left = x.chunk(2, dim=1)
|
| 519 |
+
|
| 520 |
+
# decoder cross attention
|
| 521 |
+
if self.anybimanual:
|
| 522 |
+
latents_right = self.decoder_cross_attn_right(ins_right, context=x_right)
|
| 523 |
+
latents_left = self.decoder_cross_attn_left(ins_left, context=x_left)
|
| 524 |
+
if self.lang_fusion_type == "seq":
|
| 525 |
+
latents_right = latents_right[:, l_right.shape[1] :]
|
| 526 |
+
latents_left = latents_left[:, l_left.shape[1] :]
|
| 527 |
+
else:
|
| 528 |
+
latents_right = self.decoder_cross_attn_right(ins, context=x_right)
|
| 529 |
+
latents_left = self.decoder_cross_attn_left(ins, context=x_left)
|
| 530 |
+
if self.lang_fusion_type == "seq":
|
| 531 |
+
latents_right = latents_right[:, l.shape[1] :]
|
| 532 |
+
latents_left = latents_left[:, l.shape[1] :]
|
| 533 |
+
|
| 534 |
+
# crop out the language part of the output sequence
|
| 535 |
+
|
| 536 |
+
# reshape back to voxel grid
|
| 537 |
+
latents_right = latents_right.view(
|
| 538 |
+
b, *queries_orig_shape[1:-1], latents_right.shape[-1]
|
| 539 |
+
) # [B,20,20,20,64]
|
| 540 |
+
latents_right = rearrange(latents_right, "b ... d -> b d ...") # [B,64,20,20,20]
|
| 541 |
+
|
| 542 |
+
# reshape back to voxel grid
|
| 543 |
+
latents_left = latents_left.view(
|
| 544 |
+
b, *queries_orig_shape[1:-1], latents_left.shape[-1]
|
| 545 |
+
) # [B,20,20,20,64]
|
| 546 |
+
latents_left = rearrange(latents_left, "b ... d -> b d ...") # [B,64,20,20,20]
|
| 547 |
+
|
| 548 |
+
# aggregated features from 2nd softmax and maxpool for MLP decoders
|
| 549 |
+
|
| 550 |
+
feats_right = feats.copy()
|
| 551 |
+
feats_left = feats
|
| 552 |
+
|
| 553 |
+
|
| 554 |
+
feats_right.extend(
|
| 555 |
+
[self.ss1(latents_right.contiguous()), self.global_maxp(latents_right).view(b, -1)]
|
| 556 |
+
)
|
| 557 |
+
feats_left.extend(
|
| 558 |
+
[self.ss1(latents_left.contiguous()), self.global_maxp(latents_left).view(b, -1)]
|
| 559 |
+
)
|
| 560 |
+
|
| 561 |
+
# upsample
|
| 562 |
+
u0_right = self.up0(latents_right)
|
| 563 |
+
u0_left = self.up0(latents_left)
|
| 564 |
+
|
| 565 |
+
# ablations
|
| 566 |
+
if self.no_skip_connection:
|
| 567 |
+
u_right = self.final(u0_right)
|
| 568 |
+
u_left = self.final(u0_left)
|
| 569 |
+
elif self.no_perceiver:
|
| 570 |
+
u_right = self.final(d0)
|
| 571 |
+
u_left = self.final(d0)
|
| 572 |
+
else:
|
| 573 |
+
u_right = self.final(torch.cat([d0, u0_right], dim=1))
|
| 574 |
+
u_left = self.final(torch.cat([d0, u0_left], dim=1))
|
| 575 |
+
|
| 576 |
+
# translation decoder
|
| 577 |
+
right_trans = self.right_trans_decoder(u_right)
|
| 578 |
+
left_trans = self.left_trans_decoder(u_left)
|
| 579 |
+
|
| 580 |
+
# rotation, gripper, and collision MLPs
|
| 581 |
+
rot_and_grip_out = None
|
| 582 |
+
if self.num_rotation_classes > 0:
|
| 583 |
+
feats_right.extend(
|
| 584 |
+
[self.ss_final(u_right.contiguous()), self.global_maxp(u_right).view(b, -1)]
|
| 585 |
+
)
|
| 586 |
+
|
| 587 |
+
right_dense0 = self.right_dense0(torch.cat(feats_right, dim=1))
|
| 588 |
+
right_dense1 = self.right_dense1(right_dense0) # [B,72*3+2+2]
|
| 589 |
+
|
| 590 |
+
right_rot_and_grip_collision_out = self.right_rot_grip_collision_ff(
|
| 591 |
+
right_dense1
|
| 592 |
+
)
|
| 593 |
+
right_rot_and_grip_out = right_rot_and_grip_collision_out[
|
| 594 |
+
:, : -self.num_collision_classes
|
| 595 |
+
]
|
| 596 |
+
right_collision_out = right_rot_and_grip_collision_out[
|
| 597 |
+
:, -self.num_collision_classes :
|
| 598 |
+
]
|
| 599 |
+
|
| 600 |
+
feats_left.extend(
|
| 601 |
+
[self.ss_final(u_left.contiguous()), self.global_maxp(u_left).view(b, -1)]
|
| 602 |
+
)
|
| 603 |
+
|
| 604 |
+
left_dense0 = self.left_dense0(torch.cat(feats_left, dim=1))
|
| 605 |
+
left_dense1 = self.left_dense1(left_dense0) # [B,72*3+2+2]
|
| 606 |
+
|
| 607 |
+
left_rot_and_grip_collision_out = self.left_rot_grip_collision_ff(
|
| 608 |
+
left_dense1
|
| 609 |
+
)
|
| 610 |
+
left_rot_and_grip_out = left_rot_and_grip_collision_out[
|
| 611 |
+
:, : -self.num_collision_classes
|
| 612 |
+
]
|
| 613 |
+
left_collision_out = left_rot_and_grip_collision_out[
|
| 614 |
+
:, -self.num_collision_classes :
|
| 615 |
+
]
|
| 616 |
+
|
| 617 |
+
if not self.anybimanual:
|
| 618 |
+
L_skill = 0
|
| 619 |
+
L_voxel = 0
|
| 620 |
+
return (
|
| 621 |
+
right_trans,
|
| 622 |
+
right_rot_and_grip_out,
|
| 623 |
+
right_collision_out,
|
| 624 |
+
left_trans,
|
| 625 |
+
left_rot_and_grip_out,
|
| 626 |
+
left_collision_out
|
| 627 |
+
), L_skill,L_voxel,
|
| 628 |
+
|
third_party/AnyBimanual/agents/peract_bimanual/qattention_peract_bc_agent.py
ADDED
|
@@ -0,0 +1,1317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
from typing import List
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from torchvision import transforms
|
| 11 |
+
from pytorch3d import transforms as torch3d_tf
|
| 12 |
+
from yarr.agents.agent import (
|
| 13 |
+
Agent,
|
| 14 |
+
ActResult,
|
| 15 |
+
ScalarSummary,
|
| 16 |
+
HistogramSummary,
|
| 17 |
+
ImageSummary,
|
| 18 |
+
Summary,
|
| 19 |
+
)
|
| 20 |
+
import io
|
| 21 |
+
import PIL.Image as Image
|
| 22 |
+
import matplotlib.pyplot as plt
|
| 23 |
+
from helpers import utils
|
| 24 |
+
from helpers.utils import visualise_voxel, stack_on_channel
|
| 25 |
+
from voxel import augmentation_ab
|
| 26 |
+
from voxel.voxel_grid import VoxelGrid
|
| 27 |
+
from voxel.augmentation import apply_se3_augmentation
|
| 28 |
+
from einops import rearrange
|
| 29 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 30 |
+
|
| 31 |
+
import transformers
|
| 32 |
+
from helpers.optim.lamb import Lamb
|
| 33 |
+
import wandb
|
| 34 |
+
from termcolor import colored, cprint
|
| 35 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 36 |
+
NAME = "QAttentionAgent"
|
| 37 |
+
import plotly.graph_objects as go
|
| 38 |
+
|
| 39 |
+
class QFunction(nn.Module):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
perceiver_encoder: nn.Module,
|
| 43 |
+
voxelizer: VoxelGrid,
|
| 44 |
+
bounds_offset: float,
|
| 45 |
+
rotation_resolution: float,
|
| 46 |
+
device,
|
| 47 |
+
training,
|
| 48 |
+
):
|
| 49 |
+
super(QFunction, self).__init__()
|
| 50 |
+
self._rotation_resolution = rotation_resolution
|
| 51 |
+
self._voxelizer = voxelizer
|
| 52 |
+
self._bounds_offset = bounds_offset
|
| 53 |
+
self._qnet = perceiver_encoder.to(device)
|
| 54 |
+
# distributed training
|
| 55 |
+
if training:
|
| 56 |
+
self._qnet = DDP(self._qnet, device_ids=[device], find_unused_parameters=True)
|
| 57 |
+
|
| 58 |
+
def _argmax_3d(self, tensor_orig):
|
| 59 |
+
b, c, d, h, w = tensor_orig.shape # c will be one
|
| 60 |
+
idxs = tensor_orig.view(b, c, -1).argmax(-1)
|
| 61 |
+
indices = torch.cat([((idxs // h) // d), (idxs // h) % w, idxs % w], 1)
|
| 62 |
+
return indices
|
| 63 |
+
|
| 64 |
+
def choose_highest_action(self, q_trans, q_rot_grip, q_collision):
|
| 65 |
+
coords = self._argmax_3d(q_trans)
|
| 66 |
+
rot_and_grip_indicies = None
|
| 67 |
+
ignore_collision = None
|
| 68 |
+
if q_rot_grip is not None:
|
| 69 |
+
q_rot = torch.stack(
|
| 70 |
+
torch.split(
|
| 71 |
+
q_rot_grip[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 72 |
+
),
|
| 73 |
+
dim=1,
|
| 74 |
+
)
|
| 75 |
+
rot_and_grip_indicies = torch.cat(
|
| 76 |
+
[
|
| 77 |
+
q_rot[:, 0:1].argmax(-1),
|
| 78 |
+
q_rot[:, 1:2].argmax(-1),
|
| 79 |
+
q_rot[:, 2:3].argmax(-1),
|
| 80 |
+
q_rot_grip[:, -2:].argmax(-1, keepdim=True),
|
| 81 |
+
],
|
| 82 |
+
-1,
|
| 83 |
+
)
|
| 84 |
+
ignore_collision = q_collision[:, -2:].argmax(-1, keepdim=True)
|
| 85 |
+
return coords, rot_and_grip_indicies, ignore_collision
|
| 86 |
+
|
| 87 |
+
def forward(
|
| 88 |
+
self,
|
| 89 |
+
rgb_pcd,
|
| 90 |
+
proprio,
|
| 91 |
+
pcd,
|
| 92 |
+
lang_goal_emb,
|
| 93 |
+
lang_token_embs,
|
| 94 |
+
bounds=None,
|
| 95 |
+
prev_bounds=None,
|
| 96 |
+
prev_layer_voxel_grid=None,
|
| 97 |
+
):
|
| 98 |
+
# rgb_pcd will be list of list (list of [rgb, pcd])
|
| 99 |
+
b = rgb_pcd[0][0].shape[0]
|
| 100 |
+
pcd_flat = torch.cat([p.permute(0, 2, 3, 1).reshape(b, -1, 3) for p in pcd], 1)
|
| 101 |
+
|
| 102 |
+
# flatten RGBs and Pointclouds
|
| 103 |
+
rgb = [rp[0] for rp in rgb_pcd]
|
| 104 |
+
feat_size = rgb[0].shape[1]
|
| 105 |
+
flat_imag_features = torch.cat(
|
| 106 |
+
[p.permute(0, 2, 3, 1).reshape(b, -1, feat_size) for p in rgb], 1
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
# construct voxel grid
|
| 110 |
+
voxel_grid = self._voxelizer.coords_to_bounding_voxel_grid(
|
| 111 |
+
pcd_flat, coord_features=flat_imag_features, coord_bounds=bounds
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
# swap to channels fist
|
| 115 |
+
voxel_grid = voxel_grid.permute(0, 4, 1, 2, 3).detach()
|
| 116 |
+
|
| 117 |
+
# print(voxel_grid.shape) # [b, 10, 100, 100, 100]
|
| 118 |
+
# batch bounds if necessary
|
| 119 |
+
if bounds.shape[0] != b:
|
| 120 |
+
bounds = bounds.repeat(b, 1)
|
| 121 |
+
# print(lang_goal_emb.shape) # [B, 1024]
|
| 122 |
+
# forward pass
|
| 123 |
+
|
| 124 |
+
#TO DO: return more information
|
| 125 |
+
split_pred, L_skill, L_voxel = self._qnet(
|
| 126 |
+
voxel_grid,
|
| 127 |
+
proprio,
|
| 128 |
+
lang_goal_emb,
|
| 129 |
+
lang_token_embs,
|
| 130 |
+
prev_layer_voxel_grid,
|
| 131 |
+
bounds,
|
| 132 |
+
prev_bounds,
|
| 133 |
+
)
|
| 134 |
+
return split_pred, voxel_grid, L_skill, L_voxel
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class QAttentionPerActBCAgent(Agent):
|
| 138 |
+
def __init__(
|
| 139 |
+
self,
|
| 140 |
+
layer: int,
|
| 141 |
+
coordinate_bounds: list,
|
| 142 |
+
perceiver_encoder: nn.Module,
|
| 143 |
+
camera_names: list,
|
| 144 |
+
batch_size: int,
|
| 145 |
+
voxel_size: int,
|
| 146 |
+
bounds_offset: float,
|
| 147 |
+
voxel_feature_size: int,
|
| 148 |
+
image_crop_size: int,
|
| 149 |
+
num_rotation_classes: int,
|
| 150 |
+
rotation_resolution: float,
|
| 151 |
+
lr: float = 0.0001,
|
| 152 |
+
lr_scheduler: bool = False,
|
| 153 |
+
training_iterations: int = 100000,
|
| 154 |
+
num_warmup_steps: int = 20000,
|
| 155 |
+
trans_loss_weight: float = 1.0,
|
| 156 |
+
rot_loss_weight: float = 1.0,
|
| 157 |
+
grip_loss_weight: float = 1.0,
|
| 158 |
+
collision_loss_weight: float = 1.0,
|
| 159 |
+
include_low_dim_state: bool = False,
|
| 160 |
+
image_resolution: list = None,
|
| 161 |
+
lambda_weight_l2: float = 0.0,
|
| 162 |
+
transform_augmentation: bool = True,
|
| 163 |
+
transform_augmentation_xyz: list = [0.0, 0.0, 0.0],
|
| 164 |
+
transform_augmentation_rpy: list = [0.0, 0.0, 180.0],
|
| 165 |
+
transform_augmentation_rot_resolution: int = 5,
|
| 166 |
+
optimizer_type: str = "adam",
|
| 167 |
+
num_devices: int = 1,
|
| 168 |
+
anybimanual = False,
|
| 169 |
+
load_exists_weights = False,
|
| 170 |
+
frozen = False,
|
| 171 |
+
cfg = None,
|
| 172 |
+
aug_type = "standard",
|
| 173 |
+
):
|
| 174 |
+
self.frozen = frozen
|
| 175 |
+
self.load = load_exists_weights
|
| 176 |
+
self._layer = layer
|
| 177 |
+
self._coordinate_bounds = coordinate_bounds
|
| 178 |
+
self._perceiver_encoder = perceiver_encoder
|
| 179 |
+
self._voxel_feature_size = voxel_feature_size
|
| 180 |
+
self._bounds_offset = bounds_offset
|
| 181 |
+
self._image_crop_size = image_crop_size
|
| 182 |
+
self._lr = lr
|
| 183 |
+
self._lr_scheduler = lr_scheduler
|
| 184 |
+
self._training_iterations = training_iterations
|
| 185 |
+
self._num_warmup_steps = num_warmup_steps
|
| 186 |
+
self._trans_loss_weight = trans_loss_weight
|
| 187 |
+
self._rot_loss_weight = rot_loss_weight
|
| 188 |
+
self._grip_loss_weight = grip_loss_weight
|
| 189 |
+
self._collision_loss_weight = collision_loss_weight
|
| 190 |
+
self._include_low_dim_state = include_low_dim_state
|
| 191 |
+
self._image_resolution = image_resolution or [128, 128]
|
| 192 |
+
self._voxel_size = voxel_size
|
| 193 |
+
self._camera_names = camera_names
|
| 194 |
+
self._num_cameras = len(camera_names)
|
| 195 |
+
self._batch_size = batch_size
|
| 196 |
+
self._lambda_weight_l2 = lambda_weight_l2
|
| 197 |
+
self._transform_augmentation = transform_augmentation
|
| 198 |
+
self._transform_augmentation_xyz = torch.from_numpy(
|
| 199 |
+
np.array(transform_augmentation_xyz)
|
| 200 |
+
)
|
| 201 |
+
self._transform_augmentation_rpy = transform_augmentation_rpy
|
| 202 |
+
self._transform_augmentation_rot_resolution = (
|
| 203 |
+
transform_augmentation_rot_resolution
|
| 204 |
+
)
|
| 205 |
+
self._optimizer_type = optimizer_type
|
| 206 |
+
self._num_devices = num_devices
|
| 207 |
+
self._num_rotation_classes = num_rotation_classes
|
| 208 |
+
self._rotation_resolution = rotation_resolution
|
| 209 |
+
|
| 210 |
+
self._cross_entropy_loss = nn.CrossEntropyLoss(reduction="none")
|
| 211 |
+
self._name = NAME + "_layer" + str(self._layer)
|
| 212 |
+
self.anybimanual = anybimanual
|
| 213 |
+
self.aug_type = aug_type
|
| 214 |
+
self.cfg = cfg
|
| 215 |
+
def build(self, training: bool, device: torch.device = None):
|
| 216 |
+
self._training = training
|
| 217 |
+
|
| 218 |
+
if device is None:
|
| 219 |
+
device = torch.device("cpu")
|
| 220 |
+
|
| 221 |
+
self._device = device
|
| 222 |
+
|
| 223 |
+
self._voxelizer = VoxelGrid(
|
| 224 |
+
coord_bounds=self._coordinate_bounds,
|
| 225 |
+
voxel_size=self._voxel_size,
|
| 226 |
+
device=device,
|
| 227 |
+
batch_size=self._batch_size if training else 1,
|
| 228 |
+
feature_size=self._voxel_feature_size,
|
| 229 |
+
max_num_coords=np.prod(self._image_resolution) * self._num_cameras,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
self._q = (
|
| 233 |
+
QFunction(
|
| 234 |
+
self._perceiver_encoder,
|
| 235 |
+
self._voxelizer,
|
| 236 |
+
self._bounds_offset,
|
| 237 |
+
self._rotation_resolution,
|
| 238 |
+
device,
|
| 239 |
+
training,
|
| 240 |
+
)
|
| 241 |
+
.to(device)
|
| 242 |
+
.train(training)
|
| 243 |
+
)
|
| 244 |
+
|
| 245 |
+
grid_for_crop = (
|
| 246 |
+
torch.arange(0, self._image_crop_size, device=device)
|
| 247 |
+
.unsqueeze(0)
|
| 248 |
+
.repeat(self._image_crop_size, 1)
|
| 249 |
+
.unsqueeze(-1)
|
| 250 |
+
)
|
| 251 |
+
self._grid_for_crop = torch.cat(
|
| 252 |
+
[grid_for_crop.transpose(1, 0), grid_for_crop], dim=2
|
| 253 |
+
).unsqueeze(0)
|
| 254 |
+
|
| 255 |
+
self._coordinate_bounds = torch.tensor(
|
| 256 |
+
self._coordinate_bounds, device=device
|
| 257 |
+
).unsqueeze(0)
|
| 258 |
+
|
| 259 |
+
if self._training:
|
| 260 |
+
# optimizer
|
| 261 |
+
if self._optimizer_type == "lamb":
|
| 262 |
+
self._optimizer = Lamb(
|
| 263 |
+
self._q.parameters(),
|
| 264 |
+
lr=self._lr,
|
| 265 |
+
weight_decay=self._lambda_weight_l2,
|
| 266 |
+
betas=(0.9, 0.999),
|
| 267 |
+
adam=False,
|
| 268 |
+
)
|
| 269 |
+
elif self._optimizer_type == "adam":
|
| 270 |
+
self._optimizer = torch.optim.Adam(
|
| 271 |
+
self._q.parameters(),
|
| 272 |
+
lr=self._lr,
|
| 273 |
+
weight_decay=self._lambda_weight_l2,
|
| 274 |
+
)
|
| 275 |
+
else:
|
| 276 |
+
raise Exception("Unknown optimizer type")
|
| 277 |
+
|
| 278 |
+
# learning rate scheduler
|
| 279 |
+
if self._lr_scheduler:
|
| 280 |
+
self._scheduler = (
|
| 281 |
+
transformers.get_cosine_with_hard_restarts_schedule_with_warmup(
|
| 282 |
+
self._optimizer,
|
| 283 |
+
num_warmup_steps=self._num_warmup_steps,
|
| 284 |
+
num_training_steps=self._training_iterations,
|
| 285 |
+
num_cycles=self._training_iterations // 10000,
|
| 286 |
+
)
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
+
# one-hot zero tensors
|
| 290 |
+
self._action_trans_one_hot_zeros = torch.zeros(
|
| 291 |
+
(
|
| 292 |
+
self._batch_size,
|
| 293 |
+
1,
|
| 294 |
+
self._voxel_size,
|
| 295 |
+
self._voxel_size,
|
| 296 |
+
self._voxel_size,
|
| 297 |
+
),
|
| 298 |
+
dtype=int,
|
| 299 |
+
device=device,
|
| 300 |
+
)
|
| 301 |
+
self._action_rot_x_one_hot_zeros = torch.zeros(
|
| 302 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 303 |
+
)
|
| 304 |
+
self._action_rot_y_one_hot_zeros = torch.zeros(
|
| 305 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 306 |
+
)
|
| 307 |
+
self._action_rot_z_one_hot_zeros = torch.zeros(
|
| 308 |
+
(self._batch_size, self._num_rotation_classes), dtype=int, device=device
|
| 309 |
+
)
|
| 310 |
+
self._action_grip_one_hot_zeros = torch.zeros(
|
| 311 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 312 |
+
)
|
| 313 |
+
self._action_ignore_collisions_one_hot_zeros = torch.zeros(
|
| 314 |
+
(self._batch_size, 2), dtype=int, device=device
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
+
# print total params
|
| 318 |
+
logging.info(
|
| 319 |
+
"# Q Params: %d"
|
| 320 |
+
% sum(
|
| 321 |
+
p.numel()
|
| 322 |
+
for name, p in self._q.named_parameters()
|
| 323 |
+
if p.requires_grad and "clip" not in name
|
| 324 |
+
)
|
| 325 |
+
)
|
| 326 |
+
# for name, p in self._q.named_parameters():
|
| 327 |
+
# print(f"Param: {name}, requires_grad: {p.requires_grad}")
|
| 328 |
+
else:
|
| 329 |
+
for param in self._q.parameters():
|
| 330 |
+
param.requires_grad = False
|
| 331 |
+
|
| 332 |
+
# load CLIP for encoding language goals during evaluation
|
| 333 |
+
model, _ = load_clip("RN50", jit=False)
|
| 334 |
+
self._clip_rn50 = build_model(model.state_dict())
|
| 335 |
+
self._clip_rn50 = self._clip_rn50.float().to(device)
|
| 336 |
+
self._clip_rn50.eval()
|
| 337 |
+
del model
|
| 338 |
+
|
| 339 |
+
self._voxelizer.to(device)
|
| 340 |
+
self._q.to(device)
|
| 341 |
+
|
| 342 |
+
def _extract_crop(self, pixel_action, observation):
|
| 343 |
+
# Pixel action will now be (B, 2)
|
| 344 |
+
# observation = stack_on_channel(observation)
|
| 345 |
+
h = observation.shape[-1]
|
| 346 |
+
top_left_corner = torch.clamp(
|
| 347 |
+
pixel_action - self._image_crop_size // 2, 0, h - self._image_crop_size
|
| 348 |
+
)
|
| 349 |
+
grid = self._grid_for_crop + top_left_corner.unsqueeze(1)
|
| 350 |
+
grid = ((grid / float(h)) * 2.0) - 1.0 # between -1 and 1
|
| 351 |
+
# Used for cropping the images across a batch
|
| 352 |
+
# swap fro y x, to x, y
|
| 353 |
+
grid = torch.cat((grid[:, :, :, 1:2], grid[:, :, :, 0:1]), dim=-1)
|
| 354 |
+
crop = F.grid_sample(observation, grid, mode="nearest", align_corners=True)
|
| 355 |
+
return crop
|
| 356 |
+
|
| 357 |
+
def _preprocess_inputs(self, replay_sample):
|
| 358 |
+
obs = []
|
| 359 |
+
pcds = []
|
| 360 |
+
rgbs = []
|
| 361 |
+
self._crop_summary = []
|
| 362 |
+
for n in self._camera_names:
|
| 363 |
+
rgb = replay_sample["%s_rgb" % n]
|
| 364 |
+
pcd = replay_sample["%s_point_cloud" % n]
|
| 365 |
+
|
| 366 |
+
obs.append([rgb, pcd])
|
| 367 |
+
pcds.append(pcd)
|
| 368 |
+
rgbs.append(rgb)
|
| 369 |
+
return obs, pcds, rgbs
|
| 370 |
+
|
| 371 |
+
def _act_preprocess_inputs(self, observation):
|
| 372 |
+
obs, pcds = [], []
|
| 373 |
+
for n in self._camera_names:
|
| 374 |
+
rgb = observation["%s_rgb" % n]
|
| 375 |
+
pcd = observation["%s_point_cloud" % n]
|
| 376 |
+
|
| 377 |
+
obs.append([rgb, pcd])
|
| 378 |
+
pcds.append(pcd)
|
| 379 |
+
return obs, pcds
|
| 380 |
+
|
| 381 |
+
def _get_value_from_voxel_index(self, q, voxel_idx):
|
| 382 |
+
b, c, d, h, w = q.shape
|
| 383 |
+
q_trans_flat = q.view(b, c, d * h * w)
|
| 384 |
+
flat_indicies = (
|
| 385 |
+
voxel_idx[:, 0] * d * h + voxel_idx[:, 1] * h + voxel_idx[:, 2]
|
| 386 |
+
)[:, None].int()
|
| 387 |
+
highest_idxs = flat_indicies.unsqueeze(-1).repeat(1, c, 1)
|
| 388 |
+
chosen_voxel_values = q_trans_flat.gather(2, highest_idxs)[
|
| 389 |
+
..., 0
|
| 390 |
+
] # (B, trans + rot + grip)
|
| 391 |
+
return chosen_voxel_values
|
| 392 |
+
|
| 393 |
+
def _get_value_from_rot_and_grip(self, rot_grip_q, rot_and_grip_idx):
|
| 394 |
+
q_rot = torch.stack(
|
| 395 |
+
torch.split(
|
| 396 |
+
rot_grip_q[:, :-2], int(360 // self._rotation_resolution), dim=1
|
| 397 |
+
),
|
| 398 |
+
dim=1,
|
| 399 |
+
) # B, 3, 72
|
| 400 |
+
q_grip = rot_grip_q[:, -2:]
|
| 401 |
+
rot_and_grip_values = torch.cat(
|
| 402 |
+
[
|
| 403 |
+
q_rot[:, 0].gather(1, rot_and_grip_idx[:, 0:1]),
|
| 404 |
+
q_rot[:, 1].gather(1, rot_and_grip_idx[:, 1:2]),
|
| 405 |
+
q_rot[:, 2].gather(1, rot_and_grip_idx[:, 2:3]),
|
| 406 |
+
q_grip.gather(1, rot_and_grip_idx[:, 3:4]),
|
| 407 |
+
],
|
| 408 |
+
-1,
|
| 409 |
+
)
|
| 410 |
+
return rot_and_grip_values
|
| 411 |
+
|
| 412 |
+
def _celoss(self, pred, labels):
|
| 413 |
+
return self._cross_entropy_loss(pred, labels.argmax(-1))
|
| 414 |
+
|
| 415 |
+
def _softmax_q_trans(self, q):
|
| 416 |
+
q_shape = q.shape
|
| 417 |
+
return F.softmax(q.reshape(q_shape[0], -1), dim=1).reshape(q_shape)
|
| 418 |
+
|
| 419 |
+
def _softmax_q_rot_grip(self, q_rot_grip):
|
| 420 |
+
q_rot_x_flat = q_rot_grip[
|
| 421 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 422 |
+
]
|
| 423 |
+
q_rot_y_flat = q_rot_grip[
|
| 424 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 425 |
+
]
|
| 426 |
+
q_rot_z_flat = q_rot_grip[
|
| 427 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 428 |
+
]
|
| 429 |
+
q_grip_flat = q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 430 |
+
|
| 431 |
+
q_rot_x_flat_softmax = F.softmax(q_rot_x_flat, dim=1)
|
| 432 |
+
q_rot_y_flat_softmax = F.softmax(q_rot_y_flat, dim=1)
|
| 433 |
+
q_rot_z_flat_softmax = F.softmax(q_rot_z_flat, dim=1)
|
| 434 |
+
q_grip_flat_softmax = F.softmax(q_grip_flat, dim=1)
|
| 435 |
+
|
| 436 |
+
return torch.cat(
|
| 437 |
+
[
|
| 438 |
+
q_rot_x_flat_softmax,
|
| 439 |
+
q_rot_y_flat_softmax,
|
| 440 |
+
q_rot_z_flat_softmax,
|
| 441 |
+
q_grip_flat_softmax,
|
| 442 |
+
],
|
| 443 |
+
dim=1,
|
| 444 |
+
)
|
| 445 |
+
|
| 446 |
+
def _softmax_ignore_collision(self, q_collision):
|
| 447 |
+
q_collision_softmax = F.softmax(q_collision, dim=1)
|
| 448 |
+
return q_collision_softmax
|
| 449 |
+
|
| 450 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 451 |
+
if step > 50:
|
| 452 |
+
for name, param in self._q.named_parameters():
|
| 453 |
+
if 'fc1_right' in name:
|
| 454 |
+
param.requires_grad = False
|
| 455 |
+
if 'fc1_left' in name:
|
| 456 |
+
param.requires_grad = False
|
| 457 |
+
right_action_trans = replay_sample["right_trans_action_indicies"][
|
| 458 |
+
..., self._layer * 3 : self._layer * 3 + 3
|
| 459 |
+
].int()
|
| 460 |
+
right_action_rot_grip = replay_sample["right_rot_grip_action_indicies"].int()
|
| 461 |
+
right_action_gripper_pose = replay_sample["right_gripper_pose"]
|
| 462 |
+
right_action_ignore_collisions = replay_sample["right_ignore_collisions"].int()
|
| 463 |
+
|
| 464 |
+
left_action_trans = replay_sample["left_trans_action_indicies"][
|
| 465 |
+
..., self._layer * 3 : self._layer * 3 + 3
|
| 466 |
+
].int()
|
| 467 |
+
left_action_rot_grip = replay_sample["left_rot_grip_action_indicies"].int()
|
| 468 |
+
left_action_gripper_pose = replay_sample["left_gripper_pose"]
|
| 469 |
+
left_action_ignore_collisions = replay_sample["left_ignore_collisions"].int()
|
| 470 |
+
|
| 471 |
+
lang_goal_emb = replay_sample["lang_goal_emb"].float()
|
| 472 |
+
lang_token_embs = replay_sample["lang_token_embs"].float()
|
| 473 |
+
prev_layer_voxel_grid = replay_sample.get("prev_layer_voxel_grid", None)
|
| 474 |
+
prev_layer_bounds = replay_sample.get("prev_layer_bounds", None)
|
| 475 |
+
device = self._device
|
| 476 |
+
|
| 477 |
+
rank = device
|
| 478 |
+
bounds = self._coordinate_bounds.to(device)
|
| 479 |
+
if self._layer > 0:
|
| 480 |
+
right_cp = replay_sample[
|
| 481 |
+
"right_attention_coordinate_layer_%d" % (self._layer - 1)
|
| 482 |
+
]
|
| 483 |
+
|
| 484 |
+
left_cp = replay_sample[
|
| 485 |
+
"left_attention_coordinate_layer_%d" % (self._layer - 1)
|
| 486 |
+
]
|
| 487 |
+
|
| 488 |
+
right_bounds = torch.cat(
|
| 489 |
+
[right_cp - self._bounds_offset, right_cp + self._bounds_offset], dim=1
|
| 490 |
+
)
|
| 491 |
+
left_bounds = torch.cat(
|
| 492 |
+
[left_cp - self._bounds_offset, left_cp + self._bounds_offset], dim=1
|
| 493 |
+
)
|
| 494 |
+
|
| 495 |
+
else:
|
| 496 |
+
right_bounds = bounds
|
| 497 |
+
left_bounds = bounds
|
| 498 |
+
|
| 499 |
+
right_proprio = None
|
| 500 |
+
left_proprio = None
|
| 501 |
+
if self._include_low_dim_state:
|
| 502 |
+
right_proprio = replay_sample["right_low_dim_state"]
|
| 503 |
+
left_proprio = replay_sample["left_low_dim_state"]
|
| 504 |
+
|
| 505 |
+
# ..TODO::
|
| 506 |
+
# Can we add the coordinates of both robots?
|
| 507 |
+
#
|
| 508 |
+
|
| 509 |
+
obs, pcd, rgbs = self._preprocess_inputs(replay_sample)
|
| 510 |
+
|
| 511 |
+
# batch size
|
| 512 |
+
bs = pcd[0].shape[0]
|
| 513 |
+
|
| 514 |
+
# We can move the point cloud w.r.t to the other robot's cooridinate system
|
| 515 |
+
# similar to apply_se3_augmentation
|
| 516 |
+
# SE(3) augmentation of point clouds and actions
|
| 517 |
+
if self._transform_augmentation:
|
| 518 |
+
from voxel import augmentation, augmentation_ab
|
| 519 |
+
if self.aug_type == "ab":
|
| 520 |
+
(
|
| 521 |
+
right_action_trans,
|
| 522 |
+
right_action_rot_grip,
|
| 523 |
+
left_action_trans,
|
| 524 |
+
left_action_rot_grip,
|
| 525 |
+
pcd,
|
| 526 |
+
) = augmentation_ab.bimanual_apply_se3_augmentation(
|
| 527 |
+
pcd,
|
| 528 |
+
right_action_gripper_pose,
|
| 529 |
+
right_action_trans,
|
| 530 |
+
right_action_rot_grip,
|
| 531 |
+
left_action_gripper_pose,
|
| 532 |
+
left_action_trans,
|
| 533 |
+
left_action_rot_grip,
|
| 534 |
+
bounds,
|
| 535 |
+
self._layer,
|
| 536 |
+
self._transform_augmentation_xyz,
|
| 537 |
+
self._transform_augmentation_rpy,
|
| 538 |
+
self._transform_augmentation_rot_resolution,
|
| 539 |
+
self._voxel_size,
|
| 540 |
+
self._rotation_resolution,
|
| 541 |
+
self._device,
|
| 542 |
+
)
|
| 543 |
+
else:
|
| 544 |
+
(
|
| 545 |
+
right_action_trans,
|
| 546 |
+
right_action_rot_grip,
|
| 547 |
+
left_action_trans,
|
| 548 |
+
left_action_rot_grip,
|
| 549 |
+
pcd,
|
| 550 |
+
) = augmentation.bimanual_apply_se3_augmentation(
|
| 551 |
+
pcd,
|
| 552 |
+
right_action_gripper_pose,
|
| 553 |
+
right_action_trans,
|
| 554 |
+
right_action_rot_grip,
|
| 555 |
+
left_action_gripper_pose,
|
| 556 |
+
left_action_trans,
|
| 557 |
+
left_action_rot_grip,
|
| 558 |
+
bounds,
|
| 559 |
+
self._layer,
|
| 560 |
+
self._transform_augmentation_xyz,
|
| 561 |
+
self._transform_augmentation_rpy,
|
| 562 |
+
self._transform_augmentation_rot_resolution,
|
| 563 |
+
self._voxel_size,
|
| 564 |
+
self._rotation_resolution,
|
| 565 |
+
self._device,
|
| 566 |
+
)
|
| 567 |
+
else:
|
| 568 |
+
right_action_trans = right_action_trans.int()
|
| 569 |
+
left_action_trans = left_action_trans.int()
|
| 570 |
+
|
| 571 |
+
proprio = torch.cat((right_proprio, left_proprio), dim=1)
|
| 572 |
+
right_action = (
|
| 573 |
+
right_action_trans,
|
| 574 |
+
right_action_rot_grip,
|
| 575 |
+
right_action_ignore_collisions,
|
| 576 |
+
)
|
| 577 |
+
left_action = (
|
| 578 |
+
left_action_trans,
|
| 579 |
+
left_action_rot_grip,
|
| 580 |
+
left_action_ignore_collisions,
|
| 581 |
+
)
|
| 582 |
+
# forward pass
|
| 583 |
+
q, voxel_grid, L_skill, L_voxel = self._q(
|
| 584 |
+
obs,
|
| 585 |
+
proprio,
|
| 586 |
+
pcd,
|
| 587 |
+
lang_goal_emb,
|
| 588 |
+
lang_token_embs,
|
| 589 |
+
bounds,
|
| 590 |
+
prev_layer_bounds,
|
| 591 |
+
prev_layer_voxel_grid,
|
| 592 |
+
|
| 593 |
+
)
|
| 594 |
+
|
| 595 |
+
(
|
| 596 |
+
right_q_trans,
|
| 597 |
+
right_q_rot_grip,
|
| 598 |
+
right_q_collision,
|
| 599 |
+
left_q_trans,
|
| 600 |
+
left_q_rot_grip,
|
| 601 |
+
left_q_collision,
|
| 602 |
+
) = q
|
| 603 |
+
|
| 604 |
+
# argmax to choose best action
|
| 605 |
+
(
|
| 606 |
+
right_coords,
|
| 607 |
+
right_rot_and_grip_indicies,
|
| 608 |
+
right_ignore_collision_indicies,
|
| 609 |
+
) = self._q.choose_highest_action(
|
| 610 |
+
right_q_trans, right_q_rot_grip, right_q_collision
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
(
|
| 614 |
+
left_coords,
|
| 615 |
+
left_rot_and_grip_indicies,
|
| 616 |
+
left_ignore_collision_indicies,
|
| 617 |
+
) = self._q.choose_highest_action(
|
| 618 |
+
left_q_trans, left_q_rot_grip, left_q_collision
|
| 619 |
+
)
|
| 620 |
+
|
| 621 |
+
right_q_trans_loss, right_q_rot_loss, right_q_grip_loss, right_q_collision_loss = 0.0, 0.0, 0.0, 0.0
|
| 622 |
+
left_q_trans_loss, left_q_rot_loss, left_q_grip_loss, left_q_collision_loss = 0.0, 0.0, 0.0, 0.0
|
| 623 |
+
|
| 624 |
+
# translation one-hot
|
| 625 |
+
right_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach()
|
| 626 |
+
left_action_trans_one_hot = self._action_trans_one_hot_zeros.clone().detach()
|
| 627 |
+
for b in range(bs):
|
| 628 |
+
right_gt_coord = right_action_trans[b, :].int()
|
| 629 |
+
right_action_trans_one_hot[
|
| 630 |
+
b, :, right_gt_coord[0], right_gt_coord[1], right_gt_coord[2]
|
| 631 |
+
] = 1
|
| 632 |
+
left_gt_coord = left_action_trans[b, :].int()
|
| 633 |
+
left_action_trans_one_hot[
|
| 634 |
+
b, :, left_gt_coord[0], left_gt_coord[1], left_gt_coord[2]
|
| 635 |
+
] = 1
|
| 636 |
+
|
| 637 |
+
# translation loss
|
| 638 |
+
right_q_trans_flat = right_q_trans.view(bs, -1)
|
| 639 |
+
right_action_trans_one_hot_flat = right_action_trans_one_hot.view(bs, -1)
|
| 640 |
+
right_q_trans_loss = self._celoss(
|
| 641 |
+
right_q_trans_flat, right_action_trans_one_hot_flat
|
| 642 |
+
)
|
| 643 |
+
left_q_trans_flat = left_q_trans.view(bs, -1)
|
| 644 |
+
left_action_trans_one_hot_flat = left_action_trans_one_hot.view(bs, -1)
|
| 645 |
+
left_q_trans_loss = self._celoss(
|
| 646 |
+
left_q_trans_flat, left_action_trans_one_hot_flat
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
q_trans_loss = right_q_trans_loss + left_q_trans_loss
|
| 650 |
+
|
| 651 |
+
with_rot_and_grip = (
|
| 652 |
+
len(right_rot_and_grip_indicies) > 0 and len(left_rot_and_grip_indicies) > 0
|
| 653 |
+
)
|
| 654 |
+
if with_rot_and_grip:
|
| 655 |
+
# rotation, gripper, and collision one-hots
|
| 656 |
+
right_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
|
| 657 |
+
right_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
|
| 658 |
+
right_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
|
| 659 |
+
right_action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
|
| 660 |
+
right_action_ignore_collisions_one_hot = (
|
| 661 |
+
self._action_ignore_collisions_one_hot_zeros.clone()
|
| 662 |
+
)
|
| 663 |
+
|
| 664 |
+
left_action_rot_x_one_hot = self._action_rot_x_one_hot_zeros.clone()
|
| 665 |
+
left_action_rot_y_one_hot = self._action_rot_y_one_hot_zeros.clone()
|
| 666 |
+
left_action_rot_z_one_hot = self._action_rot_z_one_hot_zeros.clone()
|
| 667 |
+
left_action_grip_one_hot = self._action_grip_one_hot_zeros.clone()
|
| 668 |
+
left_action_ignore_collisions_one_hot = (
|
| 669 |
+
self._action_ignore_collisions_one_hot_zeros.clone()
|
| 670 |
+
)
|
| 671 |
+
|
| 672 |
+
for b in range(bs):
|
| 673 |
+
right_gt_rot_grip = right_action_rot_grip[b, :].int()
|
| 674 |
+
right_action_rot_x_one_hot[b, right_gt_rot_grip[0]] = 1
|
| 675 |
+
right_action_rot_y_one_hot[b, right_gt_rot_grip[1]] = 1
|
| 676 |
+
right_action_rot_z_one_hot[b, right_gt_rot_grip[2]] = 1
|
| 677 |
+
right_action_grip_one_hot[b, right_gt_rot_grip[3]] = 1
|
| 678 |
+
|
| 679 |
+
right_gt_ignore_collisions = right_action_ignore_collisions[b, :].int()
|
| 680 |
+
right_action_ignore_collisions_one_hot[
|
| 681 |
+
b, right_gt_ignore_collisions[0]
|
| 682 |
+
] = 1
|
| 683 |
+
|
| 684 |
+
left_gt_rot_grip = left_action_rot_grip[b, :].int()
|
| 685 |
+
left_action_rot_x_one_hot[b, left_gt_rot_grip[0]] = 1
|
| 686 |
+
left_action_rot_y_one_hot[b, left_gt_rot_grip[1]] = 1
|
| 687 |
+
left_action_rot_z_one_hot[b, left_gt_rot_grip[2]] = 1
|
| 688 |
+
left_action_grip_one_hot[b, left_gt_rot_grip[3]] = 1
|
| 689 |
+
|
| 690 |
+
left_gt_ignore_collisions = left_action_ignore_collisions[b, :].int()
|
| 691 |
+
left_action_ignore_collisions_one_hot[
|
| 692 |
+
b, left_gt_ignore_collisions[0]
|
| 693 |
+
] = 1
|
| 694 |
+
|
| 695 |
+
# flatten predictions
|
| 696 |
+
right_q_rot_x_flat = right_q_rot_grip[
|
| 697 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 698 |
+
]
|
| 699 |
+
right_q_rot_y_flat = right_q_rot_grip[
|
| 700 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 701 |
+
]
|
| 702 |
+
right_q_rot_z_flat = right_q_rot_grip[
|
| 703 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 704 |
+
]
|
| 705 |
+
right_q_grip_flat = right_q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 706 |
+
right_q_ignore_collisions_flat = right_q_collision
|
| 707 |
+
|
| 708 |
+
left_q_rot_x_flat = left_q_rot_grip[
|
| 709 |
+
:, 0 * self._num_rotation_classes : 1 * self._num_rotation_classes
|
| 710 |
+
]
|
| 711 |
+
left_q_rot_y_flat = left_q_rot_grip[
|
| 712 |
+
:, 1 * self._num_rotation_classes : 2 * self._num_rotation_classes
|
| 713 |
+
]
|
| 714 |
+
left_q_rot_z_flat = left_q_rot_grip[
|
| 715 |
+
:, 2 * self._num_rotation_classes : 3 * self._num_rotation_classes
|
| 716 |
+
]
|
| 717 |
+
left_q_grip_flat = left_q_rot_grip[:, 3 * self._num_rotation_classes :]
|
| 718 |
+
left_q_ignore_collisions_flat = left_q_collision
|
| 719 |
+
|
| 720 |
+
|
| 721 |
+
# rotation loss
|
| 722 |
+
right_q_rot_loss += self._celoss(right_q_rot_x_flat, right_action_rot_x_one_hot)
|
| 723 |
+
right_q_rot_loss += self._celoss(right_q_rot_y_flat, right_action_rot_y_one_hot)
|
| 724 |
+
right_q_rot_loss += self._celoss(right_q_rot_z_flat, right_action_rot_z_one_hot)
|
| 725 |
+
|
| 726 |
+
left_q_rot_loss += self._celoss(left_q_rot_x_flat, left_action_rot_x_one_hot)
|
| 727 |
+
left_q_rot_loss += self._celoss(left_q_rot_y_flat, left_action_rot_y_one_hot)
|
| 728 |
+
left_q_rot_loss += self._celoss(left_q_rot_z_flat, left_action_rot_z_one_hot)
|
| 729 |
+
|
| 730 |
+
# gripper loss
|
| 731 |
+
right_q_grip_loss += self._celoss(right_q_grip_flat, right_action_grip_one_hot)
|
| 732 |
+
left_q_grip_loss += self._celoss(left_q_grip_flat, left_action_grip_one_hot)
|
| 733 |
+
|
| 734 |
+
# collision loss
|
| 735 |
+
right_q_collision_loss += self._celoss(
|
| 736 |
+
right_q_ignore_collisions_flat, right_action_ignore_collisions_one_hot
|
| 737 |
+
)
|
| 738 |
+
left_q_collision_loss += self._celoss(
|
| 739 |
+
left_q_ignore_collisions_flat, left_action_ignore_collisions_one_hot
|
| 740 |
+
)
|
| 741 |
+
|
| 742 |
+
|
| 743 |
+
q_trans_loss = right_q_trans_loss + left_q_trans_loss
|
| 744 |
+
q_rot_loss = right_q_rot_loss + left_q_rot_loss
|
| 745 |
+
q_grip_loss = right_q_grip_loss + left_q_grip_loss
|
| 746 |
+
q_collision_loss = right_q_collision_loss + left_q_collision_loss
|
| 747 |
+
|
| 748 |
+
combined_losses = (
|
| 749 |
+
(q_trans_loss * self._trans_loss_weight)
|
| 750 |
+
+ (q_rot_loss * self._rot_loss_weight)
|
| 751 |
+
+ (q_grip_loss * self._grip_loss_weight)
|
| 752 |
+
+ (q_collision_loss * self._collision_loss_weight)
|
| 753 |
+
+ 0.0001*L_skill
|
| 754 |
+
+ 0.01*L_voxel
|
| 755 |
+
)
|
| 756 |
+
total_loss = combined_losses.mean()
|
| 757 |
+
|
| 758 |
+
if step % 10 == 0 and rank == 0 and wandb.run is not None:
|
| 759 |
+
wandb.log({
|
| 760 |
+
'train/grip_loss': q_grip_loss.mean(),
|
| 761 |
+
'train/trans_loss': q_trans_loss.mean(),
|
| 762 |
+
'train/rot_loss': q_rot_loss.mean(),
|
| 763 |
+
'train/collision_loss': q_collision_loss.mean(),
|
| 764 |
+
'train/total_loss': total_loss,
|
| 765 |
+
}, step=step)
|
| 766 |
+
|
| 767 |
+
torch.autograd.set_detect_anomaly(True)
|
| 768 |
+
self._optimizer.zero_grad()
|
| 769 |
+
total_loss.backward()
|
| 770 |
+
self._optimizer.step()
|
| 771 |
+
torch.cuda.empty_cache()
|
| 772 |
+
|
| 773 |
+
self._summaries = {
|
| 774 |
+
"losses/total_loss": total_loss,
|
| 775 |
+
"losses/trans_loss": q_trans_loss.mean(),
|
| 776 |
+
"losses/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 777 |
+
"losses/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 778 |
+
|
| 779 |
+
"losses/right/trans_loss": q_trans_loss.mean(),
|
| 780 |
+
"losses/right/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 781 |
+
"losses/right/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 782 |
+
"losses/right/collision_loss": q_collision_loss.mean() if with_rot_and_grip else 0.0,
|
| 783 |
+
|
| 784 |
+
"losses/left/trans_loss": q_trans_loss.mean(),
|
| 785 |
+
"losses/left/rot_loss": q_rot_loss.mean() if with_rot_and_grip else 0.0,
|
| 786 |
+
"losses/left/grip_loss": q_grip_loss.mean() if with_rot_and_grip else 0.0,
|
| 787 |
+
"losses/left/collision_loss": q_collision_loss.mean() if with_rot_and_grip else 0.0,
|
| 788 |
+
|
| 789 |
+
"losses/collision_loss": q_collision_loss.mean()
|
| 790 |
+
if with_rot_and_grip
|
| 791 |
+
else 0.0,
|
| 792 |
+
}
|
| 793 |
+
|
| 794 |
+
self._wandb_summaries = {
|
| 795 |
+
'losses/total_loss': total_loss,
|
| 796 |
+
'losses/trans_loss': q_trans_loss.mean(),
|
| 797 |
+
'losses/rot_loss': q_rot_loss.mean() if with_rot_and_grip else 0.,
|
| 798 |
+
'losses/grip_loss': q_grip_loss.mean() if with_rot_and_grip else 0.,
|
| 799 |
+
'losses/collision_loss': q_collision_loss.mean() if with_rot_and_grip else 0.
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
if self._lr_scheduler:
|
| 803 |
+
self._scheduler.step()
|
| 804 |
+
self._summaries["learning_rate"] = self._scheduler.get_last_lr()[0]
|
| 805 |
+
|
| 806 |
+
self._vis_voxel_grid = voxel_grid[0]
|
| 807 |
+
self._right_vis_translation_qvalue = self._softmax_q_trans(right_q_trans[0])
|
| 808 |
+
self._right_vis_max_coordinate = right_coords[0]
|
| 809 |
+
self._right_vis_gt_coordinate = right_action_trans[0]
|
| 810 |
+
|
| 811 |
+
self._left_vis_translation_qvalue = self._softmax_q_trans(left_q_trans[0])
|
| 812 |
+
self._left_vis_max_coordinate = left_coords[0]
|
| 813 |
+
self._left_vis_gt_coordinate = left_action_trans[0]
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
# Note: PerAct doesn't use multi-layer voxel grids like C2FARM
|
| 817 |
+
# stack prev_layer_voxel_grid(s) from previous layers into a list
|
| 818 |
+
if prev_layer_voxel_grid is None:
|
| 819 |
+
prev_layer_voxel_grid = [voxel_grid]
|
| 820 |
+
else:
|
| 821 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [voxel_grid]
|
| 822 |
+
|
| 823 |
+
# stack prev_layer_bound(s) from previous layers into a list
|
| 824 |
+
if prev_layer_bounds is None:
|
| 825 |
+
prev_layer_bounds = [self._coordinate_bounds.repeat(bs, 1)]
|
| 826 |
+
else:
|
| 827 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 828 |
+
|
| 829 |
+
q_trans_vis=True
|
| 830 |
+
log_freq = getattr(getattr(getattr(self, "cfg", None), "framework", None), "log_freq", None)
|
| 831 |
+
if log_freq and step % log_freq == 0 and rank == 0:
|
| 832 |
+
print("right_predict: ",self._right_vis_max_coordinate)
|
| 833 |
+
print("right_gt: ",self._right_vis_gt_coordinate)
|
| 834 |
+
print("left_predict: ",self._left_vis_max_coordinate)
|
| 835 |
+
print("left_gt: ",self._left_vis_gt_coordinate)
|
| 836 |
+
rendered_img_right = visualise_voxel(
|
| 837 |
+
voxel_grid[0].cpu().detach().numpy(), # [10, 100, 100, 100]
|
| 838 |
+
self._right_vis_translation_qvalue.detach().cpu().numpy() if q_trans_vis else None,
|
| 839 |
+
self._right_vis_max_coordinate.detach().cpu().numpy(),
|
| 840 |
+
self._right_vis_gt_coordinate.detach().cpu().numpy(),
|
| 841 |
+
voxel_size=0.045,
|
| 842 |
+
# voxel_size=0.1, # more focus ??
|
| 843 |
+
rotation_amount=np.deg2rad(-90),
|
| 844 |
+
highlight_alpha=1.0,
|
| 845 |
+
alpha=0.4,
|
| 846 |
+
)
|
| 847 |
+
rendered_img_left = visualise_voxel(
|
| 848 |
+
voxel_grid[0].cpu().detach().numpy(), # [10, 100, 100, 100]
|
| 849 |
+
self._left_vis_translation_qvalue.detach().cpu().numpy() if q_trans_vis else None,
|
| 850 |
+
self._left_vis_max_coordinate.detach().cpu().numpy(),
|
| 851 |
+
self._left_vis_gt_coordinate.detach().cpu().numpy(),
|
| 852 |
+
voxel_size=0.045,
|
| 853 |
+
# voxel_size=0.1, # more focus ??
|
| 854 |
+
rotation_amount=np.deg2rad(-90),
|
| 855 |
+
highlight_alpha=1.0,
|
| 856 |
+
alpha=0.4,
|
| 857 |
+
)
|
| 858 |
+
os.makedirs('recon', exist_ok=True)
|
| 859 |
+
# plot three images in one row with subplots:
|
| 860 |
+
rgb_src = obs[0][0][0].squeeze(0).permute(1, 2, 0) / 2 + 0.5
|
| 861 |
+
|
| 862 |
+
fig, axs = plt.subplots(1, 4, figsize=(9, 3))
|
| 863 |
+
# src
|
| 864 |
+
axs[0].imshow(rgb_src.cpu().numpy())
|
| 865 |
+
axs[0].title.set_text('src')
|
| 866 |
+
|
| 867 |
+
axs[1].imshow(rendered_img_right)
|
| 868 |
+
axs[1].text(0, 40, 'predicted', color='blue')
|
| 869 |
+
axs[1].text(0, 80, 'gt', color='red')
|
| 870 |
+
axs[2].imshow(rendered_img_left)
|
| 871 |
+
axs[2].text(0, 40, 'predicted', color='blue')
|
| 872 |
+
axs[2].text(0, 80, 'gt', color='red')
|
| 873 |
+
for ax in axs:
|
| 874 |
+
ax.axis('off')
|
| 875 |
+
plt.tight_layout()
|
| 876 |
+
|
| 877 |
+
if rank == 0:
|
| 878 |
+
if wandb.run is not None:
|
| 879 |
+
buf = io.BytesIO()
|
| 880 |
+
plt.savefig(buf, format='png')
|
| 881 |
+
buf.seek(0)
|
| 882 |
+
|
| 883 |
+
image = Image.open(buf)
|
| 884 |
+
wandb.log({"eval/recon_img": wandb.Image(image)}, step=step)
|
| 885 |
+
|
| 886 |
+
buf.close()
|
| 887 |
+
cprint(f'Saved to wandb', 'cyan')
|
| 888 |
+
else:
|
| 889 |
+
plt.savefig(f'recon/{step}_rgb.png')
|
| 890 |
+
workdir = os.getcwd()
|
| 891 |
+
cprint(f'Saved {workdir}/recon/{step}_rgb.png locally', 'cyan')
|
| 892 |
+
|
| 893 |
+
return {
|
| 894 |
+
"total_loss": total_loss,
|
| 895 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 896 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 897 |
+
}
|
| 898 |
+
|
| 899 |
+
def act(self, step: int,observation: dict,deterministic=False) -> ActResult:
|
| 900 |
+
deterministic = True
|
| 901 |
+
bounds = self._coordinate_bounds
|
| 902 |
+
prev_layer_voxel_grid = observation.get("prev_layer_voxel_grid", None)
|
| 903 |
+
prev_layer_bounds = observation.get("prev_layer_bounds", None)
|
| 904 |
+
lang_goal_tokens = observation.get("lang_goal_tokens", None).long()
|
| 905 |
+
|
| 906 |
+
# extract CLIP language embs
|
| 907 |
+
with torch.no_grad():
|
| 908 |
+
lang_goal_tokens = lang_goal_tokens.to(device=self._device)
|
| 909 |
+
(
|
| 910 |
+
lang_goal_emb,
|
| 911 |
+
lang_token_embs,
|
| 912 |
+
) = self._clip_rn50.encode_text_with_embeddings(lang_goal_tokens[0])
|
| 913 |
+
|
| 914 |
+
# voxelization resolution
|
| 915 |
+
res = (bounds[:, 3:] - bounds[:, :3]) / self._voxel_size
|
| 916 |
+
max_rot_index = int(360 // self._rotation_resolution)
|
| 917 |
+
right_proprio = None
|
| 918 |
+
left_proprio = None
|
| 919 |
+
|
| 920 |
+
if self._include_low_dim_state:
|
| 921 |
+
right_proprio = observation["right_low_dim_state"]
|
| 922 |
+
left_proprio = observation["left_low_dim_state"]
|
| 923 |
+
right_proprio = right_proprio[0].to(self._device)
|
| 924 |
+
left_proprio = left_proprio[0].to(self._device)
|
| 925 |
+
|
| 926 |
+
obs, pcd = self._act_preprocess_inputs(observation)
|
| 927 |
+
|
| 928 |
+
# correct batch size and device
|
| 929 |
+
obs = [[o[0][0].to(self._device), o[1][0].to(self._device)] for o in obs]
|
| 930 |
+
|
| 931 |
+
pcd = [p[0].to(self._device) for p in pcd]
|
| 932 |
+
lang_goal_emb = lang_goal_emb.to(self._device)
|
| 933 |
+
lang_token_embs = lang_token_embs.to(self._device)
|
| 934 |
+
bounds = torch.as_tensor(bounds, device=self._device)
|
| 935 |
+
prev_layer_voxel_grid = (
|
| 936 |
+
prev_layer_voxel_grid.to(self._device)
|
| 937 |
+
if prev_layer_voxel_grid is not None
|
| 938 |
+
else None
|
| 939 |
+
)
|
| 940 |
+
prev_layer_bounds = (
|
| 941 |
+
prev_layer_bounds.to(self._device)
|
| 942 |
+
if prev_layer_bounds is not None
|
| 943 |
+
else None
|
| 944 |
+
)
|
| 945 |
+
|
| 946 |
+
proprio = torch.cat((right_proprio, left_proprio), dim=1)
|
| 947 |
+
# inference
|
| 948 |
+
(
|
| 949 |
+
right_q_trans,
|
| 950 |
+
right_q_rot_grip,
|
| 951 |
+
right_q_ignore_collisions,
|
| 952 |
+
left_q_trans,
|
| 953 |
+
left_q_rot_grip,
|
| 954 |
+
left_q_ignore_collisions,
|
| 955 |
+
), vox_grid = self._q(
|
| 956 |
+
obs,
|
| 957 |
+
proprio,
|
| 958 |
+
pcd,
|
| 959 |
+
lang_goal_emb,
|
| 960 |
+
lang_token_embs,
|
| 961 |
+
bounds,
|
| 962 |
+
prev_layer_bounds,
|
| 963 |
+
prev_layer_voxel_grid
|
| 964 |
+
)
|
| 965 |
+
|
| 966 |
+
# softmax Q predictions
|
| 967 |
+
right_q_trans = self._softmax_q_trans(right_q_trans)
|
| 968 |
+
left_q_trans = self._softmax_q_trans(left_q_trans)
|
| 969 |
+
|
| 970 |
+
if right_q_rot_grip is not None:
|
| 971 |
+
right_q_rot_grip = self._softmax_q_rot_grip(right_q_rot_grip)
|
| 972 |
+
|
| 973 |
+
if left_q_rot_grip is not None:
|
| 974 |
+
left_q_rot_grip = self._softmax_q_rot_grip(left_q_rot_grip)
|
| 975 |
+
|
| 976 |
+
if right_q_ignore_collisions is not None:
|
| 977 |
+
right_q_ignore_collisions = self._softmax_ignore_collision(
|
| 978 |
+
right_q_ignore_collisions
|
| 979 |
+
)
|
| 980 |
+
|
| 981 |
+
if left_q_ignore_collisions is not None:
|
| 982 |
+
left_q_ignore_collisions = self._softmax_ignore_collision(
|
| 983 |
+
left_q_ignore_collisions
|
| 984 |
+
)
|
| 985 |
+
|
| 986 |
+
# argmax Q predictions
|
| 987 |
+
(
|
| 988 |
+
right_coords,
|
| 989 |
+
right_rot_and_grip_indicies,
|
| 990 |
+
right_ignore_collisions,
|
| 991 |
+
) = self._q.choose_highest_action(
|
| 992 |
+
right_q_trans, right_q_rot_grip, right_q_ignore_collisions
|
| 993 |
+
)
|
| 994 |
+
(
|
| 995 |
+
left_coords,
|
| 996 |
+
left_rot_and_grip_indicies,
|
| 997 |
+
left_ignore_collisions,
|
| 998 |
+
) = self._q.choose_highest_action(
|
| 999 |
+
left_q_trans, left_q_rot_grip, left_q_ignore_collisions
|
| 1000 |
+
)
|
| 1001 |
+
|
| 1002 |
+
if right_q_rot_grip is not None:
|
| 1003 |
+
right_rot_grip_action = right_rot_and_grip_indicies
|
| 1004 |
+
if right_q_ignore_collisions is not None:
|
| 1005 |
+
right_ignore_collisions_action = right_ignore_collisions.int()
|
| 1006 |
+
|
| 1007 |
+
if left_q_rot_grip is not None:
|
| 1008 |
+
left_rot_grip_action = left_rot_and_grip_indicies
|
| 1009 |
+
if left_q_ignore_collisions is not None:
|
| 1010 |
+
left_ignore_collisions_action = left_ignore_collisions.int()
|
| 1011 |
+
|
| 1012 |
+
right_coords = right_coords.int()
|
| 1013 |
+
left_coords = left_coords.int()
|
| 1014 |
+
|
| 1015 |
+
right_attention_coordinate = bounds[:, :3] + res * right_coords + res / 2
|
| 1016 |
+
left_attention_coordinate = bounds[:, :3] + res * left_coords + res / 2
|
| 1017 |
+
|
| 1018 |
+
# stack prev_layer_voxel_grid(s) into a list
|
| 1019 |
+
# NOTE: PerAct doesn't used multi-layer voxel grids like C2FARM
|
| 1020 |
+
if prev_layer_voxel_grid is None:
|
| 1021 |
+
prev_layer_voxel_grid = [vox_grid]
|
| 1022 |
+
else:
|
| 1023 |
+
prev_layer_voxel_grid = prev_layer_voxel_grid + [vox_grid]
|
| 1024 |
+
|
| 1025 |
+
if prev_layer_bounds is None:
|
| 1026 |
+
prev_layer_bounds = [bounds]
|
| 1027 |
+
else:
|
| 1028 |
+
prev_layer_bounds = prev_layer_bounds + [bounds]
|
| 1029 |
+
|
| 1030 |
+
observation_elements = {
|
| 1031 |
+
"right_attention_coordinate": right_attention_coordinate,
|
| 1032 |
+
"left_attention_coordinate": left_attention_coordinate,
|
| 1033 |
+
"prev_layer_voxel_grid": prev_layer_voxel_grid,
|
| 1034 |
+
"prev_layer_bounds": prev_layer_bounds,
|
| 1035 |
+
}
|
| 1036 |
+
info = {
|
| 1037 |
+
"voxel_grid_depth%d" % self._layer: vox_grid,
|
| 1038 |
+
"right_q_depth%d" % self._layer: right_q_trans,
|
| 1039 |
+
"right_voxel_idx_depth%d" % self._layer: right_coords,
|
| 1040 |
+
"left_q_depth%d" % self._layer: left_q_trans,
|
| 1041 |
+
"left_voxel_idx_depth%d" % self._layer: left_coords,
|
| 1042 |
+
}
|
| 1043 |
+
self._act_voxel_grid = vox_grid[0]
|
| 1044 |
+
self._right_act_max_coordinate = right_coords[0]
|
| 1045 |
+
self._right_act_qvalues = right_q_trans[0].detach()
|
| 1046 |
+
self._left_act_max_coordinate = left_coords[0]
|
| 1047 |
+
self._left_act_qvalues = left_q_trans[0].detach()
|
| 1048 |
+
|
| 1049 |
+
action = (
|
| 1050 |
+
right_coords,
|
| 1051 |
+
right_rot_grip_action,
|
| 1052 |
+
right_ignore_collisions,
|
| 1053 |
+
left_coords,
|
| 1054 |
+
left_rot_grip_action,
|
| 1055 |
+
left_ignore_collisions,
|
| 1056 |
+
)
|
| 1057 |
+
|
| 1058 |
+
return ActResult(action, observation_elements=observation_elements, info=info)
|
| 1059 |
+
|
| 1060 |
+
def update_summaries(self) -> List[Summary]:
|
| 1061 |
+
# voxel_grid = self._vis_voxel_grid.detach().cpu().numpy()
|
| 1062 |
+
summaries = []
|
| 1063 |
+
# summaries.append(
|
| 1064 |
+
# ImageSummary(
|
| 1065 |
+
# "%s/right_update_qattention" % self._name,
|
| 1066 |
+
# transforms.ToTensor()(
|
| 1067 |
+
# visualise_voxel(
|
| 1068 |
+
# voxel_grid,
|
| 1069 |
+
# self._right_vis_translation_qvalue.detach().cpu().numpy(),
|
| 1070 |
+
# self._right_vis_max_coordinate.detach().cpu().numpy(),
|
| 1071 |
+
# self._right_vis_gt_coordinate.detach().cpu().numpy(),
|
| 1072 |
+
# )
|
| 1073 |
+
# ),
|
| 1074 |
+
# )
|
| 1075 |
+
# )
|
| 1076 |
+
# summaries.append(
|
| 1077 |
+
# ImageSummary(
|
| 1078 |
+
# "%s/left_update_qattention" % self._name,
|
| 1079 |
+
# transforms.ToTensor()(
|
| 1080 |
+
# visualise_voxel(
|
| 1081 |
+
# voxel_grid,
|
| 1082 |
+
# self._left_vis_translation_qvalue.detach().cpu().numpy(),
|
| 1083 |
+
# self._left_vis_max_coordinate.detach().cpu().numpy(),
|
| 1084 |
+
# self._left_vis_gt_coordinate.detach().cpu().numpy(),
|
| 1085 |
+
# )
|
| 1086 |
+
# ),
|
| 1087 |
+
# )
|
| 1088 |
+
# )
|
| 1089 |
+
for n, v in self._summaries.items():
|
| 1090 |
+
summaries.append(ScalarSummary("%s/%s" % (self._name, n), v))
|
| 1091 |
+
|
| 1092 |
+
for name, crop in self._crop_summary:
|
| 1093 |
+
crops = (torch.cat(torch.split(crop, 3, dim=1), dim=3) + 1.0) / 2.0
|
| 1094 |
+
summaries.extend([ImageSummary("%s/crops/%s" % (self._name, name), crops)])
|
| 1095 |
+
|
| 1096 |
+
for tag, param in self._q.named_parameters():
|
| 1097 |
+
# assert not torch.isnan(param.grad.abs() <= 1.0).all()
|
| 1098 |
+
summaries.append(
|
| 1099 |
+
HistogramSummary("%s/gradient/%s" % (self._name, tag), param.grad)
|
| 1100 |
+
)
|
| 1101 |
+
summaries.append(
|
| 1102 |
+
HistogramSummary("%s/weight/%s" % (self._name, tag), param.data)
|
| 1103 |
+
)
|
| 1104 |
+
|
| 1105 |
+
return summaries
|
| 1106 |
+
|
| 1107 |
+
def update_wandb_summaries(self):
|
| 1108 |
+
summaries = dict()
|
| 1109 |
+
|
| 1110 |
+
for k, v in self._wandb_summaries.items():
|
| 1111 |
+
summaries[k] = v
|
| 1112 |
+
return summaries
|
| 1113 |
+
|
| 1114 |
+
def act_summaries(self) -> List[Summary]:
|
| 1115 |
+
# voxel_grid = self._act_voxel_grid.cpu().numpy()
|
| 1116 |
+
# right_q_attention = self._right_act_qvalues.cpu().numpy()
|
| 1117 |
+
# right_highlight_coordinate = self._right_act_max_coordinate.cpu().numpy()
|
| 1118 |
+
# right_visualization = visualise_voxel(
|
| 1119 |
+
# voxel_grid, right_q_attention, right_highlight_coordinate
|
| 1120 |
+
# )
|
| 1121 |
+
|
| 1122 |
+
# left_q_attention = self._left_act_qvalues.cpu().numpy()
|
| 1123 |
+
# left_highlight_coordinate = self._left_act_max_coordinate.cpu().numpy()
|
| 1124 |
+
# left_visualization = visualise_voxel(
|
| 1125 |
+
# voxel_grid, left_q_attention, left_highlight_coordinate
|
| 1126 |
+
# )
|
| 1127 |
+
|
| 1128 |
+
# return [
|
| 1129 |
+
# ImageSummary(
|
| 1130 |
+
# f"{self._name}/right_act_Qattention",
|
| 1131 |
+
# transforms.ToTensor()(right_visualization),
|
| 1132 |
+
# ),
|
| 1133 |
+
# ImageSummary(
|
| 1134 |
+
# f"{self._name}/left_act_Qattention",
|
| 1135 |
+
# transforms.ToTensor()(left_visualization),
|
| 1136 |
+
# ),
|
| 1137 |
+
# ]
|
| 1138 |
+
return []
|
| 1139 |
+
|
| 1140 |
+
def concat_weights(self, param, target_size, dims=-1):
|
| 1141 |
+
if param.size(-1) < target_size:
|
| 1142 |
+
param = torch.cat([param, param], dims)
|
| 1143 |
+
return param
|
| 1144 |
+
def load_weights(self, savedir: str):
|
| 1145 |
+
device = (
|
| 1146 |
+
self._device
|
| 1147 |
+
if not self._training
|
| 1148 |
+
else torch.device("cuda:%d" % self._device)
|
| 1149 |
+
)
|
| 1150 |
+
|
| 1151 |
+
weight_file = os.path.join(savedir, "%s.pt" % self._name)
|
| 1152 |
+
state_dict = torch.load(weight_file, map_location=device)
|
| 1153 |
+
merged_state_dict = self._q.state_dict()
|
| 1154 |
+
|
| 1155 |
+
if not self._training:
|
| 1156 |
+
for k, v in state_dict.items():
|
| 1157 |
+
if not self._training:
|
| 1158 |
+
k = k.replace("_qnet.module", "_qnet")
|
| 1159 |
+
if k in merged_state_dict:
|
| 1160 |
+
merged_state_dict[k] = v
|
| 1161 |
+
else:
|
| 1162 |
+
if "_voxelizer" not in k:
|
| 1163 |
+
logging.warning("key %s not found in checkpoint" % k)
|
| 1164 |
+
else:
|
| 1165 |
+
for k, v in state_dict.items():
|
| 1166 |
+
if not self._training:
|
| 1167 |
+
k = k.replace("_qnet.module", "_qnet")
|
| 1168 |
+
# cross_attn
|
| 1169 |
+
if k.startswith("_qnet.module.decoder_cross_attn"):
|
| 1170 |
+
right_key = k.replace("_qnet.module.decoder_cross_attn", "_qnet.module.decoder_cross_attn_right")
|
| 1171 |
+
merged_state_dict[right_key] = v
|
| 1172 |
+
left_key = k.replace("_qnet.module.decoder_cross_attn", "_qnet.module.decoder_cross_attn_left")
|
| 1173 |
+
merged_state_dict[left_key] = v
|
| 1174 |
+
if self.anybimanual:
|
| 1175 |
+
if v.size(0) == 128:
|
| 1176 |
+
merged_state_dict[right_key] = self.concat_weights(v, 256, 0)
|
| 1177 |
+
merged_state_dict[left_key] = self.concat_weights(v, 256, 0)
|
| 1178 |
+
if v.size(-1) == 128:
|
| 1179 |
+
merged_state_dict[right_key] = self.concat_weights(v, 256)
|
| 1180 |
+
merged_state_dict[left_key] = self.concat_weights(v, 256)
|
| 1181 |
+
elif k == "_qnet.module.up0.conv_up.0.conv3d.weight":
|
| 1182 |
+
if self.anybimanual:
|
| 1183 |
+
if v.size(1) == 128:
|
| 1184 |
+
merged_state_dict[k] = self.concat_weights(v, 256, 1)
|
| 1185 |
+
else:
|
| 1186 |
+
merged_state_dict[k] = v
|
| 1187 |
+
# trans_decoder
|
| 1188 |
+
elif k.startswith("_qnet.module.trans_decoder"):
|
| 1189 |
+
right_key = k.replace("_qnet.module.trans_decoder", "_qnet.module.right_trans_decoder")
|
| 1190 |
+
merged_state_dict[right_key] = v
|
| 1191 |
+
|
| 1192 |
+
left_key = k.replace("_qnet.module.trans_decoder", "_qnet.module.left_trans_decoder")
|
| 1193 |
+
merged_state_dict[left_key] = v
|
| 1194 |
+
# dense0
|
| 1195 |
+
elif k.startswith("_qnet.module.dense0"):
|
| 1196 |
+
right_key = k.replace("_qnet.module.dense0", "_qnet.module.right_dense0")
|
| 1197 |
+
merged_state_dict[right_key] = v
|
| 1198 |
+
|
| 1199 |
+
left_key = k.replace("_qnet.module.dense0", "_qnet.module.left_dense0")
|
| 1200 |
+
merged_state_dict[left_key] = v
|
| 1201 |
+
if self.anybimanual:
|
| 1202 |
+
if v.size(-1) == 1024:
|
| 1203 |
+
merged_state_dict[right_key] = torch.cat([v, v[:, :512]], dim=-1)
|
| 1204 |
+
merged_state_dict[left_key] = torch.cat([v, v[:, :512]], dim=-1)
|
| 1205 |
+
# dense1
|
| 1206 |
+
elif k.startswith("_qnet.module.dense1"):
|
| 1207 |
+
right_key = k.replace("_qnet.module.dense1", "_qnet.module.right_dense1")
|
| 1208 |
+
merged_state_dict[right_key] = v
|
| 1209 |
+
|
| 1210 |
+
left_key = k.replace("_qnet.module.dense1", "_qnet.module.left_dense1")
|
| 1211 |
+
merged_state_dict[left_key] = v
|
| 1212 |
+
# collision
|
| 1213 |
+
elif k.startswith("_qnet.module.rot_grip_collision_ff"):
|
| 1214 |
+
right_key = k.replace("_qnet.module.rot_grip_collision_ff", "_qnet.module.right_rot_grip_collision_ff")
|
| 1215 |
+
merged_state_dict[right_key] = v
|
| 1216 |
+
|
| 1217 |
+
left_key = k.replace("_qnet.module.rot_grip_collision_ff", "_qnet.module.left_rot_grip_collision_ff")
|
| 1218 |
+
merged_state_dict[left_key] = v
|
| 1219 |
+
elif k.startswith("_qnet.module.cross_attend_blocks"):
|
| 1220 |
+
if self.anybimanual:
|
| 1221 |
+
if k.startswith("_qnet.module.cross_attend_blocks.0"):
|
| 1222 |
+
merged_state_dict[k] = v
|
| 1223 |
+
k_1 = k.replace("_qnet.module.cross_attend_blocks.0","_qnet.module.cross_attend_blocks.1")
|
| 1224 |
+
merged_state_dict[k_1] = v
|
| 1225 |
+
if self.anybimanual:
|
| 1226 |
+
if v.size(-1) == 128:
|
| 1227 |
+
merged_state_dict[k_1] = self.concat_weights(v, 256)
|
| 1228 |
+
else:
|
| 1229 |
+
k_2 = k.replace("_qnet.module.cross_attend_blocks.1","_qnet.module.cross_attend_blocks.2")
|
| 1230 |
+
k_3 = k.replace("_qnet.module.cross_attend_blocks.1","_qnet.module.cross_attend_blocks.3")
|
| 1231 |
+
merged_state_dict[k_2] = v
|
| 1232 |
+
merged_state_dict[k_3] = v
|
| 1233 |
+
if self.anybimanual:
|
| 1234 |
+
if v.size(-1) == 128:
|
| 1235 |
+
merged_state_dict[k_2] = self.concat_weights(v, 256)
|
| 1236 |
+
merged_state_dict[k_3] = self.concat_weights(v, 256)
|
| 1237 |
+
else:
|
| 1238 |
+
if k.startswith("_qnet.module.cross_attend_blocks.0"):
|
| 1239 |
+
merged_state_dict[k] = v
|
| 1240 |
+
else:
|
| 1241 |
+
merged_state_dict[k] = v
|
| 1242 |
+
k_2 = k.replace("_qnet.module.cross_attend_blocks.1","_qnet.module.cross_attend_blocks.2")
|
| 1243 |
+
merged_state_dict[k_2] = v
|
| 1244 |
+
if self.anybimanual:
|
| 1245 |
+
if v.size(-1) == 128:
|
| 1246 |
+
merged_state_dict[k_2] = self.concat_weights(v, 256)
|
| 1247 |
+
if self.anybimanual:
|
| 1248 |
+
if v.size(-1) == 128:
|
| 1249 |
+
merged_state_dict[k] = self.concat_weights(v, 256)
|
| 1250 |
+
# proprio
|
| 1251 |
+
elif k == '_qnet.module.proprio_preprocess.linear.weight':
|
| 1252 |
+
if v.shape[1] != 8:
|
| 1253 |
+
new_v = torch.cat([v,v], dim=1)
|
| 1254 |
+
merged_state_dict['_qnet.module.proprio_preprocess.linear.weight'] = new_v
|
| 1255 |
+
else:
|
| 1256 |
+
merged_state_dict[k] = v
|
| 1257 |
+
elif k == '_qnet.module.proprio_preprocess.linear.bias':
|
| 1258 |
+
merged_state_dict['_qnet.module.proprio_preprocess.linear.bias'] = v
|
| 1259 |
+
# pos_with_lang
|
| 1260 |
+
elif k == "_qnet.module.pos_encoding":
|
| 1261 |
+
if (v.shape[1] != 8077 or v.shape[1] != 8154) and v.shape[1] < 154:
|
| 1262 |
+
if self.anybimanual:
|
| 1263 |
+
lang_max_seq_len = 154
|
| 1264 |
+
else:
|
| 1265 |
+
lang_max_seq_len = 77
|
| 1266 |
+
spatial_size = v.shape[1]
|
| 1267 |
+
input_dim_before_seq = v.shape[-1]
|
| 1268 |
+
flattened_v = v.view(1, -1, input_dim_before_seq) # (1, spatial_size**3, self.input_dim_before_seq)
|
| 1269 |
+
new_pos_encoding = torch.randn(1, lang_max_seq_len, input_dim_before_seq, device=device)
|
| 1270 |
+
merged_pos_encoding = torch.cat([flattened_v, new_pos_encoding], dim=1) # (1, lang_max_seq_len + spatial_size**3, self.input_dim_before_seq)
|
| 1271 |
+
merged_state_dict["_qnet.module.pos_encoding"] = merged_pos_encoding
|
| 1272 |
+
else:
|
| 1273 |
+
merged_state_dict["_qnet.module.pos_encoding"] = v
|
| 1274 |
+
elif k in merged_state_dict:
|
| 1275 |
+
merged_state_dict[k] = v
|
| 1276 |
+
|
| 1277 |
+
# else:
|
| 1278 |
+
# if "_voxelizer" not in k:
|
| 1279 |
+
# logging.warning("key %s not found in checkpoint" % k)
|
| 1280 |
+
|
| 1281 |
+
|
| 1282 |
+
if not self._training:
|
| 1283 |
+
b = merged_state_dict["_voxelizer._ones_max_coords"].shape[0]
|
| 1284 |
+
merged_state_dict["_voxelizer._ones_max_coords"] = merged_state_dict[
|
| 1285 |
+
"_voxelizer._ones_max_coords"
|
| 1286 |
+
][0:1]
|
| 1287 |
+
flat_shape = merged_state_dict["_voxelizer._flat_output"].shape[0]
|
| 1288 |
+
merged_state_dict["_voxelizer._flat_output"] = merged_state_dict[
|
| 1289 |
+
"_voxelizer._flat_output"
|
| 1290 |
+
][0 : flat_shape // b]
|
| 1291 |
+
merged_state_dict["_voxelizer._tiled_batch_indices"] = merged_state_dict[
|
| 1292 |
+
"_voxelizer._tiled_batch_indices"
|
| 1293 |
+
][0:1]
|
| 1294 |
+
merged_state_dict["_voxelizer._index_grid"] = merged_state_dict[
|
| 1295 |
+
"_voxelizer._index_grid"
|
| 1296 |
+
][0:1]
|
| 1297 |
+
self._q.load_state_dict(merged_state_dict)
|
| 1298 |
+
|
| 1299 |
+
if self.frozen:
|
| 1300 |
+
print("Freezing parameters from PerAct")
|
| 1301 |
+
for name, param in self._q.named_parameters():
|
| 1302 |
+
if name in state_dict:
|
| 1303 |
+
param.requires_grad = False
|
| 1304 |
+
|
| 1305 |
+
logging.info(
|
| 1306 |
+
"# Q Params: %d"
|
| 1307 |
+
% sum(
|
| 1308 |
+
p.numel()
|
| 1309 |
+
for name, p in self._q.named_parameters()
|
| 1310 |
+
if p.requires_grad and "clip" not in name
|
| 1311 |
+
)
|
| 1312 |
+
)
|
| 1313 |
+
print("loaded weights from %s" % weight_file)
|
| 1314 |
+
|
| 1315 |
+
|
| 1316 |
+
def save_weights(self, savedir: str):
|
| 1317 |
+
torch.save(self._q.state_dict(), os.path.join(savedir, "%s.pt" % self._name))
|
third_party/AnyBimanual/agents/peract_bimanual/qattention_stack_agent.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from yarr.agents.agent import Agent, ActResult, Summary
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
from helpers import utils
|
| 9 |
+
from agents.peract_bimanual.qattention_peract_bc_agent import QAttentionPerActBCAgent
|
| 10 |
+
|
| 11 |
+
NAME = "QAttentionStackAgent"
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class QAttentionStackAgent(Agent):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
qattention_agents: List[QAttentionPerActBCAgent],
|
| 18 |
+
rotation_resolution: float,
|
| 19 |
+
camera_names: List[str],
|
| 20 |
+
rotation_prediction_depth: int = 0,
|
| 21 |
+
):
|
| 22 |
+
super(QAttentionStackAgent, self).__init__()
|
| 23 |
+
self._qattention_agents = qattention_agents
|
| 24 |
+
self._rotation_resolution = rotation_resolution
|
| 25 |
+
self._camera_names = camera_names
|
| 26 |
+
self._rotation_prediction_depth = rotation_prediction_depth
|
| 27 |
+
|
| 28 |
+
def build(self, training: bool, device=None) -> None:
|
| 29 |
+
self._device = device
|
| 30 |
+
if self._device is None:
|
| 31 |
+
self._device = torch.device("cpu")
|
| 32 |
+
for qa in self._qattention_agents:
|
| 33 |
+
qa.build(training, device)
|
| 34 |
+
|
| 35 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 36 |
+
priorities = 0
|
| 37 |
+
total_losses = 0.0
|
| 38 |
+
for qa in self._qattention_agents:
|
| 39 |
+
update_dict = qa.update(step, replay_sample)
|
| 40 |
+
replay_sample.update(update_dict)
|
| 41 |
+
total_losses += update_dict["total_loss"]
|
| 42 |
+
return {
|
| 43 |
+
"total_losses": total_losses,
|
| 44 |
+
}
|
| 45 |
+
|
| 46 |
+
def act(self, step: int, observation: dict, deterministic=False) -> ActResult:
|
| 47 |
+
observation_elements = {}
|
| 48 |
+
(
|
| 49 |
+
right_translation_results,
|
| 50 |
+
right_rot_grip_results,
|
| 51 |
+
right_ignore_collisions_results,
|
| 52 |
+
) = ([], [], [])
|
| 53 |
+
(
|
| 54 |
+
left_translation_results,
|
| 55 |
+
left_rot_grip_results,
|
| 56 |
+
left_ignore_collisions_results,
|
| 57 |
+
) = ([], [], [])
|
| 58 |
+
|
| 59 |
+
infos = {}
|
| 60 |
+
for depth, qagent in enumerate(self._qattention_agents):
|
| 61 |
+
act_results = qagent.act(step, observation, deterministic)
|
| 62 |
+
right_attention_coordinate = (
|
| 63 |
+
act_results.observation_elements["right_attention_coordinate"]
|
| 64 |
+
.cpu()
|
| 65 |
+
.numpy()
|
| 66 |
+
)
|
| 67 |
+
left_attention_coordinate = (
|
| 68 |
+
act_results.observation_elements["left_attention_coordinate"]
|
| 69 |
+
.cpu()
|
| 70 |
+
.numpy()
|
| 71 |
+
)
|
| 72 |
+
observation_elements[
|
| 73 |
+
"right_attention_coordinate_layer_%d" % depth
|
| 74 |
+
] = right_attention_coordinate[0]
|
| 75 |
+
observation_elements[
|
| 76 |
+
"left_attention_coordinate_layer_%d" % depth
|
| 77 |
+
] = left_attention_coordinate[0]
|
| 78 |
+
|
| 79 |
+
(
|
| 80 |
+
right_translation_idxs,
|
| 81 |
+
right_rot_grip_idxs,
|
| 82 |
+
right_ignore_collisions_idxs,
|
| 83 |
+
left_translation_idxs,
|
| 84 |
+
left_rot_grip_idxs,
|
| 85 |
+
left_ignore_collisions_idxs,
|
| 86 |
+
) = act_results.action
|
| 87 |
+
|
| 88 |
+
right_translation_results.append(right_translation_idxs)
|
| 89 |
+
if right_rot_grip_idxs is not None:
|
| 90 |
+
right_rot_grip_results.append(right_rot_grip_idxs)
|
| 91 |
+
if right_ignore_collisions_idxs is not None:
|
| 92 |
+
right_ignore_collisions_results.append(right_ignore_collisions_idxs)
|
| 93 |
+
|
| 94 |
+
left_translation_results.append(left_translation_idxs)
|
| 95 |
+
if left_rot_grip_idxs is not None:
|
| 96 |
+
left_rot_grip_results.append(left_rot_grip_idxs)
|
| 97 |
+
if left_ignore_collisions_idxs is not None:
|
| 98 |
+
left_ignore_collisions_results.append(left_ignore_collisions_idxs)
|
| 99 |
+
|
| 100 |
+
observation[
|
| 101 |
+
"right_attention_coordinate"
|
| 102 |
+
] = act_results.observation_elements["right_attention_coordinate"]
|
| 103 |
+
observation["left_attention_coordinate"] = act_results.observation_elements[
|
| 104 |
+
"left_attention_coordinate"
|
| 105 |
+
]
|
| 106 |
+
|
| 107 |
+
observation["prev_layer_voxel_grid"] = act_results.observation_elements[
|
| 108 |
+
"prev_layer_voxel_grid"
|
| 109 |
+
]
|
| 110 |
+
observation["prev_layer_bounds"] = act_results.observation_elements[
|
| 111 |
+
"prev_layer_bounds"
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
for n in self._camera_names:
|
| 115 |
+
extrinsics = observation["%s_camera_extrinsics" % n][0, 0].cpu().numpy()
|
| 116 |
+
intrinsics = observation["%s_camera_intrinsics" % n][0, 0].cpu().numpy()
|
| 117 |
+
px, py = utils.point_to_pixel_index(
|
| 118 |
+
right_attention_coordinate[0], extrinsics, intrinsics
|
| 119 |
+
)
|
| 120 |
+
pc_t = torch.tensor(
|
| 121 |
+
[[[py, px]]], dtype=torch.float32, device=self._device
|
| 122 |
+
)
|
| 123 |
+
observation[f"right_{n}_pixel_coord"] = pc_t
|
| 124 |
+
observation_elements[f"right_{n}_pixel_coord"] = [py, px]
|
| 125 |
+
|
| 126 |
+
px, py = utils.point_to_pixel_index(
|
| 127 |
+
left_attention_coordinate[0], extrinsics, intrinsics
|
| 128 |
+
)
|
| 129 |
+
pc_t = torch.tensor(
|
| 130 |
+
[[[py, px]]], dtype=torch.float32, device=self._device
|
| 131 |
+
)
|
| 132 |
+
observation[f"left_{n}_pixel_coord"] = pc_t
|
| 133 |
+
observation_elements[f"left_{n}_pixel_coord"] = [py, px]
|
| 134 |
+
infos.update(act_results.info)
|
| 135 |
+
|
| 136 |
+
right_rgai = torch.cat(right_rot_grip_results, 1)[0].cpu().numpy()
|
| 137 |
+
# ..todo:: utils.correct_rotation_instability does nothing so we can ignore it
|
| 138 |
+
# right_rgai = utils.correct_rotation_instability(right_rgai, self._rotation_resolution)
|
| 139 |
+
right_ignore_collisions = (
|
| 140 |
+
torch.cat(right_ignore_collisions_results, 1)[0].cpu().numpy()
|
| 141 |
+
)
|
| 142 |
+
right_trans_action_indicies = (
|
| 143 |
+
torch.cat(right_translation_results, 1)[0].cpu().numpy()
|
| 144 |
+
)
|
| 145 |
+
|
| 146 |
+
observation_elements[
|
| 147 |
+
"right_trans_action_indicies"
|
| 148 |
+
] = right_trans_action_indicies[:3]
|
| 149 |
+
observation_elements["right_rot_grip_action_indicies"] = right_rgai[:4]
|
| 150 |
+
|
| 151 |
+
left_rgai = torch.cat(left_rot_grip_results, 1)[0].cpu().numpy()
|
| 152 |
+
left_ignore_collisions = (
|
| 153 |
+
torch.cat(left_ignore_collisions_results, 1)[0].cpu().numpy()
|
| 154 |
+
)
|
| 155 |
+
left_trans_action_indicies = (
|
| 156 |
+
torch.cat(left_translation_results, 1)[0].cpu().numpy()
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
observation_elements["left_trans_action_indicies"] = left_trans_action_indicies[
|
| 160 |
+
3:
|
| 161 |
+
]
|
| 162 |
+
observation_elements["left_rot_grip_action_indicies"] = left_rgai[4:]
|
| 163 |
+
|
| 164 |
+
continuous_action = np.concatenate(
|
| 165 |
+
[
|
| 166 |
+
right_attention_coordinate[0],
|
| 167 |
+
utils.discrete_euler_to_quaternion(
|
| 168 |
+
right_rgai[-4:-1], self._rotation_resolution
|
| 169 |
+
),
|
| 170 |
+
right_rgai[-1:],
|
| 171 |
+
right_ignore_collisions,
|
| 172 |
+
left_attention_coordinate[0],
|
| 173 |
+
utils.discrete_euler_to_quaternion(
|
| 174 |
+
left_rgai[-4:-1], self._rotation_resolution
|
| 175 |
+
),
|
| 176 |
+
left_rgai[-1:],
|
| 177 |
+
left_ignore_collisions,
|
| 178 |
+
]
|
| 179 |
+
)
|
| 180 |
+
return ActResult(
|
| 181 |
+
continuous_action, observation_elements=observation_elements, info=infos
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
def update_summaries(self) -> List[Summary]:
|
| 185 |
+
summaries = []
|
| 186 |
+
for qa in self._qattention_agents:
|
| 187 |
+
summaries.extend(qa.update_summaries())
|
| 188 |
+
return summaries
|
| 189 |
+
|
| 190 |
+
def update_wandb_summaries(self):
|
| 191 |
+
summaries = {}
|
| 192 |
+
for qa in self._qattention_agents:
|
| 193 |
+
summaries.update(qa.update_wandb_summaries())
|
| 194 |
+
return summaries
|
| 195 |
+
def act_summaries(self) -> List[Summary]:
|
| 196 |
+
s = []
|
| 197 |
+
for qa in self._qattention_agents:
|
| 198 |
+
s.extend(qa.act_summaries())
|
| 199 |
+
return s
|
| 200 |
+
|
| 201 |
+
def load_weights(self, savedir: str):
|
| 202 |
+
for qa in self._qattention_agents:
|
| 203 |
+
print(dir(qa))
|
| 204 |
+
qa.load_weights(savedir)
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def save_weights(self, savedir: str):
|
| 208 |
+
for qa in self._qattention_agents:
|
| 209 |
+
qa.save_weights(savedir)
|
third_party/AnyBimanual/agents/peract_bimanual/skill_manager.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import transformers
|
| 4 |
+
from agents.peract_bimanual.trajectory_gpt2 import GPT2Model
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
class SkillManager(nn.Module):
|
| 7 |
+
def __init__(
|
| 8 |
+
self,
|
| 9 |
+
num_classes,
|
| 10 |
+
embedding_matrix=None,
|
| 11 |
+
voxel_dim=128,
|
| 12 |
+
lang_dim=128,
|
| 13 |
+
hidden_size=128,
|
| 14 |
+
output_dim=18,
|
| 15 |
+
max_voxels=8000,
|
| 16 |
+
max_lang_tokens=77,
|
| 17 |
+
**kwargs):
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
self.hidden_size = hidden_size
|
| 21 |
+
self.output_dim = output_dim
|
| 22 |
+
|
| 23 |
+
# GPT-2 configuration
|
| 24 |
+
config = transformers.GPT2Config(
|
| 25 |
+
vocab_size=1, # not used
|
| 26 |
+
n_embd=hidden_size,
|
| 27 |
+
n_head=4,
|
| 28 |
+
n_ctx=1077,
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
self.max_voxels = max_voxels
|
| 32 |
+
self.max_lang_tokens = max_lang_tokens
|
| 33 |
+
self.embed_voxel = nn.Linear(voxel_dim, hidden_size)
|
| 34 |
+
self.embed_lang = nn.Linear(lang_dim, hidden_size)
|
| 35 |
+
self.transformer = GPT2Model(config)
|
| 36 |
+
self.embed_ln = nn.LayerNorm(hidden_size)
|
| 37 |
+
self.predict_logits = nn.Linear(hidden_size, output_dim)
|
| 38 |
+
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 39 |
+
self.num_class = num_classes
|
| 40 |
+
if embedding_matrix is not None:
|
| 41 |
+
self.embeddings_matrix = embedding_matrix.to(self.device)
|
| 42 |
+
|
| 43 |
+
def forward(self, voxel_embedding, language_embedding):
|
| 44 |
+
batch_size = voxel_embedding.shape[0]
|
| 45 |
+
voxel_embeddings = self.embed_voxel(voxel_embedding) # [b, 8000, hidden_size]
|
| 46 |
+
language_embeddings = self.embed_lang(language_embedding) # [b, 77, hidden_size]
|
| 47 |
+
voxel_embeddings = voxel_embeddings.permute(0, 2, 1) # [b, hidden_size, 8000]
|
| 48 |
+
voxel_embeddings = F.avg_pool1d(voxel_embeddings, kernel_size=16, stride=16) # [b, hidden_size, 1000]
|
| 49 |
+
voxel_embeddings = voxel_embeddings.permute(0, 2, 1) # [b, 1000, hidden_size]
|
| 50 |
+
inputs = torch.cat([language_embeddings, voxel_embeddings], dim=1) # [b, 8077, hidden_size]
|
| 51 |
+
stacked_inputs = self.embed_ln(inputs)
|
| 52 |
+
attention_mask = torch.ones(
|
| 53 |
+
(batch_size, self.max_lang_tokens + self.max_voxels),
|
| 54 |
+
device=voxel_embedding.device,
|
| 55 |
+
dtype=torch.long # Ensure correct dtype
|
| 56 |
+
)
|
| 57 |
+
assert torch.isfinite(attention_mask).all(), "attention_mask contains NaN or Inf"
|
| 58 |
+
assert torch.all((attention_mask == 1)), "attention_mask contains values not equal to 1"
|
| 59 |
+
transformer_outputs = self.transformer(
|
| 60 |
+
inputs_embeds=stacked_inputs,
|
| 61 |
+
attention_mask=None,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
hidden_state = transformer_outputs.last_hidden_state # [b, 8077, hidden_size]
|
| 65 |
+
aggregated_hidden = hidden_state.mean(dim=1) # [b, hidden_size]
|
| 66 |
+
logits = self.predict_logits(aggregated_hidden) # [b, output_dim]
|
| 67 |
+
probs = F.softmax(logits, dim=1)
|
| 68 |
+
skill = torch.matmul(probs, self.embeddings_matrix.to(probs.device))
|
| 69 |
+
skill = skill.view(-1,77,512)
|
| 70 |
+
return skill
|
third_party/AnyBimanual/agents/peract_bimanual/trajectory_gpt2.py
ADDED
|
@@ -0,0 +1,775 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The OpenAI Team Authors and HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""PyTorch OpenAI GPT-2 model."""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
from dataclasses import dataclass
|
| 20 |
+
from typing import List, Optional, Tuple
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
import torch.nn as nn
|
| 24 |
+
from torch.nn import CrossEntropyLoss, MSELoss
|
| 25 |
+
|
| 26 |
+
from transformers.activations import ACT2FN
|
| 27 |
+
from transformers.file_utils import (
|
| 28 |
+
ModelOutput,
|
| 29 |
+
add_code_sample_docstrings,
|
| 30 |
+
add_start_docstrings,
|
| 31 |
+
add_start_docstrings_to_model_forward,
|
| 32 |
+
replace_return_docstrings,
|
| 33 |
+
)
|
| 34 |
+
from transformers.modeling_outputs import (
|
| 35 |
+
BaseModelOutputWithPastAndCrossAttentions,
|
| 36 |
+
)
|
| 37 |
+
from transformers.modeling_utils import (
|
| 38 |
+
Conv1D,
|
| 39 |
+
PreTrainedModel,
|
| 40 |
+
SequenceSummary,
|
| 41 |
+
find_pruneable_heads_and_indices,
|
| 42 |
+
prune_conv1d_layer,
|
| 43 |
+
)
|
| 44 |
+
from transformers.utils import logging
|
| 45 |
+
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
|
| 46 |
+
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
| 47 |
+
|
| 48 |
+
logger = logging.get_logger(__name__)
|
| 49 |
+
|
| 50 |
+
_CONFIG_FOR_DOC = "GPT2Config"
|
| 51 |
+
_TOKENIZER_FOR_DOC = "GPT2Tokenizer"
|
| 52 |
+
|
| 53 |
+
GPT2_PRETRAINED_MODEL_ARCHIVE_LIST = [
|
| 54 |
+
"gpt2",
|
| 55 |
+
"gpt2-medium",
|
| 56 |
+
"gpt2-large",
|
| 57 |
+
"gpt2-xl",
|
| 58 |
+
"distilgpt2",
|
| 59 |
+
# See all GPT-2 models at https://huggingface.co/models?filter=gpt2
|
| 60 |
+
]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
|
| 64 |
+
"""Load tf checkpoints in a pytorch model"""
|
| 65 |
+
try:
|
| 66 |
+
import re
|
| 67 |
+
|
| 68 |
+
import tensorflow as tf
|
| 69 |
+
except ImportError:
|
| 70 |
+
logger.error(
|
| 71 |
+
"Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
|
| 72 |
+
"https://www.tensorflow.org/install/ for installation instructions."
|
| 73 |
+
)
|
| 74 |
+
raise
|
| 75 |
+
tf_path = os.path.abspath(gpt2_checkpoint_path)
|
| 76 |
+
logger.info("Converting TensorFlow checkpoint from {}".format(tf_path))
|
| 77 |
+
# Load weights from TF model
|
| 78 |
+
init_vars = tf.train.list_variables(tf_path)
|
| 79 |
+
names = []
|
| 80 |
+
arrays = []
|
| 81 |
+
for name, shape in init_vars:
|
| 82 |
+
logger.info("Loading TF weight {} with shape {}".format(name, shape))
|
| 83 |
+
array = tf.train.load_variable(tf_path, name)
|
| 84 |
+
names.append(name)
|
| 85 |
+
arrays.append(array.squeeze())
|
| 86 |
+
|
| 87 |
+
for name, array in zip(names, arrays):
|
| 88 |
+
name = name[6:] # skip "model/"
|
| 89 |
+
name = name.split("/")
|
| 90 |
+
pointer = model
|
| 91 |
+
for m_name in name:
|
| 92 |
+
if re.fullmatch(r"[A-Za-z]+\d+", m_name):
|
| 93 |
+
scope_names = re.split(r"(\d+)", m_name)
|
| 94 |
+
else:
|
| 95 |
+
scope_names = [m_name]
|
| 96 |
+
if scope_names[0] == "w" or scope_names[0] == "g":
|
| 97 |
+
pointer = getattr(pointer, "weight")
|
| 98 |
+
elif scope_names[0] == "b":
|
| 99 |
+
pointer = getattr(pointer, "bias")
|
| 100 |
+
elif scope_names[0] == "wpe" or scope_names[0] == "wte":
|
| 101 |
+
pointer = getattr(pointer, scope_names[0])
|
| 102 |
+
pointer = getattr(pointer, "weight")
|
| 103 |
+
else:
|
| 104 |
+
pointer = getattr(pointer, scope_names[0])
|
| 105 |
+
if len(scope_names) >= 2:
|
| 106 |
+
num = int(scope_names[1])
|
| 107 |
+
pointer = pointer[num]
|
| 108 |
+
try:
|
| 109 |
+
assert (
|
| 110 |
+
pointer.shape == array.shape
|
| 111 |
+
), f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
|
| 112 |
+
except AssertionError as e:
|
| 113 |
+
e.args += (pointer.shape, array.shape)
|
| 114 |
+
raise
|
| 115 |
+
logger.info("Initialize PyTorch weight {}".format(name))
|
| 116 |
+
pointer.data = torch.from_numpy(array)
|
| 117 |
+
return model
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
class Attention(nn.Module):
|
| 121 |
+
def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
|
| 122 |
+
super().__init__()
|
| 123 |
+
|
| 124 |
+
n_state = nx # in Attention: n_state=768 (nx=n_embd)
|
| 125 |
+
# [switch nx => n_state from Block to Attention to keep identical to TF implem]
|
| 126 |
+
assert n_state % config.n_head == 0
|
| 127 |
+
self.register_buffer(
|
| 128 |
+
"bias", torch.tril(torch.ones((n_ctx, n_ctx), dtype=torch.uint8)).view(1, 1, n_ctx, n_ctx)
|
| 129 |
+
)
|
| 130 |
+
self.register_buffer("masked_bias", torch.tensor(-1e4))
|
| 131 |
+
self.n_head = config.n_head
|
| 132 |
+
self.split_size = n_state
|
| 133 |
+
self.scale = scale
|
| 134 |
+
self.is_cross_attention = is_cross_attention
|
| 135 |
+
if self.is_cross_attention:
|
| 136 |
+
self.c_attn = Conv1D(2 * n_state, nx)
|
| 137 |
+
self.q_attn = Conv1D(n_state, nx)
|
| 138 |
+
else:
|
| 139 |
+
self.c_attn = Conv1D(3 * n_state, nx)
|
| 140 |
+
self.c_proj = Conv1D(n_state, nx)
|
| 141 |
+
self.attn_dropout = nn.Dropout(config.attn_pdrop)
|
| 142 |
+
self.resid_dropout = nn.Dropout(config.resid_pdrop)
|
| 143 |
+
self.pruned_heads = set()
|
| 144 |
+
|
| 145 |
+
def prune_heads(self, heads):
|
| 146 |
+
if len(heads) == 0:
|
| 147 |
+
return
|
| 148 |
+
heads, index = find_pruneable_heads_and_indices(
|
| 149 |
+
heads, self.n_head, self.split_size // self.n_head, self.pruned_heads
|
| 150 |
+
)
|
| 151 |
+
index_attn = torch.cat([index, index + self.split_size, index + (2 * self.split_size)])
|
| 152 |
+
|
| 153 |
+
# Prune conv1d layers
|
| 154 |
+
self.c_attn = prune_conv1d_layer(self.c_attn, index_attn, dim=1)
|
| 155 |
+
self.c_proj = prune_conv1d_layer(self.c_proj, index, dim=0)
|
| 156 |
+
|
| 157 |
+
# Update hyper params
|
| 158 |
+
self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
|
| 159 |
+
self.n_head = self.n_head - len(heads)
|
| 160 |
+
self.pruned_heads = self.pruned_heads.union(heads)
|
| 161 |
+
|
| 162 |
+
def _attn(self, q, k, v, attention_mask=None, head_mask=None, output_attentions=False):
|
| 163 |
+
w = torch.matmul(q, k)
|
| 164 |
+
if self.scale:
|
| 165 |
+
w = w / (float(v.size(-1)) ** 0.5)
|
| 166 |
+
nd, ns = w.size(-2), w.size(-1)
|
| 167 |
+
|
| 168 |
+
if not self.is_cross_attention:
|
| 169 |
+
# if only "normal" attention layer implements causal mask
|
| 170 |
+
mask = self.bias[:, :, ns - nd: ns, :ns]
|
| 171 |
+
w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
|
| 172 |
+
|
| 173 |
+
if attention_mask is not None:
|
| 174 |
+
# Apply the attention mask
|
| 175 |
+
w = w + attention_mask
|
| 176 |
+
|
| 177 |
+
w = nn.Softmax(dim=-1)(w)
|
| 178 |
+
w = self.attn_dropout(w)
|
| 179 |
+
|
| 180 |
+
# Mask heads if we want to
|
| 181 |
+
if head_mask is not None:
|
| 182 |
+
w = w * head_mask
|
| 183 |
+
|
| 184 |
+
outputs = [torch.matmul(w, v)]
|
| 185 |
+
if output_attentions:
|
| 186 |
+
outputs.append(w)
|
| 187 |
+
return outputs
|
| 188 |
+
|
| 189 |
+
def merge_heads(self, x):
|
| 190 |
+
x = x.permute(0, 2, 1, 3).contiguous()
|
| 191 |
+
new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
|
| 192 |
+
return x.view(*new_x_shape) # in Tensorflow implem: fct merge_states
|
| 193 |
+
|
| 194 |
+
def split_heads(self, x, k=False):
|
| 195 |
+
new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
|
| 196 |
+
x = x.view(*new_x_shape) # in Tensorflow implem: fct split_states
|
| 197 |
+
if k:
|
| 198 |
+
return x.permute(0, 2, 3, 1) # (batch, head, head_features, seq_length)
|
| 199 |
+
else:
|
| 200 |
+
return x.permute(0, 2, 1, 3) # (batch, head, seq_length, head_features)
|
| 201 |
+
|
| 202 |
+
def forward(
|
| 203 |
+
self,
|
| 204 |
+
hidden_states,
|
| 205 |
+
layer_past=None,
|
| 206 |
+
attention_mask=None,
|
| 207 |
+
head_mask=None,
|
| 208 |
+
encoder_hidden_states=None,
|
| 209 |
+
encoder_attention_mask=None,
|
| 210 |
+
use_cache=False,
|
| 211 |
+
output_attentions=False,
|
| 212 |
+
):
|
| 213 |
+
if encoder_hidden_states is not None:
|
| 214 |
+
assert hasattr(
|
| 215 |
+
self, "q_attn"
|
| 216 |
+
), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
|
| 217 |
+
query = self.q_attn(hidden_states)
|
| 218 |
+
key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
|
| 219 |
+
attention_mask = encoder_attention_mask
|
| 220 |
+
else:
|
| 221 |
+
query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
|
| 222 |
+
|
| 223 |
+
query = self.split_heads(query)
|
| 224 |
+
key = self.split_heads(key, k=True)
|
| 225 |
+
value = self.split_heads(value)
|
| 226 |
+
if layer_past is not None:
|
| 227 |
+
past_key, past_value = layer_past[0].transpose(-2, -1), layer_past[1] # transpose back cf below
|
| 228 |
+
key = torch.cat((past_key, key), dim=-1)
|
| 229 |
+
value = torch.cat((past_value, value), dim=-2)
|
| 230 |
+
|
| 231 |
+
if use_cache is True:
|
| 232 |
+
present = torch.stack((key.transpose(-2, -1), value)) # transpose to have same shapes for stacking
|
| 233 |
+
else:
|
| 234 |
+
present = (None,)
|
| 235 |
+
|
| 236 |
+
attn_outputs = self._attn(query, key, value, attention_mask, head_mask, output_attentions)
|
| 237 |
+
a = attn_outputs[0]
|
| 238 |
+
|
| 239 |
+
a = self.merge_heads(a)
|
| 240 |
+
a = self.c_proj(a)
|
| 241 |
+
a = self.resid_dropout(a)
|
| 242 |
+
|
| 243 |
+
outputs = [a, present] + attn_outputs[1:]
|
| 244 |
+
return outputs # a, present, (attentions)
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class MLP(nn.Module):
|
| 248 |
+
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
| 249 |
+
super().__init__()
|
| 250 |
+
nx = config.n_embd
|
| 251 |
+
self.c_fc = Conv1D(n_state, nx)
|
| 252 |
+
self.c_proj = Conv1D(nx, n_state)
|
| 253 |
+
self.act = ACT2FN[config.activation_function]
|
| 254 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 255 |
+
|
| 256 |
+
def forward(self, x):
|
| 257 |
+
h = self.act(self.c_fc(x))
|
| 258 |
+
h2 = self.c_proj(h)
|
| 259 |
+
return self.dropout(h2)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
class AdapterMLP(nn.Module):
|
| 263 |
+
def __init__(self, n_state, config): # in MLP: n_state=3072 (4 * n_embd)
|
| 264 |
+
super().__init__()
|
| 265 |
+
nx = config.n_embd
|
| 266 |
+
self.c_fc = Conv1D(n_state, nx)
|
| 267 |
+
self.c_proj = Conv1D(nx, n_state)
|
| 268 |
+
self.act = ACT2FN[config.activation_function]
|
| 269 |
+
self.dropout = nn.Dropout(config.resid_pdrop)
|
| 270 |
+
|
| 271 |
+
def forward(self, x):
|
| 272 |
+
h = self.act(self.c_fc(x))
|
| 273 |
+
h2 = self.c_proj(h)
|
| 274 |
+
return self.dropout(h2)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class Block(nn.Module):
|
| 278 |
+
def __init__(self, n_ctx, config, scale=False):
|
| 279 |
+
super().__init__()
|
| 280 |
+
hidden_size = config.n_embd
|
| 281 |
+
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
|
| 282 |
+
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 283 |
+
self.attn = Attention(hidden_size, n_ctx, config, scale)
|
| 284 |
+
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 285 |
+
# self.adapter_ln = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 286 |
+
if config.add_cross_attention:
|
| 287 |
+
self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
|
| 288 |
+
self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
|
| 289 |
+
self.mlp = MLP(inner_dim, config)
|
| 290 |
+
# self.adapter_mlp = AdapterMLP(512, config) # ADAPTER
|
| 291 |
+
|
| 292 |
+
def forward(
|
| 293 |
+
self,
|
| 294 |
+
hidden_states,
|
| 295 |
+
layer_past=None,
|
| 296 |
+
attention_mask=None,
|
| 297 |
+
head_mask=None,
|
| 298 |
+
encoder_hidden_states=None,
|
| 299 |
+
encoder_attention_mask=None,
|
| 300 |
+
use_cache=False,
|
| 301 |
+
output_attentions=False,
|
| 302 |
+
):
|
| 303 |
+
attn_outputs = self.attn(
|
| 304 |
+
self.ln_1(hidden_states),
|
| 305 |
+
layer_past=layer_past,
|
| 306 |
+
attention_mask=attention_mask,
|
| 307 |
+
head_mask=head_mask,
|
| 308 |
+
use_cache=use_cache,
|
| 309 |
+
output_attentions=output_attentions,
|
| 310 |
+
)
|
| 311 |
+
attn_output = attn_outputs[0] # output_attn: a, present, (attentions)
|
| 312 |
+
outputs = attn_outputs[1:]
|
| 313 |
+
# residual connection
|
| 314 |
+
hidden_states = attn_output + hidden_states
|
| 315 |
+
|
| 316 |
+
if encoder_hidden_states is not None:
|
| 317 |
+
# add one self-attention block for cross-attention
|
| 318 |
+
assert hasattr(
|
| 319 |
+
self, "crossattention"
|
| 320 |
+
), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
|
| 321 |
+
cross_attn_outputs = self.crossattention(
|
| 322 |
+
self.ln_cross_attn(hidden_states),
|
| 323 |
+
attention_mask=attention_mask,
|
| 324 |
+
head_mask=head_mask,
|
| 325 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 326 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 327 |
+
output_attentions=output_attentions,
|
| 328 |
+
)
|
| 329 |
+
attn_output = cross_attn_outputs[0]
|
| 330 |
+
# residual connection
|
| 331 |
+
hidden_states = hidden_states + attn_output
|
| 332 |
+
outputs = outputs + cross_attn_outputs[2:] # add cross attentions if we output attention weights
|
| 333 |
+
|
| 334 |
+
feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
|
| 335 |
+
# residual connection
|
| 336 |
+
hidden_states = hidden_states + feed_forward_hidden_states
|
| 337 |
+
# hidden_states = hidden_states + self.adapter_ln(self.adapter_mlp(hidden_states))
|
| 338 |
+
|
| 339 |
+
outputs = [hidden_states] + outputs
|
| 340 |
+
return outputs # hidden_states, present, (attentions, cross_attentions)
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
class GPT2PreTrainedModel(PreTrainedModel):
|
| 344 |
+
"""
|
| 345 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
| 346 |
+
models.
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
config_class = GPT2Config
|
| 350 |
+
load_tf_weights = load_tf_weights_in_gpt2
|
| 351 |
+
base_model_prefix = "transformer"
|
| 352 |
+
|
| 353 |
+
def __init__(self, *inputs, **kwargs):
|
| 354 |
+
super().__init__(*inputs, **kwargs)
|
| 355 |
+
|
| 356 |
+
def _init_weights(self, module):
|
| 357 |
+
"""Initialize the weights."""
|
| 358 |
+
if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):
|
| 359 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
| 360 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
| 361 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
| 362 |
+
if isinstance(module, (nn.Linear, Conv1D)) and module.bias is not None:
|
| 363 |
+
module.bias.data.zero_()
|
| 364 |
+
elif isinstance(module, nn.LayerNorm):
|
| 365 |
+
module.bias.data.zero_()
|
| 366 |
+
module.weight.data.fill_(1.0)
|
| 367 |
+
# module.weight.data.fill_(.01) # KL: Adapter change
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
@dataclass
|
| 371 |
+
class GPT2DoubleHeadsModelOutput(ModelOutput):
|
| 372 |
+
"""
|
| 373 |
+
Base class for outputs of models predicting if two sentences are consecutive or not.
|
| 374 |
+
Args:
|
| 375 |
+
loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when ``labels`` is provided):
|
| 376 |
+
Language modeling loss.
|
| 377 |
+
mc_loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`mc_labels` is provided):
|
| 378 |
+
Multiple choice classification loss.
|
| 379 |
+
logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices, sequence_length, config.vocab_size)`):
|
| 380 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
| 381 |
+
mc_logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, num_choices)`):
|
| 382 |
+
Prediction scores of the multiple choice classification head (scores for each choice before SoftMax).
|
| 383 |
+
past_key_values (:obj:`List[torch.FloatTensor]`, `optional`, returned when ``use_cache=True`` is passed or when ``config.use_cache=True``):
|
| 384 |
+
List of :obj:`torch.FloatTensor` of length :obj:`config.n_layers`, with each tensor of shape :obj:`(2,
|
| 385 |
+
batch_size, num_heads, sequence_length, embed_size_per_head)`).
|
| 386 |
+
Contains pre-computed hidden-states (key and values in the attention blocks) that can be used (see
|
| 387 |
+
:obj:`past_key_values` input) to speed up sequential decoding.
|
| 388 |
+
hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_hidden_states=True`` is passed or when ``config.output_hidden_states=True``):
|
| 389 |
+
Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
|
| 390 |
+
of shape :obj:`(batch_size, sequence_length, hidden_size)`.
|
| 391 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
| 392 |
+
attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``output_attentions=True`` is passed or when ``config.output_attentions=True``):
|
| 393 |
+
Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape :obj:`(batch_size, num_heads,
|
| 394 |
+
sequence_length, sequence_length)`.
|
| 395 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
| 396 |
+
heads.
|
| 397 |
+
"""
|
| 398 |
+
|
| 399 |
+
loss: Optional[torch.FloatTensor] = None
|
| 400 |
+
mc_loss: Optional[torch.FloatTensor] = None
|
| 401 |
+
logits: torch.FloatTensor = None
|
| 402 |
+
mc_logits: torch.FloatTensor = None
|
| 403 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None
|
| 404 |
+
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
| 405 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
| 406 |
+
|
| 407 |
+
|
| 408 |
+
GPT2_START_DOCSTRING = r"""
|
| 409 |
+
This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
|
| 410 |
+
methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
|
| 411 |
+
pruning heads etc.)
|
| 412 |
+
This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
|
| 413 |
+
subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
|
| 414 |
+
general usage and behavior.
|
| 415 |
+
Parameters:
|
| 416 |
+
config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.
|
| 417 |
+
Initializing with a config file does not load the weights associated with the model, only the
|
| 418 |
+
configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
|
| 419 |
+
weights.
|
| 420 |
+
"""
|
| 421 |
+
|
| 422 |
+
GPT2_INPUTS_DOCSTRING = r"""
|
| 423 |
+
Args:
|
| 424 |
+
input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`):
|
| 425 |
+
:obj:`input_ids_length` = ``sequence_length`` if :obj:`past_key_values` is ``None`` else
|
| 426 |
+
``past_key_values[0].shape[-2]`` (``sequence_length`` of input past key value states). Indices of input
|
| 427 |
+
sequence tokens in the vocabulary.
|
| 428 |
+
If :obj:`past_key_values` is used, only ``input_ids`` that do not have their past calculated should be
|
| 429 |
+
passed as ``input_ids``.
|
| 430 |
+
Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See
|
| 431 |
+
:meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for
|
| 432 |
+
details.
|
| 433 |
+
`What are input IDs? <../glossary.html#input-ids>`__
|
| 434 |
+
past_key_values (:obj:`List[torch.FloatTensor]` of length :obj:`config.n_layers`):
|
| 435 |
+
Contains precomputed hidden-states (key and values in the attention blocks) as computed by the model (see
|
| 436 |
+
:obj:`past_key_values` output below). Can be used to speed up sequential decoding. The ``input_ids`` which
|
| 437 |
+
have their past given to this model should not be passed as ``input_ids`` as they have already been
|
| 438 |
+
computed.
|
| 439 |
+
attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 440 |
+
Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
|
| 441 |
+
- 1 for tokens that are **not masked**,
|
| 442 |
+
- 0 for tokens that are **masked**.
|
| 443 |
+
`What are attention masks? <../glossary.html#attention-mask>`__
|
| 444 |
+
token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, input_ids_length)`, `optional`):
|
| 445 |
+
Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
|
| 446 |
+
1]``:
|
| 447 |
+
- 0 corresponds to a `sentence A` token,
|
| 448 |
+
- 1 corresponds to a `sentence B` token.
|
| 449 |
+
`What are token type IDs? <../glossary.html#token-type-ids>`_
|
| 450 |
+
position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
|
| 451 |
+
Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,
|
| 452 |
+
config.max_position_embeddings - 1]``.
|
| 453 |
+
`What are position IDs? <../glossary.html#position-ids>`_
|
| 454 |
+
head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`, `optional`):
|
| 455 |
+
Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
|
| 456 |
+
- 1 indicates the head is **not masked**,
|
| 457 |
+
- 0 indicates the head is **masked**.
|
| 458 |
+
inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`, `optional`):
|
| 459 |
+
Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded representation.
|
| 460 |
+
This is useful if you want more control over how to convert :obj:`input_ids` indices into associated
|
| 461 |
+
vectors than the model's internal embedding lookup matrix.
|
| 462 |
+
If :obj:`past_key_values` is used, optionally only the last :obj:`inputs_embeds` have to be input (see
|
| 463 |
+
:obj:`past_key_values`).
|
| 464 |
+
use_cache (:obj:`bool`, `optional`):
|
| 465 |
+
If set to :obj:`True`, :obj:`past_key_values` key value states are returned and can be used to speed up
|
| 466 |
+
decoding (see :obj:`past_key_values`).
|
| 467 |
+
output_attentions (:obj:`bool`, `optional`):
|
| 468 |
+
Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned
|
| 469 |
+
tensors for more detail.
|
| 470 |
+
output_hidden_states (:obj:`bool`, `optional`):
|
| 471 |
+
Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for
|
| 472 |
+
more detail.
|
| 473 |
+
return_dict (:obj:`bool`, `optional`):
|
| 474 |
+
Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.
|
| 475 |
+
"""
|
| 476 |
+
PARALLELIZE_DOCSTRING = r"""
|
| 477 |
+
Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
|
| 478 |
+
it will evenly distribute blocks across all devices.
|
| 479 |
+
Args:
|
| 480 |
+
device_map (:obj:`Dict[int, list]`, optional, defaults to None):
|
| 481 |
+
A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
|
| 482 |
+
automatically mapped to the first device (for esoteric reasons). That means that the first device should
|
| 483 |
+
have fewer attention modules mapped to it than other devices. For reference, the gpt2 models have the
|
| 484 |
+
following number of attention modules:
|
| 485 |
+
- gpt2: 12
|
| 486 |
+
- gpt2-medium: 24
|
| 487 |
+
- gpt2-large: 36
|
| 488 |
+
- gpt2-xl: 48
|
| 489 |
+
Example::
|
| 490 |
+
# Here is an example of a device map on a machine with 4 GPUs using gpt2-xl, which has a total of 48 attention modules:
|
| 491 |
+
model = GPT2LMHeadModel.from_pretrained('gpt2-xl')
|
| 492 |
+
device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7, 8],
|
| 493 |
+
1: [9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21],
|
| 494 |
+
2: [22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34],
|
| 495 |
+
3: [35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]}
|
| 496 |
+
model.parallelize(device_map)
|
| 497 |
+
"""
|
| 498 |
+
DEPARALLELIZE_DOCSTRING = r"""
|
| 499 |
+
Moves the model to cpu from a model parallel state.
|
| 500 |
+
Example::
|
| 501 |
+
# On a 4 GPU machine with gpt2-large:
|
| 502 |
+
model = GPT2LMHeadModel.from_pretrained('gpt2-large')
|
| 503 |
+
device_map = {0: [0, 1, 2, 3, 4, 5, 6, 7],
|
| 504 |
+
1: [8, 9, 10, 11, 12, 13, 14, 15],
|
| 505 |
+
2: [16, 17, 18, 19, 20, 21, 22, 23],
|
| 506 |
+
3: [24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35]}
|
| 507 |
+
model.parallelize(device_map) # Splits the model across several devices
|
| 508 |
+
model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
|
| 509 |
+
"""
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
@add_start_docstrings(
|
| 513 |
+
"The bare GPT2 Model transformer outputting raw hidden-states without any specific head on top.",
|
| 514 |
+
GPT2_START_DOCSTRING,
|
| 515 |
+
)
|
| 516 |
+
class GPT2Model(GPT2PreTrainedModel):
|
| 517 |
+
def __init__(self, config):
|
| 518 |
+
super().__init__(config)
|
| 519 |
+
|
| 520 |
+
self.wte = nn.Embedding(config.vocab_size, config.n_embd)
|
| 521 |
+
# self.wpe = nn.Embedding(config.n_positions, config.n_embd)
|
| 522 |
+
self.drop = nn.Dropout(config.embd_pdrop)
|
| 523 |
+
self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])
|
| 524 |
+
self.ln_f = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
|
| 525 |
+
|
| 526 |
+
self.init_weights()
|
| 527 |
+
# Model parallel
|
| 528 |
+
self.model_parallel = False
|
| 529 |
+
self.device_map = None
|
| 530 |
+
|
| 531 |
+
self.use_layers = None
|
| 532 |
+
|
| 533 |
+
def set_layers(self, num_layers):
|
| 534 |
+
assert 1 <= num_layers <= len(self.h)
|
| 535 |
+
if num_layers is not None:
|
| 536 |
+
num_layers -= 1
|
| 537 |
+
self.use_layers = num_layers
|
| 538 |
+
|
| 539 |
+
@add_start_docstrings(PARALLELIZE_DOCSTRING)
|
| 540 |
+
def parallelize(self, device_map=None):
|
| 541 |
+
# Check validity of device_map
|
| 542 |
+
self.device_map = (
|
| 543 |
+
get_device_map(len(self.h), range(torch.cuda.device_count())) if device_map is None else device_map
|
| 544 |
+
)
|
| 545 |
+
assert_device_map(self.device_map, len(self.h))
|
| 546 |
+
self.model_parallel = True
|
| 547 |
+
self.first_device = "cpu" if "cpu" in self.device_map.keys() else "cuda:" + str(min(self.device_map.keys()))
|
| 548 |
+
self.last_device = "cuda:" + str(max(self.device_map.keys()))
|
| 549 |
+
self.wte = self.wte.to(self.first_device)
|
| 550 |
+
self.wpe = self.wpe.to(self.first_device)
|
| 551 |
+
# Load onto devices
|
| 552 |
+
for k, v in self.device_map.items():
|
| 553 |
+
for block in v:
|
| 554 |
+
cuda_device = "cuda:" + str(k)
|
| 555 |
+
self.h[block] = self.h[block].to(cuda_device)
|
| 556 |
+
# ln_f to last
|
| 557 |
+
self.ln_f = self.ln_f.to(self.last_device)
|
| 558 |
+
|
| 559 |
+
@add_start_docstrings(DEPARALLELIZE_DOCSTRING)
|
| 560 |
+
def deparallelize(self):
|
| 561 |
+
self.model_parallel = False
|
| 562 |
+
self.device_map = None
|
| 563 |
+
self.first_device = "cpu"
|
| 564 |
+
self.last_device = "cpu"
|
| 565 |
+
self.wte = self.wte.to("cpu")
|
| 566 |
+
self.wpe = self.wpe.to("cpu")
|
| 567 |
+
for index in range(len(self.h)):
|
| 568 |
+
self.h[index] = self.h[index].to("cpu")
|
| 569 |
+
self.ln_f = self.ln_f.to("cpu")
|
| 570 |
+
torch.cuda.empty_cache()
|
| 571 |
+
|
| 572 |
+
def get_input_embeddings(self):
|
| 573 |
+
return self.wte
|
| 574 |
+
|
| 575 |
+
def set_input_embeddings(self, new_embeddings):
|
| 576 |
+
self.wte = new_embeddings
|
| 577 |
+
|
| 578 |
+
def _prune_heads(self, heads_to_prune):
|
| 579 |
+
"""
|
| 580 |
+
Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer}
|
| 581 |
+
"""
|
| 582 |
+
for layer, heads in heads_to_prune.items():
|
| 583 |
+
self.h[layer].attn.prune_heads(heads)
|
| 584 |
+
|
| 585 |
+
@add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)
|
| 586 |
+
@add_code_sample_docstrings(
|
| 587 |
+
processor_class=_TOKENIZER_FOR_DOC,
|
| 588 |
+
checkpoint="gpt2",
|
| 589 |
+
output_type=BaseModelOutputWithPastAndCrossAttentions,
|
| 590 |
+
config_class=_CONFIG_FOR_DOC,
|
| 591 |
+
)
|
| 592 |
+
def forward(
|
| 593 |
+
self,
|
| 594 |
+
input_ids=None,
|
| 595 |
+
past_key_values=None,
|
| 596 |
+
attention_mask=None,
|
| 597 |
+
token_type_ids=None,
|
| 598 |
+
position_ids=None,
|
| 599 |
+
head_mask=None,
|
| 600 |
+
inputs_embeds=None,
|
| 601 |
+
encoder_hidden_states=None,
|
| 602 |
+
encoder_attention_mask=None,
|
| 603 |
+
use_cache=None,
|
| 604 |
+
output_attentions=None,
|
| 605 |
+
output_hidden_states=None,
|
| 606 |
+
return_dict=None,
|
| 607 |
+
):
|
| 608 |
+
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
|
| 609 |
+
output_hidden_states = (
|
| 610 |
+
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
|
| 611 |
+
)
|
| 612 |
+
use_cache = use_cache if use_cache is not None else self.config.use_cache
|
| 613 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 614 |
+
|
| 615 |
+
if input_ids is not None and inputs_embeds is not None:
|
| 616 |
+
raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
|
| 617 |
+
elif input_ids is not None:
|
| 618 |
+
input_shape = input_ids.size()
|
| 619 |
+
input_ids = input_ids.view(-1, input_shape[-1])
|
| 620 |
+
batch_size = input_ids.shape[0]
|
| 621 |
+
elif inputs_embeds is not None:
|
| 622 |
+
input_shape = inputs_embeds.size()[:-1]
|
| 623 |
+
batch_size = inputs_embeds.shape[0]
|
| 624 |
+
else:
|
| 625 |
+
raise ValueError("You have to specify either input_ids or inputs_embeds")
|
| 626 |
+
|
| 627 |
+
if token_type_ids is not None:
|
| 628 |
+
token_type_ids = token_type_ids.view(-1, input_shape[-1])
|
| 629 |
+
if position_ids is not None:
|
| 630 |
+
position_ids = position_ids.view(-1, input_shape[-1])
|
| 631 |
+
|
| 632 |
+
if past_key_values is None:
|
| 633 |
+
past_length = 0
|
| 634 |
+
past_key_values = [None] * len(self.h)
|
| 635 |
+
else:
|
| 636 |
+
past_length = past_key_values[0][0].size(-2)
|
| 637 |
+
if position_ids is None:
|
| 638 |
+
device = input_ids.device if input_ids is not None else inputs_embeds.device
|
| 639 |
+
position_ids = torch.arange(past_length, input_shape[-1] + past_length, dtype=torch.long, device=device)
|
| 640 |
+
position_ids = position_ids.unsqueeze(0).view(-1, input_shape[-1])
|
| 641 |
+
|
| 642 |
+
# Attention mask.
|
| 643 |
+
if attention_mask is not None:
|
| 644 |
+
assert batch_size > 0, "batch_size has to be defined and > 0"
|
| 645 |
+
attention_mask = attention_mask.view(batch_size, -1)
|
| 646 |
+
# We create a 3D attention mask from a 2D tensor mask.
|
| 647 |
+
# Sizes are [batch_size, 1, 1, to_seq_length]
|
| 648 |
+
# So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
|
| 649 |
+
# this attention mask is more simple than the triangular masking of causal attention
|
| 650 |
+
# used in OpenAI GPT, we just need to prepare the broadcast dimension here.
|
| 651 |
+
attention_mask = attention_mask[:, None, None, :]
|
| 652 |
+
|
| 653 |
+
# Since attention_mask is 1.0 for positions we want to attend and 0.0 for
|
| 654 |
+
# masked positions, this operation will create a tensor which is 0.0 for
|
| 655 |
+
# positions we want to attend and -10000.0 for masked positions.
|
| 656 |
+
# Since we are adding it to the raw scores before the softmax, this is
|
| 657 |
+
# effectively the same as removing these entirely.
|
| 658 |
+
attention_mask = attention_mask.to(dtype=self.dtype) # fp16 compatibility
|
| 659 |
+
attention_mask = (1.0 - attention_mask) * -10000.0
|
| 660 |
+
|
| 661 |
+
# If a 2D ou 3D attention mask is provided for the cross-attention
|
| 662 |
+
# we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
|
| 663 |
+
if self.config.add_cross_attention and encoder_hidden_states is not None:
|
| 664 |
+
encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
|
| 665 |
+
encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
|
| 666 |
+
if encoder_attention_mask is None:
|
| 667 |
+
encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
|
| 668 |
+
encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
|
| 669 |
+
else:
|
| 670 |
+
encoder_attention_mask = None
|
| 671 |
+
|
| 672 |
+
# Prepare head mask if needed
|
| 673 |
+
# 1.0 in head_mask indicate we keep the head
|
| 674 |
+
# attention_probs has shape bsz x n_heads x N x N
|
| 675 |
+
# head_mask has shape n_layer x batch x n_heads x N x N
|
| 676 |
+
head_mask = self.get_head_mask(head_mask, self.config.n_layer)
|
| 677 |
+
|
| 678 |
+
if inputs_embeds is None:
|
| 679 |
+
inputs_embeds = self.wte(input_ids)
|
| 680 |
+
# position_embeds = self.wpe(position_ids)
|
| 681 |
+
hidden_states = inputs_embeds # + position_embeds
|
| 682 |
+
|
| 683 |
+
if token_type_ids is not None:
|
| 684 |
+
token_type_embeds = self.wte(token_type_ids)
|
| 685 |
+
hidden_states = hidden_states + token_type_embeds
|
| 686 |
+
|
| 687 |
+
hidden_states = self.drop(hidden_states)
|
| 688 |
+
|
| 689 |
+
output_shape = input_shape + (hidden_states.size(-1),)
|
| 690 |
+
|
| 691 |
+
presents = () if use_cache else None
|
| 692 |
+
all_self_attentions = () if output_attentions else None
|
| 693 |
+
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
|
| 694 |
+
all_hidden_states = () if output_hidden_states else None
|
| 695 |
+
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):
|
| 696 |
+
|
| 697 |
+
if self.use_layers is not None and i >= self.use_layers:
|
| 698 |
+
break
|
| 699 |
+
|
| 700 |
+
# Model parallel
|
| 701 |
+
if self.model_parallel:
|
| 702 |
+
torch.cuda.set_device(hidden_states.device)
|
| 703 |
+
# Ensure layer_past is on same device as hidden_states (might not be correct)
|
| 704 |
+
if layer_past is not None:
|
| 705 |
+
layer_past = layer_past.to(hidden_states.device)
|
| 706 |
+
# Ensure that attention_mask is always on the same device as hidden_states
|
| 707 |
+
if attention_mask is not None:
|
| 708 |
+
attention_mask = attention_mask.to(hidden_states.device)
|
| 709 |
+
if isinstance(head_mask, torch.Tensor):
|
| 710 |
+
head_mask = head_mask.to(hidden_states.device)
|
| 711 |
+
if output_hidden_states:
|
| 712 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 713 |
+
|
| 714 |
+
if getattr(self.config, "gradient_checkpointing", False):
|
| 715 |
+
|
| 716 |
+
def create_custom_forward(module):
|
| 717 |
+
def custom_forward(*inputs):
|
| 718 |
+
# checkpointing only works with tuple returns, not with lists
|
| 719 |
+
return tuple(output for output in module(*inputs, use_cache, output_attentions))
|
| 720 |
+
|
| 721 |
+
return custom_forward
|
| 722 |
+
|
| 723 |
+
outputs = torch.utils.checkpoint.checkpoint(
|
| 724 |
+
create_custom_forward(block),
|
| 725 |
+
hidden_states,
|
| 726 |
+
layer_past,
|
| 727 |
+
attention_mask,
|
| 728 |
+
head_mask[i],
|
| 729 |
+
encoder_hidden_states,
|
| 730 |
+
encoder_attention_mask,
|
| 731 |
+
)
|
| 732 |
+
else:
|
| 733 |
+
outputs = block(
|
| 734 |
+
hidden_states,
|
| 735 |
+
layer_past=layer_past,
|
| 736 |
+
attention_mask=attention_mask,
|
| 737 |
+
head_mask=head_mask[i],
|
| 738 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 739 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 740 |
+
use_cache=use_cache,
|
| 741 |
+
output_attentions=output_attentions,
|
| 742 |
+
)
|
| 743 |
+
|
| 744 |
+
hidden_states, present = outputs[:2]
|
| 745 |
+
if use_cache is True:
|
| 746 |
+
presents = presents + (present,)
|
| 747 |
+
|
| 748 |
+
if output_attentions:
|
| 749 |
+
all_self_attentions = all_self_attentions + (outputs[2],)
|
| 750 |
+
if self.config.add_cross_attention:
|
| 751 |
+
all_cross_attentions = all_cross_attentions + (outputs[3],)
|
| 752 |
+
|
| 753 |
+
# Model Parallel: If it's the last layer for that device, put things on the next device
|
| 754 |
+
if self.model_parallel:
|
| 755 |
+
for k, v in self.device_map.items():
|
| 756 |
+
if i == v[-1] and "cuda:" + str(k) != self.last_device:
|
| 757 |
+
hidden_states = hidden_states.to("cuda:" + str(k + 1))
|
| 758 |
+
|
| 759 |
+
hidden_states = self.ln_f(hidden_states)
|
| 760 |
+
|
| 761 |
+
hidden_states = hidden_states.view(*output_shape)
|
| 762 |
+
# Add last hidden state
|
| 763 |
+
if output_hidden_states:
|
| 764 |
+
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 765 |
+
|
| 766 |
+
if not return_dict:
|
| 767 |
+
return tuple(v for v in [hidden_states, presents, all_hidden_states, all_self_attentions] if v is not None)
|
| 768 |
+
|
| 769 |
+
return BaseModelOutputWithPastAndCrossAttentions(
|
| 770 |
+
last_hidden_state=hidden_states,
|
| 771 |
+
past_key_values=presents,
|
| 772 |
+
hidden_states=all_hidden_states,
|
| 773 |
+
attentions=all_self_attentions,
|
| 774 |
+
cross_attentions=all_cross_attentions,
|
| 775 |
+
)
|
third_party/AnyBimanual/agents/peract_bimanual/visual_aligner.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
class VisualAligner(nn.Module):
|
| 6 |
+
def __init__(self, input_dim=128, hidden_dim=256, mask_dim=128):
|
| 7 |
+
super(VisualAligner, self).__init__()
|
| 8 |
+
|
| 9 |
+
self.conv1 = nn.Conv1d(in_channels=input_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
|
| 10 |
+
|
| 11 |
+
self.conv_res1 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
|
| 12 |
+
self.conv_res2 = nn.Conv1d(in_channels=hidden_dim, out_channels=hidden_dim, kernel_size=3, padding=1)
|
| 13 |
+
|
| 14 |
+
self.conv2_right = nn.Conv1d(in_channels=hidden_dim, out_channels=mask_dim, kernel_size=3, padding=1)
|
| 15 |
+
self.conv2_left = nn.Conv1d(in_channels=hidden_dim, out_channels=mask_dim, kernel_size=3, padding=1)
|
| 16 |
+
|
| 17 |
+
self.activation = nn.ReLU()
|
| 18 |
+
|
| 19 |
+
def forward(self, ins):
|
| 20 |
+
ins = ins.transpose(1, 2)
|
| 21 |
+
|
| 22 |
+
features = self.activation(self.conv1(ins))
|
| 23 |
+
|
| 24 |
+
residual = features
|
| 25 |
+
features = self.activation(self.conv_res1(features))
|
| 26 |
+
features = self.conv_res2(features)
|
| 27 |
+
features = features + residual
|
| 28 |
+
|
| 29 |
+
mask_right = self.activation(self.conv2_right(features))
|
| 30 |
+
mask_left = self.activation(self.conv2_left(features))
|
| 31 |
+
|
| 32 |
+
mask_right = mask_right.transpose(1, 2)
|
| 33 |
+
mask_left = mask_left.transpose(1, 2)
|
| 34 |
+
ins = ins.transpose(1, 2)
|
| 35 |
+
|
| 36 |
+
masked_ins1 = ins * mask_right
|
| 37 |
+
masked_ins2 = ins * mask_left
|
| 38 |
+
|
| 39 |
+
return masked_ins1, masked_ins2
|
third_party/AnyBimanual/agents/replay_utils.py
ADDED
|
@@ -0,0 +1,667 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import logging
|
| 3 |
+
from typing import List
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
from rlbench.backend.observation import Observation
|
| 7 |
+
from rlbench.observation_config import ObservationConfig
|
| 8 |
+
import rlbench.utils as rlbench_utils
|
| 9 |
+
from rlbench.demo import Demo
|
| 10 |
+
from yarr.replay_buffer.replay_buffer import ReplayBuffer
|
| 11 |
+
|
| 12 |
+
from helpers import demo_loading_utils, utils
|
| 13 |
+
from helpers import observation_utils
|
| 14 |
+
from helpers.clip.core.clip import tokenize
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from yarr.replay_buffer.prioritized_replay_buffer import ObservationElement
|
| 18 |
+
from yarr.replay_buffer.replay_buffer import ReplayElement
|
| 19 |
+
from yarr.replay_buffer.task_uniform_replay_buffer import TaskUniformReplayBuffer
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
import torch
|
| 23 |
+
from torch.multiprocessing import Process, Value, Manager
|
| 24 |
+
from helpers.clip.core.clip import build_model, load_clip
|
| 25 |
+
from omegaconf import DictConfig
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
REWARD_SCALE = 100.0
|
| 29 |
+
LOW_DIM_SIZE = 4
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_replay(cfg, replay_path):
|
| 33 |
+
|
| 34 |
+
if cfg.method.robot_name == "bimanual":
|
| 35 |
+
return create_bimanual_replay(
|
| 36 |
+
cfg.replay.batch_size,
|
| 37 |
+
cfg.replay.timesteps,
|
| 38 |
+
cfg.replay.prioritisation,
|
| 39 |
+
cfg.replay.task_uniform,
|
| 40 |
+
replay_path if cfg.replay.use_disk else None,
|
| 41 |
+
cfg.rlbench.cameras,
|
| 42 |
+
cfg.method.voxel_sizes,
|
| 43 |
+
cfg.rlbench.camera_resolution,
|
| 44 |
+
)
|
| 45 |
+
else:
|
| 46 |
+
return create_unimanual_replay(
|
| 47 |
+
cfg.replay.batch_size,
|
| 48 |
+
cfg.replay.timesteps,
|
| 49 |
+
cfg.replay.prioritisation,
|
| 50 |
+
cfg.replay.task_uniform,
|
| 51 |
+
replay_path if cfg.replay.use_disk else None,
|
| 52 |
+
cfg.rlbench.cameras,
|
| 53 |
+
cfg.method.voxel_sizes,
|
| 54 |
+
cfg.rlbench.camera_resolution,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
def create_bimanual_replay(
|
| 60 |
+
batch_size: int,
|
| 61 |
+
timesteps: int,
|
| 62 |
+
prioritisation: bool,
|
| 63 |
+
task_uniform: bool,
|
| 64 |
+
save_dir: str,
|
| 65 |
+
cameras: list,
|
| 66 |
+
voxel_sizes,
|
| 67 |
+
image_size=[128, 128],
|
| 68 |
+
replay_size=3e5,
|
| 69 |
+
):
|
| 70 |
+
trans_indicies_size = 3 * len(voxel_sizes)
|
| 71 |
+
rot_and_grip_indicies_size = 3 + 1
|
| 72 |
+
gripper_pose_size = 7
|
| 73 |
+
ignore_collisions_size = 1
|
| 74 |
+
max_token_seq_len = 77
|
| 75 |
+
lang_feat_dim = 1024
|
| 76 |
+
lang_emb_dim = 512
|
| 77 |
+
|
| 78 |
+
# low_dim_state
|
| 79 |
+
observation_elements = []
|
| 80 |
+
observation_elements.append(
|
| 81 |
+
ObservationElement("right_low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 82 |
+
)
|
| 83 |
+
observation_elements.append(
|
| 84 |
+
ObservationElement("left_low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 85 |
+
)
|
| 86 |
+
|
| 87 |
+
# rgb, depth, point cloud, intrinsics, extrinsics
|
| 88 |
+
for cname in cameras:
|
| 89 |
+
observation_elements.append(
|
| 90 |
+
# color, height, width
|
| 91 |
+
ObservationElement(
|
| 92 |
+
"%s_rgb" % cname,
|
| 93 |
+
(
|
| 94 |
+
3,
|
| 95 |
+
image_size[1],
|
| 96 |
+
image_size[0],
|
| 97 |
+
),
|
| 98 |
+
np.float32,
|
| 99 |
+
)
|
| 100 |
+
)
|
| 101 |
+
observation_elements.append(
|
| 102 |
+
ObservationElement("%s_point_cloud" % cname, (3, image_size[1], image_size[0]), np.float16)
|
| 103 |
+
) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
|
| 104 |
+
observation_elements.append(
|
| 105 |
+
ObservationElement(
|
| 106 |
+
"%s_camera_extrinsics" % cname,
|
| 107 |
+
(
|
| 108 |
+
4,
|
| 109 |
+
4,
|
| 110 |
+
),
|
| 111 |
+
np.float32,
|
| 112 |
+
)
|
| 113 |
+
)
|
| 114 |
+
observation_elements.append(
|
| 115 |
+
ObservationElement(
|
| 116 |
+
"%s_camera_intrinsics" % cname,
|
| 117 |
+
(
|
| 118 |
+
3,
|
| 119 |
+
3,
|
| 120 |
+
),
|
| 121 |
+
np.float32,
|
| 122 |
+
)
|
| 123 |
+
)
|
| 124 |
+
|
| 125 |
+
# discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
|
| 126 |
+
for robot_name in ["right", "left"]:
|
| 127 |
+
observation_elements.extend(
|
| 128 |
+
[
|
| 129 |
+
ReplayElement(
|
| 130 |
+
f"{robot_name}_trans_action_indicies",
|
| 131 |
+
(trans_indicies_size,),
|
| 132 |
+
np.int32,
|
| 133 |
+
),
|
| 134 |
+
ReplayElement(
|
| 135 |
+
f"{robot_name}_rot_grip_action_indicies",
|
| 136 |
+
(rot_and_grip_indicies_size,),
|
| 137 |
+
np.int32,
|
| 138 |
+
),
|
| 139 |
+
ReplayElement(
|
| 140 |
+
f"{robot_name}_ignore_collisions",
|
| 141 |
+
(ignore_collisions_size,),
|
| 142 |
+
np.int32,
|
| 143 |
+
),
|
| 144 |
+
ReplayElement(
|
| 145 |
+
f"{robot_name}_gripper_pose", (gripper_pose_size,), np.float32
|
| 146 |
+
),
|
| 147 |
+
]
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
observation_elements.extend(
|
| 151 |
+
[
|
| 152 |
+
ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
|
| 153 |
+
ReplayElement(
|
| 154 |
+
"lang_token_embs",
|
| 155 |
+
(
|
| 156 |
+
max_token_seq_len,
|
| 157 |
+
lang_emb_dim,
|
| 158 |
+
),
|
| 159 |
+
np.float32,
|
| 160 |
+
), # extracted from CLIP's language encoder
|
| 161 |
+
ReplayElement("task", (), str),
|
| 162 |
+
ReplayElement(
|
| 163 |
+
"lang_goal", (1,), object
|
| 164 |
+
), # language goal string for debugging and visualization
|
| 165 |
+
]
|
| 166 |
+
)
|
| 167 |
+
|
| 168 |
+
extra_replay_elements = [
|
| 169 |
+
ReplayElement("demo", (), bool),
|
| 170 |
+
]
|
| 171 |
+
|
| 172 |
+
replay_buffer = TaskUniformReplayBuffer(
|
| 173 |
+
save_dir=save_dir,
|
| 174 |
+
batch_size=batch_size,
|
| 175 |
+
timesteps=timesteps,
|
| 176 |
+
replay_capacity=int(replay_size),
|
| 177 |
+
action_shape=(8 * 2,),
|
| 178 |
+
action_dtype=np.float32,
|
| 179 |
+
reward_shape=(),
|
| 180 |
+
reward_dtype=np.float32,
|
| 181 |
+
update_horizon=1,
|
| 182 |
+
observation_elements=observation_elements,
|
| 183 |
+
extra_replay_elements=extra_replay_elements,
|
| 184 |
+
)
|
| 185 |
+
return replay_buffer
|
| 186 |
+
|
| 187 |
+
def create_unimanual_replay(
|
| 188 |
+
batch_size: int,
|
| 189 |
+
timesteps: int,
|
| 190 |
+
prioritisation: bool,
|
| 191 |
+
task_uniform: bool,
|
| 192 |
+
save_dir: str,
|
| 193 |
+
cameras: list,
|
| 194 |
+
voxel_sizes,
|
| 195 |
+
image_size=[128, 128],
|
| 196 |
+
replay_size=3e5,
|
| 197 |
+
):
|
| 198 |
+
trans_indicies_size = 3 * len(voxel_sizes)
|
| 199 |
+
rot_and_grip_indicies_size = 3 + 1
|
| 200 |
+
gripper_pose_size = 7
|
| 201 |
+
ignore_collisions_size = 1
|
| 202 |
+
max_token_seq_len = 77
|
| 203 |
+
lang_feat_dim = 1024
|
| 204 |
+
lang_emb_dim = 512
|
| 205 |
+
|
| 206 |
+
# low_dim_state
|
| 207 |
+
observation_elements = []
|
| 208 |
+
observation_elements.append(
|
| 209 |
+
ObservationElement("low_dim_state", (LOW_DIM_SIZE,), np.float32)
|
| 210 |
+
)
|
| 211 |
+
|
| 212 |
+
# rgb, depth, point cloud, intrinsics, extrinsics
|
| 213 |
+
for cname in cameras:
|
| 214 |
+
observation_elements.append(
|
| 215 |
+
ObservationElement(
|
| 216 |
+
"%s_rgb" % cname,
|
| 217 |
+
(
|
| 218 |
+
3,
|
| 219 |
+
*image_size,
|
| 220 |
+
),
|
| 221 |
+
np.float32,
|
| 222 |
+
)
|
| 223 |
+
)
|
| 224 |
+
observation_elements.append(
|
| 225 |
+
ObservationElement("%s_point_cloud" % cname, (3, *image_size), np.float32)
|
| 226 |
+
) # see pyrep/objects/vision_sensor.py on how pointclouds are extracted from depth frames
|
| 227 |
+
observation_elements.append(
|
| 228 |
+
ObservationElement(
|
| 229 |
+
"%s_camera_extrinsics" % cname,
|
| 230 |
+
(
|
| 231 |
+
4,
|
| 232 |
+
4,
|
| 233 |
+
),
|
| 234 |
+
np.float32,
|
| 235 |
+
)
|
| 236 |
+
)
|
| 237 |
+
observation_elements.append(
|
| 238 |
+
ObservationElement(
|
| 239 |
+
"%s_camera_intrinsics" % cname,
|
| 240 |
+
(
|
| 241 |
+
3,
|
| 242 |
+
3,
|
| 243 |
+
),
|
| 244 |
+
np.float32,
|
| 245 |
+
)
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
# discretized translation, discretized rotation, discrete ignore collision, 6-DoF gripper pose, and pre-trained language embeddings
|
| 249 |
+
observation_elements.extend(
|
| 250 |
+
[
|
| 251 |
+
ReplayElement("trans_action_indicies", (trans_indicies_size,), np.int32),
|
| 252 |
+
ReplayElement(
|
| 253 |
+
"rot_grip_action_indicies", (rot_and_grip_indicies_size,), np.int32
|
| 254 |
+
),
|
| 255 |
+
ReplayElement("ignore_collisions", (ignore_collisions_size,), np.int32),
|
| 256 |
+
ReplayElement("gripper_pose", (gripper_pose_size,), np.float32),
|
| 257 |
+
ReplayElement("lang_goal_emb", (lang_feat_dim,), np.float32),
|
| 258 |
+
ReplayElement(
|
| 259 |
+
"lang_token_embs",
|
| 260 |
+
(
|
| 261 |
+
max_token_seq_len,
|
| 262 |
+
lang_emb_dim,
|
| 263 |
+
),
|
| 264 |
+
np.float32,
|
| 265 |
+
), # extracted from CLIP's language encoder
|
| 266 |
+
ReplayElement("task", (), str),
|
| 267 |
+
ReplayElement(
|
| 268 |
+
"lang_goal", (1,), object
|
| 269 |
+
), # language goal string for debugging and visualization
|
| 270 |
+
]
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
extra_replay_elements = [
|
| 274 |
+
ReplayElement("demo", (), bool),
|
| 275 |
+
]
|
| 276 |
+
|
| 277 |
+
replay_buffer = TaskUniformReplayBuffer(
|
| 278 |
+
save_dir=save_dir,
|
| 279 |
+
batch_size=batch_size,
|
| 280 |
+
timesteps=timesteps,
|
| 281 |
+
replay_capacity=int(replay_size),
|
| 282 |
+
action_shape=(8,),
|
| 283 |
+
action_dtype=np.float32,
|
| 284 |
+
reward_shape=(),
|
| 285 |
+
reward_dtype=np.float32,
|
| 286 |
+
update_horizon=1,
|
| 287 |
+
observation_elements=observation_elements,
|
| 288 |
+
extra_replay_elements=extra_replay_elements,
|
| 289 |
+
)
|
| 290 |
+
return replay_buffer
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def _get_action(
|
| 295 |
+
obs_tp1: Observation,
|
| 296 |
+
obs_tm1: Observation,
|
| 297 |
+
rlbench_scene_bounds: List[float], # metric 3D bounds of the scene
|
| 298 |
+
voxel_sizes: List[int],
|
| 299 |
+
bounds_offset: List[float],
|
| 300 |
+
rotation_resolution: int,
|
| 301 |
+
crop_augmentation: bool,
|
| 302 |
+
):
|
| 303 |
+
quat = utils.normalize_quaternion(obs_tp1.gripper_pose[3:])
|
| 304 |
+
if quat[-1] < 0:
|
| 305 |
+
quat = -quat
|
| 306 |
+
disc_rot = utils.quaternion_to_discrete_euler(quat, rotation_resolution)
|
| 307 |
+
disc_rot = utils.correct_rotation_instability(disc_rot, rotation_resolution)
|
| 308 |
+
|
| 309 |
+
attention_coordinate = obs_tp1.gripper_pose[:3]
|
| 310 |
+
trans_indicies, attention_coordinates = [], []
|
| 311 |
+
bounds = np.array(rlbench_scene_bounds)
|
| 312 |
+
ignore_collisions = int(obs_tm1.ignore_collisions)
|
| 313 |
+
for depth, vox_size in enumerate(
|
| 314 |
+
voxel_sizes
|
| 315 |
+
): # only single voxelization-level is used in PerAct
|
| 316 |
+
if depth > 0:
|
| 317 |
+
if crop_augmentation:
|
| 318 |
+
shift = bounds_offset[depth - 1] * 0.75
|
| 319 |
+
attention_coordinate += np.random.uniform(-shift, shift, size=(3,))
|
| 320 |
+
bounds = np.concatenate(
|
| 321 |
+
[
|
| 322 |
+
attention_coordinate - bounds_offset[depth - 1],
|
| 323 |
+
attention_coordinate + bounds_offset[depth - 1],
|
| 324 |
+
]
|
| 325 |
+
)
|
| 326 |
+
index = utils.point_to_voxel_index(obs_tp1.gripper_pose[:3], vox_size, bounds)
|
| 327 |
+
trans_indicies.extend(index.tolist())
|
| 328 |
+
res = (bounds[3:] - bounds[:3]) / vox_size
|
| 329 |
+
attention_coordinate = bounds[:3] + res * index
|
| 330 |
+
attention_coordinates.append(attention_coordinate)
|
| 331 |
+
|
| 332 |
+
rot_and_grip_indicies = disc_rot.tolist()
|
| 333 |
+
grip = float(obs_tp1.gripper_open)
|
| 334 |
+
rot_and_grip_indicies.extend([int(obs_tp1.gripper_open)])
|
| 335 |
+
return (
|
| 336 |
+
trans_indicies,
|
| 337 |
+
rot_and_grip_indicies,
|
| 338 |
+
ignore_collisions,
|
| 339 |
+
np.concatenate([obs_tp1.gripper_pose, np.array([grip])]),
|
| 340 |
+
attention_coordinates,
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def _add_keypoints_to_replay(
|
| 345 |
+
cfg: DictConfig,
|
| 346 |
+
task: str,
|
| 347 |
+
replay: ReplayBuffer,
|
| 348 |
+
inital_obs: Observation,
|
| 349 |
+
demo: Demo,
|
| 350 |
+
episode_keypoints: List[int],
|
| 351 |
+
|
| 352 |
+
description: str = "",
|
| 353 |
+
clip_model=None,
|
| 354 |
+
device="cpu",
|
| 355 |
+
):
|
| 356 |
+
|
| 357 |
+
cameras = cfg.rlbench.cameras
|
| 358 |
+
rlbench_scene_bounds = cfg.rlbench.scene_bounds
|
| 359 |
+
voxel_sizes = cfg.method.voxel_sizes
|
| 360 |
+
bounds_offset = cfg.method.bounds_offset
|
| 361 |
+
rotation_resolution = cfg.method.rotation_resolution
|
| 362 |
+
crop_augmentation = cfg.method.crop_augmentation
|
| 363 |
+
robot_name = cfg.method.robot_name
|
| 364 |
+
|
| 365 |
+
prev_action = None
|
| 366 |
+
obs = inital_obs
|
| 367 |
+
|
| 368 |
+
for k, keypoint in enumerate(episode_keypoints):
|
| 369 |
+
obs_tp1 = demo[keypoint]
|
| 370 |
+
obs_tm1 = demo[max(0, keypoint - 1)]
|
| 371 |
+
|
| 372 |
+
if obs_tp1.is_bimanual and robot_name == "bimanual":
|
| 373 |
+
#assert isinstance(obs_tp1, BimanualObservation)
|
| 374 |
+
(
|
| 375 |
+
right_trans_indicies,
|
| 376 |
+
right_rot_grip_indicies,
|
| 377 |
+
right_ignore_collisions,
|
| 378 |
+
right_action,
|
| 379 |
+
right_attention_coordinates,
|
| 380 |
+
) = _get_action(
|
| 381 |
+
obs_tp1.right,
|
| 382 |
+
obs_tm1.right,
|
| 383 |
+
rlbench_scene_bounds,
|
| 384 |
+
voxel_sizes,
|
| 385 |
+
bounds_offset,
|
| 386 |
+
rotation_resolution,
|
| 387 |
+
crop_augmentation,
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
(
|
| 391 |
+
left_trans_indicies,
|
| 392 |
+
left_rot_grip_indicies,
|
| 393 |
+
left_ignore_collisions,
|
| 394 |
+
left_action,
|
| 395 |
+
left_attention_coordinates,
|
| 396 |
+
) = _get_action(
|
| 397 |
+
obs_tp1.left,
|
| 398 |
+
obs_tm1.left,
|
| 399 |
+
rlbench_scene_bounds,
|
| 400 |
+
voxel_sizes,
|
| 401 |
+
bounds_offset,
|
| 402 |
+
rotation_resolution,
|
| 403 |
+
crop_augmentation,
|
| 404 |
+
)
|
| 405 |
+
|
| 406 |
+
action = np.append(right_action, left_action)
|
| 407 |
+
|
| 408 |
+
right_ignore_collisions = np.array([right_ignore_collisions])
|
| 409 |
+
left_ignore_collisions = np.array([left_ignore_collisions])
|
| 410 |
+
|
| 411 |
+
elif robot_name == "unimanual":
|
| 412 |
+
(
|
| 413 |
+
trans_indicies,
|
| 414 |
+
rot_grip_indicies,
|
| 415 |
+
ignore_collisions,
|
| 416 |
+
action,
|
| 417 |
+
attention_coordinates,
|
| 418 |
+
) = _get_action(
|
| 419 |
+
obs_tp1,
|
| 420 |
+
obs_tm1,
|
| 421 |
+
rlbench_scene_bounds,
|
| 422 |
+
voxel_sizes,
|
| 423 |
+
bounds_offset,
|
| 424 |
+
rotation_resolution,
|
| 425 |
+
crop_augmentation,
|
| 426 |
+
)
|
| 427 |
+
gripper_pose = obs_tp1.gripper_pose
|
| 428 |
+
elif obs_tp1.is_bimanual and robot_name == "right":
|
| 429 |
+
(
|
| 430 |
+
trans_indicies,
|
| 431 |
+
rot_grip_indicies,
|
| 432 |
+
ignore_collisions,
|
| 433 |
+
action,
|
| 434 |
+
attention_coordinates,
|
| 435 |
+
) = _get_action(
|
| 436 |
+
obs_tp1.right,
|
| 437 |
+
obs_tm1.right,
|
| 438 |
+
rlbench_scene_bounds,
|
| 439 |
+
voxel_sizes,
|
| 440 |
+
bounds_offset,
|
| 441 |
+
rotation_resolution,
|
| 442 |
+
crop_augmentation,
|
| 443 |
+
)
|
| 444 |
+
gripper_pose = obs_tp1.right.gripper_pose
|
| 445 |
+
elif obs_tp1.is_bimanual and robot_name == "left":
|
| 446 |
+
(
|
| 447 |
+
trans_indicies,
|
| 448 |
+
rot_grip_indicies,
|
| 449 |
+
ignore_collisions,
|
| 450 |
+
action,
|
| 451 |
+
attention_coordinates,
|
| 452 |
+
) = _get_action(
|
| 453 |
+
obs_tp1.left,
|
| 454 |
+
obs_tm1.left,
|
| 455 |
+
rlbench_scene_bounds,
|
| 456 |
+
voxel_sizes,
|
| 457 |
+
bounds_offset,
|
| 458 |
+
rotation_resolution,
|
| 459 |
+
crop_augmentation,
|
| 460 |
+
)
|
| 461 |
+
gripper_pose = obs_tp1.left.gripper_pose
|
| 462 |
+
else:
|
| 463 |
+
logging.error("Invalid robot name %s", cfg.method.robot_name)
|
| 464 |
+
raise Exception("Invalid robot name.")
|
| 465 |
+
|
| 466 |
+
terminal = k == len(episode_keypoints) - 1
|
| 467 |
+
reward = float(terminal) * REWARD_SCALE if terminal else 0
|
| 468 |
+
|
| 469 |
+
obs_dict = observation_utils.extract_obs(
|
| 470 |
+
obs,
|
| 471 |
+
t=k,
|
| 472 |
+
prev_action=prev_action,
|
| 473 |
+
cameras=cameras,
|
| 474 |
+
episode_length=cfg.rlbench.episode_length,
|
| 475 |
+
robot_name=robot_name
|
| 476 |
+
)
|
| 477 |
+
tokens = tokenize([description]).numpy()
|
| 478 |
+
token_tensor = torch.from_numpy(tokens).to(device)
|
| 479 |
+
sentence_emb, token_embs = clip_model.encode_text_with_embeddings(token_tensor)
|
| 480 |
+
obs_dict["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
|
| 481 |
+
obs_dict["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
|
| 482 |
+
|
| 483 |
+
prev_action = np.copy(action)
|
| 484 |
+
|
| 485 |
+
others = {"demo": True}
|
| 486 |
+
if robot_name == "bimanual":
|
| 487 |
+
final_obs = {
|
| 488 |
+
"right_trans_action_indicies": right_trans_indicies,
|
| 489 |
+
"right_rot_grip_action_indicies": right_rot_grip_indicies,
|
| 490 |
+
"right_gripper_pose": obs_tp1.right.gripper_pose,
|
| 491 |
+
"left_trans_action_indicies": left_trans_indicies,
|
| 492 |
+
"left_rot_grip_action_indicies": left_rot_grip_indicies,
|
| 493 |
+
"left_gripper_pose": obs_tp1.left.gripper_pose,
|
| 494 |
+
"task": task,
|
| 495 |
+
"lang_goal": np.array([description], dtype=object),
|
| 496 |
+
}
|
| 497 |
+
else:
|
| 498 |
+
final_obs = {
|
| 499 |
+
"trans_action_indicies": trans_indicies,
|
| 500 |
+
"rot_grip_action_indicies": rot_grip_indicies,
|
| 501 |
+
"gripper_pose": gripper_pose,
|
| 502 |
+
"task": task,
|
| 503 |
+
"lang_goal": np.array([description], dtype=object),
|
| 504 |
+
}
|
| 505 |
+
|
| 506 |
+
others.update(final_obs)
|
| 507 |
+
others.update(obs_dict)
|
| 508 |
+
|
| 509 |
+
timeout = False
|
| 510 |
+
replay.add(action, reward, terminal, timeout, **others)
|
| 511 |
+
obs = obs_tp1
|
| 512 |
+
|
| 513 |
+
# final step
|
| 514 |
+
obs_dict_tp1 = observation_utils.extract_obs(
|
| 515 |
+
obs_tp1,
|
| 516 |
+
t=k + 1,
|
| 517 |
+
prev_action=prev_action,
|
| 518 |
+
cameras=cameras,
|
| 519 |
+
episode_length=cfg.rlbench.episode_length,
|
| 520 |
+
robot_name=cfg.method.robot_name
|
| 521 |
+
)
|
| 522 |
+
obs_dict_tp1["lang_goal_emb"] = sentence_emb[0].float().detach().cpu().numpy()
|
| 523 |
+
obs_dict_tp1["lang_token_embs"] = token_embs[0].float().detach().cpu().numpy()
|
| 524 |
+
|
| 525 |
+
obs_dict_tp1.pop("wrist_world_to_cam", None)
|
| 526 |
+
obs_dict_tp1.update(final_obs)
|
| 527 |
+
replay.add_final(**obs_dict_tp1)
|
| 528 |
+
|
| 529 |
+
def check_if_replay_exists(task: str, d_idx: int, replay_path: str):
|
| 530 |
+
replay_file = os.path.join(replay_path, f"{task}_replay_{d_idx}.pkl")
|
| 531 |
+
return os.path.exists(replay_file)
|
| 532 |
+
|
| 533 |
+
def fill_replay(
|
| 534 |
+
cfg: DictConfig,
|
| 535 |
+
obs_config: ObservationConfig,
|
| 536 |
+
rank: int,
|
| 537 |
+
replay: ReplayBuffer,
|
| 538 |
+
task: str,
|
| 539 |
+
clip_model=None,
|
| 540 |
+
device="cpu",
|
| 541 |
+
):
|
| 542 |
+
|
| 543 |
+
num_demos=cfg.rlbench.demos
|
| 544 |
+
demo_augmentation=cfg.method.demo_augmentation
|
| 545 |
+
demo_augmentation_every_n=cfg.method.demo_augmentation_every_n
|
| 546 |
+
keypoint_method=cfg.method.keypoint_method
|
| 547 |
+
|
| 548 |
+
|
| 549 |
+
if clip_model is None:
|
| 550 |
+
model, _ = load_clip("RN50", jit=False, device=device)
|
| 551 |
+
clip_model = build_model(model.state_dict())
|
| 552 |
+
clip_model.to(device)
|
| 553 |
+
del model
|
| 554 |
+
|
| 555 |
+
task_folder = cfg.replay.task_folder
|
| 556 |
+
replay_path = os.path.join(
|
| 557 |
+
cfg.replay.path, task_folder
|
| 558 |
+
)
|
| 559 |
+
logging.debug("Filling %s replay ..." % task)
|
| 560 |
+
for d_idx in range(num_demos):
|
| 561 |
+
# load demo from disk
|
| 562 |
+
if check_if_replay_exists(task, d_idx, replay_path):
|
| 563 |
+
logging.info(f"Replay for demo {d_idx} already exists, skipping...")
|
| 564 |
+
continue
|
| 565 |
+
demo = rlbench_utils.get_stored_demos(
|
| 566 |
+
amount=1,
|
| 567 |
+
image_paths=False,
|
| 568 |
+
dataset_root=cfg.rlbench.demo_path,
|
| 569 |
+
variation_number=-1,
|
| 570 |
+
task_name=task,
|
| 571 |
+
obs_config=obs_config,
|
| 572 |
+
random_selection=False,
|
| 573 |
+
from_episode_number=d_idx,
|
| 574 |
+
)[0]
|
| 575 |
+
|
| 576 |
+
descs = demo._observations[0].misc["descriptions"]
|
| 577 |
+
|
| 578 |
+
# extract keypoints (a.k.a keyframes)
|
| 579 |
+
episode_keypoints = demo_loading_utils.keypoint_discovery(
|
| 580 |
+
demo, method=keypoint_method
|
| 581 |
+
)
|
| 582 |
+
|
| 583 |
+
if rank == 0:
|
| 584 |
+
logging.info(
|
| 585 |
+
f"Loading Demo({d_idx}) - found {len(episode_keypoints)} keypoints - {task}"
|
| 586 |
+
)
|
| 587 |
+
|
| 588 |
+
for i in range(len(demo) - 1):
|
| 589 |
+
if not demo_augmentation and i > 0:
|
| 590 |
+
break
|
| 591 |
+
if i % demo_augmentation_every_n != 0:
|
| 592 |
+
continue
|
| 593 |
+
|
| 594 |
+
obs = demo[i]
|
| 595 |
+
desc = descs[0]
|
| 596 |
+
# if our starting point is past one of the keypoints, then remove it
|
| 597 |
+
while len(episode_keypoints) > 0 and i >= episode_keypoints[0]:
|
| 598 |
+
episode_keypoints = episode_keypoints[1:]
|
| 599 |
+
if len(episode_keypoints) == 0:
|
| 600 |
+
break
|
| 601 |
+
_add_keypoints_to_replay(
|
| 602 |
+
cfg,
|
| 603 |
+
task,
|
| 604 |
+
replay,
|
| 605 |
+
obs,
|
| 606 |
+
demo,
|
| 607 |
+
episode_keypoints,
|
| 608 |
+
description=desc,
|
| 609 |
+
clip_model=clip_model,
|
| 610 |
+
device=device,
|
| 611 |
+
)
|
| 612 |
+
logging.debug("Replay %s filled with demos." % task)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def fill_multi_task_replay(
|
| 616 |
+
cfg: DictConfig,
|
| 617 |
+
obs_config: ObservationConfig,
|
| 618 |
+
rank: int,
|
| 619 |
+
replay: ReplayBuffer,
|
| 620 |
+
tasks: List[str],
|
| 621 |
+
clip_model=None,
|
| 622 |
+
):
|
| 623 |
+
|
| 624 |
+
tasks = cfg.rlbench.tasks
|
| 625 |
+
|
| 626 |
+
manager = Manager()
|
| 627 |
+
store = manager.dict()
|
| 628 |
+
|
| 629 |
+
# create a MP dict for storing indicies
|
| 630 |
+
# TODO(mohit): this shouldn't be initialized here
|
| 631 |
+
del replay._task_idxs
|
| 632 |
+
task_idxs = manager.dict()
|
| 633 |
+
replay._task_idxs = task_idxs
|
| 634 |
+
replay._create_storage(store)
|
| 635 |
+
replay.add_count = Value("i", 0)
|
| 636 |
+
|
| 637 |
+
# fill replay buffer in parallel across tasks
|
| 638 |
+
max_parallel_processes = cfg.replay.max_parallel_processes
|
| 639 |
+
processes = []
|
| 640 |
+
n = np.arange(len(tasks))
|
| 641 |
+
split_n = utils.split_list(n, max_parallel_processes)
|
| 642 |
+
for split in split_n:
|
| 643 |
+
for e_idx, task_idx in enumerate(split):
|
| 644 |
+
task = tasks[int(task_idx)]
|
| 645 |
+
model_device = torch.device(
|
| 646 |
+
"cuda:%s" % (e_idx % torch.cuda.device_count())
|
| 647 |
+
if torch.cuda.is_available()
|
| 648 |
+
else "cpu"
|
| 649 |
+
)
|
| 650 |
+
p = Process(
|
| 651 |
+
target=fill_replay,
|
| 652 |
+
args=(
|
| 653 |
+
cfg,
|
| 654 |
+
obs_config,
|
| 655 |
+
rank,
|
| 656 |
+
replay,
|
| 657 |
+
task,
|
| 658 |
+
clip_model,
|
| 659 |
+
model_device
|
| 660 |
+
),
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
p.start()
|
| 664 |
+
processes.append(p)
|
| 665 |
+
|
| 666 |
+
for p in processes:
|
| 667 |
+
p.join()
|
third_party/AnyBimanual/agents/rvt/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""RVT package marker.
|
| 2 |
+
|
| 3 |
+
Keep package import side-effect free so downstream code can import the
|
| 4 |
+
visual stack without pulling in the full training launcher and its
|
| 5 |
+
optional dependencies.
|
| 6 |
+
"""
|
third_party/AnyBimanual/agents/rvt/launch_utils.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import List
|
| 3 |
+
import torch
|
| 4 |
+
import numpy as np
|
| 5 |
+
|
| 6 |
+
from omegaconf import DictConfig
|
| 7 |
+
|
| 8 |
+
from yarr.agents.agent import Agent
|
| 9 |
+
from yarr.agents.agent import ActResult
|
| 10 |
+
from yarr.agents.agent import Summary
|
| 11 |
+
from yarr.agents.agent import ScalarSummary
|
| 12 |
+
import wandb
|
| 13 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 14 |
+
import pickle
|
| 15 |
+
from helpers.preprocess_agent import PreprocessAgent
|
| 16 |
+
from agents.rvt.rvt.models.skill_manager import SkillManager
|
| 17 |
+
from agents.rvt.rvt.models.visual_aligner import VisualAligner
|
| 18 |
+
|
| 19 |
+
from agents.rvt.rvt.mvt.mvt import MVT
|
| 20 |
+
from agents.rvt.rvt.models import rvt_agent
|
| 21 |
+
from agents.rvt.rvt.utils.peract_utils import (
|
| 22 |
+
CAMERAS,
|
| 23 |
+
SCENE_BOUNDS,
|
| 24 |
+
IMAGE_SIZE,
|
| 25 |
+
DATA_FOLDER,
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
import agents.rvt.rvt.config as exp_cfg_mod
|
| 30 |
+
import agents.rvt.rvt.models.rvt_agent as rvt_agent
|
| 31 |
+
import agents.rvt.rvt.mvt.config as mvt_cfg_mod
|
| 32 |
+
import os
|
| 33 |
+
|
| 34 |
+
def create_agent(cfg: DictConfig):
|
| 35 |
+
|
| 36 |
+
exp_cfg = exp_cfg_mod.get_cfg_defaults()
|
| 37 |
+
exp_cfg.bs = cfg.replay.batch_size
|
| 38 |
+
exp_cfg.tasks = ','.join(cfg.rlbench.tasks)
|
| 39 |
+
|
| 40 |
+
exp_cfg.freeze()
|
| 41 |
+
|
| 42 |
+
mvt_cfg = mvt_cfg_mod.get_cfg_defaults()
|
| 43 |
+
mvt_cfg.proprio_dim = cfg.method.low_dim_size
|
| 44 |
+
mvt_cfg.freeze()
|
| 45 |
+
|
| 46 |
+
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 47 |
+
pkl_path = os.path.join(current_dir, "../../lang_token.pkl")
|
| 48 |
+
pkl_path = os.path.abspath(pkl_path)
|
| 49 |
+
with open(pkl_path, "rb") as f:
|
| 50 |
+
embeddings_dict = pickle.load(f)
|
| 51 |
+
flattened_embeddings = []
|
| 52 |
+
for key in embeddings_dict.keys():
|
| 53 |
+
embedding = torch.tensor(embeddings_dict[key])
|
| 54 |
+
flattened_embedding = embedding.view(-1)
|
| 55 |
+
flattened_embeddings.append(flattened_embedding)
|
| 56 |
+
embeddings_matrix = torch.stack(flattened_embeddings)
|
| 57 |
+
skill_manager = SkillManager(num_classes=18,embedding_matrix=embeddings_matrix)
|
| 58 |
+
visual_aligner = VisualAligner()
|
| 59 |
+
agent = RVTAgentWrapper(cfg.framework.checkpoint_name_prefix, cfg.rlbench, mvt_cfg, exp_cfg, skill_manager, visual_aligner)
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
preprocess_agent = PreprocessAgent(pose_agent=agent)
|
| 63 |
+
return preprocess_agent
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class RVTAgentWrapper(Agent):
|
| 68 |
+
|
| 69 |
+
def __init__(self, checkpoint_name_prefix, rlbench_cfg, mvt_cfg, exp_cfg, skill_manager, visual_aligner):
|
| 70 |
+
self._checkpoint_filename = f"{checkpoint_name_prefix}.pt"
|
| 71 |
+
self.rvt_agent = None
|
| 72 |
+
self.rlbench_cfg = rlbench_cfg
|
| 73 |
+
self.mvt_cfg = mvt_cfg
|
| 74 |
+
self.exp_cfg = exp_cfg
|
| 75 |
+
self._summaries = {}
|
| 76 |
+
self.skill_manager = skill_manager
|
| 77 |
+
self.visual_aligner = visual_aligner
|
| 78 |
+
|
| 79 |
+
def build(self, training: bool, device=None) -> None:
|
| 80 |
+
|
| 81 |
+
import torch
|
| 82 |
+
torch.cuda.set_device(device)
|
| 83 |
+
torch.cuda.empty_cache()
|
| 84 |
+
self._device = device
|
| 85 |
+
if isinstance(device, int):
|
| 86 |
+
device = f"cuda:{device}"
|
| 87 |
+
|
| 88 |
+
rvt = MVT(
|
| 89 |
+
renderer_device=device,
|
| 90 |
+
**self.mvt_cfg,
|
| 91 |
+
)
|
| 92 |
+
rvt = rvt.to(device)
|
| 93 |
+
|
| 94 |
+
if training:
|
| 95 |
+
rvt = DDP(rvt, device_ids=[device])
|
| 96 |
+
|
| 97 |
+
self.rvt_agent = rvt_agent.RVTAgent(
|
| 98 |
+
network=rvt,
|
| 99 |
+
#image_resolution=self.rlbench_cfg.camera_resolution,
|
| 100 |
+
skill_manager=self.skill_manager,
|
| 101 |
+
visual_aligner=self.visual_aligner,
|
| 102 |
+
stage_two=False,
|
| 103 |
+
add_lang=self.mvt_cfg.add_lang,
|
| 104 |
+
scene_bounds=self.rlbench_cfg.scene_bounds,
|
| 105 |
+
cameras=self.rlbench_cfg.cameras,
|
| 106 |
+
log_dir="/tmp/eval_run",
|
| 107 |
+
**self.exp_cfg.peract,
|
| 108 |
+
**self.exp_cfg.rvt,
|
| 109 |
+
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
self.rvt_agent.build(training, device)
|
| 113 |
+
|
| 114 |
+
def update(self, step: int, replay_sample: dict) -> dict:
|
| 115 |
+
for k, v in replay_sample.items():
|
| 116 |
+
replay_sample[k] = v.unsqueeze(1)
|
| 117 |
+
# RVT is based on the PerAct's Colab version.
|
| 118 |
+
replay_sample["lang_goal_embs"] = replay_sample["lang_token_embs"]
|
| 119 |
+
replay_sample["tasks"] = self.exp_cfg.tasks.split(',')
|
| 120 |
+
|
| 121 |
+
update_dict = self.rvt_agent.update(step, replay_sample)
|
| 122 |
+
|
| 123 |
+
|
| 124 |
+
for key, val in self.rvt_agent.loss_log.items():
|
| 125 |
+
self._summaries[key] = np.mean(np.array(val))
|
| 126 |
+
device = self._device
|
| 127 |
+
rank = device
|
| 128 |
+
if step % 10 == 0 and rank == 0:
|
| 129 |
+
wandb.log({
|
| 130 |
+
'train/grip_loss': update_dict["grip_loss"],
|
| 131 |
+
'train/trans_loss': update_dict["trans_loss"],
|
| 132 |
+
'train/rot_loss': (update_dict["rot_loss_x"]+update_dict["rot_loss_y"]+update_dict["rot_loss_z"]),
|
| 133 |
+
'train/collision_loss': update_dict["collision_loss"],
|
| 134 |
+
'train/total_loss': update_dict["total_loss"],
|
| 135 |
+
}, step=step)
|
| 136 |
+
self._wandb_summaries = {
|
| 137 |
+
'losses/grip_loss': update_dict["grip_loss"],
|
| 138 |
+
'losses/trans_loss': update_dict["trans_loss"],
|
| 139 |
+
'losses/rot_loss': (update_dict["rot_loss_x"]+update_dict["rot_loss_y"]+update_dict["rot_loss_z"]),
|
| 140 |
+
'losses/collision_loss': update_dict["collision_loss"],
|
| 141 |
+
'losses/total_loss': update_dict["total_loss"],
|
| 142 |
+
}
|
| 143 |
+
return {
|
| 144 |
+
"total_losses": update_dict["total_loss"],
|
| 145 |
+
}
|
| 146 |
+
|
| 147 |
+
return result
|
| 148 |
+
|
| 149 |
+
def act(self, step: int, observation: dict, deterministic: bool) -> ActResult:
|
| 150 |
+
return self.rvt_agent.act(step, observation, deterministic)
|
| 151 |
+
|
| 152 |
+
def reset(self) -> None:
|
| 153 |
+
self.rvt_agent.reset()
|
| 154 |
+
|
| 155 |
+
def update_summaries(self) -> List[Summary]:
|
| 156 |
+
summaries = []
|
| 157 |
+
for k, v in self._summaries.items():
|
| 158 |
+
summaries.append(ScalarSummary(f"RVT/{k}", v))
|
| 159 |
+
return summaries
|
| 160 |
+
|
| 161 |
+
def update_wandb_summaries(self):
|
| 162 |
+
summaries = dict()
|
| 163 |
+
|
| 164 |
+
for k, v in self._wandb_summaries.items():
|
| 165 |
+
summaries[k] = v
|
| 166 |
+
return summaries
|
| 167 |
+
|
| 168 |
+
def act_summaries(self) -> List[Summary]:
|
| 169 |
+
return []
|
| 170 |
+
|
| 171 |
+
def load_weights(self, savedir: str) -> None:
|
| 172 |
+
"""
|
| 173 |
+
copied from RVT
|
| 174 |
+
"""
|
| 175 |
+
device = torch.device("cuda:0")
|
| 176 |
+
weight_file = os.path.join(savedir, self._checkpoint_filename)
|
| 177 |
+
state_dict = torch.load(weight_file, map_location=device)
|
| 178 |
+
|
| 179 |
+
skill = self.rvt_agent.skill_manager
|
| 180 |
+
visual_aligner = self.rvt_agent.visual_aligner
|
| 181 |
+
model = self.rvt_agent._network
|
| 182 |
+
optimizer = self.rvt_agent._optimizer
|
| 183 |
+
lr_sched = self.rvt_agent._lr_sched
|
| 184 |
+
|
| 185 |
+
if isinstance(model, DDP):
|
| 186 |
+
model = model.module
|
| 187 |
+
model.load_state_dict(state_dict["model_state"])
|
| 188 |
+
optimizer.load_state_dict(state_dict["optimizer_state"])
|
| 189 |
+
lr_sched.load_state_dict(state_dict["lr_sched_state"])
|
| 190 |
+
|
| 191 |
+
return self.rvt_agent.load_clip()
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def save_weights(self, savedir: str) -> None:
|
| 195 |
+
|
| 196 |
+
os.makedirs(savedir, exist_ok=True)
|
| 197 |
+
weight_file = os.path.join(savedir, self._checkpoint_filename)
|
| 198 |
+
skill = self.rvt_agent.skill_manager
|
| 199 |
+
visual_aligner = self.rvt_agent.visual_aligner
|
| 200 |
+
model = self.rvt_agent._network
|
| 201 |
+
optimizer = self.rvt_agent._optimizer
|
| 202 |
+
lr_sched = self.rvt_agent._lr_sched
|
| 203 |
+
|
| 204 |
+
if isinstance(model, DDP):
|
| 205 |
+
model = model.module
|
| 206 |
+
|
| 207 |
+
skill_state = skill.state_dict()
|
| 208 |
+
visual_aligner_state = visual_aligner.state_dict()
|
| 209 |
+
model_state = model.state_dict()
|
| 210 |
+
|
| 211 |
+
torch.save(
|
| 212 |
+
{
|
| 213 |
+
"skill_state": skill_state,
|
| 214 |
+
"visual_aligner_state": visual_aligner_state,
|
| 215 |
+
"model_state": model_state,
|
| 216 |
+
"optimizer_state": optimizer.state_dict(),
|
| 217 |
+
"lr_sched_state": lr_sched.state_dict(),
|
| 218 |
+
},
|
| 219 |
+
weight_file,
|
| 220 |
+
)
|
| 221 |
+
|
third_party/AnyBimanual/agents/rvt/rvt/config.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the NVIDIA Source Code License [see LICENSE for details].
|
| 4 |
+
|
| 5 |
+
from yacs.config import CfgNode as CN
|
| 6 |
+
|
| 7 |
+
_C = CN()
|
| 8 |
+
|
| 9 |
+
_C.agent = "our"
|
| 10 |
+
_C.tasks = "insert_onto_square_peg,open_drawer,place_wine_at_rack_location,light_bulb_in"
|
| 11 |
+
_C.exp_id = "def"
|
| 12 |
+
_C.resume = ""
|
| 13 |
+
# bs per device, effective bs is scaled by num device
|
| 14 |
+
_C.bs = 4
|
| 15 |
+
_C.epochs = 20
|
| 16 |
+
# number of dataloader workers, >= 0
|
| 17 |
+
_C.num_workers = 0
|
| 18 |
+
# 'transition_uniform' or 'task_uniform'
|
| 19 |
+
_C.sample_distribution_mode = 'transition_uniform'
|
| 20 |
+
_C.train_iter = 16 * 10000
|
| 21 |
+
|
| 22 |
+
# arguments present in both peract and rvt
|
| 23 |
+
# some of them donot support every possible combination in peract
|
| 24 |
+
_C.peract = CN()
|
| 25 |
+
_C.peract.lambda_weight_l2 = 1e-6
|
| 26 |
+
# lr should be thought on per sample basis
|
| 27 |
+
# effective lr is multiplied by bs * num_devices
|
| 28 |
+
_C.peract.lr = 2.5e-5
|
| 29 |
+
_C.peract.optimizer_type = "lamb"
|
| 30 |
+
_C.peract.warmup_steps = 0
|
| 31 |
+
_C.peract.lr_cos_dec = False
|
| 32 |
+
_C.peract.add_rgc_loss = True
|
| 33 |
+
_C.peract.num_rotation_classes = 72
|
| 34 |
+
_C.peract.amp = False
|
| 35 |
+
_C.peract.bnb = False
|
| 36 |
+
_C.peract.transform_augmentation = True
|
| 37 |
+
_C.peract.transform_augmentation_xyz = [0.1, 0.1, 0.1]
|
| 38 |
+
_C.peract.transform_augmentation_rpy = [0.0, 0.0, 20.0]
|
| 39 |
+
|
| 40 |
+
# arguments present in only rvt and not peract
|
| 41 |
+
_C.rvt = CN()
|
| 42 |
+
_C.rvt.gt_hm_sigma = 1.5
|
| 43 |
+
_C.rvt.img_aug = 0.1
|
| 44 |
+
_C.rvt.place_with_mean = True
|
| 45 |
+
_C.rvt.move_pc_in_bound = True
|
| 46 |
+
|
| 47 |
+
# arguments present in peract official
|
| 48 |
+
_C.peract_official = CN()
|
| 49 |
+
_C.peract_official.cfg_path = "configs/peract_official_config.yaml"
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def get_cfg_defaults():
|
| 53 |
+
"""Get a yacs CfgNode object with default values for my_project."""
|
| 54 |
+
return _C.clone()
|
third_party/AnyBimanual/agents/rvt/rvt/configs/peract_official_config.yaml
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# copied from: https://github.com/peract/peract/releases/download/v1.0.0/peract_600k.zip
|
| 2 |
+
method:
|
| 3 |
+
name: PERACT_BC
|
| 4 |
+
lr: 0.0005
|
| 5 |
+
lr_scheduler: false
|
| 6 |
+
num_warmup_steps: 3000
|
| 7 |
+
optimizer: lamb
|
| 8 |
+
activation: lrelu
|
| 9 |
+
norm: None
|
| 10 |
+
lambda_weight_l2: 1.0e-06
|
| 11 |
+
trans_loss_weight: 1.0
|
| 12 |
+
rot_loss_weight: 1.0
|
| 13 |
+
grip_loss_weight: 1.0
|
| 14 |
+
collision_loss_weight: 1.0
|
| 15 |
+
rotation_resolution: 5
|
| 16 |
+
image_crop_size: 64
|
| 17 |
+
bounds_offset:
|
| 18 |
+
- 0.15
|
| 19 |
+
voxel_sizes:
|
| 20 |
+
- 100
|
| 21 |
+
num_latents: 2048
|
| 22 |
+
latent_dim: 512
|
| 23 |
+
transformer_depth: 6
|
| 24 |
+
transformer_iterations: 1
|
| 25 |
+
cross_heads: 1
|
| 26 |
+
cross_dim_head: 64
|
| 27 |
+
latent_heads: 8
|
| 28 |
+
latent_dim_head: 64
|
| 29 |
+
pos_encoding_with_lang: false
|
| 30 |
+
lang_fusion_type: seq
|
| 31 |
+
voxel_patch_size: 5
|
| 32 |
+
voxel_patch_stride: 5
|
| 33 |
+
input_dropout: 0.1
|
| 34 |
+
attn_dropout: 0.1
|
| 35 |
+
decoder_dropout: 0.0
|
| 36 |
+
crop_augmentation: true
|
| 37 |
+
final_dim: 64
|
| 38 |
+
transform_augmentation:
|
| 39 |
+
apply_se3: true
|
| 40 |
+
aug_xyz:
|
| 41 |
+
- 0.125
|
| 42 |
+
- 0.125
|
| 43 |
+
- 0.125
|
| 44 |
+
aug_rpy:
|
| 45 |
+
- 0.0
|
| 46 |
+
- 0.0
|
| 47 |
+
- 0.0
|
| 48 |
+
aug_rot_resolution: 5
|
| 49 |
+
demo_augmentation: true
|
| 50 |
+
demo_augmentation_every_n: 10
|
| 51 |
+
no_skip_connection: false
|
| 52 |
+
no_perceiver: false
|
| 53 |
+
no_language: false
|
| 54 |
+
keypoint_method: heuristic
|
| 55 |
+
ddp:
|
| 56 |
+
master_addr: "localhost"
|
| 57 |
+
master_port: "29500"
|
| 58 |
+
num_devices: 1
|
| 59 |
+
rlbench:
|
| 60 |
+
task_name: multi
|
| 61 |
+
tasks:
|
| 62 |
+
- change_channel
|
| 63 |
+
- close_jar
|
| 64 |
+
- insert_onto_square_peg
|
| 65 |
+
- light_bulb_in
|
| 66 |
+
- meat_off_grill
|
| 67 |
+
- open_drawer
|
| 68 |
+
- place_cups
|
| 69 |
+
- place_shape_in_shape_sorter
|
| 70 |
+
- push_buttons
|
| 71 |
+
- put_groceries_in_cupboard
|
| 72 |
+
- put_item_in_drawer
|
| 73 |
+
- put_money_in_safe
|
| 74 |
+
- reach_and_drag
|
| 75 |
+
- stack_blocks
|
| 76 |
+
- stack_cups
|
| 77 |
+
- turn_tap
|
| 78 |
+
- set_clock_to_time
|
| 79 |
+
- place_wine_at_rack_location
|
| 80 |
+
- put_rubbish_in_color_bin
|
| 81 |
+
- slide_block_to_color_target
|
| 82 |
+
- sweep_to_dustpan_of_size
|
| 83 |
+
demos: 100
|
| 84 |
+
demo_path: /raid/dataset/
|
| 85 |
+
episode_length: 25
|
| 86 |
+
cameras:
|
| 87 |
+
- front
|
| 88 |
+
- left_shoulder
|
| 89 |
+
- right_shoulder
|
| 90 |
+
- wrist
|
| 91 |
+
camera_resolution:
|
| 92 |
+
- 128
|
| 93 |
+
- 128
|
| 94 |
+
scene_bounds:
|
| 95 |
+
- -0.3
|
| 96 |
+
- -0.5
|
| 97 |
+
- 0.6
|
| 98 |
+
- 0.7
|
| 99 |
+
- 0.5
|
| 100 |
+
- 1.6
|
| 101 |
+
include_lang_goal_in_obs: True
|
| 102 |
+
replay:
|
| 103 |
+
batch_size: 16
|
| 104 |
+
timesteps: 1
|
| 105 |
+
prioritisation: false
|
| 106 |
+
task_uniform: true
|
| 107 |
+
use_disk: true
|
| 108 |
+
path: /raid/arm/replay
|
| 109 |
+
max_parallel_processes: 32
|
| 110 |
+
framework:
|
| 111 |
+
log_freq: 100
|
| 112 |
+
save_freq: 10000
|
| 113 |
+
train_envs: 1
|
| 114 |
+
replay_ratio: 16
|
| 115 |
+
transitions_before_train: 200
|
| 116 |
+
tensorboard_logging: true
|
| 117 |
+
csv_logging: true
|
| 118 |
+
training_iterations: 600001
|
| 119 |
+
gpu: 0
|
| 120 |
+
env_gpu: 0
|
| 121 |
+
logdir: /home/user/workspace/logs_may16_n100
|
| 122 |
+
seeds: 1
|
| 123 |
+
start_seed: 0
|
| 124 |
+
load_existing_weights: true
|
| 125 |
+
num_weights_to_keep: 60
|
| 126 |
+
record_every_n: 5
|
| 127 |
+
|
third_party/AnyBimanual/agents/rvt/rvt/configs/rvt.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
exp_id: rvt
|
| 2 |
+
tasks: all
|
| 3 |
+
bs: 3
|
| 4 |
+
num_workers: 3
|
| 5 |
+
epochs: 15
|
| 6 |
+
sample_distribution_mode: task_uniform
|
| 7 |
+
peract:
|
| 8 |
+
lr: 1e-4
|
| 9 |
+
warmup_steps: 2000
|
| 10 |
+
optimizer_type: lamb
|
| 11 |
+
lr_cos_dec: True
|
| 12 |
+
transform_augmentation_xyz: [0.125, 0.125, 0.125]
|
| 13 |
+
transform_augmentation_rpy: [0.0, 0.0, 45.0]
|
| 14 |
+
rvt:
|
| 15 |
+
place_with_mean: False
|
third_party/AnyBimanual/agents/rvt/rvt/configs/rvt2.yaml
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
exp_id: rvt2
|
| 2 |
+
tasks: all
|
| 3 |
+
bs: 24
|
| 4 |
+
num_workers: 3
|
| 5 |
+
epochs: 15
|
| 6 |
+
sample_distribution_mode: task_uniform
|
| 7 |
+
peract:
|
| 8 |
+
lr: 1.25e-5
|
| 9 |
+
warmup_steps: 2000
|
| 10 |
+
optimizer_type: lamb
|
| 11 |
+
lr_cos_dec: True
|
| 12 |
+
transform_augmentation_xyz: [0.125, 0.125, 0.125]
|
| 13 |
+
transform_augmentation_rpy: [0.0, 0.0, 45.0]
|
| 14 |
+
amp: True
|
| 15 |
+
bnb: True
|
| 16 |
+
lambda_weight_l2: 1e-4
|
| 17 |
+
rvt:
|
| 18 |
+
place_with_mean: False
|
| 19 |
+
img_aug: 0.0
|
third_party/AnyBimanual/agents/rvt/rvt/eval.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the NVIDIA Source Code License [see LICENSE for details].
|
| 4 |
+
|
| 5 |
+
import os
|
| 6 |
+
import yaml
|
| 7 |
+
import csv
|
| 8 |
+
import torch
|
| 9 |
+
import cv2
|
| 10 |
+
import shutil
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
from omegaconf import OmegaConf
|
| 15 |
+
from multiprocessing import Value
|
| 16 |
+
from tensorflow.python.summary.summary_iterator import summary_iterator
|
| 17 |
+
from copy import deepcopy
|
| 18 |
+
|
| 19 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
| 20 |
+
os.environ["BITSANDBYTES_NOWELCOME"] = "1"
|
| 21 |
+
|
| 22 |
+
from rlbench.backend import task as rlbench_task
|
| 23 |
+
from rlbench.backend.utils import task_file_to_task_class
|
| 24 |
+
from rlbench.action_modes.gripper_action_modes import Discrete
|
| 25 |
+
from rlbench.action_modes.action_mode import MoveArmThenGripper
|
| 26 |
+
from yarr.utils.rollout_generator import RolloutGenerator
|
| 27 |
+
from yarr.utils.stat_accumulator import SimpleAccumulator
|
| 28 |
+
from yarr.utils.log_writer import LogWriter
|
| 29 |
+
from yarr.agents.agent import VideoSummary
|
| 30 |
+
|
| 31 |
+
import rvt.mvt.config as default_mvt_cfg
|
| 32 |
+
import rvt.models.rvt_agent as rvt_agent
|
| 33 |
+
import rvt.config as default_exp_cfg
|
| 34 |
+
|
| 35 |
+
from rvt.mvt.mvt import MVT
|
| 36 |
+
from rvt.libs.peract.helpers import utils
|
| 37 |
+
from rvt.utils.custom_rlbench_env import (
|
| 38 |
+
CustomMultiTaskRLBenchEnv2 as CustomMultiTaskRLBenchEnv,
|
| 39 |
+
)
|
| 40 |
+
from rvt.utils.peract_utils import (
|
| 41 |
+
CAMERAS,
|
| 42 |
+
SCENE_BOUNDS,
|
| 43 |
+
IMAGE_SIZE,
|
| 44 |
+
get_official_peract,
|
| 45 |
+
)
|
| 46 |
+
from rvt.utils.rlbench_planning import (
|
| 47 |
+
EndEffectorPoseViaPlanning2 as EndEffectorPoseViaPlanning,
|
| 48 |
+
)
|
| 49 |
+
from rvt.utils.rvt_utils import (
|
| 50 |
+
TensorboardManager,
|
| 51 |
+
get_eval_parser,
|
| 52 |
+
RLBENCH_TASKS,
|
| 53 |
+
)
|
| 54 |
+
from rvt.utils.rvt_utils import load_agent as load_agent_state
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def load_agent(
|
| 58 |
+
model_path=None,
|
| 59 |
+
peract_official=False,
|
| 60 |
+
peract_model_dir=None,
|
| 61 |
+
exp_cfg_path=None,
|
| 62 |
+
mvt_cfg_path=None,
|
| 63 |
+
eval_log_dir="",
|
| 64 |
+
device=0,
|
| 65 |
+
use_input_place_with_mean=False,
|
| 66 |
+
):
|
| 67 |
+
device = f"cuda:{device}"
|
| 68 |
+
|
| 69 |
+
if not (peract_official):
|
| 70 |
+
assert model_path is not None
|
| 71 |
+
|
| 72 |
+
# load exp_cfg
|
| 73 |
+
model_folder = os.path.join(os.path.dirname(model_path))
|
| 74 |
+
|
| 75 |
+
exp_cfg = default_exp_cfg.get_cfg_defaults()
|
| 76 |
+
if exp_cfg_path != None:
|
| 77 |
+
exp_cfg.merge_from_file(exp_cfg_path)
|
| 78 |
+
else:
|
| 79 |
+
exp_cfg.merge_from_file(os.path.join(model_folder, "exp_cfg.yaml"))
|
| 80 |
+
|
| 81 |
+
# NOTE: to not use place_with_mean in evaluation
|
| 82 |
+
# needed for rvt-1 but not rvt-2
|
| 83 |
+
if not use_input_place_with_mean:
|
| 84 |
+
# for backward compatibility
|
| 85 |
+
old_place_with_mean = exp_cfg.rvt.place_with_mean
|
| 86 |
+
exp_cfg.rvt.place_with_mean = True
|
| 87 |
+
|
| 88 |
+
exp_cfg.freeze()
|
| 89 |
+
|
| 90 |
+
# create agent
|
| 91 |
+
if exp_cfg.agent == "original":
|
| 92 |
+
# initialize PerceiverIO Transformer
|
| 93 |
+
VOXEL_SIZES = [100] # 100x100x100 voxels
|
| 94 |
+
NUM_LATENTS = 512 # PerceiverIO latents
|
| 95 |
+
BATCH_SIZE_TRAIN = 1
|
| 96 |
+
perceiver_encoder = PerceiverIO(
|
| 97 |
+
depth=6,
|
| 98 |
+
iterations=1,
|
| 99 |
+
voxel_size=VOXEL_SIZES[0],
|
| 100 |
+
initial_dim=3 + 3 + 1 + 3,
|
| 101 |
+
low_dim_size=4,
|
| 102 |
+
layer=0,
|
| 103 |
+
num_rotation_classes=72,
|
| 104 |
+
num_grip_classes=2,
|
| 105 |
+
num_collision_classes=2,
|
| 106 |
+
num_latents=NUM_LATENTS,
|
| 107 |
+
latent_dim=512,
|
| 108 |
+
cross_heads=1,
|
| 109 |
+
latent_heads=8,
|
| 110 |
+
cross_dim_head=64,
|
| 111 |
+
latent_dim_head=64,
|
| 112 |
+
weight_tie_layers=False,
|
| 113 |
+
activation="lrelu",
|
| 114 |
+
input_dropout=0.1,
|
| 115 |
+
attn_dropout=0.1,
|
| 116 |
+
decoder_dropout=0.0,
|
| 117 |
+
voxel_patch_size=5,
|
| 118 |
+
voxel_patch_stride=5,
|
| 119 |
+
final_dim=64,
|
| 120 |
+
)
|
| 121 |
+
|
| 122 |
+
# initialize PerceiverActor
|
| 123 |
+
agent = PerceiverActorAgent(
|
| 124 |
+
coordinate_bounds=SCENE_BOUNDS,
|
| 125 |
+
perceiver_encoder=perceiver_encoder,
|
| 126 |
+
camera_names=CAMERAS,
|
| 127 |
+
batch_size=BATCH_SIZE_TRAIN,
|
| 128 |
+
voxel_size=VOXEL_SIZES[0],
|
| 129 |
+
voxel_feature_size=3,
|
| 130 |
+
num_rotation_classes=72,
|
| 131 |
+
rotation_resolution=5,
|
| 132 |
+
image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
|
| 133 |
+
transform_augmentation=False,
|
| 134 |
+
**exp_cfg.peract,
|
| 135 |
+
)
|
| 136 |
+
elif exp_cfg.agent == "our":
|
| 137 |
+
mvt_cfg = default_mvt_cfg.get_cfg_defaults()
|
| 138 |
+
if mvt_cfg_path != None:
|
| 139 |
+
mvt_cfg.merge_from_file(mvt_cfg_path)
|
| 140 |
+
else:
|
| 141 |
+
mvt_cfg.merge_from_file(os.path.join(model_folder, "mvt_cfg.yaml"))
|
| 142 |
+
|
| 143 |
+
mvt_cfg.freeze()
|
| 144 |
+
|
| 145 |
+
# for rvt-2 we do not change place_with_mean regardless of the arg
|
| 146 |
+
# done this way to ensure backward compatibility and allow the
|
| 147 |
+
# flexibility for rvt-1
|
| 148 |
+
if mvt_cfg.stage_two:
|
| 149 |
+
exp_cfg.defrost()
|
| 150 |
+
exp_cfg.rvt.place_with_mean = old_place_with_mean
|
| 151 |
+
exp_cfg.freeze()
|
| 152 |
+
|
| 153 |
+
rvt = MVT(
|
| 154 |
+
renderer_device=device,
|
| 155 |
+
**mvt_cfg,
|
| 156 |
+
)
|
| 157 |
+
|
| 158 |
+
agent = rvt_agent.RVTAgent(
|
| 159 |
+
network=rvt.to(device),
|
| 160 |
+
image_resolution=[IMAGE_SIZE, IMAGE_SIZE],
|
| 161 |
+
add_lang=mvt_cfg.add_lang,
|
| 162 |
+
stage_two=mvt_cfg.stage_two,
|
| 163 |
+
rot_ver=mvt_cfg.rot_ver,
|
| 164 |
+
scene_bounds=SCENE_BOUNDS,
|
| 165 |
+
cameras=CAMERAS,
|
| 166 |
+
log_dir=f"{eval_log_dir}/eval_run",
|
| 167 |
+
**exp_cfg.peract,
|
| 168 |
+
**exp_cfg.rvt,
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
raise NotImplementedError
|
| 172 |
+
|
| 173 |
+
agent.build(training=False, device=device)
|
| 174 |
+
load_agent_state(model_path, agent)
|
| 175 |
+
agent.eval()
|
| 176 |
+
|
| 177 |
+
elif peract_official: # load official peract model, using the provided code
|
| 178 |
+
try:
|
| 179 |
+
model_folder = os.path.join(os.path.abspath(peract_model_dir), "..", "..")
|
| 180 |
+
train_cfg_path = os.path.join(model_folder, "config.yaml")
|
| 181 |
+
agent = get_official_peract(train_cfg_path, False, device, bs=1)
|
| 182 |
+
except FileNotFoundError:
|
| 183 |
+
print("Config file not found, trying to load again in our format")
|
| 184 |
+
train_cfg_path = "configs/peract_official_config.yaml"
|
| 185 |
+
agent = get_official_peract(train_cfg_path, False, device, bs=1)
|
| 186 |
+
agent.load_weights(peract_model_dir)
|
| 187 |
+
agent.eval()
|
| 188 |
+
|
| 189 |
+
print("Agent Information")
|
| 190 |
+
print(agent)
|
| 191 |
+
return agent
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
@torch.no_grad()
|
| 195 |
+
def eval(
|
| 196 |
+
agent,
|
| 197 |
+
tasks,
|
| 198 |
+
eval_datafolder,
|
| 199 |
+
start_episode=0,
|
| 200 |
+
eval_episodes=25,
|
| 201 |
+
episode_length=25,
|
| 202 |
+
replay_ground_truth=False,
|
| 203 |
+
device=0,
|
| 204 |
+
headless=True,
|
| 205 |
+
logging=False,
|
| 206 |
+
log_dir=None,
|
| 207 |
+
verbose=True,
|
| 208 |
+
save_video=False,
|
| 209 |
+
):
|
| 210 |
+
agent.eval()
|
| 211 |
+
if isinstance(agent, rvt_agent.RVTAgent):
|
| 212 |
+
agent.load_clip()
|
| 213 |
+
|
| 214 |
+
camera_resolution = [IMAGE_SIZE, IMAGE_SIZE]
|
| 215 |
+
obs_config = utils.create_obs_config(CAMERAS, camera_resolution, method_name="")
|
| 216 |
+
|
| 217 |
+
gripper_mode = Discrete()
|
| 218 |
+
arm_action_mode = EndEffectorPoseViaPlanning()
|
| 219 |
+
action_mode = MoveArmThenGripper(arm_action_mode, gripper_mode)
|
| 220 |
+
|
| 221 |
+
task_files = [
|
| 222 |
+
t.replace(".py", "")
|
| 223 |
+
for t in os.listdir(rlbench_task.TASKS_PATH)
|
| 224 |
+
if t != "__init__.py" and t.endswith(".py")
|
| 225 |
+
]
|
| 226 |
+
|
| 227 |
+
task_classes = []
|
| 228 |
+
if tasks[0] == "all":
|
| 229 |
+
tasks = RLBENCH_TASKS
|
| 230 |
+
if verbose:
|
| 231 |
+
print(f"evaluate on {len(tasks)} tasks: ", tasks)
|
| 232 |
+
|
| 233 |
+
for task in tasks:
|
| 234 |
+
if task not in task_files:
|
| 235 |
+
raise ValueError("Task %s not recognised!." % task)
|
| 236 |
+
task_classes.append(task_file_to_task_class(task))
|
| 237 |
+
|
| 238 |
+
eval_env = CustomMultiTaskRLBenchEnv(
|
| 239 |
+
task_classes=task_classes,
|
| 240 |
+
observation_config=obs_config,
|
| 241 |
+
action_mode=action_mode,
|
| 242 |
+
dataset_root=eval_datafolder,
|
| 243 |
+
episode_length=episode_length,
|
| 244 |
+
headless=headless,
|
| 245 |
+
swap_task_every=eval_episodes,
|
| 246 |
+
include_lang_goal_in_obs=True,
|
| 247 |
+
time_in_state=True,
|
| 248 |
+
record_every_n=1 if save_video else -1,
|
| 249 |
+
)
|
| 250 |
+
|
| 251 |
+
eval_env.eval = True
|
| 252 |
+
|
| 253 |
+
device = f"cuda:{device}"
|
| 254 |
+
|
| 255 |
+
if logging:
|
| 256 |
+
assert log_dir is not None
|
| 257 |
+
|
| 258 |
+
# create metric saving writer
|
| 259 |
+
csv_file = "eval_results.csv"
|
| 260 |
+
if not os.path.exists(os.path.join(log_dir, csv_file)):
|
| 261 |
+
with open(os.path.join(log_dir, csv_file), "w") as csv_fp:
|
| 262 |
+
fieldnames = ["task", "success rate", "length", "total_transitions"]
|
| 263 |
+
csv_writer = csv.DictWriter(csv_fp, fieldnames=fieldnames)
|
| 264 |
+
csv_writer.writeheader()
|
| 265 |
+
|
| 266 |
+
# evaluate agent
|
| 267 |
+
rollout_generator = RolloutGenerator(device)
|
| 268 |
+
stats_accumulator = SimpleAccumulator(eval_video_fps=30)
|
| 269 |
+
|
| 270 |
+
eval_env.launch()
|
| 271 |
+
|
| 272 |
+
current_task_id = -1
|
| 273 |
+
|
| 274 |
+
num_tasks = len(tasks)
|
| 275 |
+
step_signal = Value("i", -1)
|
| 276 |
+
|
| 277 |
+
scores = []
|
| 278 |
+
for task_id in range(num_tasks):
|
| 279 |
+
task_rewards = []
|
| 280 |
+
for ep in range(start_episode, start_episode + eval_episodes):
|
| 281 |
+
episode_rollout = []
|
| 282 |
+
generator = rollout_generator.generator(
|
| 283 |
+
step_signal=step_signal,
|
| 284 |
+
env=eval_env,
|
| 285 |
+
agent=agent,
|
| 286 |
+
episode_length=episode_length,
|
| 287 |
+
timesteps=1,
|
| 288 |
+
eval=True,
|
| 289 |
+
eval_demo_seed=ep,
|
| 290 |
+
record_enabled=False,
|
| 291 |
+
replay_ground_truth=replay_ground_truth,
|
| 292 |
+
)
|
| 293 |
+
try:
|
| 294 |
+
for replay_transition in generator:
|
| 295 |
+
episode_rollout.append(replay_transition)
|
| 296 |
+
except StopIteration as e:
|
| 297 |
+
continue
|
| 298 |
+
except Exception as e:
|
| 299 |
+
eval_env.shutdown()
|
| 300 |
+
raise e
|
| 301 |
+
|
| 302 |
+
for transition in episode_rollout:
|
| 303 |
+
stats_accumulator.step(transition, True)
|
| 304 |
+
current_task_id = transition.info["active_task_id"]
|
| 305 |
+
assert current_task_id == task_id
|
| 306 |
+
|
| 307 |
+
task_name = tasks[task_id]
|
| 308 |
+
reward = episode_rollout[-1].reward
|
| 309 |
+
task_rewards.append(reward)
|
| 310 |
+
lang_goal = eval_env._lang_goal
|
| 311 |
+
if verbose:
|
| 312 |
+
print(
|
| 313 |
+
f"Evaluating {task_name} | Episode {ep} | Score: {reward} | Episode Length: {len(episode_rollout)} | Lang Goal: {lang_goal}"
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# report summaries
|
| 317 |
+
summaries = []
|
| 318 |
+
summaries.extend(stats_accumulator.pop())
|
| 319 |
+
task_name = tasks[task_id]
|
| 320 |
+
if logging:
|
| 321 |
+
# writer csv first
|
| 322 |
+
with open(os.path.join(log_dir, csv_file), "a") as csv_fp:
|
| 323 |
+
fieldnames = ["task", "success rate", "length", "total_transitions"]
|
| 324 |
+
csv_writer = csv.DictWriter(csv_fp, fieldnames=fieldnames)
|
| 325 |
+
csv_results = {"task": task_name}
|
| 326 |
+
for s in summaries:
|
| 327 |
+
if s.name == "eval_envs/return":
|
| 328 |
+
csv_results["success rate"] = s.value
|
| 329 |
+
elif s.name == "eval_envs/length":
|
| 330 |
+
csv_results["length"] = s.value
|
| 331 |
+
elif s.name == "eval_envs/total_transitions":
|
| 332 |
+
csv_results["total_transitions"] = s.value
|
| 333 |
+
if "eval" in s.name:
|
| 334 |
+
s.name = "%s/%s" % (s.name, task_name)
|
| 335 |
+
csv_writer.writerow(csv_results)
|
| 336 |
+
else:
|
| 337 |
+
for s in summaries:
|
| 338 |
+
if "eval" in s.name:
|
| 339 |
+
s.name = "%s/%s" % (s.name, task_name)
|
| 340 |
+
|
| 341 |
+
if len(summaries) > 0:
|
| 342 |
+
task_score = [
|
| 343 |
+
s.value for s in summaries if f"eval_envs/return/{task_name}" in s.name
|
| 344 |
+
][0]
|
| 345 |
+
else:
|
| 346 |
+
task_score = "unknown"
|
| 347 |
+
|
| 348 |
+
print(f"[Evaluation] Finished {task_name} | Final Score: {task_score}\n")
|
| 349 |
+
|
| 350 |
+
scores.append(task_score)
|
| 351 |
+
|
| 352 |
+
if save_video:
|
| 353 |
+
video_image_folder = "./tmp"
|
| 354 |
+
record_fps = 25
|
| 355 |
+
record_folder = os.path.join(log_dir, "videos")
|
| 356 |
+
os.makedirs(record_folder, exist_ok=True)
|
| 357 |
+
video_success_cnt = 0
|
| 358 |
+
video_fail_cnt = 0
|
| 359 |
+
video_cnt = 0
|
| 360 |
+
for summary in summaries:
|
| 361 |
+
if isinstance(summary, VideoSummary):
|
| 362 |
+
video = deepcopy(summary.value)
|
| 363 |
+
video = np.transpose(video, (0, 2, 3, 1))
|
| 364 |
+
video = video[:, :, :, ::-1]
|
| 365 |
+
if task_rewards[video_cnt] > 99:
|
| 366 |
+
video_path = os.path.join(
|
| 367 |
+
record_folder,
|
| 368 |
+
f"{task_name}_success_{video_success_cnt}.mp4",
|
| 369 |
+
)
|
| 370 |
+
video_success_cnt += 1
|
| 371 |
+
else:
|
| 372 |
+
video_path = os.path.join(
|
| 373 |
+
record_folder, f"{task_name}_fail_{video_fail_cnt}.mp4"
|
| 374 |
+
)
|
| 375 |
+
video_fail_cnt += 1
|
| 376 |
+
video_cnt += 1
|
| 377 |
+
os.makedirs(video_image_folder, exist_ok=True)
|
| 378 |
+
for idx in range(len(video) - 10):
|
| 379 |
+
cv2.imwrite(
|
| 380 |
+
os.path.join(video_image_folder, f"{idx}.png"), video[idx]
|
| 381 |
+
)
|
| 382 |
+
images_path = os.path.join(video_image_folder, r"%d.png")
|
| 383 |
+
os.system(
|
| 384 |
+
"ffmpeg -i {} -vf palettegen palette.png -hide_banner -loglevel error".format(
|
| 385 |
+
images_path
|
| 386 |
+
)
|
| 387 |
+
)
|
| 388 |
+
os.system(
|
| 389 |
+
"ffmpeg -framerate {} -i {} -i palette.png -lavfi paletteuse {} -hide_banner -loglevel error".format(
|
| 390 |
+
record_fps, images_path, video_path
|
| 391 |
+
)
|
| 392 |
+
)
|
| 393 |
+
os.remove("palette.png")
|
| 394 |
+
shutil.rmtree(video_image_folder)
|
| 395 |
+
|
| 396 |
+
eval_env.shutdown()
|
| 397 |
+
|
| 398 |
+
if logging:
|
| 399 |
+
csv_fp.close()
|
| 400 |
+
|
| 401 |
+
# set agent to back train mode
|
| 402 |
+
agent.train()
|
| 403 |
+
|
| 404 |
+
# unloading clip to save memory
|
| 405 |
+
if isinstance(agent, rvt_agent.RVTAgent):
|
| 406 |
+
agent.unload_clip()
|
| 407 |
+
agent._network.free_mem()
|
| 408 |
+
|
| 409 |
+
return scores
|
| 410 |
+
|
| 411 |
+
|
| 412 |
+
def get_model_index(filename):
|
| 413 |
+
"""
|
| 414 |
+
:param filenam: path of file of format /.../model_idx.pth
|
| 415 |
+
:return: idx or None
|
| 416 |
+
"""
|
| 417 |
+
if len(filename) >= 9 and filename[-4:] == ".pth":
|
| 418 |
+
try:
|
| 419 |
+
index = int(filename[:-4].split("_")[-1])
|
| 420 |
+
except:
|
| 421 |
+
index = None
|
| 422 |
+
else:
|
| 423 |
+
index = None
|
| 424 |
+
return index
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
def _eval(args):
|
| 428 |
+
|
| 429 |
+
model_paths = []
|
| 430 |
+
if not (args.peract_official):
|
| 431 |
+
assert args.model_name is not None
|
| 432 |
+
model_paths.append(os.path.join(args.model_folder, args.model_name))
|
| 433 |
+
else:
|
| 434 |
+
model_paths.append(None)
|
| 435 |
+
|
| 436 |
+
# skipping evaluated models
|
| 437 |
+
if args.skip:
|
| 438 |
+
"""
|
| 439 |
+
to_skip: {
|
| 440 |
+
0: {'light_bulb_in': False, .....}
|
| 441 |
+
1: {'light_bulb_in': False, .....}
|
| 442 |
+
.
|
| 443 |
+
.
|
| 444 |
+
}
|
| 445 |
+
"""
|
| 446 |
+
to_skip = {
|
| 447 |
+
get_model_index(x): {y: False for y in args.tasks} for x in model_paths
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
filenames = os.listdir(args.eval_log_dir)
|
| 451 |
+
for filename in filenames:
|
| 452 |
+
if not filename.startswith("events.out.tfevents."):
|
| 453 |
+
continue
|
| 454 |
+
summ = summary_iterator(f"{args.eval_log_dir}/{filename}")
|
| 455 |
+
# skipping the time log of the summary
|
| 456 |
+
try:
|
| 457 |
+
next(summ)
|
| 458 |
+
except:
|
| 459 |
+
# moving to the next file
|
| 460 |
+
continue
|
| 461 |
+
for cur_summ in summ:
|
| 462 |
+
cur_task = cur_summ.summary.value[0].tag[5:]
|
| 463 |
+
cur_step = cur_summ.step
|
| 464 |
+
if cur_step in to_skip:
|
| 465 |
+
to_skip[cur_step][cur_task] = True
|
| 466 |
+
|
| 467 |
+
tb = TensorboardManager(args.eval_log_dir)
|
| 468 |
+
for model_path in model_paths:
|
| 469 |
+
tasks_to_eval = deepcopy(args.tasks)
|
| 470 |
+
|
| 471 |
+
if args.peract_official:
|
| 472 |
+
model_idx = 0
|
| 473 |
+
else:
|
| 474 |
+
model_idx = get_model_index(model_path)
|
| 475 |
+
if model_idx is None:
|
| 476 |
+
model_idx = 0
|
| 477 |
+
|
| 478 |
+
if args.skip:
|
| 479 |
+
for _task in args.tasks:
|
| 480 |
+
if to_skip[model_idx][_task]:
|
| 481 |
+
tasks_to_eval.remove(_task)
|
| 482 |
+
|
| 483 |
+
if len(tasks_to_eval) == 0:
|
| 484 |
+
print(f"Skipping model_idx={model_idx} for args.tasks={args.tasks}")
|
| 485 |
+
continue
|
| 486 |
+
|
| 487 |
+
if not (args.peract_official):
|
| 488 |
+
agent = load_agent(
|
| 489 |
+
model_path=model_path,
|
| 490 |
+
exp_cfg_path=args.exp_cfg_path,
|
| 491 |
+
mvt_cfg_path=args.mvt_cfg_path,
|
| 492 |
+
eval_log_dir=args.eval_log_dir,
|
| 493 |
+
device=args.device,
|
| 494 |
+
use_input_place_with_mean=args.use_input_place_with_mean,
|
| 495 |
+
)
|
| 496 |
+
|
| 497 |
+
agent_eval_log_dir = os.path.join(
|
| 498 |
+
args.eval_log_dir, os.path.basename(model_path).split(".")[0]
|
| 499 |
+
)
|
| 500 |
+
else:
|
| 501 |
+
agent = load_agent(
|
| 502 |
+
peract_official=args.peract_official,
|
| 503 |
+
peract_model_dir=args.peract_model_dir,
|
| 504 |
+
device=args.device,
|
| 505 |
+
use_input_place_with_mean=args.use_input_place_with_mean,
|
| 506 |
+
)
|
| 507 |
+
agent_eval_log_dir = os.path.join(args.eval_log_dir, "final")
|
| 508 |
+
|
| 509 |
+
os.makedirs(agent_eval_log_dir, exist_ok=True)
|
| 510 |
+
scores = eval(
|
| 511 |
+
agent=agent,
|
| 512 |
+
tasks=tasks_to_eval,
|
| 513 |
+
eval_datafolder=args.eval_datafolder,
|
| 514 |
+
start_episode=args.start_episode,
|
| 515 |
+
eval_episodes=args.eval_episodes,
|
| 516 |
+
episode_length=args.episode_length,
|
| 517 |
+
replay_ground_truth=args.ground_truth,
|
| 518 |
+
device=args.device,
|
| 519 |
+
headless=args.headless,
|
| 520 |
+
logging=True,
|
| 521 |
+
log_dir=agent_eval_log_dir,
|
| 522 |
+
verbose=True,
|
| 523 |
+
save_video=args.save_video,
|
| 524 |
+
)
|
| 525 |
+
print(f"model {model_path}, scores {scores}")
|
| 526 |
+
task_scores = {}
|
| 527 |
+
for i in range(len(tasks_to_eval)):
|
| 528 |
+
task_scores[tasks_to_eval[i]] = scores[i]
|
| 529 |
+
|
| 530 |
+
print("save ", task_scores)
|
| 531 |
+
tb.update("eval", model_idx, task_scores)
|
| 532 |
+
tb.writer.flush()
|
| 533 |
+
|
| 534 |
+
tb.close()
|
| 535 |
+
|
| 536 |
+
|
| 537 |
+
if __name__ == "__main__":
|
| 538 |
+
parser = get_eval_parser()
|
| 539 |
+
|
| 540 |
+
args = parser.parse_args()
|
| 541 |
+
|
| 542 |
+
if args.log_name is None:
|
| 543 |
+
args.log_name = "none"
|
| 544 |
+
|
| 545 |
+
if not (args.peract_official):
|
| 546 |
+
args.eval_log_dir = os.path.join(args.model_folder, "eval", args.log_name)
|
| 547 |
+
else:
|
| 548 |
+
args.eval_log_dir = os.path.join(args.peract_model_dir, "eval", args.log_name)
|
| 549 |
+
|
| 550 |
+
os.makedirs(args.eval_log_dir, exist_ok=True)
|
| 551 |
+
|
| 552 |
+
# save the arguments for future reference
|
| 553 |
+
with open(os.path.join(args.eval_log_dir, "eval_config.yaml"), "w") as fp:
|
| 554 |
+
yaml.dump(args.__dict__, fp)
|
| 555 |
+
|
| 556 |
+
_eval(args)
|
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/.gitattributes
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pcd_data.tar.gz filter=lfs diff=lfs merge=lfs -text
|
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/.gitignore
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
build/*
|
| 2 |
+
*.egg-info/*
|
| 3 |
+
*.so
|
| 4 |
+
*/__pycache__/*
|
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/LICENSE
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2022-2023, NVIDIA Corporation & affiliates. All rights reserved.
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
NVIDIA Source Code License for instant neural graphics primitives
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
=======================================================================
|
| 8 |
+
|
| 9 |
+
1. Definitions
|
| 10 |
+
|
| 11 |
+
"Licensor" means any person or entity that distributes its Work.
|
| 12 |
+
|
| 13 |
+
"Software" means the original work of authorship made available under
|
| 14 |
+
this License.
|
| 15 |
+
|
| 16 |
+
"Work" means the Software and any additions to or derivative works of
|
| 17 |
+
the Software that are made available under this License.
|
| 18 |
+
|
| 19 |
+
The terms "reproduce," "reproduction," "derivative works," and
|
| 20 |
+
"distribution" have the meaning as provided under U.S. copyright law;
|
| 21 |
+
provided, however, that for the purposes of this License, derivative
|
| 22 |
+
works shall not include works that remain separable from, or merely
|
| 23 |
+
link (or bind by name) to the interfaces of, the Work.
|
| 24 |
+
|
| 25 |
+
Works, including the Software, are "made available" under this License
|
| 26 |
+
by including in or with the Work either (a) a copyright notice
|
| 27 |
+
referencing the applicability of this License to the Work, or (b) a
|
| 28 |
+
copy of this License.
|
| 29 |
+
|
| 30 |
+
2. License Grants
|
| 31 |
+
|
| 32 |
+
2.1 Copyright Grant. Subject to the terms and conditions of this
|
| 33 |
+
License, each Licensor grants to you a perpetual, worldwide,
|
| 34 |
+
non-exclusive, royalty-free, copyright license to reproduce,
|
| 35 |
+
prepare derivative works of, publicly display, publicly perform,
|
| 36 |
+
sublicense and distribute its Work and any resulting derivative
|
| 37 |
+
works in any form.
|
| 38 |
+
|
| 39 |
+
3. Limitations
|
| 40 |
+
|
| 41 |
+
3.1 Redistribution. You may reproduce or distribute the Work only
|
| 42 |
+
if (a) you do so under this License, (b) you include a complete
|
| 43 |
+
copy of this License with your distribution, and (c) you retain
|
| 44 |
+
without modification any copyright, patent, trademark, or
|
| 45 |
+
attribution notices that are present in the Work.
|
| 46 |
+
|
| 47 |
+
3.2 Derivative Works. You may specify that additional or different
|
| 48 |
+
terms apply to the use, reproduction, and distribution of your
|
| 49 |
+
derivative works of the Work ("Your Terms") only if (a) Your Terms
|
| 50 |
+
provide that the use limitation in Section 3.3 applies to your
|
| 51 |
+
derivative works, and (b) you identify the specific derivative
|
| 52 |
+
works that are subject to Your Terms. Notwithstanding Your Terms,
|
| 53 |
+
this License (including the redistribution requirements in Section
|
| 54 |
+
3.1) will continue to apply to the Work itself.
|
| 55 |
+
|
| 56 |
+
3.3 Use Limitation. The Work and any derivative works thereof only
|
| 57 |
+
may be used or intended for use non-commercially. Notwithstanding
|
| 58 |
+
the foregoing, NVIDIA and its affiliates may use the Work and any
|
| 59 |
+
derivative works commercially. As used herein, "non-commercially"
|
| 60 |
+
means for research or evaluation purposes only.
|
| 61 |
+
|
| 62 |
+
3.4 Patent Claims. If you bring or threaten to bring a patent claim
|
| 63 |
+
against any Licensor (including any claim, cross-claim or
|
| 64 |
+
counterclaim in a lawsuit) to enforce any patents that you allege
|
| 65 |
+
are infringed by any Work, then your rights under this License from
|
| 66 |
+
such Licensor (including the grant in Section 2.1) will terminate
|
| 67 |
+
immediately.
|
| 68 |
+
|
| 69 |
+
3.5 Trademarks. This License does not grant any rights to use any
|
| 70 |
+
Licensor’s or its affiliates’ names, logos, or trademarks, except
|
| 71 |
+
as necessary to reproduce the notices described in this License.
|
| 72 |
+
|
| 73 |
+
3.6 Termination. If you violate any term of this License, then your
|
| 74 |
+
rights under this License (including the grant in Section 2.1) will
|
| 75 |
+
terminate immediately.
|
| 76 |
+
|
| 77 |
+
4. Disclaimer of Warranty.
|
| 78 |
+
|
| 79 |
+
THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
| 80 |
+
KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
|
| 81 |
+
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
|
| 82 |
+
NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
|
| 83 |
+
THIS LICENSE.
|
| 84 |
+
|
| 85 |
+
5. Limitation of Liability.
|
| 86 |
+
|
| 87 |
+
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
|
| 88 |
+
THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
|
| 89 |
+
SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
|
| 90 |
+
INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
|
| 91 |
+
OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
|
| 92 |
+
(INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
|
| 93 |
+
LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
|
| 94 |
+
COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
|
| 95 |
+
THE POSSIBILITY OF SUCH DAMAGES.
|
| 96 |
+
|
| 97 |
+
=======================================================================
|
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/README.md
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## Point Renderer
|
| 2 |
+
A minimal, lightweight CUDA-accelerated renderer of pointclouds.
|
| 3 |
+
|
| 4 |
+
<div align="center"><img src="demo.png"/></div>
|
| 5 |
+
|
| 6 |
+
### Install
|
| 7 |
+
|
| 8 |
+
```
|
| 9 |
+
pip install -r requirements.txt
|
| 10 |
+
pip install -e .
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
### Run
|
| 14 |
+
|
| 15 |
+
**Load Data**
|
| 16 |
+
Extract included pcd_data.tar.gz
|
| 17 |
+
|
| 18 |
+
```
|
| 19 |
+
import numpy as np
|
| 20 |
+
|
| 21 |
+
data = np.load("pcd_data/w1280_h720/3.npy", allow_pickle=True)
|
| 22 |
+
data = data[None][0]
|
| 23 |
+
pc = data["pc"]
|
| 24 |
+
rgb = data["img_feat"]
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
**Render the image**
|
| 28 |
+
|
| 29 |
+
```
|
| 30 |
+
# Make the renderer
|
| 31 |
+
from point_renderer.renderer import PointRenderer
|
| 32 |
+
renderer = PointRenderer(device="cuda", perf_timer=False)
|
| 33 |
+
|
| 34 |
+
# Define a batch of cameras
|
| 35 |
+
img_size = (512, 512)
|
| 36 |
+
K = renderer.get_camera_intrinsics(hfov=70, img_size=img_size)
|
| 37 |
+
camera_poses = renderer.get_batch_of_camera_poses(
|
| 38 |
+
cam_positions=[[1.5, 1.5, 1.5],[-1.5, -1.5, -1.5]],
|
| 39 |
+
cam_lookats=[[0.0, 0.0, 0.0],[0.0, 0.0, 0.0]])
|
| 40 |
+
|
| 41 |
+
# Render the pointcloud from the given cameras
|
| 42 |
+
images, depths = renderer.render_batch(pc, rgb, camera_poses, K, img_size,
|
| 43 |
+
default_color=1.0,
|
| 44 |
+
splat_radius=0.005,
|
| 45 |
+
aa_factor=2
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
# Show the results
|
| 49 |
+
plt.imshow(images[0].detach().cpu().numpy()); plt.show()
|
| 50 |
+
plt.imshow(depths[0].detach().cpu().numpy()); plt.show()
|
| 51 |
+
plt.imshow(images[1].detach().cpu().numpy()); plt.show()
|
| 52 |
+
plt.imshow(depths[1].detach().cpu().numpy()); plt.show()
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
.. Or run the jupyter notebook that has this same code above, and also all the benchmarks.
|
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/demo.png
ADDED
|
Git LFS Details
|
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/image_0_splat_2xaa.png
ADDED
|
third_party/AnyBimanual/agents/rvt/rvt/libs/point-renderer/point_renderer/cameras.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from point_renderer import ops
|
| 4 |
+
from functools import lru_cache
|
| 5 |
+
|
| 6 |
+
@lru_cache(maxsize=32)
|
| 7 |
+
def linalg_inv(poses):
|
| 8 |
+
return torch.linalg.inv(poses)
|
| 9 |
+
|
| 10 |
+
class Cameras:
|
| 11 |
+
def __init__(self, poses, intrinsics, img_size, inv_poses=None):
|
| 12 |
+
self.poses = poses
|
| 13 |
+
self.img_size = img_size
|
| 14 |
+
if inv_poses is None:
|
| 15 |
+
self.inv_poses = linalg_inv(poses)
|
| 16 |
+
else:
|
| 17 |
+
self.inv_poses = inv_poses
|
| 18 |
+
self.intrinsics = intrinsics
|
| 19 |
+
|
| 20 |
+
def __len__(self):
|
| 21 |
+
return len(self.poses)
|
| 22 |
+
|
| 23 |
+
def scale(self, constant):
|
| 24 |
+
self.intrinsics = self.intrinsics.clone()
|
| 25 |
+
self.intrinsics[:, :2, :3] *= constant
|
| 26 |
+
|
| 27 |
+
def is_orthographic(self):
|
| 28 |
+
raise ValueError("is_orthographic should be called on child classes only")
|
| 29 |
+
|
| 30 |
+
def is_perspective(self):
|
| 31 |
+
raise ValueError("is_perspective should be called on child classes only")
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class PerspectiveCameras(Cameras):
|
| 35 |
+
def __init__(self, poses, intrinsics, img_size, inv_poses=None):
|
| 36 |
+
super().__init__(poses, intrinsics, img_size, inv_poses)
|
| 37 |
+
|
| 38 |
+
@classmethod
|
| 39 |
+
def from_lookat(cls, eyes, ats, ups, hfov, img_size, device="cpu"):
|
| 40 |
+
cam_poses = []
|
| 41 |
+
for eye, at, up in zip(eyes, ats, ups):
|
| 42 |
+
T = ops.lookat_to_cam_pose(eye, at, up, device=device)
|
| 43 |
+
cam_poses.append(T)
|
| 44 |
+
cam_poses = torch.stack(cam_poses, dim=0)
|
| 45 |
+
intrinsics = ops.fov_and_size_to_intrinsics(hfov, img_size, device=device)
|
| 46 |
+
intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
|
| 47 |
+
return PerspectiveCameras(cam_poses, intrinsics, img_size)
|
| 48 |
+
|
| 49 |
+
@classmethod
|
| 50 |
+
def from_rotation_and_translation(cls, R, T, S, hfov, img_size):
|
| 51 |
+
device = R.device
|
| 52 |
+
assert T.device == device
|
| 53 |
+
cam_poses = torch.zeros((R.shape[0], 4, 4), device=device, dtype=torch.float)
|
| 54 |
+
cam_poses[:, :3, :3] = R * S[None, :]
|
| 55 |
+
cam_poses[:, :3, 3] = T
|
| 56 |
+
cam_poses[:, 3, 3] = 1.0
|
| 57 |
+
intrinsics = ops.fov_and_size_to_intrinsics(hfov, img_size, device=device)
|
| 58 |
+
intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
|
| 59 |
+
return PerspectiveCameras(cam_poses, intrinsics, img_size)
|
| 60 |
+
|
| 61 |
+
def to(self, device):
|
| 62 |
+
return PerspectiveCameras(self.poses.to(device), self.intrinsics.to(device), self.inv_poses.to(device))
|
| 63 |
+
|
| 64 |
+
def is_orthographic(self):
|
| 65 |
+
return False
|
| 66 |
+
|
| 67 |
+
def is_perspective(self):
|
| 68 |
+
return True
|
| 69 |
+
|
| 70 |
+
class OrthographicCameras(Cameras):
|
| 71 |
+
def __init__(self, poses, intrinsics, img_size, inv_poses=None):
|
| 72 |
+
super().__init__(poses, intrinsics, img_size, inv_poses)
|
| 73 |
+
|
| 74 |
+
@classmethod
|
| 75 |
+
def from_lookat(cls, eyes, ats, ups, img_sizes_w, img_size_px, device="cpu"):
|
| 76 |
+
"""
|
| 77 |
+
Args:
|
| 78 |
+
eyes: Nx3 tensor of camera coordinates
|
| 79 |
+
ats: Nx3 tensor of look-at directions
|
| 80 |
+
ups: Nx3 tensor of up-vectors
|
| 81 |
+
scale: Nx2 tensor defining image sizes in world coordinates
|
| 82 |
+
img_size: 2-dim tuple defining image size in pixels
|
| 83 |
+
Returns:
|
| 84 |
+
OrthographicCamera
|
| 85 |
+
"""
|
| 86 |
+
if isinstance(img_sizes_w, list):
|
| 87 |
+
img_sizes_w = torch.tensor(img_sizes_w, device=device)[None, :].repeat((len(eyes), 1))
|
| 88 |
+
|
| 89 |
+
cam_poses = []
|
| 90 |
+
for eye, at, up in zip(eyes, ats, ups):
|
| 91 |
+
T = ops.lookat_to_cam_pose(eye, at, up, device=device)
|
| 92 |
+
cam_poses.append(T)
|
| 93 |
+
cam_poses = torch.stack(cam_poses, dim=0)
|
| 94 |
+
intrinsics = ops.orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device=device)
|
| 95 |
+
return OrthographicCameras(cam_poses, intrinsics, img_size_px)
|
| 96 |
+
|
| 97 |
+
@classmethod
|
| 98 |
+
def from_rotation_and_translation(cls, R, T, img_sizes_w, img_size_px, device="cpu"):
|
| 99 |
+
if isinstance(img_sizes_w, list):
|
| 100 |
+
img_sizes_w = torch.tensor(img_sizes_w, device=device)[None, :].repeat((len(R), 1))
|
| 101 |
+
|
| 102 |
+
device = R.device
|
| 103 |
+
assert T.device == device
|
| 104 |
+
cam_poses = torch.zeros((R.shape[0], 4, 4), device=device, dtype=torch.float)
|
| 105 |
+
cam_poses[:, :3, :3] = R
|
| 106 |
+
cam_poses[:, :3, 3] = T
|
| 107 |
+
cam_poses[:, 3, 3] = 1.0
|
| 108 |
+
intrinsics = ops.orthographic_intrinsics_from_scales(img_sizes_w, img_size_px, device=device)
|
| 109 |
+
intrinsics = intrinsics[None, :, :].repeat((cam_poses.shape[0], 1, 1)).contiguous()
|
| 110 |
+
return OrthographicCameras(cam_poses, intrinsics, img_size_px)
|
| 111 |
+
|
| 112 |
+
def to(self, device):
|
| 113 |
+
return OrthographicCameras(self.poses.to(device), self.intrinsics.to(device), self.inv_poses.to(device))
|
| 114 |
+
|
| 115 |
+
def is_orthographic(self):
|
| 116 |
+
return True
|
| 117 |
+
|
| 118 |
+
def is_perspective(self):
|
| 119 |
+
return False
|