Spaces:
Build error
Build error
# Copyright 2022 The T5X Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""T5X Metrics.

Defines Metric objects and collections used by T5X models. These objects use the
CLU metrics library.
"""
import dataclasses | |
from typing import MutableMapping, Optional, Union | |
from clu import metrics as clu_metrics | |
import flax # Only used for flax.struct.dataclass. | |
import jax | |
from jax.experimental.global_device_array import GlobalDeviceArray | |
import jax.numpy as jnp | |
import numpy as np | |
# Values accepted as scalar metric quantities: Python numbers, NumPy scalars or
# arrays, and JAX arrays.
Scalar = Union[int, float, np.number, np.ndarray, jnp.ndarray]
# Mutable mapping from metric name to a CLU metric object.
MetricsMap = MutableMapping[str, clu_metrics.Metric]
def _check_param(value, *, ndim=None, dtype=jnp.float32): | |
"""Raises a `ValueError` if `value` does not match ndim/dtype. | |
Args: | |
value: Value to be tested. | |
ndim: Expected dimensions. | |
dtype: Expected dtype. | |
Raises: | |
A `ValueError` if `value` does not match `ndim` or `dtype`, or if `value` | |
is not an instance of `jnp.ndarray`. | |
""" | |
if ndim is not None and value.ndim != ndim: | |
raise ValueError(f"Expected ndim={ndim}, got ndim={value.ndim}") | |
if dtype is not None and value.dtype != dtype: | |
raise ValueError(f"Expected dtype={dtype}, got dtype={value.dtype}") | |
# flax.struct.dataclass turns the annotated attributes into constructor fields
# (so `cls(total=...)` and `Sum(value)` work) and provides `replace`.
@flax.struct.dataclass
class Sum(clu_metrics.Metric):
  """Computes the sum of a scalar or a batch of tensors.

  See also documentation of `Metric`.
  """

  # Running total of all values seen so far.
  total: Scalar

  @classmethod
  def from_model_output(cls, values: Scalar, **_) -> clu_metrics.Metric:
    """Initializes a Sum Metric from array (or singular) values.

    Args:
      values: array of values to sum (or a single value).

    Returns:
      A Sum object.
    """
    values = jnp.asarray(values)
    if values.ndim == 0:
      values = values[None]
    return cls(total=values.sum())

  def merge(self, other: "Sum") -> "Sum":
    """Returns a new Sum whose total is the sum of both totals."""
    return type(self)(total=self.total + other.total)

  def compute(self) -> Scalar:
    """Returns the accumulated total."""
    return self.total
# flax.struct.dataclass makes `steps` a constructor field and provides the
# `replace` method used by `replace_steps`.
@flax.struct.dataclass
class Step(clu_metrics.Metric):
  """Abstract class representing a per-step or steps-per-X metric.

  Tracks number of steps. Must be set manually using replace_steps, since the
  use of microbatches may otherwise cause the computation to be incorrect.

  See also documentation of `Metric`.
  """

  # Number of steps; None means it has not been set yet.
  steps: Optional[int] = 1

  def replace_steps(self, steps) -> "Step":
    """Returns a copy of this metric with `steps` replaced.

    Args:
      steps: the new number of steps.

    Returns:
      A new metric of the same class.
    """
    return self.replace(steps=steps)

  def compute(self) -> Scalar:
    """Returns the number of steps.

    Raises:
      ValueError: if `steps` is None.
    """
    if self.steps is None:
      raise ValueError(
          "`steps` must be set by calling `replace_steps` before computing metric."
      )
    return self.steps
# flax.struct.dataclass registers `total` (and the inherited `steps`) as
# constructor fields.
@flax.struct.dataclass
class AveragePerStep(Step):
  """Represents per-step average (total divided by number of steps).

  See also documentation of `Step`.
  """

  # Sum of all values; must be set before `compute` is called.
  total: Optional[Scalar] = None

  @classmethod
  def from_model_output(cls,
                        values: Scalar,
                        steps: Optional[int] = 1,
                        **_) -> clu_metrics.Metric:
    """Initializes an AveragePerStep Metric from array (or singular) values.

    Args:
      values: array of values to sum (or a single value).
      steps: number of steps, defaults to 1.

    Returns:
      AveragePerStep object.
    """
    values = jnp.asarray(values)
    if values.ndim == 0:
      values = values[None]
    return cls(total=values.sum(), steps=steps)

  def merge(self, other: "AveragePerStep") -> "AveragePerStep":
    """Returns a new metric with totals and step counts added together."""
    assert type(self) is type(other)
    return type(self)(
        total=self.total + other.total, steps=self.steps + other.steps)

  def compute(self) -> Scalar:
    """Returns `total / steps`.

    Raises:
      ValueError: if `steps` or `total` is None.
    """
    steps = super().compute()
    if self.total is None:
      raise ValueError("`AveragePerStep` `total` cannot be None.")
    return self.total / steps
# flax.struct.dataclass makes `duration` a constructor field and provides the
# `replace` method used by `replace_duration`.
@flax.struct.dataclass
class Time(clu_metrics.Metric):
  """Computes the sum of a float-valued metric over a period of time.

  Duration (the denominator) must be set manually. This is because JAX does not
  properly support time functions inside compiled functions. Calling time.time()
  inside a compiled function results in the stored time being the compilation
  time, not the run time.

  See also documentation of `Metric`.
  """

  # Elapsed time; None until set via `replace_duration`.
  duration: Optional[Scalar] = None

  def merge(self, other: "Time") -> "Time":
    """Returns `self` unchanged; `other` is ignored.

    The duration is expected to be set via `replace_duration` rather than
    accumulated by merging.
    """
    return self

  def compute(self) -> Scalar:
    """Returns the duration.

    Raises:
      ValueError: if `duration` has not been set.
    """
    if self.duration is None:
      raise ValueError(
          "`Time` `duration` must be set by calling `replace_duration` before computing."
      )
    return self.duration

  def replace_duration(self, duration: Scalar) -> "Time":
    """Replaces duration with the given value.

    Should be used outside a compiled function to set the duration of the
    metric.

    Args:
      duration: metric duration

    Returns:
      A new Time object.
    """
    return self.replace(duration=duration)
# flax.struct.dataclass registers `numerator` (and the inherited `duration`) as
# constructor fields.
@flax.struct.dataclass
class TimeRate(Time):
  """Computes a float-valued metric divided by a period of time (a rate).

  Duration (the denominator) must be set using replace_duration. This is because
  JAX does not properly support time functions inside compiled functions.
  Calling time.time() inside a compiled function results in the stored time
  being the compilation time, not the run time.

  See also documentation of `Time` and `Metric`.
  """

  # Accumulated quantity to be divided by `duration`.
  numerator: Optional[jnp.ndarray] = None

  @classmethod
  def from_model_output(cls, numerator: float, **_) -> clu_metrics.Metric:
    """Initializes a TimeRate Metric from a float value (the numerator).

    Args:
      numerator: a float (numerator of the metric)

    Returns:
      A TimeRate object.
    """
    return cls(numerator=numerator)

  def merge(self, other: "TimeRate") -> "TimeRate":
    """Returns a new TimeRate whose numerator is the sum of both numerators.

    Both operands must still have `duration` unset; the duration is applied
    afterwards via `replace_duration`.
    """
    assert_msg = "Merging with non-None durations is currently not supported."
    assert self.duration is None and other.duration is None, assert_msg
    return type(self)(numerator=self.numerator + other.numerator)

  def compute(self) -> Scalar:
    """Returns `numerator / duration`.

    Raises:
      ValueError: if `duration` or `numerator` is None.
    """
    duration = super().compute()
    if self.numerator is None:
      raise ValueError("`TimeRate` `numerator` cannot be None.")
    return self.numerator / duration

  def replace_duration(self, duration: Scalar) -> "TimeRate":
    """Returns a copy with `duration` set; must be called on host values.

    Args:
      duration: metric duration.

    Returns:
      A new TimeRate object.

    Raises:
      ValueError: if `numerator` is not a `np.ndarray` or `GlobalDeviceArray`,
        which indicates the call is happening inside a compiled function.
    """
    if not isinstance(self.numerator, (np.ndarray, GlobalDeviceArray)):
      # Previously the pieces were passed as separate exception args, which
      # rendered the message as a tuple; format it into one string instead.
      raise ValueError(
          "Expected numerator to be of type np.ndarray or GlobalDeviceArray "
          "since method should be called outside of a compiled function. "
          f"Got {type(self.numerator)}")
    return super().replace_duration(duration)
# flax.struct.dataclass registers the inherited `steps` and `duration` fields
# as constructor arguments.
@flax.struct.dataclass
class StepsPerTime(Step, Time):
  """Represents a metric computed as number of steps per time.

  See also documentation of `Step` and `Time`.
  """

  @classmethod
  def from_model_output(cls,
                        steps: Optional[int] = 1,
                        **_) -> clu_metrics.Metric:
    """Initializes an StepsPerTime Metric.

    Args:
      steps: number of steps, defaults to 1.

    Returns:
      StepsPerTime object.
    """
    return cls(steps=steps)

  def merge(self, other: "StepsPerTime") -> "StepsPerTime":
    """Returns a new metric with the step counts added (duration left unset)."""
    assert type(self) is type(other)
    return type(self)(steps=self.steps + other.steps)

  def compute(self) -> Scalar:
    """Returns `steps / duration`.

    Raises:
      ValueError: if `steps` or `duration` has not been set.
    """
    # Both parents are called explicitly: with MRO Step -> Time,
    # `super().compute()` would only reach `Step.compute`.
    steps = Step.compute(self)
    duration = Time.compute(self)
    return steps / duration
def is_metric_obj(obj):
  """Returns True iff `obj` is an instance of a CLU `Metric` (sub)class."""
  metric_base = clu_metrics.Metric
  return isinstance(obj, metric_base)
def is_time_metric(obj):
  """Returns True iff `obj` is a `Time` metric, including any subclass of it."""
  time_types = (Time,)
  return isinstance(obj, time_types)
def create_metrics_dict(float_metrics_dict):
  """Wraps every value of a name->float mapping in a `Sum` metric.

  Args:
    float_metrics_dict: mapping from metric name to a float value.

  Returns:
    A new dict with the same keys, where each value `v` became `Sum(v)`.
  """
  return {
      name: Sum(value) for name, value in float_metrics_dict.items()
  }
def shape_obj_to_defined_obj(obj: clu_metrics.Metric):
  """Materializes a shape-only Metric into a Metric holding zero arrays.

  `obj` is expected to be a Metric subclass instance whose member variables
  are ShapeDtypeStructs (e.g. from jax.eval_shape). This returns a new object
  of the same class in which each member that has a `shape` attribute is
  replaced by `jnp.zeros` with that shape and dtype. Members that are
  themselves Metrics are converted recursively, and all other members are
  passed through unchanged.

  Args:
    obj: a clu.metrics.Metric object where each member variable is a
      ShapeDtypeStruct (from jax.eval_shape).

  Returns:
    A Metric object with class variables initialized as zero arrays.
  """

  def _materialize(member):
    # Nested metrics are handled by recursing on the whole sub-metric.
    if isinstance(member, clu_metrics.Metric):
      return shape_obj_to_defined_obj(member)
    # Anything without a shape (e.g. a plain int) is kept verbatim.
    if not hasattr(member, "shape"):
      return member
    return jnp.zeros(shape=member.shape, dtype=member.dtype)

  init_kwargs = {}
  for fld in dataclasses.fields(obj):
    init_kwargs[fld.name] = _materialize(getattr(obj, fld.name))
  return type(obj)(**init_kwargs)
def set_time_metrics_duration(metrics, duration):
  """Sets `duration` on every `Time` metric in a metrics pytree.

  This applies to `Time` and all of its subclasses (e.g. `TimeRate`,
  `StepsPerTime`); every other element of the tree is returned unchanged.
  Like `Time.replace_duration`, it should be called outside a compiled function.

  Args:
    metrics: a pytree (e.g. a dict of metrics) that may contain `Time` metrics.
    duration: the duration to set on each `Time` metric.

  Returns:
    A pytree with the same structure, with durations replaced.
  """

  def fn(o):
    if isinstance(o, Time):
      return o.replace_duration(duration)
    else:
      return o

  # `is_leaf` makes tree_map hand whole `Time` objects to `fn` instead of
  # descending into their fields. `jax.tree_util.tree_map` is used because the
  # `jax.tree_map` alias is deprecated and removed in newer JAX releases.
  return jax.tree_util.tree_map(
      fn, metrics, is_leaf=lambda obj: isinstance(obj, Time))
def set_step_metrics_num_steps(metrics, num_steps):
  """Sets steps for Step objects in metrics pytree.

  Args:
    metrics: a pytree (e.g. a dict of metrics) that may contain `Step` metrics.
    num_steps: the number of steps to set on each `Step` metric.

  Returns:
    A pytree with the same structure, where every `Step` metric has had
    `replace_steps(num_steps)` applied and everything else is unchanged.
  """

  def fn(o):
    if isinstance(o, Step):
      return o.replace_steps(num_steps)
    else:
      return o

  # `is_leaf=is_metric_obj` hands whole Metric objects to `fn` rather than
  # their fields. `jax.tree_util.tree_map` replaces the deprecated (and, in
  # newer JAX releases, removed) `jax.tree_map` alias.
  return jax.tree_util.tree_map(fn, metrics, is_leaf=is_metric_obj)