| | |
| | |
| | |
| | |
| | @@ -4,7 +4,7 @@ ENVS=( |
| | assembly |
| | basketball |
| | bin-picking |
| | - box-close |
| | + #box-close |
| | button-press-topdown |
| | button-press-topdown-wall |
| | button-press |
| | |
| | |
| | |
| | |
| | @@ -2,10 +2,10 @@ import glob |
| | import json |
| | import subprocess |
| | |
| | -import wandb |
| | from accelerate import Accelerator |
| | from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments |
| | |
| | +import wandb |
| | from gia.config import Arguments |
| | from gia.eval.utils import is_slurm_available |
| | |
| | |
| | |
| | |
| | |
| | @@ -180,7 +180,7 @@ def make(task_name: str, num_envs: int = 1): |
| | import metaworld |
| | |
| | env_id = TASK_TO_ENV_MAPPING[task_name] |
| | - env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs) |
| | + env = gym.make(env_id) |
| | |
| | else: |
| | raise ValueError(f"Unknown task name: {task_name}") |
| | |
| | |
| | |
| | |
| | @@ -54,7 +54,7 @@ class GiaAgent: |
| | self.action_space = action_space |
| | self.deterministic = deterministic |
| | self.device = next(model.parameters()).device |
| | - self._max_length = self.model.config.max_position_embeddings - 10 |
| | + self._max_length = self.model.config.max_position_embeddings - 100 |
| | |
| | if isinstance(observation_space, spaces.Box): |
| | self._observation_key = "continuous_observations" |
| | |
| | |
| | |
| | |
| | @@ -1,7 +1,6 @@ |
| | import gym |
| | from gym.vector.vector_env import VectorEnv |
| | |
| | -from gia.eval.mappings import TASK_TO_ENV_MAPPING |
| | from gia.eval.rl.rl_evaluator import RLEvaluator |
| | |
| | |
| |
|