Quentin Gallouédec commited on
Commit
76e0bcf
·
1 Parent(s): 4a5bd80

move eval to dedicated file

Browse files
Files changed (2) hide show
  1. app.py +3 -86
  2. src/evaluation.py +277 -0
app.py CHANGED
@@ -1,40 +1,24 @@
1
- import fnmatch
2
  import glob
3
  import json
4
- import logging
5
  import os
6
  import pprint
7
 
8
  import gradio as gr
9
- import gymnasium as gym
10
- import numpy as np
11
  import pandas as pd
12
- import torch
13
  from apscheduler.schedulers.background import BackgroundScheduler
14
- from huggingface_hub import hf_hub_download, snapshot_download
15
- from huggingface_hub.utils._errors import EntryNotFoundError
16
 
17
  from src.css_html_js import dark_mode_gradio_js
18
  from src.envs import API, RESULTS_PATH, RESULTS_REPO, TOKEN
 
19
  from src.logging import configure_root_logger, setup_logger
20
 
21
- logging.getLogger("openai").setLevel(logging.WARNING)
22
- logger = setup_logger(__name__)
23
-
24
  configure_root_logger()
25
  logger = setup_logger(__name__)
26
 
27
  pp = pprint.PrettyPrinter(width=80)
28
 
29
 
30
- ALL_ENV_IDS = [
31
- "CartPole-v1",
32
- "MountainCar-v0",
33
- "Acrobot-v1",
34
- "Hopper-v4",
35
- ]
36
-
37
-
38
  def model_hyperlink(link, model_id):
39
  return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_id}</a>'
40
 
@@ -44,73 +28,6 @@ def make_clickable_model(model_id):
44
  return model_hyperlink(link, model_id)
45
 
46
 
47
- def pattern_match(patterns, source_list):
48
- if isinstance(patterns, str):
49
- patterns = [patterns]
50
-
51
- env_ids = set()
52
- for pattern in patterns:
53
- for matching in fnmatch.filter(source_list, pattern):
54
- env_ids.add(matching)
55
- return sorted(list(env_ids))
56
-
57
-
58
- def evaluate(model_id, revision):
59
- tags = API.model_info(model_id, revision=revision).tags
60
-
61
- # Extract the environment IDs from the tags (usually only one)
62
- env_ids = pattern_match(tags, ALL_ENV_IDS)
63
- logger.info(f"Selected environments: {env_ids}")
64
-
65
- results = {}
66
-
67
- # Check if the agent exists
68
- try:
69
- agent_path = hf_hub_download(repo_id=model_id, filename="agent.pt")
70
- except EntryNotFoundError:
71
- logger.error("Agent not found")
72
- return None
73
-
74
- # Check safety
75
- security = next(iter(API.get_paths_info(model_id, "agent.pt", expand=True))).security
76
- if security is None or "safe" not in security:
77
- logger.error("Agent safety not available")
78
- return None
79
- elif not security["safe"]:
80
- logger.error("Agent not safe")
81
- return None
82
-
83
- # Load the agent
84
- try:
85
- agent = torch.jit.load(agent_path)
86
- except Exception as e:
87
- logger.error(f"Error loading agent: {e}")
88
- return None
89
-
90
- # Evaluate the agent on the environments
91
- for env_id in env_ids:
92
- episodic_rewards = []
93
- env = gym.make(env_id)
94
- for _ in range(10):
95
- episodic_reward = 0.0
96
- observation, info = env.reset()
97
- done = False
98
- while not done:
99
- torch_observation = torch.from_numpy(np.array([observation]))
100
- action = agent(torch_observation).numpy()[0]
101
- observation, reward, terminated, truncated, info = env.step(action)
102
- done = terminated or truncated
103
- episodic_reward += reward
104
-
105
- episodic_rewards.append(episodic_reward)
106
-
107
- mean_reward = np.mean(episodic_rewards)
108
- std_reward = np.std(episodic_rewards)
109
- results[env_id] = {"episodic_return_mean": mean_reward, "episodic_reward_std": std_reward}
110
- logger.info(f"Environment {env_id}: {mean_reward} ± {std_reward}")
111
- return results
112
-
113
-
114
  def _backend_routine():
115
  # List only the text classification models
116
  rl_models = list(API.list_models(filter="reinforcement-learning"))
@@ -265,7 +182,7 @@ with gr.Blocks(js=dark_mode_gradio_js) as demo:
265
 
266
 
267
  scheduler = BackgroundScheduler()
268
- scheduler.add_job(func=backend_routine, trigger="interval", seconds=5 * 60)
269
  scheduler.start()
270
 
271
 
 
 
1
  import glob
2
  import json
 
3
  import os
4
  import pprint
5
 
6
  import gradio as gr
 
 
7
  import pandas as pd
 
8
  from apscheduler.schedulers.background import BackgroundScheduler
9
+ from huggingface_hub import snapshot_download
 
10
 
11
  from src.css_html_js import dark_mode_gradio_js
12
  from src.envs import API, RESULTS_PATH, RESULTS_REPO, TOKEN
13
+ from src.evaluation import ALL_ENV_IDS, evaluate
14
  from src.logging import configure_root_logger, setup_logger
15
 
 
 
 
16
  configure_root_logger()
17
  logger = setup_logger(__name__)
18
 
19
  pp = pprint.PrettyPrinter(width=80)
20
 
21
 
 
 
 
 
 
 
 
 
22
  def model_hyperlink(link, model_id):
23
  return f'<a target="_blank" href="{link}" style="color: var(--link-text-color); text-decoration: underline;text-decoration-style: dotted;">{model_id}</a>'
24
 
 
28
  return model_hyperlink(link, model_id)
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  def _backend_routine():
32
  # List only the text classification models
33
  rl_models = list(API.list_models(filter="reinforcement-learning"))
 
182
 
183
 
184
  scheduler = BackgroundScheduler()
185
+ scheduler.add_job(func=backend_routine, trigger="interval", seconds=0.5 * 60)
186
  scheduler.start()
187
 
188
 
src/evaluation.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import fnmatch
2
+ from typing import Dict, SupportsFloat
3
+
4
+ import gymnasium as gym
5
+ import numpy as np
6
+ import torch
7
+ from gymnasium import wrappers
8
+ from huggingface_hub import hf_hub_download
9
+ from huggingface_hub.utils._errors import EntryNotFoundError
10
+
11
+ from src.envs import API
12
+ from src.logging import setup_logger
13
+
14
+ logger = setup_logger(__name__)
15
+
16
+
17
+ ALL_ENV_IDS = [
18
+ "CartPole-v1",
19
+ "MountainCar-v0",
20
+ "Acrobot-v1",
21
+ "Hopper-v4",
22
+ "MsPacmanNoFrameskip-v4",
23
+ ]
24
+
25
+
26
+ class NoopResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
27
+ """
28
+ Sample initial states by taking random number of no-ops on reset.
29
+ No-op is assumed to be action 0.
30
+
31
+ :param env: Environment to wrap
32
+ :param noop_max: Maximum value of no-ops to run
33
+ """
34
+
35
+ def __init__(self, env: gym.Env, noop_max: int = 30) -> None:
36
+ super().__init__(env)
37
+ self.noop_max = noop_max
38
+ self.override_num_noops = None
39
+ self.noop_action = 0
40
+ assert env.unwrapped.get_action_meanings()[0] == "NOOP" # type: ignore[attr-defined]
41
+
42
+ def reset(self, **kwargs):
43
+ self.env.reset(**kwargs)
44
+ if self.override_num_noops is not None:
45
+ noops = self.override_num_noops
46
+ else:
47
+ noops = self.unwrapped.np_random.integers(1, self.noop_max + 1)
48
+ assert noops > 0
49
+ obs = np.zeros(0)
50
+ info: Dict = {}
51
+ for _ in range(noops):
52
+ obs, _, terminated, truncated, info = self.env.step(self.noop_action)
53
+ if terminated or truncated:
54
+ obs, info = self.env.reset(**kwargs)
55
+ return obs, info
56
+
57
+
58
+ class FireResetEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
59
+ """
60
+ Take action on reset for environments that are fixed until firing.
61
+
62
+ :param env: Environment to wrap
63
+ """
64
+
65
+ def __init__(self, env: gym.Env) -> None:
66
+ super().__init__(env)
67
+ assert env.unwrapped.get_action_meanings()[1] == "FIRE" # type: ignore[attr-defined]
68
+ assert len(env.unwrapped.get_action_meanings()) >= 3 # type: ignore[attr-defined]
69
+
70
+ def reset(self, **kwargs):
71
+ self.env.reset(**kwargs)
72
+ obs, _, terminated, truncated, _ = self.env.step(1)
73
+ if terminated or truncated:
74
+ self.env.reset(**kwargs)
75
+ obs, _, terminated, truncated, _ = self.env.step(2)
76
+ if terminated or truncated:
77
+ self.env.reset(**kwargs)
78
+ return obs, {}
79
+
80
+
81
+ class EpisodicLifeEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
82
+ """
83
+ Make end-of-life == end-of-episode, but only reset on true game over.
84
+ Done by DeepMind for the DQN and co. since it helps value estimation.
85
+
86
+ :param env: Environment to wrap
87
+ """
88
+
89
+ def __init__(self, env: gym.Env) -> None:
90
+ super().__init__(env)
91
+ self.lives = 0
92
+ self.was_real_done = True
93
+
94
+ def step(self, action: int):
95
+ obs, reward, terminated, truncated, info = self.env.step(action)
96
+ self.was_real_done = terminated or truncated
97
+ # check current lives, make loss of life terminal,
98
+ # then update lives to handle bonus lives
99
+ lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
100
+ if 0 < lives < self.lives:
101
+ # for Qbert sometimes we stay in lives == 0 condition for a few frames
102
+ # so its important to keep lives > 0, so that we only reset once
103
+ # the environment advertises done.
104
+ terminated = True
105
+ self.lives = lives
106
+ return obs, reward, terminated, truncated, info
107
+
108
+ def reset(self, **kwargs):
109
+ """
110
+ Calls the Gym environment reset, only when lives are exhausted.
111
+ This way all states are still reachable even though lives are episodic,
112
+ and the learner need not know about any of this behind-the-scenes.
113
+
114
+ :param kwargs: Extra keywords passed to env.reset() call
115
+ :return: the first observation of the environment
116
+ """
117
+ if self.was_real_done:
118
+ obs, info = self.env.reset(**kwargs)
119
+ else:
120
+ # no-op step to advance from terminal/lost life state
121
+ obs, _, terminated, truncated, info = self.env.step(0)
122
+
123
+ # The no-op step can lead to a game over, so we need to check it again
124
+ # to see if we should reset the environment and avoid the
125
+ # monitor.py `RuntimeError: Tried to step environment that needs reset`
126
+ if terminated or truncated:
127
+ obs, info = self.env.reset(**kwargs)
128
+ self.lives = self.env.unwrapped.ale.lives() # type: ignore[attr-defined]
129
+ return obs, info
130
+
131
+
132
+ class MaxAndSkipEnv(gym.Wrapper[np.ndarray, int, np.ndarray, int]):
133
+ """
134
+ Return only every ``skip``-th frame (frameskipping)
135
+ and return the max between the two last frames.
136
+
137
+ :param env: Environment to wrap
138
+ :param skip: Number of ``skip``-th frame
139
+ The same action will be taken ``skip`` times.
140
+ """
141
+
142
+ def __init__(self, env: gym.Env, skip: int = 4) -> None:
143
+ super().__init__(env)
144
+ # most recent raw observations (for max pooling across time steps)
145
+ assert env.observation_space.dtype is not None, "No dtype specified for the observation space"
146
+ assert env.observation_space.shape is not None, "No shape defined for the observation space"
147
+ self._obs_buffer = np.zeros((2, *env.observation_space.shape), dtype=env.observation_space.dtype)
148
+ self._skip = skip
149
+
150
+ def step(self, action: int):
151
+ """
152
+ Step the environment with the given action
153
+ Repeat action, sum reward, and max over last observations.
154
+
155
+ :param action: the action
156
+ :return: observation, reward, terminated, truncated, information
157
+ """
158
+ total_reward = 0.0
159
+ terminated = truncated = False
160
+ for i in range(self._skip):
161
+ obs, reward, terminated, truncated, info = self.env.step(action)
162
+ done = terminated or truncated
163
+ if i == self._skip - 2:
164
+ self._obs_buffer[0] = obs
165
+ if i == self._skip - 1:
166
+ self._obs_buffer[1] = obs
167
+ total_reward += float(reward)
168
+ if done:
169
+ break
170
+ # Note that the observation on the done=True frame
171
+ # doesn't matter
172
+ max_frame = self._obs_buffer.max(axis=0)
173
+
174
+ return max_frame, total_reward, terminated, truncated, info
175
+
176
+
177
+ class ClipRewardEnv(gym.RewardWrapper):
178
+ """
179
+ Clip the reward to {+1, 0, -1} by its sign.
180
+
181
+ :param env: Environment to wrap
182
+ """
183
+
184
+ def __init__(self, env: gym.Env) -> None:
185
+ super().__init__(env)
186
+
187
+ def reward(self, reward: SupportsFloat) -> float:
188
+ """
189
+ Bin reward to {+1, 0, -1} by its sign.
190
+
191
+ :param reward:
192
+ :return:
193
+ """
194
+ return np.sign(float(reward))
195
+
196
+
197
+ def make(env_id):
198
+ def thunk():
199
+ env = gym.make(env_id)
200
+ env = wrappers.RecordEpisodeStatistics(env)
201
+ if "NoFrameskip" in env_id:
202
+ env = NoopResetEnv(env, noop_max=30)
203
+ env = MaxAndSkipEnv(env, skip=4)
204
+ env = EpisodicLifeEnv(env)
205
+ if "FIRE" in env.unwrapped.get_action_meanings():
206
+ env = FireResetEnv(env)
207
+ env = ClipRewardEnv(env)
208
+ env = wrappers.ResizeObservation(env, (84, 84))
209
+ env = wrappers.GrayScaleObservation(env)
210
+ env = wrappers.FrameStack(env, 4)
211
+ return env
212
+
213
+ return thunk
214
+
215
+
216
+ def pattern_match(patterns, source_list):
217
+ if isinstance(patterns, str):
218
+ patterns = [patterns]
219
+
220
+ env_ids = set()
221
+ for pattern in patterns:
222
+ for matching in fnmatch.filter(source_list, pattern):
223
+ env_ids.add(matching)
224
+ return sorted(list(env_ids))
225
+
226
+
227
+ def evaluate(model_id, revision):
228
+ tags = API.model_info(model_id, revision=revision).tags
229
+
230
+ # Extract the environment IDs from the tags (usually only one)
231
+ env_ids = pattern_match(tags, ALL_ENV_IDS)
232
+ logger.info(f"Selected environments: {env_ids}")
233
+
234
+ results = {}
235
+
236
+ # Check if the agent exists
237
+ try:
238
+ agent_path = hf_hub_download(repo_id=model_id, filename="agent.pt")
239
+ except EntryNotFoundError:
240
+ logger.error("Agent not found")
241
+ return None
242
+
243
+ # Check safety
244
+ security = next(iter(API.get_paths_info(model_id, "agent.pt", expand=True))).security
245
+ if security is None or "safe" not in security:
246
+ logger.error("Agent safety not available")
247
+ return None
248
+ elif not security["safe"]:
249
+ logger.error("Agent not safe")
250
+ return None
251
+
252
+ # Load the agent
253
+ try:
254
+ agent = torch.jit.load(agent_path)
255
+ except Exception as e:
256
+ logger.error(f"Error loading agent: {e}")
257
+ return None
258
+
259
+ # Evaluate the agent on the environments
260
+ for env_id in env_ids:
261
+ envs = gym.vector.SyncVectorEnv([make(env_id) for _ in range(3)])
262
+ observations, _ = envs.reset()
263
+ episodic_returns = []
264
+ while len(episodic_returns) < 10:
265
+ actions = agent(torch.tensor(observations)).numpy()
266
+ observations, _, _, _, infos = envs.step(actions)
267
+ if "final_info" in infos:
268
+ for info in infos["final_info"]:
269
+ if info is None or "episode" not in info:
270
+ continue
271
+ episodic_returns.append(info["episode"]["r"])
272
+
273
+ mean_reward = float(np.mean(episodic_returns))
274
+ std_reward = float(np.std(episodic_returns))
275
+ results[env_id] = {"episodic_return_mean": mean_reward, "episodic_reward_std": std_reward}
276
+ logger.info(f"Environment {env_id}: {mean_reward} ± {std_reward}")
277
+ return results