Commit
·
5c3096c
1
Parent(s):
ab0c658
Upload folder using huggingface_hub
Browse files- .summary/0/events.out.tfevents.1688753657.qgallouedec-MS-7C84 +3 -0
- README.md +1 -1
- checkpoint_p0/best_000015216_7790592_reward_381.761.pth +3 -0
- checkpoint_p0/checkpoint_000019528_9998336.pth +3 -0
- checkpoint_p0/checkpoint_000019544_10006528.pth +1 -1
- config.json +2 -2
- git.diff +219 -74
- replay.mp4 +2 -2
- sf_log.txt +0 -0
.summary/0/events.out.tfevents.1688753657.qgallouedec-MS-7C84
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:065fccda6004f2a0e31e4a2a6cccbc8f665985b21fbffd96abfb9b0942b70d85
|
| 3 |
+
size 800648
|
README.md
CHANGED
|
@@ -15,7 +15,7 @@ model-index:
|
|
| 15 |
type: bin-picking-v2
|
| 16 |
metrics:
|
| 17 |
- type: mean_reward
|
| 18 |
-
value:
|
| 19 |
name: mean_reward
|
| 20 |
verified: false
|
| 21 |
---
|
|
|
|
| 15 |
type: bin-picking-v2
|
| 16 |
metrics:
|
| 17 |
- type: mean_reward
|
| 18 |
+
value: 402.65 +/- 2.67
|
| 19 |
name: mean_reward
|
| 20 |
verified: false
|
| 21 |
---
|
checkpoint_p0/best_000015216_7790592_reward_381.761.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8ac417c935acc153a20a232e82408052bbb5b4e40794df2e9aa6a3f15bcf08f7
|
| 3 |
+
size 98239
|
checkpoint_p0/checkpoint_000019528_9998336.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3b275139d40e357758fd61256fcbf041ee60089899f4f2c424c30be014dc814
|
| 3 |
+
size 98567
|
checkpoint_p0/checkpoint_000019544_10006528.pth
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 98567
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a964c934fd32046ed2fefc2ca66fef563c63119f7dc43d4ba7b48c75384d6419
|
| 3 |
size 98567
|
config.json
CHANGED
|
@@ -128,7 +128,7 @@
|
|
| 128 |
"wandb_user": "qgallouedec",
|
| 129 |
"wandb_project": "sample_facotry_metaworld"
|
| 130 |
},
|
| 131 |
-
"git_hash": "
|
| 132 |
"git_repo_name": "https://github.com/huggingface/gia",
|
| 133 |
-
"wandb_unique_id": "bin-picking-
|
| 134 |
}
|
|
|
|
| 128 |
"wandb_user": "qgallouedec",
|
| 129 |
"wandb_project": "sample_facotry_metaworld"
|
| 130 |
},
|
| 131 |
+
"git_hash": "dda7c2cbaa4c60ae8940e37f69d814d32339d2fa",
|
| 132 |
"git_repo_name": "https://github.com/huggingface/gia",
|
| 133 |
+
"wandb_unique_id": "bin-picking-v2_20230707_201415_318153"
|
| 134 |
}
|
git.diff
CHANGED
|
@@ -318,6 +318,96 @@ index 4c3f06b..88b6c45 100644
|
|
| 318 |
]
|
| 319 |
|
| 320 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 321 |
diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py
|
| 322 |
index e21b237..c2b1907 100644
|
| 323 |
--- a/data/envs/metaworld/generate_dataset.py
|
|
@@ -333,20 +423,22 @@ index e21b237..c2b1907 100644
|
|
| 333 |
dataset["continuous_observations"][-1].append(observations["obs"].cpu().numpy()[0])
|
| 334 |
dataset["continuous_actions"][-1].append(actions[0])
|
| 335 |
diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
|
| 336 |
-
index cfdae2f..
|
| 337 |
--- a/data/envs/metaworld/generate_dataset_all.sh
|
| 338 |
+++ b/data/envs/metaworld/generate_dataset_all.sh
|
| 339 |
-
@@ -
|
|
|
|
| 340 |
|
| 341 |
ENVS=(
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
-
|
| 346 |
-
-
|
| 347 |
-
-
|
| 348 |
-
|
| 349 |
-
|
|
|
|
| 350 |
- coffee-button
|
| 351 |
- coffee-pull
|
| 352 |
- coffee-push
|
|
@@ -389,13 +481,6 @@ index cfdae2f..5db8c4b 100755
|
|
| 389 |
- sweep
|
| 390 |
- window-close
|
| 391 |
- window-open
|
| 392 |
-
+ # basketball
|
| 393 |
-
+ # bin-picking
|
| 394 |
-
+ # box-close
|
| 395 |
-
+ # button-press-topdown
|
| 396 |
-
+ # button-press-topdown-wall
|
| 397 |
-
+ # button-press
|
| 398 |
-
+ # button-press-wall
|
| 399 |
+ # coffee-button
|
| 400 |
+ # coffee-pull
|
| 401 |
+ # coffee-push
|
|
@@ -447,7 +532,7 @@ index cfdae2f..5db8c4b 100755
|
|
| 447 |
+ python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
|
| 448 |
done
|
| 449 |
diff --git a/data/envs/metaworld/push_all.sh b/data/envs/metaworld/push_all.sh
|
| 450 |
-
index 9d71467..
|
| 451 |
--- a/data/envs/metaworld/push_all.sh
|
| 452 |
+++ b/data/envs/metaworld/push_all.sh
|
| 453 |
@@ -2,57 +2,57 @@
|
|
@@ -556,13 +641,82 @@ index 9d71467..5b05c6d 100755
|
|
| 556 |
|
| 557 |
for ENV in "${ENVS[@]}"; do
|
| 558 |
- python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/sample-factory-$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
|
| 559 |
-
+ python
|
| 560 |
done
|
| 561 |
diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py
|
| 562 |
-
index 46dc581..
|
| 563 |
--- a/data/envs/metaworld/train.py
|
| 564 |
+++ b/data/envs/metaworld/train.py
|
| 565 |
-
@@ -
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 566 |
num_workers=8,
|
| 567 |
num_envs_per_worker=8,
|
| 568 |
worker_num_splits=2,
|
|
@@ -571,34 +725,18 @@ index 46dc581..c72f289 100644
|
|
| 571 |
encoder_mlp_layers=[64, 64],
|
| 572 |
env_frameskip=1,
|
| 573 |
nonlinearity="tanh",
|
| 574 |
-
|
| 575 |
-
index dbf328a..1b3c4c8 100755
|
| 576 |
-
--- a/data/envs/metaworld/train_all.sh
|
| 577 |
-
+++ b/data/envs/metaworld/train_all.sh
|
| 578 |
-
@@ -1,7 +1,7 @@
|
| 579 |
-
#!/bin/bash
|
| 580 |
-
|
| 581 |
-
ENVS=(
|
| 582 |
-
- assembly
|
| 583 |
-
+ # assembly
|
| 584 |
-
basketball
|
| 585 |
-
bin-picking
|
| 586 |
-
box-close
|
| 587 |
-
diff --git a/gia/eval/callback.py b/gia/eval/callback.py
|
| 588 |
-
index 5c3a080..4b6198f 100644
|
| 589 |
-
--- a/gia/eval/callback.py
|
| 590 |
-
+++ b/gia/eval/callback.py
|
| 591 |
-
@@ -2,10 +2,10 @@ import glob
|
| 592 |
-
import json
|
| 593 |
-
import subprocess
|
| 594 |
|
| 595 |
-
-import wandb
|
| 596 |
-
from accelerate import Accelerator
|
| 597 |
-
from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments
|
| 598 |
|
| 599 |
-
|
| 600 |
-
|
| 601 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 602 |
|
| 603 |
diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
|
| 604 |
index 91b645c..3e2cae7 100644
|
|
@@ -625,38 +763,33 @@ index 91b645c..3e2cae7 100644
|
|
| 625 |
def evaluate(self, model: GiaModel) -> float:
|
| 626 |
return self._evaluate(model)
|
| 627 |
|
| 628 |
-
diff --git a/gia/eval/
|
| 629 |
-
|
| 630 |
-
|
| 631 |
-
|
| 632 |
-
|
| 633 |
-
|
| 634 |
-
-TASK_TO_ENV_MAPPING = {
|
| 635 |
-
- "mujoco-ant": "Ant-v4",
|
| 636 |
-
- "mujoco-halfcheetah": "HalfCheetah-v4",
|
| 637 |
-
- "mujoco-hopper": "Hopper-v4",
|
| 638 |
-
- "mujoco-doublependulum": "InvertedDoublePendulum-v4",
|
| 639 |
-
- "mujoco-pendulum": "InvertedPendulum-v4",
|
| 640 |
-
- "mujoco-reacher": "Reacher-v4",
|
| 641 |
-
- "mujoco-swimmer": "Swimmer-v4",
|
| 642 |
-
- "mujoco-walker": "Walker2d-v4",
|
| 643 |
-
- # Atari etc...
|
| 644 |
-
-}
|
| 645 |
-
diff --git a/gia/eval/rl/__init__.py b/gia/eval/rl/__init__.py
|
| 646 |
-
index 36d890b..da5e0c7 100644
|
| 647 |
-
--- a/gia/eval/rl/__init__.py
|
| 648 |
-
+++ b/gia/eval/rl/__init__.py
|
| 649 |
-
@@ -1,4 +1,5 @@
|
| 650 |
-
+from .envs.core import make
|
| 651 |
-
from .gym_evaluator import GymEvaluator
|
| 652 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 653 |
|
| 654 |
-
|
| 655 |
-
|
| 656 |
diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
|
| 657 |
-
index f0d0b9b..
|
| 658 |
--- a/gia/eval/rl/gia_agent.py
|
| 659 |
+++ b/gia/eval/rl/gia_agent.py
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 660 |
@@ -75,6 +75,11 @@ class GiaAgent:
|
| 661 |
) -> Tuple[Tuple[Tensor, Tensor], ...]:
|
| 662 |
return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
|
|
@@ -712,3 +845,15 @@ index 1b8ebee..ff7d030 100644
|
|
| 712 |
},
|
| 713 |
"random": {
|
| 714 |
"mean": 220.65601680730813,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 318 |
]
|
| 319 |
|
| 320 |
|
| 321 |
+
diff --git a/data/envs/metaworld/enjoy.py b/data/envs/metaworld/enjoy.py
|
| 322 |
+
deleted file mode 100644
|
| 323 |
+
index 6ec026b..0000000
|
| 324 |
+
--- a/data/envs/metaworld/enjoy.py
|
| 325 |
+
+++ /dev/null
|
| 326 |
+
@@ -1,84 +0,0 @@
|
| 327 |
+
-import sys
|
| 328 |
+
-from typing import Dict, Optional
|
| 329 |
+
-
|
| 330 |
+
-import gym
|
| 331 |
+
-import metaworld # noqa: F401
|
| 332 |
+
-from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
|
| 333 |
+
-from sample_factory.enjoy import enjoy
|
| 334 |
+
-from sample_factory.envs.env_utils import register_env
|
| 335 |
+
-
|
| 336 |
+
-
|
| 337 |
+
-ENV_NAMES = [
|
| 338 |
+
- "assembly-v2",
|
| 339 |
+
- "basketball-v2",
|
| 340 |
+
- "bin-picking-v2",
|
| 341 |
+
- "box-close-v2",
|
| 342 |
+
- "button-press-topdown-v2",
|
| 343 |
+
- "button-press-topdown-wall-v2",
|
| 344 |
+
- "button-press-v2",
|
| 345 |
+
- "button-press-wall-v2",
|
| 346 |
+
- "coffee-button-v2",
|
| 347 |
+
- "coffee-pull-v2",
|
| 348 |
+
- "coffee-push-v2",
|
| 349 |
+
- "dial-turn-v2",
|
| 350 |
+
- "disassemble-v2",
|
| 351 |
+
- "door-close-v2",
|
| 352 |
+
- "door-lock-v2",
|
| 353 |
+
- "door-open-v2",
|
| 354 |
+
- "door-unlock-v2",
|
| 355 |
+
- "drawer-close-v2",
|
| 356 |
+
- "drawer-open-v2",
|
| 357 |
+
- "faucet-close-v2",
|
| 358 |
+
- "faucet-open-v2",
|
| 359 |
+
- "hammer-v2",
|
| 360 |
+
- "hand-insert-v2",
|
| 361 |
+
- "handle-press-side-v2",
|
| 362 |
+
- "handle-press-v2",
|
| 363 |
+
- "handle-pull-side-v2",
|
| 364 |
+
- "handle-pull-v2",
|
| 365 |
+
- "lever-pull-v2",
|
| 366 |
+
- "peg-insert-side-v2",
|
| 367 |
+
- "peg-unplug-side-v2",
|
| 368 |
+
- "pick-out-of-hole-v2",
|
| 369 |
+
- "pick-place-v2",
|
| 370 |
+
- "pick-place-wall-v2",
|
| 371 |
+
- "plate-slide-back-side-v2",
|
| 372 |
+
- "plate-slide-back-v2",
|
| 373 |
+
- "plate-slide-side-v2",
|
| 374 |
+
- "plate-slide-v2",
|
| 375 |
+
- "push-back-v2",
|
| 376 |
+
- "push-v2",
|
| 377 |
+
- "push-wall-v2",
|
| 378 |
+
- "reach-v2",
|
| 379 |
+
- "reach-wall-v2",
|
| 380 |
+
- "shelf-place-v2",
|
| 381 |
+
- "soccer-v2",
|
| 382 |
+
- "stick-pull-v2",
|
| 383 |
+
- "stick-push-v2",
|
| 384 |
+
- "sweep-into-v2",
|
| 385 |
+
- "sweep-v2",
|
| 386 |
+
- "window-close-v2",
|
| 387 |
+
- "window-open-v2",
|
| 388 |
+
-]
|
| 389 |
+
-
|
| 390 |
+
-
|
| 391 |
+
-def make_custom_env(
|
| 392 |
+
- full_env_name: str,
|
| 393 |
+
- cfg: Optional[Dict] = None,
|
| 394 |
+
- env_config: Optional[Dict] = None,
|
| 395 |
+
- render_mode: Optional[str] = None,
|
| 396 |
+
-) -> gym.Env:
|
| 397 |
+
- return gym.make(full_env_name, render_mode=render_mode)
|
| 398 |
+
-
|
| 399 |
+
-
|
| 400 |
+
-def main() -> int:
|
| 401 |
+
- for env_name in ENV_NAMES:
|
| 402 |
+
- register_env(env_name, make_custom_env)
|
| 403 |
+
- parser, _ = parse_sf_args(argv=None, evaluation=True)
|
| 404 |
+
- cfg = parse_full_cfg(parser)
|
| 405 |
+
- status = enjoy(cfg)
|
| 406 |
+
- return status
|
| 407 |
+
-
|
| 408 |
+
-
|
| 409 |
+
-if __name__ == "__main__":
|
| 410 |
+
- sys.exit(main())
|
| 411 |
diff --git a/data/envs/metaworld/generate_dataset.py b/data/envs/metaworld/generate_dataset.py
|
| 412 |
index e21b237..c2b1907 100644
|
| 413 |
--- a/data/envs/metaworld/generate_dataset.py
|
|
|
|
| 423 |
dataset["continuous_observations"][-1].append(observations["obs"].cpu().numpy()[0])
|
| 424 |
dataset["continuous_actions"][-1].append(actions[0])
|
| 425 |
diff --git a/data/envs/metaworld/generate_dataset_all.sh b/data/envs/metaworld/generate_dataset_all.sh
|
| 426 |
+
index cfdae2f..8720089 100755
|
| 427 |
--- a/data/envs/metaworld/generate_dataset_all.sh
|
| 428 |
+++ b/data/envs/metaworld/generate_dataset_all.sh
|
| 429 |
+
@@ -1,7 +1,7 @@
|
| 430 |
+
#!/bin/bash
|
| 431 |
|
| 432 |
ENVS=(
|
| 433 |
+
- assembly
|
| 434 |
+
+ # assembly
|
| 435 |
+
basketball
|
| 436 |
+
bin-picking
|
| 437 |
+
box-close
|
| 438 |
+
@@ -9,51 +9,51 @@ ENVS=(
|
| 439 |
+
button-press-topdown-wall
|
| 440 |
+
button-press
|
| 441 |
+
button-press-wall
|
| 442 |
- coffee-button
|
| 443 |
- coffee-pull
|
| 444 |
- coffee-push
|
|
|
|
| 481 |
- sweep
|
| 482 |
- window-close
|
| 483 |
- window-open
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 484 |
+ # coffee-button
|
| 485 |
+ # coffee-pull
|
| 486 |
+ # coffee-push
|
|
|
|
| 532 |
+ python generate_dataset.py --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir
|
| 533 |
done
|
| 534 |
diff --git a/data/envs/metaworld/push_all.sh b/data/envs/metaworld/push_all.sh
|
| 535 |
+
index 9d71467..4fc1fc2 100755
|
| 536 |
--- a/data/envs/metaworld/push_all.sh
|
| 537 |
+++ b/data/envs/metaworld/push_all.sh
|
| 538 |
@@ -2,57 +2,57 @@
|
|
|
|
| 641 |
|
| 642 |
for ENV in "${ENVS[@]}"; do
|
| 643 |
- python enjoy.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/sample-factory-$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
|
| 644 |
+
+ python push.py --algo=APPO --env $ENV-v2 --experiment $ENV-v2 --train_dir=./train_dir --max_num_episodes=10 --push_to_hub --hf_repository=qgallouedec/$ENV-v2 --save_video --no_render --enjoy_script=enjoy --train_script=train --load_checkpoint_kind best
|
| 645 |
done
|
| 646 |
diff --git a/data/envs/metaworld/train.py b/data/envs/metaworld/train.py
|
| 647 |
+
index 46dc581..095414e 100644
|
| 648 |
--- a/data/envs/metaworld/train.py
|
| 649 |
+++ b/data/envs/metaworld/train.py
|
| 650 |
+
@@ -2,67 +2,13 @@ import argparse
|
| 651 |
+
import sys
|
| 652 |
+
from typing import Dict, Optional
|
| 653 |
+
|
| 654 |
+
-import gym
|
| 655 |
+
+import gymnasium as gym
|
| 656 |
+
import metaworld # noqa: F401
|
| 657 |
+
from sample_factory.cfg.arguments import parse_full_cfg, parse_sf_args
|
| 658 |
+
from sample_factory.envs.env_utils import register_env
|
| 659 |
+
from sample_factory.train import run_rl
|
| 660 |
+
|
| 661 |
+
|
| 662 |
+
-ENV_NAMES = [
|
| 663 |
+
- "assembly-v2",
|
| 664 |
+
- "basketball-v2",
|
| 665 |
+
- "bin-picking-v2",
|
| 666 |
+
- "box-close-v2",
|
| 667 |
+
- "button-press-topdown-v2",
|
| 668 |
+
- "button-press-topdown-wall-v2",
|
| 669 |
+
- "button-press-v2",
|
| 670 |
+
- "button-press-wall-v2",
|
| 671 |
+
- "coffee-button-v2",
|
| 672 |
+
- "coffee-pull-v2",
|
| 673 |
+
- "coffee-push-v2",
|
| 674 |
+
- "dial-turn-v2",
|
| 675 |
+
- "disassemble-v2",
|
| 676 |
+
- "door-close-v2",
|
| 677 |
+
- "door-lock-v2",
|
| 678 |
+
- "door-open-v2",
|
| 679 |
+
- "door-unlock-v2",
|
| 680 |
+
- "drawer-close-v2",
|
| 681 |
+
- "drawer-open-v2",
|
| 682 |
+
- "faucet-close-v2",
|
| 683 |
+
- "faucet-open-v2",
|
| 684 |
+
- "hammer-v2",
|
| 685 |
+
- "hand-insert-v2",
|
| 686 |
+
- "handle-press-side-v2",
|
| 687 |
+
- "handle-press-v2",
|
| 688 |
+
- "handle-pull-side-v2",
|
| 689 |
+
- "handle-pull-v2",
|
| 690 |
+
- "lever-pull-v2",
|
| 691 |
+
- "peg-insert-side-v2",
|
| 692 |
+
- "peg-unplug-side-v2",
|
| 693 |
+
- "pick-out-of-hole-v2",
|
| 694 |
+
- "pick-place-v2",
|
| 695 |
+
- "pick-place-wall-v2",
|
| 696 |
+
- "plate-slide-back-side-v2",
|
| 697 |
+
- "plate-slide-back-v2",
|
| 698 |
+
- "plate-slide-side-v2",
|
| 699 |
+
- "plate-slide-v2",
|
| 700 |
+
- "push-back-v2",
|
| 701 |
+
- "push-v2",
|
| 702 |
+
- "push-wall-v2",
|
| 703 |
+
- "reach-v2",
|
| 704 |
+
- "reach-wall-v2",
|
| 705 |
+
- "shelf-place-v2",
|
| 706 |
+
- "soccer-v2",
|
| 707 |
+
- "stick-pull-v2",
|
| 708 |
+
- "stick-push-v2",
|
| 709 |
+
- "sweep-into-v2",
|
| 710 |
+
- "sweep-v2",
|
| 711 |
+
- "window-close-v2",
|
| 712 |
+
- "window-open-v2",
|
| 713 |
+
-]
|
| 714 |
+
-
|
| 715 |
+
-
|
| 716 |
+
def make_custom_env(
|
| 717 |
+
full_env_name: str,
|
| 718 |
+
cfg: Optional[Dict] = None,
|
| 719 |
+
@@ -79,7 +25,7 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
|
| 720 |
num_workers=8,
|
| 721 |
num_envs_per_worker=8,
|
| 722 |
worker_num_splits=2,
|
|
|
|
| 725 |
encoder_mlp_layers=[64, 64],
|
| 726 |
env_frameskip=1,
|
| 727 |
nonlinearity="tanh",
|
| 728 |
+
@@ -116,11 +62,10 @@ def override_defaults(parser: argparse.ArgumentParser) -> argparse.ArgumentParse
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 729 |
|
|
|
|
|
|
|
|
|
|
| 730 |
|
| 731 |
+
def main() -> int:
|
| 732 |
+
- for env_name in ENV_NAMES:
|
| 733 |
+
- register_env(env_name, make_custom_env)
|
| 734 |
+
parser, _ = parse_sf_args(argv=None, evaluation=False)
|
| 735 |
+
parser = override_defaults(parser)
|
| 736 |
+
cfg = parse_full_cfg(parser)
|
| 737 |
+
+ register_env(cfg.env, make_custom_env)
|
| 738 |
+
status = run_rl(cfg)
|
| 739 |
+
return status
|
| 740 |
|
| 741 |
diff --git a/gia/eval/evaluator.py b/gia/eval/evaluator.py
|
| 742 |
index 91b645c..3e2cae7 100644
|
|
|
|
| 763 |
def evaluate(self, model: GiaModel) -> float:
|
| 764 |
return self._evaluate(model)
|
| 765 |
|
| 766 |
+
diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py
|
| 767 |
+
index f1f83f5..ec5e5b2 100644
|
| 768 |
+
--- a/gia/eval/rl/envs/core.py
|
| 769 |
+
+++ b/gia/eval/rl/envs/core.py
|
| 770 |
+
@@ -176,7 +176,8 @@ def make(task_name: str, num_envs: int = 1):
|
| 771 |
+
env = gym.vector.SyncVectorEnv([env_func] * num_envs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 772 |
|
| 773 |
+
elif task_name.startswith("metaworld"):
|
| 774 |
+
- import gym
|
| 775 |
+
+ import gymnasium as gym
|
| 776 |
+
+ import metaworld
|
| 777 |
|
| 778 |
+
env_id = TASK_TO_ENV_MAPPING[task_name]
|
| 779 |
+
env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs)
|
| 780 |
diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py
|
| 781 |
+
index f0d0b9b..ca37721 100644
|
| 782 |
--- a/gia/eval/rl/gia_agent.py
|
| 783 |
+++ b/gia/eval/rl/gia_agent.py
|
| 784 |
+
@@ -9,7 +9,7 @@ from gia.datasets import GiaDataCollator, Prompter
|
| 785 |
+
from gia.model.gia_model import GiaModel
|
| 786 |
+
from gia.processing import GiaProcessor
|
| 787 |
+
|
| 788 |
+
-
|
| 789 |
+
+import sample_factory.envs.env_utils
|
| 790 |
+
class GiaAgent:
|
| 791 |
+
r"""
|
| 792 |
+
An RL agent that uses Gia to generate actions.
|
| 793 |
@@ -75,6 +75,11 @@ class GiaAgent:
|
| 794 |
) -> Tuple[Tuple[Tensor, Tensor], ...]:
|
| 795 |
return tuple((k[:, :, -self._max_length :], v[:, :, -self._max_length :]) for (k, v) in past_key_values)
|
|
|
|
| 845 |
},
|
| 846 |
"random": {
|
| 847 |
"mean": 220.65601680730813,
|
| 848 |
+
diff --git a/gia/model/gia_model.py b/gia/model/gia_model.py
|
| 849 |
+
index 7683ca5..74e82f3 100644
|
| 850 |
+
--- a/gia/model/gia_model.py
|
| 851 |
+
+++ b/gia/model/gia_model.py
|
| 852 |
+
@@ -116,6 +116,7 @@ class GiaModel(PreTrainedModel):
|
| 853 |
+
labels[~loss_mask] = -100
|
| 854 |
+
else:
|
| 855 |
+
labels = None
|
| 856 |
+
+ labels[labels>0] = 0
|
| 857 |
+
return self.causal_lm_model(
|
| 858 |
+
inputs_embeds=embeds,
|
| 859 |
+
attention_mask=attention_mask,
|
replay.mp4
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:33b28b88b4780dab4a6e7f78f8c40edba8047630cde2d14b68e05dd2acce8489
|
| 3 |
+
size 674687
|
sf_log.txt
CHANGED
|
The diff for this file is too large to render.
See raw diff
|
|
|