diff --git a/data/envs/metaworld/train_all.sh b/data/envs/metaworld/train_all.sh index dbf328a..c393191 100755 --- a/data/envs/metaworld/train_all.sh +++ b/data/envs/metaworld/train_all.sh @@ -4,7 +4,7 @@ ENVS=( assembly basketball bin-picking - box-close + #box-close button-press-topdown button-press-topdown-wall button-press diff --git a/gia/eval/callback.py b/gia/eval/callback.py index 5c3a080..4b6198f 100644 --- a/gia/eval/callback.py +++ b/gia/eval/callback.py @@ -2,10 +2,10 @@ import glob import json import subprocess -import wandb from accelerate import Accelerator from transformers import TrainerCallback, TrainerControl, TrainerState, TrainingArguments +import wandb from gia.config import Arguments from gia.eval.utils import is_slurm_available diff --git a/gia/eval/rl/envs/core.py b/gia/eval/rl/envs/core.py index ec5e5b2..3294471 100644 --- a/gia/eval/rl/envs/core.py +++ b/gia/eval/rl/envs/core.py @@ -180,7 +180,7 @@ def make(task_name: str, num_envs: int = 1): import metaworld env_id = TASK_TO_ENV_MAPPING[task_name] - env = gym.vector.SyncVectorEnv([lambda: gym.make(env_id)] * num_envs) + env = gym.make(env_id) else: raise ValueError(f"Unknown task name: {task_name}") diff --git a/gia/eval/rl/gia_agent.py b/gia/eval/rl/gia_agent.py index f0d0b9b..255beda 100644 --- a/gia/eval/rl/gia_agent.py +++ b/gia/eval/rl/gia_agent.py @@ -54,7 +54,7 @@ class GiaAgent: self.action_space = action_space self.deterministic = deterministic self.device = next(model.parameters()).device - self._max_length = self.model.config.max_position_embeddings - 10 + self._max_length = self.model.config.max_position_embeddings - 100 if isinstance(observation_space, spaces.Box): self._observation_key = "continuous_observations" diff --git a/gia/eval/rl/gym_evaluator.py b/gia/eval/rl/gym_evaluator.py index f8531ee..71e0fdc 100644 --- a/gia/eval/rl/gym_evaluator.py +++ b/gia/eval/rl/gym_evaluator.py @@ -1,7 +1,6 @@ import gym from gym.vector.vector_env import VectorEnv -from gia.eval.mappings import TASK_TO_ENV_MAPPING from gia.eval.rl.rl_evaluator import RLEvaluator