Spaces:
Build error
Build error
File size: 9,485 Bytes
b100e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 |
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""T5X Metrics.

Defines Metric objects and collections used by T5X models. These objects use
the CLU metrics library.
"""
import dataclasses
from typing import MutableMapping, Optional, Union
from clu import metrics as clu_metrics
import flax # Only used for flax.struct.dataclass.
import jax
from jax.experimental.global_device_array import GlobalDeviceArray
import jax.numpy as jnp
import numpy as np
# Mapping from metric name to a CLU `Metric` object.
MetricsMap = MutableMapping[str, clu_metrics.Metric]
# Value types accepted where a metric value is expected (Python numbers,
# NumPy scalars/arrays, or JAX arrays).
Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray]
def _check_param(value, *, ndim=None, dtype=jnp.float32):
"""Raises a `ValueError` if `value` does not match ndim/dtype.
Args:
value: Value to be tested.
ndim: Expected dimensions.
dtype: Expected dtype.
Raises:
A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value`
is not an instance of `jnp.ndarray`.
"""
if ndim is not None and value.ndim != ndim:
raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}")
if dtype is not None and value.dtype != dtype:
raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}")
@flax.struct.dataclass
class Sum(clu_metrics.Metric):
  """Metric that accumulates the sum of a scalar or a batch of tensors.

  Merging two `Sum` metrics adds their totals; `compute` returns the running
  total. See also documentation of `Metric`.
  """

  total: Scalar

  @classmethod
  def from_model_output(cls, values: Scalar, **_) -> clu_metrics.Metric:
    """Builds a `Sum` metric from an array of values or a single value.

    Args:
      values: values to add up; a 0-d input is treated as a 1-element array.

    Returns:
      A Sum object whose `total` is the sum of `values`.
    """
    arr = jnp.asarray(values)
    # Promote a 0-d input to rank 1 before reducing.
    flat = arr[None] if arr.ndim == 0 else arr
    return cls(total=flat.sum())

  def merge(self, other: "Sum") -> "Sum":
    combined = self.total + other.total
    return type(self)(total=combined)

  def compute(self) -> Scalar:
    return self.total
@flax.struct.dataclass
class Step(clu_metrics.Metric):
  """Abstract base for metrics that depend on a number of steps.

  Holds a step count that has to be set explicitly via `replace_steps`: when
  microbatches are used, counting steps automatically could otherwise make the
  computed value incorrect.

  See also documentation of `Metric`.
  """

  steps: Optional[int] = 1

  def replace_steps(self, steps) -> "Step":
    """Returns a copy of this metric with `steps` replaced."""
    return self.replace(steps=steps)

  def compute(self) -> Scalar:
    """Returns the step count; raises `ValueError` if it is unset (None)."""
    if self.steps is not None:
      return self.steps
    raise ValueError(
        "`steps` must be set by calling `replace_steps` before computing metric."
    )
@flax.struct.dataclass
class AveragePerStep(Step):
  """Per-step average: accumulated total divided by the number of steps.

  Merging adds both the totals and the step counts.
  See also documentation of `Step`.
  """

  total: Optional[Scalar] = None

  @classmethod
  def from_model_output(cls,
                        values: Scalar,
                        steps: Optional[int] = 1,
                        **_) -> clu_metrics.Metric:
    """Builds an AveragePerStep metric from array (or singular) values.

    Args:
      values: values to sum; a 0-d input is treated as a 1-element array.
      steps: number of steps the values cover, defaults to 1.

    Returns:
      AveragePerStep object.
    """
    arr = jnp.asarray(values)
    if arr.ndim == 0:
      arr = arr[None]
    summed = arr.sum()
    return cls(total=summed, steps=steps)

  def merge(self, other: "AveragePerStep") -> "AveragePerStep":
    assert type(self) is type(other)
    new_total = self.total + other.total
    new_steps = self.steps + other.steps
    return type(self)(total=new_total, steps=new_steps)

  def compute(self) -> Scalar:
    # Validates `steps` first (via Step.compute), then `total`.
    num_steps = super().compute()
    if self.total is None:
      raise ValueError("`AveragePerStep` `total` cannot be None.")
    return self.total / num_steps
@flax.struct.dataclass
class Time(clu_metrics.Metric):
  """Base metric carrying a duration that is supplied outside compiled code.

  The duration (denominator for subclasses) must be set manually with
  `replace_duration`. JAX does not properly support time functions inside
  compiled functions: calling time.time() there stores the compilation time,
  not the run time.

  See also documentation of `Metric`.
  """

  duration: Optional[Scalar] = None

  def merge(self, other: "Time") -> "Time":
    # `other` is ignored; the receiver is returned unchanged.
    del other
    return self

  def compute(self) -> Scalar:
    """Returns the duration; raises `ValueError` if it has not been set."""
    if self.duration is not None:
      return self.duration
    raise ValueError(
        "`Time` `duration` must be set by calling `replace_duration` before computing."
    )

  def replace_duration(self, duration: Scalar) -> "Time":
    """Returns a copy of this metric with `duration` set.

    Intended to be called outside a compiled function, once the elapsed time
    is known.

    Args:
      duration: metric duration.

    Returns:
      A new Time object.
    """
    updated = self.replace(duration=duration)
    return updated
@flax.struct.dataclass
class TimeRate(Time):
  """Computes the rate of a float-valued metric over a period of time.

  `numerator` is summed across merges and divided by `duration` in `compute`.
  Duration (the denominator) must be set using `replace_duration`. This is
  because JAX does not properly support time functions inside compiled
  functions: calling time.time() inside a compiled function results in the
  stored time being the compilation time, not the run time.

  See also documentation of `Time` and `Metric`.
  """

  numerator: Optional[jnp.ndarray] = None

  @classmethod
  def from_model_output(cls, numerator: float, **_) -> clu_metrics.Metric:
    """Initializes a TimeRate Metric from a float value (the numerator).

    Args:
      numerator: a float (numerator of the metric).

    Returns:
      A TimeRate object with no duration set.
    """
    return cls(numerator=numerator)

  def merge(self, other: "TimeRate") -> "TimeRate":
    """Adds numerators; only valid while neither side has a duration set."""
    assert_msg = "Merging with non-None durations is currently not supported."
    assert self.duration is None and other.duration is None, assert_msg
    return type(self)(numerator=self.numerator + other.numerator)

  def compute(self) -> Scalar:
    """Returns `numerator / duration`.

    Raises:
      ValueError: if `duration` has not been set or `numerator` is None.
    """
    duration = super().compute()
    if self.numerator is None:
      raise ValueError("`TimeRate` `numerator` cannot be None.")
    return self.numerator / duration

  def replace_duration(self, duration: Scalar) -> "TimeRate":
    """Returns a copy with `duration` set; must be called outside `jit`.

    Raises:
      ValueError: if `numerator` is not a concrete np.ndarray or
        GlobalDeviceArray (i.e. the call appears to happen inside a compiled
        function).
    """
    if not isinstance(self.numerator, (np.ndarray, GlobalDeviceArray)):
      # Build a single message string; passing several positional args to
      # ValueError produced a tuple repr as the error message.
      raise ValueError(
          "Expected numerator to be of type np.ndarray or GlobalDeviceArray "
          "since method should be called outside of a compiled function. "
          f"Got {type(self.numerator)}")
    return super().replace_duration(duration)
@flax.struct.dataclass
class StepsPerTime(Step, Time):
  """Metric reporting the number of steps per unit of time.

  Combines the step count of `Step` with the manually-set duration of `Time`.
  See also documentation of `Step`.
  """

  @classmethod
  def from_model_output(cls,
                        steps: Optional[int] = 1,
                        **_) -> clu_metrics.Metric:
    """Builds a StepsPerTime metric.

    Args:
      steps: number of steps, defaults to 1.

    Returns:
      StepsPerTime object with no duration set.
    """
    return cls(steps=steps)

  def merge(self, other: "StepsPerTime") -> "StepsPerTime":
    assert type(self) is type(other)
    # Only the step counts are combined; the result is built without a
    # duration, so any duration on either operand is not carried over.
    total_steps = self.steps + other.steps
    return type(self)(steps=total_steps)

  def compute(self) -> Scalar:
    num_steps = Step.compute(self)
    elapsed = Time.compute(self)
    return num_steps / elapsed
def is_metric_obj(obj):
  """Returns True iff `obj` is an instance of a CLU `Metric`."""
  metric_cls = clu_metrics.Metric
  return isinstance(obj, metric_cls)
def is_time_metric(obj):
  """Returns True iff `obj` is a `Time` metric (subclasses included)."""
  time_cls = Time
  return isinstance(obj, time_cls)
def create_metrics_dict(float_metrics_dict):
  """Wraps each value of a `{name: float}` dict in a `Sum` metric.

  Args:
    float_metrics_dict: mapping from metric name to a numeric value.

  Returns:
    A new dict with the same keys, where each value is a `Sum` whose `total`
    is the original value.
  """
  return {
      name: Sum(total=value) for name, value in float_metrics_dict.items()
  }
def shape_obj_to_defined_obj(obj: clu_metrics.Metric):
  """Materializes a shape-only Metric into one whose array fields are zeros.

  `obj` is expected to be a Metric subclass whose member variables are
  ShapeDtypeStructs (e.g. produced by `jax.eval_shape`). A new object of the
  same class is returned in which every member that has a `shape` attribute
  is replaced by `jnp.zeros` with that shape and dtype. Members that are
  themselves Metrics are converted recursively, and members without a `shape`
  are passed through unchanged.

  Args:
    obj: a clu.metrics.Metric object where each member variable is a
      ShapeDtypeStruct (from jax.eval_shape).

  Returns:
    A Metric object with class variables initialized as zero arrays.
  """

  def _materialize(fld):
    value = getattr(obj, fld.name)
    if isinstance(value, clu_metrics.Metric):
      return shape_obj_to_defined_obj(value)
    if not hasattr(value, "shape"):
      return value
    return jnp.zeros(shape=value.shape, dtype=value.dtype)

  kwargs = {}
  for fld in dataclasses.fields(obj):
    kwargs[fld.name] = _materialize(fld)
  return obj.__class__(**kwargs)
def set_time_metrics_duration(metrics, duration):
  """Sets `duration` on every `Time` metric in a metrics pytree.

  Applies to `Time` and all of its subclasses (e.g. `TimeRate`,
  `StepsPerTime`), not only `TimeRate`.

  Args:
    metrics: pytree (e.g. a dict) whose leaves may include `Time` metrics.
    duration: value passed to each metric's `replace_duration`.

  Returns:
    A pytree of the same structure in which every `Time` metric `m` has been
    replaced by `m.replace_duration(duration)`; other leaves are unchanged.
  """

  def fn(o):
    if isinstance(o, Time):
      return o.replace_duration(duration)
    return o

  # `is_leaf` stops traversal at Time metrics so they reach `fn` whole rather
  # than being descended into. `jax.tree_util.tree_map` replaces the
  # deprecated (and, in newer JAX, removed) `jax.tree_map` alias.
  return jax.tree_util.tree_map(
      fn, metrics, is_leaf=lambda obj: isinstance(obj, Time))
def set_step_metrics_num_steps(metrics, num_steps):
  """Sets the step count on every `Step` metric in a metrics pytree.

  Args:
    metrics: pytree (e.g. a dict) whose leaves may include CLU metrics.
    num_steps: value passed to each `Step` metric's `replace_steps`.

  Returns:
    A pytree of the same structure in which every `Step` metric `m` has been
    replaced by `m.replace_steps(num_steps)`; other leaves are unchanged.
  """

  def fn(o):
    if isinstance(o, Step):
      return o.replace_steps(num_steps)
    return o

  # Treat every Metric as a leaf so non-Step metrics are passed through intact
  # instead of being descended into. `jax.tree_util.tree_map` replaces the
  # deprecated (and, in newer JAX, removed) `jax.tree_map` alias.
  return jax.tree_util.tree_map(fn, metrics, is_leaf=is_metric_obj)
|