lvwerra HF staff commited on
Commit
8a2f5ea
1 Parent(s): 2484162

Update Space (evaluate main: 1c421923)

Browse files
Files changed (4) hide show
  1. README.md +109 -6
  2. app.py +6 -0
  3. requirements.txt +7 -0
  4. rl_reliability.py +186 -0
README.md CHANGED
@@ -1,12 +1,115 @@
1
  ---
2
- title: Rl_reliability
3
- emoji: 👀
4
- colorFrom: indigo
5
- colorTo: yellow
 
 
6
  sdk: gradio
7
- sdk_version: 3.0.9
8
  app_file: app.py
9
  pinned: false
10
  ---
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
+ title: RL Reliability
3
+ datasets:
4
+ -
5
+ tags:
6
+ - evaluate
7
+ - metric
8
  sdk: gradio
9
+ sdk_version: 3.0.2
10
  app_file: app.py
11
  pinned: false
12
  ---
13
 
14
+ # Metric Card for RL Reliability
15
+
16
+ ## Metric Description
17
+ The RL Reliability Metrics library provides a set of metrics for measuring the reliability of reinforcement learning (RL) algorithms.
18
+
19
+ ## How to Use
20
+
21
+ ```python
22
+ import evaluate
23
+ import numpy as np
24
+
25
+ rl_reliability = evaluate.load("rl_reliability", "online")
26
+ results = rl_reliability.compute(
27
+ timesteps=[np.linspace(0, 2000000, 1000)],
28
+ rewards=[np.linspace(0, 100, 1000)]
29
+ )
30
+
31
+ rl_reliability = evaluate.load("rl_reliability", "offline")
32
+ results = rl_reliability.compute(
33
+ timesteps=[np.linspace(0, 2000000, 1000)],
34
+ rewards=[np.linspace(0, 100, 1000)]
35
+ )
36
+ ```
37
+
38
+
39
+ ### Inputs
40
+ - **timesteps** *(List[int]): For each run a an list/array with its timesteps.*
41
+ - **rewards** *(List[float]): For each run a an list/array with its rewards.*
42
+
43
+ KWARGS:
44
+ - **baseline="default"** *(Union[str, float]) Normalization used for curves. When `"default"` is passed the curves are normalized by their range in the online setting and by the median performance across runs in the offline case. When a float is passed the curves are divided by that value.*
45
+ - **eval_points=[50000, 150000, ..., 2000000]** *(List[int]) Statistics will be computed at these points*
46
+ - **freq_thresh=0.01** *(float) Frequency threshold for low-pass filtering.*
47
+ - **window_size=100000** *(int) Defines a window centered at each eval point.*
48
+ - **window_size_trimmed=99000** *(int) To handle shortened curves due to differencing*
49
+ - **alpha=0.05** *(float)The "value at risk" (VaR) cutoff point, a float in the range [0,1].*
50
+
51
+ ### Output Values
52
+
53
+ In `"online"` mode:
54
+ - HighFreqEnergyWithinRuns: High Frequency across Time (DT)
55
+ - IqrWithinRuns: IQR across Time (DT)
56
+ - MadWithinRuns: 'MAD across Time (DT)
57
+ - StddevWithinRuns: Stddev across Time (DT)
58
+ - LowerCVaROnDiffs: Lower CVaR on Differences (SRT)
59
+ - UpperCVaROnDiffs: Upper CVaR on Differences (SRT)
60
+ - MaxDrawdown: Max Drawdown (LRT)
61
+ - LowerCVaROnDrawdown: Lower CVaR on Drawdown (LRT)
62
+ - UpperCVaROnDrawdown: Upper CVaR on Drawdown (LRT)
63
+ - LowerCVaROnRaw: Lower CVaR on Raw
64
+ - UpperCVaROnRaw: Upper CVaR on Raw
65
+ - IqrAcrossRuns: IQR across Runs (DR)
66
+ - MadAcrossRuns: MAD across Runs (DR)
67
+ - StddevAcrossRuns: Stddev across Runs (DR)
68
+ - LowerCVaROnAcross: Lower CVaR across Runs (RR)
69
+ - UpperCVaROnAcross: Upper CVaR across Runs (RR)
70
+ - MedianPerfDuringTraining: Median Performance across Runs
71
+
72
+ In `"offline"` mode:
73
+ - MadAcrossRollouts: MAD across rollouts (DF)
74
+ - IqrAcrossRollouts: IQR across rollouts (DF)
75
+ - LowerCVaRAcrossRollouts: Lower CVaR across rollouts (RF)
76
+ - UpperCVaRAcrossRollouts: Upper CVaR across rollouts (RF)
77
+ - MedianPerfAcrossRollouts: Median Performance across rollouts
78
+
79
+
80
+ ### Examples
81
+ First get the sample data from the repository:
82
+
83
+ ```bash
84
+ wget https://storage.googleapis.com/rl-reliability-metrics/data/tf_agents_example_csv_dataset.tgz
85
+ tar -xvzf tf_agents_example_csv_dataset.tgz
86
+ ```
87
+
88
+ Load the sample data:
89
+ ```python
90
+ dfs = [pd.read_csv(f"./csv_data/sac_humanoid_{i}_train.csv") for i in range(1, 4)]
91
+ ```
92
+
93
+ Compute the metrics:
94
+ ```python
95
+ rl_reliability = evaluate.load("rl_reliability", "online")
96
+ rl_reliability.compute(timesteps=[df["Metrics/EnvironmentSteps"] for df in dfs],
97
+ rewards=[df["Metrics/AverageReturn"] for df in dfs])
98
+ ```
99
+
100
+ ## Limitations and Bias
101
+ This implementation of RL reliability metrics does not compute permutation tests to determine whether algorithms are statistically different in their metric values and also does not compute bootstrap confidence intervals on the rankings of the algorithms. See the [original library](https://github.com/google-research/rl-reliability-metrics/) for more resources.
102
+
103
+ ## Citation
104
+
105
+ ```bibtex
106
+ @conference{rl_reliability_metrics,
107
+ title = {Measuring the Reliability of Reinforcement Learning Algorithms},
108
+ author = {Stephanie CY Chan, Sam Fishman, John Canny, Anoop Korattikara, and Sergio Guadarrama},
109
+ booktitle = {International Conference on Learning Representations, Addis Ababa, Ethiopia},
110
+ year = 2020,
111
+ }
112
+ ```
113
+
114
+ ## Further References
115
+ - Homepage: https://github.com/google-research/rl-reliability-metrics
app.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
1
+ import evaluate
2
+ from evaluate.utils import launch_gradio_widget
3
+
4
+
5
+ module = evaluate.load("rl_reliability", "online")
6
+ launch_gradio_widget(module)
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
1
+ # TODO: fix github to release
2
+ git+https://github.com/huggingface/evaluate.git@main
3
+ datasets~=2.0
4
+ git+https://github.com/google-research/rl-reliability-metrics
5
+ scipy
6
+ tensorflow
7
+ gin-config
rl_reliability.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Computes the RL Reliability Metrics."""
15
+
16
+ import datasets
17
+ import numpy as np
18
+ from rl_reliability_metrics.evaluation import eval_metrics
19
+ from rl_reliability_metrics.metrics import metrics_offline, metrics_online
20
+
21
+ import evaluate
22
+
23
+
24
+ logger = evaluate.logging.get_logger(__name__)
25
+
26
+ DEFAULT_EVAL_POINTS = [
27
+ 50000,
28
+ 150000,
29
+ 250000,
30
+ 350000,
31
+ 450000,
32
+ 550000,
33
+ 650000,
34
+ 750000,
35
+ 850000,
36
+ 950000,
37
+ 1050000,
38
+ 1150000,
39
+ 1250000,
40
+ 1350000,
41
+ 1450000,
42
+ 1550000,
43
+ 1650000,
44
+ 1750000,
45
+ 1850000,
46
+ 1950000,
47
+ ]
48
+
49
+ N_RUNS_RECOMMENDED = 10
50
+
51
+ _CITATION = """\
52
+ @conference{rl_reliability_metrics,
53
+ title = {Measuring the Reliability of Reinforcement Learning Algorithms},
54
+ author = {Stephanie CY Chan, Sam Fishman, John Canny, Anoop Korattikara, and Sergio Guadarrama},
55
+ booktitle = {International Conference on Learning Representations, Addis Ababa, Ethiopia},
56
+ year = 2020,
57
+ }
58
+ """
59
+
60
+ _DESCRIPTION = """\
61
+ This new module is designed to solve this great NLP task and is crafted with a lot of care.
62
+ """
63
+
64
+
65
+ _KWARGS_DESCRIPTION = """
66
+ Computes the RL reliability metrics from a set of experiments. There is an `"online"` and `"offline"` configuration for evaluation.
67
+ Args:
68
+ timestamps: list of timestep lists/arrays that serve as index.
69
+ rewards: list of reward lists/arrays of each experiment.
70
+ Returns:
71
+ dictionary: a set of reliability metrics
72
+ Examples:
73
+ >>> import numpy as np
74
+ >>> rl_reliability = evaluate.load("rl_reliability", "online")
75
+ >>> results = rl_reliability.compute(
76
+ ... timesteps=[np.linspace(0, 2000000, 1000)],
77
+ ... rewards=[np.linspace(0, 100, 1000)]
78
+ ... )
79
+ >>> print(results["LowerCVaROnRaw"].round(4))
80
+ [0.0258]
81
+ """
82
+
83
+
84
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
85
+ class RLReliability(evaluate.EvaluationModule):
86
+ """Computes the RL Reliability Metrics."""
87
+
88
+ def _info(self):
89
+ if self.config_name not in ["online", "offline"]:
90
+ raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")
91
+
92
+ return evaluate.EvaluationModuleInfo(
93
+ module_type="metric",
94
+ description=_DESCRIPTION,
95
+ citation=_CITATION,
96
+ inputs_description=_KWARGS_DESCRIPTION,
97
+ features=datasets.Features(
98
+ {
99
+ "timesteps": datasets.Sequence(datasets.Value("int64")),
100
+ "rewards": datasets.Sequence(datasets.Value("float")),
101
+ }
102
+ ),
103
+ homepage="https://github.com/google-research/rl-reliability-metrics",
104
+ )
105
+
106
+ def _compute(
107
+ self,
108
+ timesteps,
109
+ rewards,
110
+ baseline="default",
111
+ freq_thresh=0.01,
112
+ window_size=100000,
113
+ window_size_trimmed=99000,
114
+ alpha=0.05,
115
+ eval_points=None,
116
+ ):
117
+ if len(timesteps) < N_RUNS_RECOMMENDED:
118
+ logger.warning(
119
+ f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
120
+ )
121
+
122
+ curves = []
123
+ for timestep, reward in zip(timesteps, rewards):
124
+ curves.append(np.stack([timestep, reward]))
125
+
126
+ if self.config_name == "online":
127
+ if baseline == "default":
128
+ baseline = "curve_range"
129
+ if eval_points is None:
130
+ eval_points = DEFAULT_EVAL_POINTS
131
+
132
+ metrics = [
133
+ metrics_online.HighFreqEnergyWithinRuns(thresh=freq_thresh),
134
+ metrics_online.IqrWithinRuns(
135
+ window_size=window_size_trimmed, eval_points=eval_points, baseline=baseline
136
+ ),
137
+ metrics_online.IqrAcrossRuns(
138
+ lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
139
+ ),
140
+ metrics_online.LowerCVaROnDiffs(baseline=baseline),
141
+ metrics_online.LowerCVaROnDrawdown(baseline=baseline),
142
+ metrics_online.LowerCVaROnAcross(
143
+ lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
144
+ ),
145
+ metrics_online.LowerCVaROnRaw(alpha=alpha, baseline=baseline),
146
+ metrics_online.MadAcrossRuns(
147
+ lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
148
+ ),
149
+ metrics_online.MadWithinRuns(
150
+ eval_points=eval_points, window_size=window_size_trimmed, baseline=baseline
151
+ ),
152
+ metrics_online.MaxDrawdown(),
153
+ metrics_online.StddevAcrossRuns(
154
+ lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
155
+ ),
156
+ metrics_online.StddevWithinRuns(
157
+ eval_points=eval_points, window_size=window_size_trimmed, baseline=baseline
158
+ ),
159
+ metrics_online.UpperCVaROnAcross(
160
+ alpha=alpha,
161
+ lowpass_thresh=freq_thresh,
162
+ eval_points=eval_points,
163
+ window_size=window_size,
164
+ baseline=baseline,
165
+ ),
166
+ metrics_online.UpperCVaROnDiffs(alpha=alpha, baseline=baseline),
167
+ metrics_online.UpperCVaROnDrawdown(alpha=alpha, baseline=baseline),
168
+ metrics_online.UpperCVaROnRaw(alpha=alpha, baseline=baseline),
169
+ metrics_online.MedianPerfDuringTraining(window_size=window_size, eval_points=eval_points),
170
+ ]
171
+ else:
172
+ if baseline == "default":
173
+ baseline = "median_perf"
174
+
175
+ metrics = [
176
+ metrics_offline.MadAcrossRollouts(baseline=baseline),
177
+ metrics_offline.IqrAcrossRollouts(baseline=baseline),
178
+ metrics_offline.StddevAcrossRollouts(baseline=baseline),
179
+ metrics_offline.LowerCVaRAcrossRollouts(alpha=alpha, baseline=baseline),
180
+ metrics_offline.UpperCVaRAcrossRollouts(alpha=alpha, baseline=baseline),
181
+ metrics_offline.MedianPerfAcrossRollouts(baseline=None),
182
+ ]
183
+
184
+ evaluator = eval_metrics.Evaluator(metrics=metrics)
185
+ result = evaluator.compute_metrics(curves)
186
+ return result