Spaces:
Runtime error
Runtime error
# Copyright 2023 The Orbit Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Tests for orbit.utils.tpu_summaries.""" | |
import functools | |
import os | |
from orbit.utils import common | |
from orbit.utils import tpu_summaries | |
import tensorflow as tf, tf_keras | |
class TrainFunctionWithSummaries(tpu_summaries.OptionalSummariesFunction): | |
"""Implements a two-program approach for summaries on TPU.""" | |
def __call__(self, num_steps): | |
if tf.summary.should_record_summaries(): | |
output = self.with_summaries(tf.constant(1)) | |
num_steps -= 1 | |
if num_steps >= 1: | |
output = self.without_summaries(num_steps) | |
return output | |
def train_function_with_summaries(function=None, **kwargs): | |
if function is not None: | |
return TrainFunctionWithSummaries(function, **kwargs) | |
return functools.partial(TrainFunctionWithSummaries, **kwargs) | |
class DummyTrainer(tf.Module): | |
def __init__(self): | |
self.step_counter = common.create_global_step() | |
def train_with_tpu_summary_optimization(self, num_steps): | |
for _ in tf.range(num_steps): | |
tf.summary.scalar("step", self.step_counter, step=self.step_counter) | |
self.step_counter.assign_add(1) | |
return self.step_counter | |
def train_with_tpu_summary_optimization_and_input_signature(self, num_steps): | |
for _ in tf.range(num_steps): | |
tf.summary.scalar("step", self.step_counter, step=self.step_counter) | |
self.step_counter.assign_add(1) | |
return self.step_counter | |
def train_with_tpu_summary_optimization_no_decorator(self, num_steps): | |
for _ in tf.range(num_steps): | |
tf.summary.scalar("step", self.step_counter, step=self.step_counter) | |
self.step_counter.assign_add(1) | |
return self.step_counter | |
class TpuSummariesTest(tf.test.TestCase): | |
def setUp(self): | |
super().setUp() | |
self.trainer = DummyTrainer() | |
def _get_events_from_logdir(self, logdir): | |
event_files = tf.io.gfile.listdir(logdir) | |
self.assertLen(event_files, 1) | |
path = os.path.join(logdir, event_files[0]) | |
events = list(tf.compat.v1.train.summary_iterator(path)) | |
return [event for event in events if event.WhichOneof("what") == "summary"] | |
def _validate_tpu_summary_optimization(self, function, *args, **kwargs): | |
logdir = self.get_temp_dir() | |
with tf.summary.create_file_writer(logdir).as_default(): | |
with tf.summary.record_if(lambda: self.trainer.step_counter % 20 == 0): | |
for _ in range(4): | |
output = function(tf.constant(10), *args, **kwargs) | |
events = self._get_events_from_logdir(logdir) | |
self.assertLen(events, 2) | |
self.assertEqual(events[0].step, 0) | |
self.assertEqual(events[1].step, 20) | |
return output | |
def test_train_with_tpu_summary_optimization(self): | |
output = self._validate_tpu_summary_optimization( | |
self.trainer.train_with_tpu_summary_optimization) | |
self.assertEqual(output, self.trainer.step_counter.numpy()) | |
def test_train_with_tpu_summary_optimization_no_decorator(self): | |
optimized = train_function_with_summaries( | |
self.trainer.train_with_tpu_summary_optimization_no_decorator) | |
output = self._validate_tpu_summary_optimization(optimized) | |
self.assertEqual(output, self.trainer.step_counter.numpy()) | |
def test_train_with_tpu_summary_optimization_and_input_signature(self): | |
output = self._validate_tpu_summary_optimization( | |
self.trainer.train_with_tpu_summary_optimization_and_input_signature) | |
self.assertEqual(output, self.trainer.step_counter.numpy()) | |
function = self.trainer.train_with_tpu_summary_optimization_and_input_signature | |
expected = (tf.TensorSpec((), dtype=tf.int32),) | |
input_signature = function.with_summaries.input_signature | |
self.assertEqual(input_signature, expected) | |
input_signature = function.without_summaries.input_signature | |
self.assertEqual(input_signature, expected) | |
if __name__ == "__main__": | |
tf.test.main() | |