lvwerra HF staff committed on
Commit
c0f1666
1 Parent(s): 289642b

Update Space (evaluate main: c447fc8e)

Browse files
Files changed (2) hide show
  1. requirements.txt +1 -1
  2. rl_reliability.py +7 -28
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
2
  git+https://github.com/google-research/rl-reliability-metrics
3
  scipy
4
  tensorflow
1
+ git+https://github.com/huggingface/evaluate@c447fc8eda9c62af501bfdc6988919571050d950
2
  git+https://github.com/google-research/rl-reliability-metrics
3
  scipy
4
  tensorflow
rl_reliability.py CHANGED
@@ -13,9 +13,6 @@
13
  # limitations under the License.
14
  """Computes the RL Reliability Metrics."""
15
 
16
- from dataclasses import dataclass
17
- from typing import List, Optional
18
-
19
  import datasets
20
  import numpy as np
21
  from rl_reliability_metrics.evaluation import eval_metrics
@@ -84,27 +81,11 @@ Examples:
84
  """
85
 
86
 
87
- @dataclass
88
- class RLReliabilityConfig(evaluate.info.Config):
89
-
90
- name: str = "default"
91
-
92
- baseline: str = "default"
93
- freq_thresh: float = 0.01
94
- window_size: int = 100000
95
- window_size_trimmed: int = 99000
96
- alpha: float = 0.05
97
- eval_points: Optional[List] = None
98
-
99
-
100
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
101
  class RLReliability(evaluate.Metric):
102
  """Computes the RL Reliability Metrics."""
103
 
104
- CONFIG_CLASS = RLReliabilityConfig
105
- ALLOWED_CONFIG_NAMES = ["online", "offline"]
106
-
107
- def _info(self, config):
108
  if self.config_name not in ["online", "offline"]:
109
  raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")
110
 
@@ -113,7 +94,6 @@ class RLReliability(evaluate.Metric):
113
  description=_DESCRIPTION,
114
  citation=_CITATION,
115
  inputs_description=_KWARGS_DESCRIPTION,
116
- config=config,
117
  features=datasets.Features(
118
  {
119
  "timesteps": datasets.Sequence(datasets.Value("int64")),
@@ -127,19 +107,18 @@ class RLReliability(evaluate.Metric):
127
  self,
128
  timesteps,
129
  rewards,
 
 
 
 
 
 
130
  ):
131
  if len(timesteps) < N_RUNS_RECOMMENDED:
132
  logger.warning(
133
  f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
134
  )
135
 
136
- baseline = self.config.baseline
137
- freq_thresh = self.config.freq_thresh
138
- window_size = self.config.window_size
139
- window_size_trimmed = self.config.window_size_trimmed
140
- alpha = self.config.alpha
141
- eval_points = self.config.eval_points
142
-
143
  curves = []
144
  for timestep, reward in zip(timesteps, rewards):
145
  curves.append(np.stack([timestep, reward]))
13
  # limitations under the License.
14
  """Computes the RL Reliability Metrics."""
15
 
 
 
 
16
  import datasets
17
  import numpy as np
18
  from rl_reliability_metrics.evaluation import eval_metrics
81
  """
82
 
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
85
  class RLReliability(evaluate.Metric):
86
  """Computes the RL Reliability Metrics."""
87
 
88
+ def _info(self):
 
 
 
89
  if self.config_name not in ["online", "offline"]:
90
  raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")
91
 
94
  description=_DESCRIPTION,
95
  citation=_CITATION,
96
  inputs_description=_KWARGS_DESCRIPTION,
 
97
  features=datasets.Features(
98
  {
99
  "timesteps": datasets.Sequence(datasets.Value("int64")),
107
  self,
108
  timesteps,
109
  rewards,
110
+ baseline="default",
111
+ freq_thresh=0.01,
112
+ window_size=100000,
113
+ window_size_trimmed=99000,
114
+ alpha=0.05,
115
+ eval_points=None,
116
  ):
117
  if len(timesteps) < N_RUNS_RECOMMENDED:
118
  logger.warning(
119
  f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
120
  )
121
 
 
 
 
 
 
 
 
122
  curves = []
123
  for timestep, reward in zip(timesteps, rewards):
124
  curves.append(np.stack([timestep, reward]))