Quentin Gallouédec committed on
Commit c67a861
1 Parent(s): 1b0277d

remove backend from the front

Files changed (8)
  1. README.md +1 -1
  2. app.py +20 -36
  3. packages.txt +0 -3
  4. requirements.txt +7 -24
  5. src/backend.py +0 -90
  6. src/css_html_js.py +0 -20
  7. src/evaluation.py +0 -365
  8. src/logging.py +0 -37
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🥇
 colorFrom: green
 colorTo: indigo
 sdk: gradio
-sdk_version: 4.20.0
+sdk_version: 4.31.0
 app_file: app.py
 pinned: true
 license: apache-2.0
app.py CHANGED
@@ -1,5 +1,3 @@
-import glob
-import json
 import logging
 import os
 
@@ -8,17 +6,22 @@ import numpy as np
 import pandas as pd
 import scipy.stats
 from apscheduler.schedulers.background import BackgroundScheduler
+from datasets import load_dataset
 from huggingface_hub import HfApi
 
-from src.backend import backend_routine
-from src.logging import configure_root_logger, setup_logger
-
-configure_root_logger()
-logger = setup_logger(__name__)
+# Set up logging
+logger = logging.getLogger("app")
+logger.setLevel(logging.INFO)
+formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
+ch = logging.StreamHandler()
+ch.setFormatter(formatter)
+logger.addHandler(ch)
+
+# Disable the absl logger (annoying)
 logging.getLogger("absl").setLevel(logging.WARNING)
 
 API = HfApi(token=os.environ.get("TOKEN"))
-RESULTS_REPO = "open-rl-leaderboard/results"
+RESULTS_REPO = "open-rl-leaderboard/results_v2"
 REFRESH_RATE = 5 * 60  # 5 minutes
 ALL_ENV_IDS = {
     "Atari": [
@@ -127,30 +130,11 @@ def iqm(x):
 
 def get_leaderboard_df():
     logger.info("Downloading results")
-    dir_path = API.snapshot_download(repo_id=RESULTS_REPO, repo_type="dataset")
-    pattern = os.path.join(dir_path, "**", "results_*.json")
-    filenames = glob.glob(pattern, recursive=True)
-
-    data = []
-    for filename in filenames:
-        try:
-            with open(filename) as fp:
-                report = json.load(fp)
-            if report["status"] == "DONE" and len(report["results"]) > 0:
-                user_id, model_id = report["config"]["model_id"].split("/")
-                row = {"user_id": user_id, "model_id": model_id, "model_sha": report["config"]["model_sha"]}
-                env_ids = list(report["results"].keys())
-                assert len(env_ids) == 1, "Only one environment supported for the moment"
-                row["env_id"] = env_ids[0]
-                row["iqm_episodic_return"] = iqm(report["results"][env_ids[0]]["episodic_returns"])
-                data.append(row)
-        except Exception as e:
-            logger.error(f"Error while processing {filename}: {e}")
-
-    df = pd.DataFrame(data)  # create DataFrame
-    df = df.fillna("")  # replace NaN values with empty strings
-    # Save to csv
-    df.to_csv("leaderboard.csv", index=False)
+    dataset = load_dataset(RESULTS_REPO, split="train")  # split is not important, but we need to use "train"
+    df = dataset.to_pandas()  # convert to pandas dataframe
+    df = df[df["status"] == "DONE"]  # keep only the models that are done
+    df["iqm_episodic_return"] = df["episodic_returns"].apply(iqm)
+    logger.debug("Results downloaded")
     return df
 
 
@@ -180,10 +164,10 @@ def refresh_video(df, env_id):
     if not env_df.empty:
        user_id = env_df.iloc[0]["user_id"]
        model_id = env_df.iloc[0]["model_id"]
-       model_sha = env_df.iloc[0]["model_sha"]
+       sha = env_df.iloc[0]["sha"]
        repo_id = f"{user_id}/{model_id}"
        try:
-           video_path = API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=model_sha, repo_type="model")
+           video_path = API.hf_hub_download(repo_id=repo_id, filename="replay.mp4", revision=sha, repo_type="model")
            return video_path
        except Exception as e:
            logger.error(f"Error while downloading video for {env_id}: {e}")
@@ -217,7 +201,8 @@ This leaderboard is quite empty... 😢
 Be the first to submit your model!
 Check the tab "🚀 Getting my agent evaluated"
 """
-
+
+
 def refresh_num_models(df):
     return f"The leaderboard currently contains {len(df):,} models."
 
@@ -269,7 +254,7 @@ with gr.Blocks(css=css) as demo:
             # If the env_id ends with "NoFrameskip-v4", we remove it to improve readability
             tab_env_id = env_id[: -len("NoFrameskip-v4")] if env_id.endswith("NoFrameskip-v4") else env_id
             with gr.TabItem(tab_env_id) as tab:
-                logger.info(f"Creating tab for {env_id}")
+                logger.debug(f"Creating tab for {env_id}")
                 with gr.Row(equal_height=False):
                     with gr.Column(scale=3):
                         gr_df = gr.components.Dataframe(
@@ -308,7 +293,6 @@ with gr.Blocks(css=css) as demo:
    demo.load(refresh, outputs=list(all_gr_dfs.values()) + list(all_gr_winners.values()) + [num_models_md])

scheduler = BackgroundScheduler()
-scheduler.add_job(func=backend_routine, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.add_job(func=update_globals, trigger="interval", seconds=REFRESH_RATE, max_instances=1)
scheduler.start()
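For reference, the new front-end data path above can be reproduced outside the Space. The following is a minimal standalone sketch of the changed get_leaderboard_df(): the repo name and the status / episodic_returns columns come from the diff, while the iqm() body is an assumption (an interquartile mean, i.e. a 25%-trimmed mean via scipy.stats.trim_mean, consistent with the scipy.stats import that app.py keeps) and the printed columns are the ones app.py reads elsewhere.

```python
# Standalone sketch of the new leaderboard load path; assumptions flagged inline.
import scipy.stats
from datasets import load_dataset


def iqm(x):
    # Assumed implementation: interquartile mean = mean of the middle 50% of returns.
    return scipy.stats.trim_mean(x, proportiontocut=0.25, axis=None)


dataset = load_dataset("open-rl-leaderboard/results_v2", split="train")
df = dataset.to_pandas()
df = df[df["status"] == "DONE"]  # keep only finished evaluations, as in the diff
df["iqm_episodic_return"] = df["episodic_returns"].apply(iqm)
# Column names assumed from what app.py reads (user_id, model_id, env_id)
print(df[["user_id", "model_id", "env_id", "iqm_episodic_return"]].head())
```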
packages.txt DELETED
@@ -1,3 +0,0 @@
-swig
-libosmesa6-dev
-patchelf
requirements.txt CHANGED
@@ -1,24 +1,7 @@
-APScheduler==3.10.1
-black==23.11.0
-click==8.1.3
-datasets==2.14.5
-gradio==4.20.0
-gradio_client
-gymnasium[all,accept-rom-license]==0.29.1
-huggingface-hub>=0.18.0
-matplotlib==3.7.1
-free-mujoco-py
-mujoco<=2.3.7
-numpy==1.24.2
-pandas==2.0.0
-python-dateutil==2.8.2
-requests==2.28.2
-rliable==1.0.8
-torch==2.2.2
-tqdm==4.65.0
-
-
-# Log Visualizer
-BeautifulSoup4==4.12.2
-lxml==4.9.3
-rich==13.3.4
+APScheduler==3.10.4
+datasets==2.19.1
+gradio==4.31.2
+huggingface-hub==0.23.0
+numpy==1.26.4
+pandas==2.2.2
+scipy==1.13.0
src/backend.py DELETED
@@ -1,90 +0,0 @@
-import json
-import os
-import random
-import re
-import tempfile
-
-from huggingface_hub import CommitOperationAdd, HfApi
-
-from src.evaluation import evaluate
-from src.logging import setup_logger
-
-logger = setup_logger(__name__)
-
-API = HfApi(token=os.environ.get("TOKEN"))
-RESULTS_REPO = "open-rl-leaderboard/results"
-
-
-def _backend_routine():
-    # List only the reinforcement-learning models
-    rl_models = list(API.list_models(filter="reinforcement-learning"))
-    logger.info(f"Found {len(rl_models)} RL models")
-    compatible_models = []
-    for model in rl_models:
-        filenames = [sib.rfilename for sib in model.siblings]
-        if "agent.pt" in filenames:
-            compatible_models.append((model.modelId, model.sha))
-
-    logger.info(f"Found {len(compatible_models)} compatible models")
-
-    # Get the results
-    pattern = re.compile(r"^[^/]*/[^/]*/[^/]*results_[a-f0-9]+\.json$")
-    filenames = API.list_repo_files(RESULTS_REPO, repo_type="dataset")
-    filenames = [filename for filename in filenames if pattern.match(filename)]
-
-    evaluated_models = set()
-    for filename in filenames:
-        path = API.hf_hub_download(repo_id=RESULTS_REPO, filename=filename, repo_type="dataset")
-        with open(path) as fp:
-            report = json.load(fp)
-        evaluated_models.add((report["config"]["model_id"], report["config"]["model_sha"]))
-
-    # Find the models that are not associated with any results
-    pending_models = list(set(compatible_models) - evaluated_models)
-    logger.info(f"Found {len(pending_models)} pending models")
-
-    if len(pending_models) == 0:
-        return None
-
-    # Run an evaluation on the models
-    with tempfile.TemporaryDirectory() as tmp_dir:
-        commits = []
-        model_id, sha = random.choice(pending_models)
-        logger.info(f"Running evaluation on {model_id}")
-        report = {"config": {"model_id": model_id, "model_sha": sha}}
-        try:
-            evaluations = evaluate(model_id, revision=sha)
-        except Exception as e:
-            logger.error(f"Error evaluating {model_id}: {e}")
-            evaluations = None
-
-        if evaluations is not None:
-            report["results"] = evaluations
-            report["status"] = "DONE"
-        else:
-            report["status"] = "FAILED"
-
-        # Update the results
-        dumped = json.dumps(report, indent=2)
-        path_in_repo = f"{model_id}/results_{sha}.json"
-        local_path = os.path.join(tmp_dir, path_in_repo)
-        os.makedirs(os.path.dirname(local_path), exist_ok=True)
-        with open(local_path, "w") as f:
-            f.write(dumped)
-
-        commits.append(CommitOperationAdd(path_in_repo=path_in_repo, path_or_fileobj=local_path))
-
-        API.create_commit(
-            repo_id=RESULTS_REPO, commit_message="Add evaluation results", operations=commits, repo_type="dataset"
-        )
-
-
-def backend_routine():
-    try:
-        _backend_routine()
-    except Exception as e:
-        logger.error(f"{e.__class__.__name__}: {str(e)}")
-
-
-if __name__ == "__main__":
-    backend_routine()
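With this file gone, the pending-model discovery step has to live outside the Space. Below is a hedged sketch of how that step could target the new results_v2 dataset instead of per-model JSON files. The user_id / model_id / sha column names are assumptions taken from what app.py now reads; the "agent.pt" filter and the reinforcement-learning tag mirror the deleted code.

```python
# Sketch: find models awaiting evaluation, run outside the Space (assumptions inline).
import os

from datasets import load_dataset
from huggingface_hub import HfApi

API = HfApi(token=os.environ.get("TOKEN"))


def pending_models():
    # Models tagged reinforcement-learning that ship a TorchScript agent;
    # full=True is needed so that siblings (file listings) are populated.
    rl_models = API.list_models(filter="reinforcement-learning", full=True)
    compatible = {
        (model.id, model.sha)
        for model in rl_models
        if any(sib.rfilename == "agent.pt" for sib in (model.siblings or []))
    }
    # Models already recorded in the results dataset (column names assumed
    # from what app.py reads: user_id, model_id, sha).
    dataset = load_dataset("open-rl-leaderboard/results_v2", split="train")
    evaluated = {(f"{row['user_id']}/{row['model_id']}", row["sha"]) for row in dataset}
    return sorted(compatible - evaluated)


if __name__ == "__main__":
    for model_id, sha in pending_models():
        print(model_id, sha)
```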
src/css_html_js.py DELETED
@@ -1,20 +0,0 @@
-style_content = """
-pre, code {
-    background-color: #272822;
-}
-.scrollable {
-    font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace;
-    height: 500px;
-    overflow: auto;
-}
-"""
-dark_mode_gradio_js = """
-function refresh() {
-    const url = new URL(window.location);
-
-    if (url.searchParams.get('__theme') !== 'dark') {
-        url.searchParams.set('__theme', 'dark');
-        window.location.href = url.href;
-    }
-}
-"""
src/evaluation.py DELETED
@@ -1,365 +0,0 @@
-import fnmatch
-import os
-from typing import Dict, SupportsFloat
-
-import gymnasium as gym
-import numpy as np
-import torch
-from gymnasium import wrappers
-from huggingface_hub import HfApi
-from huggingface_hub.utils._errors import EntryNotFoundError
-
-from src.logging import setup_logger
-
-logger = setup_logger(__name__)
-
-API = HfApi(token=os.environ.get("TOKEN"))
-
-
-ALL_ENV_IDS = [
-    "AdventureNoFrameskip-v4",
-    "AirRaidNoFrameskip-v4",
-    "AlienNoFrameskip-v4",
-    "AmidarNoFrameskip-v4",
-    "AssaultNoFrameskip-v4",
-    "AsterixNoFrameskip-v4",
-    "AsteroidsNoFrameskip-v4",
-    "AtlantisNoFrameskip-v4",
-    "BankHeistNoFrameskip-v4",
-    "BattleZoneNoFrameskip-v4",
-    "BeamRiderNoFrameskip-v4",
-    "BerzerkNoFrameskip-v4",
-    "BowlingNoFrameskip-v4",
-    "BoxingNoFrameskip-v4",
-    "BreakoutNoFrameskip-v4",
-    "CarnivalNoFrameskip-v4",
-    "CentipedeNoFrameskip-v4",
-    "ChopperCommandNoFrameskip-v4",
-    "CrazyClimberNoFrameskip-v4",
-    "DefenderNoFrameskip-v4",
-    "DemonAttackNoFrameskip-v4",
-    "DoubleDunkNoFrameskip-v4",
-    "ElevatorActionNoFrameskip-v4",
-    "EnduroNoFrameskip-v4",
-    "FishingDerbyNoFrameskip-v4",
-    "FreewayNoFrameskip-v4",
-    "FrostbiteNoFrameskip-v4",
-    "GopherNoFrameskip-v4",
-    "GravitarNoFrameskip-v4",
-    "HeroNoFrameskip-v4",
-    "IceHockeyNoFrameskip-v4",
-    "JamesbondNoFrameskip-v4",
-    "JourneyEscapeNoFrameskip-v4",
-    "KangarooNoFrameskip-v4",
-    "KrullNoFrameskip-v4",
-    "KungFuMasterNoFrameskip-v4",
-    "MontezumaRevengeNoFrameskip-v4",
-    "MsPacmanNoFrameskip-v4",
-    "NameThisGameNoFrameskip-v4",
-    "PhoenixNoFrameskip-v4",
-    "PitfallNoFrameskip-v4",
-    "PongNoFrameskip-v4",
-    "PooyanNoFrameskip-v4",
-    "PrivateEyeNoFrameskip-v4",
-    "QbertNoFrameskip-v4",
-    "RiverraidNoFrameskip-v4",
-    "RoadRunnerNoFrameskip-v4",
-    "RobotankNoFrameskip-v4",
-    "SeaquestNoFrameskip-v4",
-    "SkiingNoFrameskip-v4",
-    "SolarisNoFrameskip-v4",
-    "SpaceInvadersNoFrameskip-v4",
-    "StarGunnerNoFrameskip-v4",
-    "TennisNoFrameskip-v4",
-    "TimePilotNoFrameskip-v4",
-    "TutankhamNoFrameskip-v4",
-    "UpNDownNoFrameskip-v4",
-    "VentureNoFrameskip-v4",
-    "VideoPinballNoFrameskip-v4",
-    "WizardOfWorNoFrameskip-v4",
-    "YarsRevengeNoFrameskip-v4",
-    "ZaxxonNoFrameskip-v4",
-    # Box2D
-    "BipedalWalker-v3",
-    "BipedalWalkerHardcore-v3",
-    "CarRacing-v2",
-    "LunarLander-v2",
-    "LunarLanderContinuous-v2",
-    # Toy text
-    "Blackjack-v1",
-    "CliffWalking-v0",
-    "FrozenLake-v1",
-    "FrozenLake8x8-v1",
-    # Classic control
-    "Acrobot-v1",
-    "CartPole-v1",
-    "MountainCar-v0",
-    "MountainCarContinuous-v0",
-    "Pendulum-v1",
-    # MuJoCo
-    "Ant-v4",
-    "HalfCheetah-v4",
-    "Hopper-v4",
-    "Humanoid-v4",
-    "HumanoidStandup-v4",
-    "InvertedDoublePendulum-v4",
-    "InvertedPendulum-v4",
-    "Pusher-v4",
-    "Reacher-v4",
-    "Swimmer-v4",
-    "Walker2d-v4",
-]
-
-NUM_EPISODES = 50
-
-
-class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
-    """
-    Sample initial states by taking random number of no-ops on reset.
-    No-op is assumed to be action 0.
-
-    :param env: Environment to wrap
-    :param noop_max: Maximum value of no-ops to run
-    """
-
-    def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
-        super().__init__(env)
-        self.noop_max = noop_max
-        self.override_num_noops = None
-        self.noop_action = 0
-        assert env.unwrapped.get_action_meanings()[0] == "NOOP"  # type: ignore[attr-defined]
-
-    def reset(self, **kwargs):
-        self.env.reset(**kwargs)
-        if self.override_num_noops is not None:
-            noops = self.override_num_noops
-        else:
-            noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
-        assert noops > 0
-        obs = np.zeros(0)
-        info: Dict = {}
-        for _ in range(noops):
-            obs, _, terminated, truncated, info = self.env.step(self.noop_action)
-            if terminated or truncated:
-                obs, info = self.env.reset(**kwargs)
-        return obs, info
-
-
-class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
-    """
-    Take action on reset for environments that are fixed until firing.
-
-    :param env: Environment to wrap
-    """
-
-    def __init__(self, env: gym.Env) -> None:
-        super().__init__(env)
-        assert env.unwrapped.get_action_meanings()[1] == "FIRE"  # type: ignore[attr-defined]
-        assert len(env.unwrapped.get_action_meanings()) >= 3  # type: ignore[attr-defined]
-
-    def reset(self, **kwargs):
-        self.env.reset(**kwargs)
-        obs, _, terminated, truncated, _ = self.env.step(1)
-        if terminated or truncated:
-            self.env.reset(**kwargs)
-        obs, _, terminated, truncated, _ = self.env.step(2)
-        if terminated or truncated:
-            self.env.reset(**kwargs)
-        return obs, {}
-
-
-class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
-    """
-    Make end-of-life == end-of-episode, but only reset on true game over.
-    Done by DeepMind for the DQN and co. since it helps value estimation.
-
-    :param env: Environment to wrap
-    """
-
-    def __init__(self, env: gym.Env) -> None:
-        super().__init__(env)
-        self.lives = 0
-        self.was_real_done = True
-
-    def step(self, action: int):
-        obs, reward, terminated, truncated, info = self.env.step(action)
-        self.was_real_done = terminated or truncated
-        # check current lives, make loss of life terminal,
-        # then update lives to handle bonus lives
-        lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
-        if 0 < lives < self.lives:
-            # for Qbert sometimes we stay in lives == 0 condition for a few frames
-            # so its important to keep lives > 0, so that we only reset once
-            # the environment advertises done.
-            terminated = True
-        self.lives = lives
-        return obs, reward, terminated, truncated, info
-
-    def reset(self, **kwargs):
-        """
-        Calls the Gym environment reset, only when lives are exhausted.
-        This way all states are still reachable even though lives are episodic,
-        and the learner need not know about any of this behind-the-scenes.
-
-        :param kwargs: Extra keywords passed to env.reset() call
-        :return: the first observation of the environment
-        """
-        if self.was_real_done:
-            obs, info = self.env.reset(**kwargs)
-        else:
-            # no-op step to advance from terminal/lost life state
-            obs, _, terminated, truncated, info = self.env.step(0)
-
-            # The no-op step can lead to a game over, so we need to check it again
-            # to see if we should reset the environment and avoid the
-            # monitor.py `RuntimeError: Tried to step environment that needs reset`
-            if terminated or truncated:
-                obs, info = self.env.reset(**kwargs)
-        self.lives = self.env.unwrapped.ale.lives()  # type: ignore[attr-defined]
-        return obs, info
-
-
-class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
-    """
-    Return only every ``skip``-th frame (frameskipping)
-    and return the max between the two last frames.
-
-    :param env: Environment to wrap
-    :param skip: Number of ``skip``-th frame
-        The same action will be taken ``skip`` times.
-    """
-
-    def __init__(self, env: gym.Env, skip: int = 4) -> None:
-        super().__init__(env)
-        # most recent raw observations (for max pooling across time steps)
-        assert env.observation_space.dtype is not None, "No dtype specified for the observation space"
-        assert env.observation_space.shape is not None, "No shape defined for the observation space"
-        self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
-        self._skip = skip
-
-    def step(self, action: int):
-        """
-        Step the environment with the given action
-        Repeat action, sum reward, and max over last observations.
-
-        :param action: the action
-        :return: observation, reward, terminated, truncated, information
-        """
-        total_reward = 0.0
-        terminated = truncated = False
-        for i in range(self._skip):
-            obs, reward, terminated, truncated, info = self.env.step(action)
-            done = terminated or truncated
-            if i == self._skip - 2:
-                self._obs_buffer[0] = obs
-            if i == self._skip - 1:
-                self._obs_buffer[1] = obs
-            total_reward += float(reward)
-            if done:
-                break
-        # Note that the observation on the done=True frame
-        # doesn't matter
-        max_frame = self._obs_buffer.max(axis=0)
-
-        return max_frame, total_reward, terminated, truncated, info
-
-
-class ClipRewardEnv(gym.RewardWrapper):
-    """
-    Clip the reward to {+1, 0, -1} by its sign.
-
-    :param env: Environment to wrap
-    """
-
-    def __init__(self, env: gym.Env) -> None:
-        super().__init__(env)
-
-    def reward(self, reward: SupportsFloat) -> float:
-        """
-        Bin reward to {+1, 0, -1} by its sign.
-
-        :param reward:
-        :return:
-        """
-        return np.sign(float(reward))
-
-
-def make(env_id):
-    def thunk():
-        env = gym.make(env_id)
-        env = wrappers.RecordEpisodeStatistics(env)
-        if "NoFrameskip" in env_id:
-            env = NoopResetEnv(env, noop_max=30)
-            env = MaxAndSkipEnv(env, skip=4)
-            env = EpisodicLifeEnv(env)
-            if "FIRE" in env.unwrapped.get_action_meanings():
-                env = FireResetEnv(env)
-            env = ClipRewardEnv(env)
-            env = wrappers.ResizeObservation(env, (84, 84))
-            env = wrappers.GrayScaleObservation(env)
-            env = wrappers.FrameStack(env, 4)
-        return env
-
-    return thunk
-
-
-def pattern_match(patterns, source_list):
-    if isinstance(patterns, str):
-        patterns = [patterns]
-
-    env_ids = set()
-    for pattern in patterns:
-        for matching in fnmatch.filter(source_list, pattern):
-            env_ids.add(matching)
-    return sorted(list(env_ids))
-
-
-def evaluate(model_id, revision):
-    tags = API.model_info(model_id, revision=revision).tags
-
-    # Extract the environment IDs from the tags (usually only one)
-    env_ids = pattern_match(tags, ALL_ENV_IDS)
-    logger.info(f"Selected environments: {env_ids}")
-
-    results = {}
-
-    # Check if the agent exists
-    try:
-        agent_path = API.hf_hub_download(repo_id=model_id, filename="agent.pt")
-    except EntryNotFoundError:
-        logger.error("Agent not found")
-        return None
-
-    # Check safety
-    security = next(iter(API.get_paths_info(model_id, "agent.pt", expand=True))).security
-    if security is None or "safe" not in security:
-        logger.warn("Agent safety not available")
-        # return None
-    elif not security["safe"]:
-        logger.error("Agent not safe")
-        return None
-
-    # Load the agent
-    try:
-        agent = torch.jit.load(agent_path).to("cuda")
-    except Exception as e:
-        logger.error(f"Error loading agent: {e}")
-        return None
-
-    # Evaluate the agent on the environments
-    for env_id in env_ids:
-        envs = gym.vector.SyncVectorEnv([make(env_id) for _ in range(1)])
-        observations, _ = envs.reset()
-        episodic_returns = []
-        while len(episodic_returns) < NUM_EPISODES:
-            actions = agent(torch.tensor(observations)).numpy()
-            observations, _, _, _, infos = envs.step(actions)
-            if "final_info" in infos:
-                for info in infos["final_info"]:
-                    if info is None or "episode" not in info:
-                        continue
-                    episodic_returns.append(float(info["episode"]["r"]))
-
-        results[env_id] = {"episodic_returns": episodic_returns}
-        logger.info(f"Environment {env_id}: {np.mean(episodic_returns)} ± {np.std(episodic_returns)}")
-    return results
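The core of the removed evaluate() loop is its episodic-return collection pattern around gymnasium's vector API. Below is a self-contained sketch of that pattern under gymnasium 0.29, reduced to CartPole with a random policy standing in for the TorchScript agent; the make() factory here is a stripped-down stand-in for the one deleted above.

```python
# Sketch of the deleted evaluation loop: collect per-episode returns from a
# vectorized env via RecordEpisodeStatistics (gymnasium 0.29 API assumed).
import gymnasium as gym
import numpy as np
from gymnasium import wrappers


def make(env_id):
    def thunk():
        env = gym.make(env_id)
        # Fills infos["final_info"][i]["episode"]["r"] when an episode ends
        env = wrappers.RecordEpisodeStatistics(env)
        return env

    return thunk


envs = gym.vector.SyncVectorEnv([make("CartPole-v1")])
observations, _ = envs.reset(seed=0)
episodic_returns = []
while len(episodic_returns) < 5:  # the deleted code used NUM_EPISODES = 50
    actions = envs.action_space.sample()  # a real run would call the TorchScript agent here
    observations, _, _, _, infos = envs.step(actions)
    if "final_info" in infos:  # set by the vector env when a sub-env finishes an episode
        for info in infos["final_info"]:
            if info is not None and "episode" in info:
                episodic_returns.append(float(info["episode"]["r"]))
print(f"{np.mean(episodic_returns)} ± {np.std(episodic_returns)}")
```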
src/logging.py DELETED
@@ -1,37 +0,0 @@
-from pathlib import Path
-
-proj_dir = Path(__file__).parents[1]
-
-log_file = proj_dir / "output.log"
-
-
-import logging
-
-
-def setup_logger(name: str):
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.INFO)
-
-    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-
-    # Create a file handler to write logs to a file
-    file_handler = logging.FileHandler(log_file)
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(formatter)
-    logger.addHandler(file_handler)
-
-    return logger
-
-
-def configure_root_logger():
-    # Configure the root logger
-    logging.basicConfig(level=logging.INFO)
-    root_logger = logging.getLogger()
-
-    formatter = logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
-
-    file_handler = logging.FileHandler(log_file)
-    file_handler.setLevel(logging.INFO)
-    file_handler.setFormatter(formatter)
-
-    root_logger.addHandler(file_handler)