Update Space (evaluate main: e4a27243)
Browse files- requirements.txt +1 -1
 - rl_reliability.py +28 -7
 
    	
        requirements.txt
    CHANGED
    
    | 
         @@ -1,4 +1,4 @@ 
     | 
|
| 1 | 
         
            -
            git+https://github.com/huggingface/evaluate@ 
     | 
| 2 | 
         
             
            git+https://github.com/google-research/rl-reliability-metrics
         
     | 
| 3 | 
         
             
            scipy
         
     | 
| 4 | 
         
             
            tensorflow
         
     | 
| 
         | 
|
| 1 | 
         
            +
            git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
         
     | 
| 2 | 
         
             
            git+https://github.com/google-research/rl-reliability-metrics
         
     | 
| 3 | 
         
             
            scipy
         
     | 
| 4 | 
         
             
            tensorflow
         
     | 
    	
        rl_reliability.py
    CHANGED
    
    | 
         @@ -13,6 +13,9 @@ 
     | 
|
| 13 | 
         
             
            # limitations under the License.
         
     | 
| 14 | 
         
             
            """Computes the RL Reliability Metrics."""
         
     | 
| 15 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 16 | 
         
             
            import datasets
         
     | 
| 17 | 
         
             
            import numpy as np
         
     | 
| 18 | 
         
             
            from rl_reliability_metrics.evaluation import eval_metrics
         
     | 
| 
         @@ -81,11 +84,27 @@ Examples: 
     | 
|
| 81 | 
         
             
            """
         
     | 
| 82 | 
         | 
| 83 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 84 | 
         
             
            @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
         
     | 
| 85 | 
         
             
            class RLReliability(evaluate.Metric):
         
     | 
| 86 | 
         
             
                """Computes the RL Reliability Metrics."""
         
     | 
| 87 | 
         | 
| 88 | 
         
            -
                 
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 89 | 
         
             
                    if self.config_name not in ["online", "offline"]:
         
     | 
| 90 | 
         
             
                        raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")
         
     | 
| 91 | 
         | 
| 
         @@ -94,6 +113,7 @@ class RLReliability(evaluate.Metric): 
     | 
|
| 94 | 
         
             
                        description=_DESCRIPTION,
         
     | 
| 95 | 
         
             
                        citation=_CITATION,
         
     | 
| 96 | 
         
             
                        inputs_description=_KWARGS_DESCRIPTION,
         
     | 
| 
         | 
|
| 97 | 
         
             
                        features=datasets.Features(
         
     | 
| 98 | 
         
             
                            {
         
     | 
| 99 | 
         
             
                                "timesteps": datasets.Sequence(datasets.Value("int64")),
         
     | 
| 
         @@ -107,18 +127,19 @@ class RLReliability(evaluate.Metric): 
     | 
|
| 107 | 
         
             
                    self,
         
     | 
| 108 | 
         
             
                    timesteps,
         
     | 
| 109 | 
         
             
                    rewards,
         
     | 
| 110 | 
         
            -
                    baseline="default",
         
     | 
| 111 | 
         
            -
                    freq_thresh=0.01,
         
     | 
| 112 | 
         
            -
                    window_size=100000,
         
     | 
| 113 | 
         
            -
                    window_size_trimmed=99000,
         
     | 
| 114 | 
         
            -
                    alpha=0.05,
         
     | 
| 115 | 
         
            -
                    eval_points=None,
         
     | 
| 116 | 
         
             
                ):
         
     | 
| 117 | 
         
             
                    if len(timesteps) < N_RUNS_RECOMMENDED:
         
     | 
| 118 | 
         
             
                        logger.warning(
         
     | 
| 119 | 
         
             
                            f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
         
     | 
| 120 | 
         
             
                        )
         
     | 
| 121 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 122 | 
         
             
                    curves = []
         
     | 
| 123 | 
         
             
                    for timestep, reward in zip(timesteps, rewards):
         
     | 
| 124 | 
         
             
                        curves.append(np.stack([timestep, reward]))
         
     | 
| 
         | 
|
| 13 | 
         
             
            # limitations under the License.
         
     | 
| 14 | 
         
             
            """Computes the RL Reliability Metrics."""
         
     | 
| 15 | 
         | 
| 16 | 
         
            +
            from dataclasses import dataclass
         
     | 
| 17 | 
         
            +
            from typing import List, Optional
         
     | 
| 18 | 
         
            +
             
     | 
| 19 | 
         
             
            import datasets
         
     | 
| 20 | 
         
             
            import numpy as np
         
     | 
| 21 | 
         
             
            from rl_reliability_metrics.evaluation import eval_metrics
         
     | 
| 
         | 
|
| 84 | 
         
             
            """
         
     | 
| 85 | 
         | 
| 86 | 
         | 
| 87 | 
         
            +
            @dataclass
         
     | 
| 88 | 
         
            +
            class RLReliabilityConfig(evaluate.info.Config):
         
     | 
| 89 | 
         
            +
             
     | 
| 90 | 
         
            +
                name: str = "default"
         
     | 
| 91 | 
         
            +
             
     | 
| 92 | 
         
            +
                baseline: str = "default"
         
     | 
| 93 | 
         
            +
                freq_thresh: float = 0.01
         
     | 
| 94 | 
         
            +
                window_size: int = 100000
         
     | 
| 95 | 
         
            +
                window_size_trimmed: int = 99000
         
     | 
| 96 | 
         
            +
                alpha: float = 0.05
         
     | 
| 97 | 
         
            +
                eval_points: Optional[List] = None
         
     | 
| 98 | 
         
            +
             
     | 
| 99 | 
         
            +
             
     | 
| 100 | 
         
             
            @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
         
     | 
| 101 | 
         
             
            class RLReliability(evaluate.Metric):
         
     | 
| 102 | 
         
             
                """Computes the RL Reliability Metrics."""
         
     | 
| 103 | 
         | 
| 104 | 
         
            +
                CONFIG_CLASS = RLReliabilityConfig
         
     | 
| 105 | 
         
            +
                ALLOWED_CONFIG_NAMES = ["online", "offline"]
         
     | 
| 106 | 
         
            +
             
     | 
| 107 | 
         
            +
                def _info(self, config):
         
     | 
| 108 | 
         
             
                    if self.config_name not in ["online", "offline"]:
         
     | 
| 109 | 
         
             
                        raise KeyError("""You should supply a configuration name selected in '["online", "offline"]'""")
         
     | 
| 110 | 
         | 
| 
         | 
|
| 113 | 
         
             
                        description=_DESCRIPTION,
         
     | 
| 114 | 
         
             
                        citation=_CITATION,
         
     | 
| 115 | 
         
             
                        inputs_description=_KWARGS_DESCRIPTION,
         
     | 
| 116 | 
         
            +
                        config=config,
         
     | 
| 117 | 
         
             
                        features=datasets.Features(
         
     | 
| 118 | 
         
             
                            {
         
     | 
| 119 | 
         
             
                                "timesteps": datasets.Sequence(datasets.Value("int64")),
         
     | 
| 
         | 
|
| 127 | 
         
             
                    self,
         
     | 
| 128 | 
         
             
                    timesteps,
         
     | 
| 129 | 
         
             
                    rewards,
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 130 | 
         
             
                ):
         
     | 
| 131 | 
         
             
                    if len(timesteps) < N_RUNS_RECOMMENDED:
         
     | 
| 132 | 
         
             
                        logger.warning(
         
     | 
| 133 | 
         
             
                            f"For robust statistics it is recommended to use at least {N_RUNS_RECOMMENDED} runs whereas you provided {len(timesteps)}."
         
     | 
| 134 | 
         
             
                        )
         
     | 
| 135 | 
         | 
| 136 | 
         
            +
                    baseline = self.config.baseline
         
     | 
| 137 | 
         
            +
                    freq_thresh = self.config.freq_thresh
         
     | 
| 138 | 
         
            +
                    window_size = self.config.window_size
         
     | 
| 139 | 
         
            +
                    window_size_trimmed = self.config.window_size_trimmed
         
     | 
| 140 | 
         
            +
                    alpha = self.config.alpha
         
     | 
| 141 | 
         
            +
                    eval_points = self.config.eval_points
         
     | 
| 142 | 
         
            +
             
     | 
| 143 | 
         
             
                    curves = []
         
     | 
| 144 | 
         
             
                    for timestep, reward in zip(timesteps, rewards):
         
     | 
| 145 | 
         
             
                        curves.append(np.stack([timestep, reward]))
         
     |