qgallouedec HF staff committed on
Commit
a40ca17
1 Parent(s): 37d4b84

initial commit

Browse files
Files changed (9) hide show
  1. Makefile +13 -0
  2. README.md +7 -6
  3. app.py +90 -0
  4. packages.txt +3 -0
  5. pyproject.toml +15 -0
  6. requirements.txt +24 -0
  7. src/backend.py +90 -0
  8. src/evaluation.py +365 -0
  9. src/logging.py +37 -0
Makefile ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Developer helper targets for formatting and linting the Space's code.
# Neither target produces a file, so both are declared phony.
.PHONY: style quality


# Rewrite the sources in place: format, sort imports, apply lint autofixes.
style:
	python -m black --line-length 119 src app.py
	python -m isort src app.py
	ruff check --fix src app.py


# Check-only counterpart of `style`: runs the same tools without rewriting files.
quality:
	python -m black --check --line-length 119 src app.py
	python -m isort --check-only src app.py
	ruff check src app.py
README.md CHANGED
@@ -1,12 +1,13 @@
1
  ---
2
  title: Backend
3
- emoji: 🦀
4
  colorFrom: red
5
- colorTo: purple
6
  sdk: gradio
7
- sdk_version: 4.31.5
8
  app_file: app.py
9
- pinned: false
 
 
 
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
  title: Backend
3
+ emoji: 🥇
4
  colorFrom: red
5
+ colorTo: indigo
6
  sdk: gradio
7
+ sdk_version: 4.20.0
8
  app_file: app.py
9
+ pinned: true
10
+ license: apache-2.0
11
+ tags:
12
+ - leaderboard
13
  ---
 
 
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ from functools import partial
3
+ from io import StringIO
4
+
5
+ import gradio as gr
6
+ from apscheduler.schedulers.background import BackgroundScheduler
7
+ from bs4 import BeautifulSoup
8
+ from rich.console import Console
9
+ from rich.syntax import Syntax
10
+
11
+ from src.backend import backend_routine
12
+ from src.logging import configure_root_logger, log_file, setup_logger
13
+
14
+ logging.getLogger("httpx").setLevel(logging.WARNING)
15
+ logging.getLogger("numexpr").setLevel(logging.WARNING)
16
+ logging.getLogger("absl").setLevel(logging.WARNING)
17
+
18
+ configure_root_logger()
19
+
20
+ logging.basicConfig(level=logging.INFO)
21
+ logger = setup_logger(__name__)
22
+
23
+
24
def log_file_to_html_string(reverse=True, max_lines=300):
    """
    Render the tail of the log file as a syntax-highlighted HTML document.

    :param reverse: If True, the most recent log line is shown first
    :param max_lines: Maximum number of trailing log lines to render
    :return: The prettified HTML document as a string
    """
    from collections import deque

    # This function is polled repeatedly by the UI while the log keeps growing,
    # so stream the file through a bounded deque instead of materializing every
    # line just to keep the last few.
    with open(log_file, "rt") as f:
        tail = deque(f, maxlen=max_lines)

    output = "".join(reversed(tail) if reverse else tail)
    syntax = Syntax(output, "python", theme="monokai", word_wrap=True)

    # Record the rendering into an in-memory console so it can be exported as HTML
    console = Console(record=True, width=150, style="#272822", file=StringIO())
    console.print(syntax)
    html_content = console.export_html(inline_styles=True)

    # Parse the HTML content using BeautifulSoup
    soup = BeautifulSoup(html_content, "lxml")

    # Swap the inline style of the <pre> tag for the .scrollable class
    pre_tag = soup.pre
    pre_tag["class"] = "scrollable"
    del pre_tag["style"]

    # Append the background colour and the .scrollable rules to the <style> tag
    style_tag = soup.style
    style_tag.append(
        """
        pre, code {
            background-color: #272822;
        }
        .scrollable {
            font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace;
            height: 500px;
            overflow: auto;
        }
        """
    )

    return soup.prettify()
63
+
64
+
65
+ REPO_ID = "open-rl-leaderboard/leaderboard"
66
+ RESULTS_REPO = "open-rl-leaderboard/results_v2"
67
+
68
+
69
+ links_md = f"""
70
+ # Important links
71
+ | Description | Link |
72
+ |-----------------|------|
73
+ | Leaderboard | [{REPO_ID}](https://huggingface.co/spaces/{REPO_ID}) |
74
+ | Results Repo | [{RESULTS_REPO}](https://huggingface.co/datasets/{RESULTS_REPO}) |
75
+ """
76
+
77
+
78
+ with gr.Blocks() as demo:
79
+ gr.Markdown(links_md)
80
+ gr.HTML(partial(log_file_to_html_string), every=1)
81
+ with gr.Row():
82
+ gr.DownloadButton("Download Log File", value=log_file)
83
+
84
+
85
+ scheduler = BackgroundScheduler()
86
+ scheduler.add_job(func=backend_routine, trigger="interval", seconds=5 * 60, max_instances=1)
87
+ scheduler.start()
88
+
89
+ if __name__ == "__main__":
90
+ demo.queue().launch()
packages.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ swig
2
+ libosmesa6-dev
3
+ patchelf
pyproject.toml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [tool.ruff]
2
+ line-length = 119
3
+
4
+ [tool.ruff.lint]
5
+ # Enable pycodestyle (`E`) and Pyflakes (`F`) codes by default.
6
+ select = ["E", "F"]
7
+ ignore = ["E501"] # line too long (black is taking care of this)
8
+ fixable = ["A", "B", "C", "D", "E", "F", "G", "I", "N", "Q", "S", "T", "W", "ANN", "ARG", "BLE", "COM", "DJ", "DTZ", "EM", "ERA", "EXE", "FBT", "ICN", "INP", "ISC", "NPY", "PD", "PGH", "PIE", "PL", "PT", "PTH", "PYI", "RET", "RSE", "RUF", "SIM", "SLF", "TCH", "TID", "TRY", "UP", "YTT"]
9
+
10
+ [tool.isort]
11
+ profile = "black"
12
+ line_length = 119
13
+
14
+ [tool.black]
15
+ line-length = 119
requirements.txt ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ APScheduler==3.10.1
2
+ black==23.11.0
3
+ click==8.1.3
4
+ datasets==2.14.5
5
+ gradio==4.20.0
6
+ gradio_client
7
+ gymnasium[all,accept-rom-license]==0.29.1
8
+ huggingface-hub>=0.18.0
9
+ matplotlib==3.7.1
10
+ free-mujoco-py
11
+ mujoco<=2.3.7
12
+ numpy==1.24.2
13
+ pandas==2.0.0
14
+ python-dateutil==2.8.2
15
+ requests==2.28.2
16
+ rliable==1.0.8
17
+ torch==2.2.2
18
+ tqdm==4.65.0
19
+
20
+
21
+ # Log Visualizer
22
+ BeautifulSoup4==4.12.2
23
+ lxml==4.9.3
24
+ rich==13.3.4
src/backend.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ import re
5
+ import tempfile
6
+
7
+ from huggingface_hub import CommitOperationAdd, HfApi
8
+
9
+ from src.evaluation import evaluate
10
+ from src.logging import setup_logger
11
+
12
+ logger = setup_logger(__name__)
13
+
14
+ API = HfApi(token=os.environ.get("TOKEN"))
15
+ RESULTS_REPO = "open-rl-leaderboard/results"
16
+
17
+
18
def _backend_routine():
    """
    Evaluate one pending model and upload its report to the results dataset.

    A model is "compatible" when its repository contains an ``agent.pt`` file,
    and "pending" when no report for its (model id, sha) pair exists yet in
    ``RESULTS_REPO``. One pending model is picked at random, evaluated, and a
    JSON report (status ``DONE`` or ``FAILED``) is committed to the dataset.

    :return: None
    """
    # List the reinforcement learning models that provide an agent file
    rl_models = list(API.list_models(filter="reinforcement-learning"))
    logger.info(f"Found {len(rl_models)} RL models")
    compatible_models = [
        (model.modelId, model.sha)
        for model in rl_models
        # Guard against a listing without file information (siblings is None)
        if any(sib.rfilename == "agent.pt" for sib in (model.siblings or []))
    ]
    logger.info(f"Found {len(compatible_models)} compatible models")

    # Get the existing reports: <user>/<model>/results_<sha>.json
    pattern = re.compile(r"^[^/]*/[^/]*/[^/]*results_[a-f0-9]+\.json$")
    filenames = API.list_repo_files(RESULTS_REPO, repo_type="dataset")
    filenames = [filename for filename in filenames if pattern.match(filename)]

    evaluated_models = set()
    for filename in filenames:
        path = API.hf_hub_download(repo_id=RESULTS_REPO, filename=filename, repo_type="dataset")
        with open(path) as fp:
            report = json.load(fp)
        evaluated_models.add((report["config"]["model_id"], report["config"]["model_sha"]))

    # Find the models that are not associated with any results
    pending_models = list(set(compatible_models) - evaluated_models)
    logger.info(f"Found {len(pending_models)} pending models")

    if not pending_models:
        return None

    # Run an evaluation on a single, randomly chosen pending model
    with tempfile.TemporaryDirectory() as tmp_dir:
        model_id, sha = random.choice(pending_models)
        logger.info(f"Running evaluation on {model_id}")
        report = {"config": {"model_id": model_id, "model_sha": sha}}
        try:
            evaluations = evaluate(model_id, revision=sha)
        except Exception as e:
            # Keep the traceback in the log; the failure is still reported below
            logger.exception(f"Error evaluating {model_id}: {e}")
            evaluations = None

        if evaluations is not None:
            report["results"] = evaluations
            report["status"] = "DONE"
        else:
            # A FAILED report is uploaded too, so the model is no longer pending
            report["status"] = "FAILED"

        # Write the report locally, mirroring its path in the dataset
        dumped = json.dumps(report, indent=2)
        path_in_repo = f"{model_id}/results_{sha}.json"
        local_path = os.path.join(tmp_dir, path_in_repo)
        os.makedirs(os.path.dirname(local_path), exist_ok=True)
        with open(local_path, "w") as f:
            f.write(dumped)

        # Commit while the temporary directory (and thus the file) still exists
        operations = [CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=local_path)]
        API.create_commit(
            repo_id=RESULTS_REPO, commit_message="Add evaluation results", operations=operations, repo_type="dataset"
        )
80
+
81
+
82
def backend_routine():
    """
    Run one backend pass without ever raising.

    Any exception raised by ``_backend_routine`` is logged together with its
    traceback and then swallowed, so a failure in one pass does not propagate
    to the caller.
    """
    try:
        _backend_routine()
    except Exception as e:
        # logger.exception logs at ERROR level and appends the traceback
        logger.exception(f"{e.__class__.__name__}: {str(e)}")
87
+
88
+
89
+ if __name__ == "__main__":
90
+ backend_routine()
src/evaluation.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fnmatch
2
+ import os
3
+ from typing import Dict, SupportsFloat
4
+
5
+ import gymnasium as gym
6
+ import numpy as np
7
+ import torch
8
+ from gymnasium import wrappers
9
+ from huggingface_hub import HfApi
10
+ from huggingface_hub.utils._errors import EntryNotFoundError
11
+
12
+ from src.logging import setup_logger
13
+
14
+ logger = setup_logger(__name__)
15
+
16
+ API = HfApi(token=os.environ.get("TOKEN"))
17
+
18
+
19
+ ALL_ENV_IDS = [
20
+ "AdventureNoFrameskip-v4",
21
+ "AirRaidNoFrameskip-v4",
22
+ "AlienNoFrameskip-v4",
23
+ "AmidarNoFrameskip-v4",
24
+ "AssaultNoFrameskip-v4",
25
+ "AsterixNoFrameskip-v4",
26
+ "AsteroidsNoFrameskip-v4",
27
+ "AtlantisNoFrameskip-v4",
28
+ "BankHeistNoFrameskip-v4",
29
+ "BattleZoneNoFrameskip-v4",
30
+ "BeamRiderNoFrameskip-v4",
31
+ "BerzerkNoFrameskip-v4",
32
+ "BowlingNoFrameskip-v4",
33
+ "BoxingNoFrameskip-v4",
34
+ "BreakoutNoFrameskip-v4",
35
+ "CarnivalNoFrameskip-v4",
36
+ "CentipedeNoFrameskip-v4",
37
+ "ChopperCommandNoFrameskip-v4",
38
+ "CrazyClimberNoFrameskip-v4",
39
+ "DefenderNoFrameskip-v4",
40
+ "DemonAttackNoFrameskip-v4",
41
+ "DoubleDunkNoFrameskip-v4",
42
+ "ElevatorActionNoFrameskip-v4",
43
+ "EnduroNoFrameskip-v4",
44
+ "FishingDerbyNoFrameskip-v4",
45
+ "FreewayNoFrameskip-v4",
46
+ "FrostbiteNoFrameskip-v4",
47
+ "GopherNoFrameskip-v4",
48
+ "GravitarNoFrameskip-v4",
49
+ "HeroNoFrameskip-v4",
50
+ "IceHockeyNoFrameskip-v4",
51
+ "JamesbondNoFrameskip-v4",
52
+ "JourneyEscapeNoFrameskip-v4",
53
+ "KangarooNoFrameskip-v4",
54
+ "KrullNoFrameskip-v4",
55
+ "KungFuMasterNoFrameskip-v4",
56
+ "MontezumaRevengeNoFrameskip-v4",
57
+ "MsPacmanNoFrameskip-v4",
58
+ "NameThisGameNoFrameskip-v4",
59
+ "PhoenixNoFrameskip-v4",
60
+ "PitfallNoFrameskip-v4",
61
+ "PongNoFrameskip-v4",
62
+ "PooyanNoFrameskip-v4",
63
+ "PrivateEyeNoFrameskip-v4",
64
+ "QbertNoFrameskip-v4",
65
+ "RiverraidNoFrameskip-v4",
66
+ "RoadRunnerNoFrameskip-v4",
67
+ "RobotankNoFrameskip-v4",
68
+ "SeaquestNoFrameskip-v4",
69
+ "SkiingNoFrameskip-v4",
70
+ "SolarisNoFrameskip-v4",
71
+ "SpaceInvadersNoFrameskip-v4",
72
+ "StarGunnerNoFrameskip-v4",
73
+ "TennisNoFrameskip-v4",
74
+ "TimePilotNoFrameskip-v4",
75
+ "TutankhamNoFrameskip-v4",
76
+ "UpNDownNoFrameskip-v4",
77
+ "VentureNoFrameskip-v4",
78
+ "VideoPinballNoFrameskip-v4",
79
+ "WizardOfWorNoFrameskip-v4",
80
+ "YarsRevengeNoFrameskip-v4",
81
+ "ZaxxonNoFrameskip-v4",
82
+ # Box2D
83
+ "BipedalWalker-v3",
84
+ "BipedalWalkerHardcore-v3",
85
+ "CarRacing-v2",
86
+ "LunarLander-v2",
87
+ "LunarLanderContinuous-v2",
88
+ # Toy text
89
+ "Blackjack-v1",
90
+ "CliffWalking-v0",
91
+ "FrozenLake-v1",
92
+ "FrozenLake8x8-v1",
93
+ # Classic control
94
+ "Acrobot-v1",
95
+ "CartPole-v1",
96
+ "MountainCar-v0",
97
+ "MountainCarContinuous-v0",
98
+ "Pendulum-v1",
99
+ # MuJoCo
100
+ "Ant-v4",
101
+ "HalfCheetah-v4",
102
+ "Hopper-v4",
103
+ "Humanoid-v4",
104
+ "HumanoidStandup-v4",
105
+ "InvertedDoublePendulum-v4",
106
+ "InvertedPendulum-v4",
107
+ "Pusher-v4",
108
+ "Reacher-v4",
109
+ "Swimmer-v4",
110
+ "Walker2d-v4",
111
+ ]
112
+
113
+ NUM_EPISODES = 50
114
+
115
+
116
class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Randomize the initial state by playing a random number of no-ops on reset.
    Action 0 is expected to be the no-op.

    :param env: Environment to wrap
    :param noop_max: Upper bound (inclusive) on the number of no-ops to play
    """

    def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
        super().__init__(env)
        self.noop_max = noop_max
        # When set, this fixed count is used instead of a random draw
        self.override_num_noops = None
        self.noop_action = 0
        assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]

    def reset(self, **kwargs):
        """
        Reset the environment, then play the no-op action several times.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the observation and info produced by the last step (or reset)
        """
        self.env.reset(**kwargs)
        if self.override_num_noops is None:
            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
        else:
            noops = self.override_num_noops
        assert noops > 0

        obs, info = np.zeros(0), {}
        for _noop in range(noops):
            obs, _reward, terminated, truncated, info = self.env.step(self.noop_action)
            # The episode may end during the no-ops; start over from a fresh one
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        return obs, info
146
+
147
+
148
class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Press FIRE on reset, for environments that stay frozen until firing.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        action_meanings = env.unwrapped.get_action_meanings  # type: ignore[attr-defined]
        assert action_meanings()[1] == "FIRE"
        assert len(action_meanings()) >= 3

    def reset(self, **kwargs):
        """
        Reset the environment, then play action 1 followed by action 2.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the observation of the last step and an empty info dict
        """
        self.env.reset(**kwargs)
        for start_action in (1, 2):
            obs, _, terminated, truncated, _ = self.env.step(start_action)
            if terminated or truncated:
                self.env.reset(**kwargs)
        return obs, {}
169
+
170
+
171
class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Treat the loss of a life as the end of an episode, while only truly
    resetting the game on a real game over.

    :param env: Environment to wrap
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        self.lives = 0
        self.was_real_done = True

    def step(self, action: int):
        """
        Step the environment and flag a lost life as a termination.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Remember whether the underlying game really ended
        self.was_real_done = terminated or truncated

        remaining = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        # Only a drop to a still-positive count is a lost life: for Qbert the
        # counter can sit at 0 for a few frames, and we must reset only once
        # the environment itself advertises done.
        lost_life = 0 < remaining < self.lives
        if lost_life:
            terminated = True
        # Always refresh the count so that bonus lives are taken into account
        self.lives = remaining
        return obs, reward, terminated, truncated, info

    def reset(self, **kwargs):
        """
        Reset the underlying environment only when the game is really over;
        otherwise just step past the lost-life state with a no-op.

        :param kwargs: Extra keywords passed to env.reset() call
        :return: the first observation of the environment and its info
        """
        if not self.was_real_done:
            # no-op step to advance from terminal/lost life state
            obs, _, terminated, truncated, info = self.env.step(0)
            # That no-op can itself end the game; reset in that case to avoid
            # stepping an environment that needs a reset.
            if terminated or truncated:
                obs, info = self.env.reset(**kwargs)
        else:
            obs, info = self.env.reset(**kwargs)
        self.lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
        return obs, info
220
+
221
+
222
class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
    """
    Frame skipping: repeat each action ``skip`` times and return the
    element-wise max of the two most recent frames.

    :param env: Environment to wrap
    :param skip: Number of times each action is repeated
    """

    def __init__(self, env: gym.Env, skip: int = 4) -> None:
        super().__init__(env)
        space = env.observation_space
        assert space.dtype is not None, "No dtype specified for the observation space"
        assert space.shape is not None, "No shape defined for the observation space"
        # Holds the last two raw observations (for max pooling across time steps)
        self._obs_buffer = np.zeros((2, *space.shape), dtype=space.dtype)
        self._skip = skip

    def step(self, action: int):
        """
        Repeat the action, sum the rewards, and max over the last observations.

        :param action: the action
        :return: observation, reward, terminated, truncated, information
        """
        total_reward = 0.0
        terminated = truncated = False
        first_buffered = self._skip - 2
        for i in range(self._skip):
            obs, reward, terminated, truncated, info = self.env.step(action)
            # Frames skip-2 and skip-1 go to buffer slots 0 and 1 respectively
            slot = i - first_buffered
            if slot in (0, 1):
                self._obs_buffer[slot] = obs
            total_reward += float(reward)
            if terminated or truncated:
                break
        # Note that the observation on the done=True frame doesn't matter
        max_frame = self._obs_buffer.max(axis=0)

        return max_frame, total_reward, terminated, truncated, info
265
+
266
+
267
class ClipRewardEnv(gym.RewardWrapper):
    """
    Replace every reward by its sign, i.e. one of {+1, 0, -1}.

    :param env: Environment to wrap
    """

    def reward(self, reward: SupportsFloat) -> float:
        """
        Map the reward to its sign.

        :param reward: the raw reward
        :return: the sign of the reward
        """
        as_float = float(reward)
        return np.sign(as_float)
285
+
286
+
287
def make(env_id):
    """
    Build a zero-argument factory for the given environment.

    Every environment records episode statistics. Ids containing
    ``"NoFrameskip"`` additionally get the Atari preprocessing stack.

    :param env_id: Id of the environment to create
    :return: a callable that creates the wrapped environment
    """

    def thunk():
        env = wrappers.RecordEpisodeStatistics(gym.make(env_id))
        if "NoFrameskip" not in env_id:
            return env

        # Atari preprocessing; the wrapping order matters
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = wrappers.ResizeObservation(env, (84, 84))
        env = wrappers.GrayScaleObservation(env)
        return wrappers.FrameStack(env, 4)

    return thunk
304
+
305
+
306
def pattern_match(patterns, source_list):
    """
    Select the entries of ``source_list`` matched by any of the glob patterns.

    :param patterns: A single pattern, an iterable of patterns, or None
        (treated as "no pattern")
    :param source_list: Candidate names to filter
    :return: the sorted list of distinct matching names
    """
    if patterns is None:
        # e.g. a model without any tag: nothing can match
        return []
    if isinstance(patterns, str):
        patterns = [patterns]

    matched = set()
    for pattern in patterns:
        matched.update(fnmatch.filter(source_list, pattern))
    return sorted(matched)
315
+
316
+
317
def evaluate(model_id, revision):
    """
    Evaluate the agent of a model repository on the environments named by its tags.

    :param model_id: Id of the model repository on the Hub
    :param revision: Revision (commit sha) of the repository to evaluate
    :return: a dict mapping each environment id to ``{"episodic_returns": [...]}``,
        or None when the agent is missing, flagged unsafe, or cannot be loaded
    """
    tags = API.model_info(model_id, revision=revision).tags

    # Extract the environment IDs from the tags (usually only one)
    env_ids = pattern_match(tags, ALL_ENV_IDS)
    logger.info(f"Selected environments: {env_ids}")

    results = {}

    # Check if the agent exists. The revision is pinned so that the file that
    # is evaluated is the one the report will be attributed to.
    try:
        agent_path = API.hf_hub_download(repo_id=model_id, filename="agent.pt", revision=revision)
    except EntryNotFoundError:
        logger.error("Agent not found")
        return None

    # Check safety (of the same revision)
    security = next(iter(API.get_paths_info(model_id, "agent.pt", expand=True, revision=revision))).security
    if security is None or "safe" not in security:
        # Missing scan information is tolerated: warn and carry on
        logger.warning("Agent safety not available")
    elif not security["safe"]:
        logger.error("Agent not safe")
        return None

    # Load the agent; inputs are later created on this same device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    try:
        agent = torch.jit.load(agent_path, map_location=device)
    except Exception as e:
        logger.error(f"Error loading agent: {e}")
        return None

    # Evaluate the agent on the environments
    for env_id in env_ids:
        envs = gym.vector.SyncVectorEnv([make(env_id) for _ in range(1)])
        try:
            observations, _ = envs.reset()
            episodic_returns = []
            while len(episodic_returns) < NUM_EPISODES:
                # Inference only; bring the actions back to the CPU for the envs
                with torch.no_grad():
                    actions = agent(torch.tensor(observations, device=device)).cpu().numpy()
                observations, _, _, _, infos = envs.step(actions)
                if "final_info" in infos:
                    for info in infos["final_info"]:
                        if info is None or "episode" not in info:
                            continue
                        episodic_returns.append(float(info["episode"]["r"]))
        finally:
            # Release the environments even if the agent fails mid-evaluation
            envs.close()

        results[env_id] = {"episodic_returns": episodic_returns}
        logger.info(f"Environment {env_id}: {np.mean(episodic_returns)} ± {np.std(episodic_returns)}")
    return results
src/logging.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+
3
+ proj_dir = Path(__file__).parents[1]
4
+
5
+ log_file = proj_dir / "output.log"
6
+
7
+
8
+ import logging
9
+
10
+
11
def setup_logger(name: str):
    """
    Return the logger ``name``, set to INFO and writing to the log file.

    Calling this again for the same name does not attach a second handler,
    so records are not duplicated by repeated setup.

    NOTE(review): records also propagate to the root logger; when
    ``configure_root_logger`` has attached its own handler for the same file,
    each record is written twice — confirm whether that is intended.

    :param name: Name of the logger, typically ``__name__``
    :return: the configured logger
    """
    import os

    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # FileHandler stores its target as an absolute path in ``baseFilename``
    target = os.path.abspath(os.fspath(log_file))
    for handler in logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            return logger

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    # Create a file handler to write logs to a file
    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    return logger
24
+
25
+
26
def configure_root_logger():
    """
    Configure the root logger at INFO level and make it write to the log file.

    Calling this more than once does not attach a second handler for the file.
    """
    import os

    # Configure the root logger
    logging.basicConfig(level=logging.INFO)
    root_logger = logging.getLogger()

    # FileHandler stores its target as an absolute path in ``baseFilename``
    target = os.path.abspath(os.fspath(log_file))
    for handler in root_logger.handlers:
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            return

    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")

    file_handler = logging.FileHandler(log_file)
    file_handler.setLevel(logging.INFO)
    file_handler.setFormatter(formatter)

    root_logger.addHandler(file_handler)