import fnmatch
import os
from typing import Dict, SupportsFloat

import gymnasium as gym
import numpy as np
import torch
from gymnasium import wrappers
from huggingface_hub import HfApi
from huggingface_hub.utils import EntryNotFoundError

from src.logging import setup_logger

logger = setup_logger(__name__)

API = HfApi(token=os.environ.get("TOKEN"))

ALL_ENV_IDS = [
    # Atari
    "AdventureNoFrameskip-v4",
    "AirRaidNoFrameskip-v4",
    "AlienNoFrameskip-v4",
    "AmidarNoFrameskip-v4",
    "AssaultNoFrameskip-v4",
    "AsterixNoFrameskip-v4",
    "AsteroidsNoFrameskip-v4",
    "AtlantisNoFrameskip-v4",
    "BankHeistNoFrameskip-v4",
    "BattleZoneNoFrameskip-v4",
    "BeamRiderNoFrameskip-v4",
    "BerzerkNoFrameskip-v4",
    "BowlingNoFrameskip-v4",
    "BoxingNoFrameskip-v4",
    "BreakoutNoFrameskip-v4",
    "CarnivalNoFrameskip-v4",
    "CentipedeNoFrameskip-v4",
    "ChopperCommandNoFrameskip-v4",
    "CrazyClimberNoFrameskip-v4",
    "DefenderNoFrameskip-v4",
    "DemonAttackNoFrameskip-v4",
    "DoubleDunkNoFrameskip-v4",
    "ElevatorActionNoFrameskip-v4",
    "EnduroNoFrameskip-v4",
    "FishingDerbyNoFrameskip-v4",
    "FreewayNoFrameskip-v4",
    "FrostbiteNoFrameskip-v4",
    "GopherNoFrameskip-v4",
    "GravitarNoFrameskip-v4",
    "HeroNoFrameskip-v4",
    "IceHockeyNoFrameskip-v4",
    "JamesbondNoFrameskip-v4",
    "JourneyEscapeNoFrameskip-v4",
    "KangarooNoFrameskip-v4",
    "KrullNoFrameskip-v4",
    "KungFuMasterNoFrameskip-v4",
    "MontezumaRevengeNoFrameskip-v4",
    "MsPacmanNoFrameskip-v4",
    "NameThisGameNoFrameskip-v4",
    "PhoenixNoFrameskip-v4",
    "PitfallNoFrameskip-v4",
    "PongNoFrameskip-v4",
    "PooyanNoFrameskip-v4",
    "PrivateEyeNoFrameskip-v4",
    "QbertNoFrameskip-v4",
    "RiverraidNoFrameskip-v4",
    "RoadRunnerNoFrameskip-v4",
    "RobotankNoFrameskip-v4",
    "SeaquestNoFrameskip-v4",
    "SkiingNoFrameskip-v4",
    "SolarisNoFrameskip-v4",
    "SpaceInvadersNoFrameskip-v4",
    "StarGunnerNoFrameskip-v4",
    "TennisNoFrameskip-v4",
    "TimePilotNoFrameskip-v4",
    "TutankhamNoFrameskip-v4",
    "UpNDownNoFrameskip-v4",
    "VentureNoFrameskip-v4",
    "VideoPinballNoFrameskip-v4",
    "WizardOfWorNoFrameskip-v4",
    "YarsRevengeNoFrameskip-v4",
    "ZaxxonNoFrameskip-v4",
    # Classic control
    "CartPole-v1",
    "MountainCar-v0",
    "Acrobot-v1",
    # MuJoCo
    "Ant-v4",
    "HalfCheetah-v4",
    "Hopper-v4",
    "Humanoid-v4",
    "HumanoidStandup-v4",
    "InvertedDoublePendulum-v4",
    "InvertedPendulum-v4",
    "Pusher-v4",
    "Reacher-v4",
    "Swimmer-v4",
    "Walker2d-v4",
]


class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Sample initial states by taking a random number of no-ops on reset.
    No-op is assumed to be action 0.

    :param env: Environment to wrap
    :param noop_max: Maximum number of no-ops to run
    """

    def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
        super().__init__(env)
        self.noop_max = noop_max
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        if self.override_num_noops is not None:
            noops = self.override_num_noops
        else:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        assert noops > 0
        obs = np.zeros(0)
        info: Dict = {}
        for _ in range(noops):
            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info
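

# A minimal usage sketch (not used by the pipeline directly; assumes gymnasium's Atari
# extras, i.e. ale-py and the ROMs, are installed):
#
#   env = NoopResetEnv(gym.make("BreakoutNoFrameskip-v4"), noop_max=30)
#   obs, info = env.reset(seed=0)
#   # reset() executed between 1 and 30 NOOP frames before returning obs,
#   # so each episode starts from a slightly different emulator state.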


class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Take action on reset for environments that are fixed until firing.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
        assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(1)
        if terminated or truncated:
            self.env.reset(**kwargs)
        obs, _, terminated, truncated, _ = self.env.step(2)
        if terminated or truncated:
            self.env.reset(**kwargs)
        return obs, {}
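

# A minimal usage sketch (Breakout is an example of a game that waits for FIRE):
#
#   env = gym.make("BreakoutNoFrameskip-v4")
#   assert "FIRE" in env.unwrapped.get_action_meanings()
#   obs, _ = FireResetEnv(env).reset(seed=0)
#   # the game has already started (FIRE was pressed) by the time obs is returned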


class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Make end-of-life == end-of-episode, but only reset on true game over.
    Done by DeepMind for the DQN and co. since it helps value estimation.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action: int):
        obs, reward, terminated, truncated, info = self.env.step(action)
        self.was_real_done = terminated or truncated
        # Check current lives, make loss of life terminal,
        # then update lives to handle bonus lives.
        lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        if 0 < lives < self.lives:
            # For Qbert we sometimes stay in the lives == 0 condition for a few frames,
            # so it's important to require lives > 0 here, so that we only reset once
            # the environment advertises done.
            terminated = True
        self.lives = lives
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Calls the Gym environment reset, but only when lives are exhausted.
        This way all states are still reachable even though lives are episodic,
        and the learner need not know about any of this behind-the-scenes.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment
        """
        if self.was_real_done:
            obs, info = self.env.reset(**kwargs)
        else:
            # No-op step to advance from the terminal/lost-life state
            obs, _, terminated, truncated, info = self.env.step(0)
            # The no-op step can lead to a game over, so we need to check it again
            # to see if we should reset the environment and avoid the
            # monitor.py `RuntimeError: Tried to step environment that needs reset`
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        return obs, info
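

# Behaviour sketch: in a game with multiple lives (e.g. Breakout starts with 5), losing a
# life makes this wrapper return terminated=True, but reset() keeps playing the same
# underlying game; the emulator is only truly reset once ale.lives() reaches 0.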


class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Return only every ``skip``-th frame (frameskipping) and return the pixel-wise max of
    the two last frames (to remove the flickering of sprites that Atari games only render
    every other frame).

    :param env: Environment to wrap
    :param skip: Number of frames to skip; the same action is repeated ``skip`` times
    """

    def __init__(self, env: gym.Env, skip: int = 4) -> None:
        super().__init__(env)
        # Most recent raw observations (for max pooling across time steps)
        assert env.observation_space.dtype is not None, "No dtype specified for the observation space"
        assert env.observation_space.shape is not None, "No shape defined for the observation space"
        self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
        self._skip = skip

    def step(self, action: int):
        """
        Step the environment with the given action.
        Repeat the action, sum the rewards, and max over the last two observations.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        total_reward = 0.0
        terminated = truncated = False
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            done = terminated or truncated
            if i == self._skip - 2:
                self._obs_buffer[0] = obs
            if i == self._skip - 1:
                self._obs_buffer[1] = obs
            total_reward += float(reward)
            if done:
                break
        # Note that the observation on the done=True frame doesn't matter
        max_frame = self._obs_buffer.max(axis=0)
        return max_frame, total_reward, terminated, truncated, info
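

# Behaviour sketch: with skip=4 (and assuming the episode does not end early), one call to
# step() advances the emulator by 4 frames with the same action, returns the sum of the 4
# rewards, and returns the pixel-wise max of frames 3 and 4.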


class ClipRewardEnv(gym.RewardWrapper):
    """
    Clip the reward to {+1, 0, -1} by its sign.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)

    def reward(self, reward: SupportsFloat) -> float:
        """
        Bin reward to {+1, 0, -1} by its sign.

        :param reward: the unclipped reward
        :return: the sign of the reward
        """
        return np.sign(float(reward))
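

# Behaviour sketch: a raw reward of +250.0 becomes +1.0, -3.0 becomes -1.0, and 0.0 stays
# 0.0, so the scale of returns is comparable across games.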


def make(env_id):
    def thunk():
        env = gym.make(env_id)
        env = wrappers.RecordEpisodeStatistics(env)
        if "NoFrameskip" in env_id:
            # Standard DeepMind-style Atari preprocessing
            env = NoopResetEnv(env, noop_max=30)
            env = MaxAndSkipEnv(env, skip=4)
            env = EpisodicLifeEnv(env)
            if "FIRE" in env.unwrapped.get_action_meanings():
                env = FireResetEnv(env)
            env = ClipRewardEnv(env)
            env = wrappers.ResizeObservation(env, (84, 84))
            env = wrappers.GrayScaleObservation(env)
            env = wrappers.FrameStack(env, 4)
        return env

    return thunk
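

# A minimal usage sketch: make() returns a thunk so that environment creation can be
# deferred to gymnasium's vector API, e.g.
#
#   envs = gym.vector.SyncVectorEnv([make("PongNoFrameskip-v4") for _ in range(3)])
#   observations, _ = envs.reset()
#   # observations has shape (3, 4, 84, 84): 3 parallel envs, 4 stacked 84x84 grayscale frames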


def pattern_match(patterns, source_list):
    """Return the sorted entries of ``source_list`` matching any of the fnmatch-style ``patterns``."""
    if isinstance(patterns, str):
        patterns = [patterns]
    env_ids = set()
    for pattern in patterns:
        for matching in fnmatch.filter(source_list, pattern):
            env_ids.add(matching)
    return sorted(env_ids)
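

# Examples:
#   pattern_match("Pong*", ALL_ENV_IDS)                 -> ["PongNoFrameskip-v4"]
#   pattern_match(["CartPole-v1", "*-v4"], ALL_ENV_IDS) -> every v4 entry plus "CartPole-v1"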


def evaluate(model_id, revision):
    tags = API.model_info(model_id, revision=revision).tags
    # Extract the environment IDs from the tags (usually only one)
    env_ids = pattern_match(tags, ALL_ENV_IDS)
    logger.info(f"Selected environments: {env_ids}")

    results = {}

    # Check that the agent exists
    try:
        agent_path = API.hf_hub_download(repo_id=model_id, filename="agent.pt")
    except EntryNotFoundError:
        logger.error("Agent not found")
        return None

    # Check safety
    security = next(iter(API.get_paths_info(model_id, "agent.pt", expand=True))).security
    if security is None or "safe" not in security:
        logger.error("Agent safety not available")
        return None
    elif not security["safe"]:
        logger.error("Agent not safe")
        return None

    # Load the agent
    try:
        agent = torch.jit.load(agent_path)
    except Exception as e:
        logger.error(f"Error loading agent: {e}")
        return None

    # Evaluate the agent on the environments
    for env_id in env_ids:
        envs = gym.vector.SyncVectorEnv([make(env_id) for _ in range(3)])
        observations, _ = envs.reset()
        episodic_returns = []
        while len(episodic_returns) < 10:
            actions = agent(torch.tensor(observations)).numpy()
            observations, _, _, _, infos = envs.step(actions)
            if "final_info" in infos:
                for info in infos["final_info"]:
                    if info is None or "episode" not in info:
                        continue
                    episodic_returns.append(float(info["episode"]["r"]))
        envs.close()  # release emulator resources before moving to the next environment
        results[env_id] = {"episodic_returns": episodic_returns}
        logger.info(f"Environment {env_id}: {np.mean(episodic_returns)} ± {np.std(episodic_returns)}")

    return results
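

# A minimal usage sketch ("user/model" is a hypothetical repo id): the repo's tags must
# contain an environment id from ALL_ENV_IDS and the repo must hold a TorchScript "agent.pt".
#
#   results = evaluate("user/model", revision="main")
#   if results is not None:
#       for env_id, result in results.items():
#           print(env_id, np.mean(result["episodic_returns"]))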