deanna-emery's picture
updates
93528c6
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides AbstractTrainer/Evaluator base classes, defining train/eval APIs."""
import abc
from typing import Dict, Optional, Union
import numpy as np
import tensorflow as tf, tf_keras
# Type alias for the value returned by `train` / `evaluate` below: a dictionary
# keyed by string names whose values are tensors, Python floats, NumPy scalars
# or arrays, or another dictionary of the same form. The quoted 'Output' is a
# forward reference that makes the alias recursive, which is why the pytype
# check is disabled on this line.
Output = Dict[str, Union[tf.Tensor, float, np.number, np.ndarray, 'Output']] # pytype: disable=not-supported-yet
class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
    """Base interface that a trainer must implement.

    Subclasses provide a concrete `train` method; the class itself cannot be
    instantiated while `train` remains abstract.
    """

    @abc.abstractmethod
    def train(self, num_steps: tf.Tensor) -> Optional[Output]:
        """Runs `num_steps` steps of training as one "inner loop".

        The `Controller` invokes this method for each inner loop of training.
        Grouping several steps into a single call spreads the bookkeeping
        cost of checkpointing, evaluation, and summary writing over many
        steps.

        Implementations are free to choose how the loop is expressed:

        * TensorFlow looping constructs (for example, a `for` loop over a
          `tf.range` inside a `tf.function`), which can be necessary for
          optimal performance on TPU.
        * A plain Python loop, when peak performance is not required and
          simplicity is preferred.

        Args:
          num_steps: How many training steps to run. The model decides what a
            "step" is; a single step may apply more than one parameter update
            (e.g., when training a GAN).

        Returns:
          `None`, or a dictionary mapping names to `Tensor`s or NumPy values.
          A returned dictionary is written to logs and as TensorBoard
          summaries. Nested dictionaries are allowed and produce a hierarchy
          of summary directories.
        """
class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
    """Base interface that an evaluator must implement.

    Subclasses provide a concrete `evaluate` method; the class itself cannot
    be instantiated while `evaluate` remains abstract.
    """

    @abc.abstractmethod
    def evaluate(self, num_steps: tf.Tensor) -> Optional[Output]:
        """Runs `num_steps` steps of evaluation.

        The `Controller` calls this method to perform an evaluation. The
        value of `num_steps` is the one the user supplied when calling one of
        the `Controller`'s evaluation methods. The sentinel value `-1` is
        reserved to mean "evaluate until the underlying data source is
        exhausted".

        Args:
          num_steps: How many evaluation steps to run. The model decides what
            a "step" is. Implementations may also want to support "complete"
            evaluations when `num_steps == -1`, running until a given data
            source is exhausted.

        Returns:
          `None`, or a dictionary mapping names to `Tensor`s or NumPy values.
          A returned dictionary is written to logs and as TensorBoard
          summaries. Nested dictionaries are allowed and produce a hierarchy
          of summary directories.
        """