Spaces:
Runtime error
Runtime error
# Copyright 2023 The Orbit Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Contains utilities for TPU summary optimization.""" | |
import contextlib | |
import functools | |
import tensorflow as tf, tf_keras | |
@contextlib.contextmanager
def _soft_device_placement():
  """Context manager for soft device placement, allowing summaries on CPU.

  On entry, the current soft device placement setting is saved and soft device
  placement is switched on. On exit, the saved setting is restored, including
  when the body of the `with` block raises.

  Yields:
    None. The manager is used purely for its enter/exit side effects.
  """
  original_setting = tf.config.get_soft_device_placement()
  try:
    tf.config.set_soft_device_placement(True)
    yield
  finally:
    # Always put back whatever the caller had configured before entering.
    tf.config.set_soft_device_placement(original_setting)
class OptionalSummariesFunction:
  """Wrapper that provides versions of a function with and without summaries.

  This is a utility class for implementing optimized summary recording via a
  two-function approach, specifically important for TPUs. Two `tf.function`
  versions of a given `function` are created: one with soft device placement
  enabled (for use on steps that require summary writing), and one with summary
  writing and soft device placement entirely disabled (for use on all other
  steps). This removes any performance impact of summaries on steps where they
  aren't recorded (b/148418718).

  This class can be used as a base class to implement summary optimizations for
  a function with a specific signature. For example, to implement efficient TPU
  summaries for a standard `train()` method (as in `orbit.AbstractTrainer`):

      class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction):
        '''Implements a two-program approach for summaries on TPU.'''

        def __call__(self, num_steps):
          if tf.summary.should_record_summaries():
            output = self.with_summaries(tf.constant(1))
            num_steps -= 1
          if num_steps >= 1:
            output = self.without_summaries(num_steps)
          return output

  This can be used directly or to implement a decorator:

      def train_function_with_summaries(function=None, **kwargs):
        if function is not None:
          return TrainFunctionWithSummaries(function, **kwargs)
        return functools.partial(TrainFunctionWithSummaries, **kwargs)

  The decorator can be applied directly to `train()` methods:

      @train_function_with_summaries
      def train(self, num_steps):
        ...

  A similar approach can be implemented for functions with different
  signatures.

  Note: The above approach assumes that the frequency of summary writing is
  based on a step interval that is divisible by the number of steps executed
  in each call to the `train()` function. This is enforced by the
  `orbit.Controller`.

  This wrapper properly handles instance methods (see `__get__`).

  Attributes:
    with_summaries: A wrapped version of the underlying function with summaries
      enabled (using whatever the active predicate is for
      `tf.summary.record_if`), and placed inside a "soft device placement"
      context to enable summary recording on TPU.
    without_summaries: A wrapped version of the underlying function with all
      summary recording disabled.
  """

  def __init__(self, function, **tf_function_kwargs):
    """Constructs an instance wrapping the given `function`.

    The given `function` is wrapped twice: Once in a "soft device placement"
    context (allowing summaries to also run on TPU), and once with summary
    recording entirely disabled.

    Both of these versions are compiled via `tf.function` (optionally using any
    supplied `tf.function` settings), and made available as attributes.

    Args:
      function: The underlying function to wrap.
      **tf_function_kwargs: Additional arguments to pass to `tf.function`.
    """

    # `functools.wraps` keeps the wrapped function's name and docstring on each
    # wrapper; `tf.function` then compiles each wrapper separately, so the two
    # variants are distinct compiled functions.
    @tf.function(**tf_function_kwargs)
    @functools.wraps(function)
    def with_summaries(*args, **kwargs):
      with _soft_device_placement():
        return function(*args, **kwargs)

    @tf.function(**tf_function_kwargs)
    @functools.wraps(function)
    def without_summaries(*args, **kwargs):
      with tf.summary.record_if(False):
        return function(*args, **kwargs)

    self.with_summaries = with_summaries
    self.without_summaries = without_summaries

  def __get__(self, instance, owner):
    """Allows this class to be used to wrap methods as well as free functions.

    For `tf.function` to work properly in all cases (e.g., when an
    input_signature is specified), any `tf.function`-converted methods must be
    properly bound to an instance if they are called as an instance method.

    This is done by implementing this `__get__` method of the descriptor
    protocol, and forwarding to the `__get__` method on the underlying
    `tf.function`s.

    Args:
      instance: The instance to bind to.
      owner: The class type of the instance.

    Returns:
      A new instance of this wrapper's class (`OptionalSummariesFunction` or a
      subclass) whose `with_summaries` and `without_summaries` attributes are
      bound to `instance`.
    """
    # Bypass `__init__` so the underlying function is not wrapped a second
    # time; only the two already-built callables are rebound.
    new = object.__new__(self.__class__)
    # pytype: disable=attribute-error  # See b/162476201.
    new.with_summaries = self.with_summaries.__get__(instance, owner)
    new.without_summaries = self.without_summaries.__get__(instance, owner)
    # pytype: enable=attribute-error
    return new