# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Computes the RL Reliability Metrics."""

import datasets
import numpy as np
from rl_reliability_metrics.evaluation import eval_metrics
from rl_reliability_metrics.metrics import metrics_offline, metrics_online

import evaluate


logger = evaluate.logging.get_logger(__name__)

# Default evaluation points: every 100k steps, from 50k up to 1.95M.
DEFAULT_EVAL_POINTS = list(range(50_000, 2_000_000, 100_000))

N_RUNS_RECOMMENDED = 10

_CITATION = """\
@conference{rl_reliability_metrics,
  title = {Measuring the Reliability of Reinforcement Learning Algorithms},
  author = {Stephanie CY Chan and Sam Fishman and John Canny and Anoop Korattikara and Sergio Guadarrama},
  booktitle = {International Conference on Learning Representations, Addis Ababa, Ethiopia},
  year = 2020,
}
"""

_DESCRIPTION = """\
Computes the RL reliability metrics from a set of experiments. There is an `"online"` and an `"offline"` configuration for evaluation.
"""

_KWARGS_DESCRIPTION = """
Computes the RL reliability metrics from a set of experiments. There is an `"online"` and an `"offline"` configuration for evaluation.
Args:
    timesteps: list of timestep lists/arrays that serve as index.
    rewards: list of reward lists/arrays of each experiment.
Returns:
    dictionary: a set of reliability metrics.
Examples:
    >>> import numpy as np
    >>> rl_reliability = evaluate.load("rl_reliability", "online")
    >>> results = rl_reliability.compute(
    ...     timesteps=[np.linspace(0, 2000000, 1000, dtype=np.int64)],
    ...     rewards=[np.linspace(0, 100, 1000)]
    ... )
    >>> print(results["LowerCVaROnRaw"].round(4))
    [0.0258]
"""


class RLReliability(evaluate.Metric):
    """Computes the RL Reliability Metrics."""

    def _info(self):
        if self.config_name not in ["online", "offline"]:
            raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")

        return evaluate.MetricInfo(
            module_type="metric",
            description=_DESCRIPTION,
            citation=_CITATION,
            inputs_description=_KWARGS_DESCRIPTION,
            features=datasets.Features(
                {
                    "timesteps": datasets.Sequence(datasets.Value("int64")),
                    "rewards": datasets.Sequence(datasets.Value("float")),
                }
            ),
            homepage="https://github.com/google-research/rl-reliability-metrics",
        )

    def _compute(
        self,
        timesteps,
        rewards,
        baseline="default",
        freq_thresh=0.01,
        window_size=100000,
        window_size_trimmed=99000,
        alpha=0.05,
        eval_points=None,
    ):
        if len(timesteps) < N_RUNS_RECOMMENDED:
            logger.warning(
                f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs, but only {len(timesteps)} were provided."
            )

        # Each curve is a 2 x T array: row 0 holds the timesteps, row 1 the rewards.
        curves = []
        for timestep, reward in zip(timesteps, rewards):
            curves.append(np.stack([timestep, reward]))

        if self.config_name == "online":
            if baseline == "default":
                baseline = "curve_range"
            if eval_points is None:
                eval_points = DEFAULT_EVAL_POINTS

            metrics = [
                metrics_online.HighFreqEnergyWithinRuns(thresh=freq_thresh),
                metrics_online.IqrWithinRuns(
                    window_size=window_size_trimmed, eval_points=eval_points, baseline=baseline
                ),
                metrics_online.IqrAcrossRuns(
                    lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
                ),
                metrics_online.LowerCVaROnDiffs(baseline=baseline),
                metrics_online.LowerCVaROnDrawdown(baseline=baseline),
                metrics_online.LowerCVaROnAcross(
                    lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
                ),
                metrics_online.LowerCVaROnRaw(alpha=alpha, baseline=baseline),
                metrics_online.MadAcrossRuns(
                    lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
                ),
                metrics_online.MadWithinRuns(
                    eval_points=eval_points, window_size=window_size_trimmed, baseline=baseline
                ),
                metrics_online.MaxDrawdown(),
                metrics_online.StddevAcrossRuns(
                    lowpass_thresh=freq_thresh, eval_points=eval_points, window_size=window_size, baseline=baseline
                ),
                metrics_online.StddevWithinRuns(
                    eval_points=eval_points, window_size=window_size_trimmed, baseline=baseline
                ),
                metrics_online.UpperCVaROnAcross(
                    alpha=alpha,
                    lowpass_thresh=freq_thresh,
                    eval_points=eval_points,
                    window_size=window_size,
                    baseline=baseline,
                ),
                metrics_online.UpperCVaROnDiffs(alpha=alpha, baseline=baseline),
                metrics_online.UpperCVaROnDrawdown(alpha=alpha, baseline=baseline),
                metrics_online.UpperCVaROnRaw(alpha=alpha, baseline=baseline),
                metrics_online.MedianPerfDuringTraining(window_size=window_size, eval_points=eval_points),
            ]
        else:
            if baseline == "default":
                baseline = "median_perf"

            metrics = [
                metrics_offline.MadAcrossRollouts(baseline=baseline),
                metrics_offline.IqrAcrossRollouts(baseline=baseline),
                metrics_offline.StddevAcrossRollouts(baseline=baseline),
                metrics_offline.LowerCVaRAcrossRollouts(alpha=alpha, baseline=baseline),
                metrics_offline.UpperCVaRAcrossRollouts(alpha=alpha, baseline=baseline),
                metrics_offline.MedianPerfAcrossRollouts(baseline=None),
            ]

        evaluator = eval_metrics.Evaluator(metrics=metrics)
        result = evaluator.compute_metrics(curves)
        return result
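

# A minimal usage sketch, not part of the module itself. It assumes this script
# is published as the "rl_reliability" module on the Hugging Face Hub (which is
# how `evaluate.load` resolves it) and that `rl_reliability_metrics` is
# installed. It simulates ten noisy training runs, as the warning above
# recommends, and lists the names of the computed online metrics.
if __name__ == "__main__":
    rng = np.random.default_rng(seed=0)
    n_runs, n_points = 10, 1000

    # Ten runs sharing the same timestep grid, each with Gaussian reward noise.
    timesteps = [np.linspace(0, 2_000_000, n_points, dtype=np.int64) for _ in range(n_runs)]
    rewards = [
        np.linspace(0, 100, n_points) + rng.normal(scale=5.0, size=n_points) for _ in range(n_runs)
    ]

    rl_reliability = evaluate.load("rl_reliability", "online")
    results = rl_reliability.compute(timesteps=timesteps, rewards=rewards)
    print(sorted(results.keys()))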