ASL-MoViNet-T5-translator / orbit /utils /tpu_summaries_test.py
deanna-emery's picture
updates
93528c6
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for orbit.utils.tpu_summaries."""
import functools
import os
from orbit.utils import common
from orbit.utils import tpu_summaries
import tensorflow as tf, tf_keras
class TrainFunctionWithSummaries(tpu_summaries.OptionalSummariesFunction):
"""Implements a two-program approach for summaries on TPU."""
def __call__(self, num_steps):
if tf.summary.should_record_summaries():
output = self.with_summaries(tf.constant(1))
num_steps -= 1
if num_steps >= 1:
output = self.without_summaries(num_steps)
return output
def train_function_with_summaries(function=None, **kwargs):
if function is not None:
return TrainFunctionWithSummaries(function, **kwargs)
return functools.partial(TrainFunctionWithSummaries, **kwargs)
class DummyTrainer(tf.Module):
def __init__(self):
self.step_counter = common.create_global_step()
@train_function_with_summaries
def train_with_tpu_summary_optimization(self, num_steps):
for _ in tf.range(num_steps):
tf.summary.scalar("step", self.step_counter, step=self.step_counter)
self.step_counter.assign_add(1)
return self.step_counter
@train_function_with_summaries(
input_signature=[tf.TensorSpec((), dtype=tf.int32)])
def train_with_tpu_summary_optimization_and_input_signature(self, num_steps):
for _ in tf.range(num_steps):
tf.summary.scalar("step", self.step_counter, step=self.step_counter)
self.step_counter.assign_add(1)
return self.step_counter
def train_with_tpu_summary_optimization_no_decorator(self, num_steps):
for _ in tf.range(num_steps):
tf.summary.scalar("step", self.step_counter, step=self.step_counter)
self.step_counter.assign_add(1)
return self.step_counter
class TpuSummariesTest(tf.test.TestCase):
def setUp(self):
super().setUp()
self.trainer = DummyTrainer()
def _get_events_from_logdir(self, logdir):
event_files = tf.io.gfile.listdir(logdir)
self.assertLen(event_files, 1)
path = os.path.join(logdir, event_files[0])
events = list(tf.compat.v1.train.summary_iterator(path))
return [event for event in events if event.WhichOneof("what") == "summary"]
def _validate_tpu_summary_optimization(self, function, *args, **kwargs):
logdir = self.get_temp_dir()
with tf.summary.create_file_writer(logdir).as_default():
with tf.summary.record_if(lambda: self.trainer.step_counter % 20 == 0):
for _ in range(4):
output = function(tf.constant(10), *args, **kwargs)
events = self._get_events_from_logdir(logdir)
self.assertLen(events, 2)
self.assertEqual(events[0].step, 0)
self.assertEqual(events[1].step, 20)
return output
def test_train_with_tpu_summary_optimization(self):
output = self._validate_tpu_summary_optimization(
self.trainer.train_with_tpu_summary_optimization)
self.assertEqual(output, self.trainer.step_counter.numpy())
def test_train_with_tpu_summary_optimization_no_decorator(self):
optimized = train_function_with_summaries(
self.trainer.train_with_tpu_summary_optimization_no_decorator)
output = self._validate_tpu_summary_optimization(optimized)
self.assertEqual(output, self.trainer.step_counter.numpy())
def test_train_with_tpu_summary_optimization_and_input_signature(self):
output = self._validate_tpu_summary_optimization(
self.trainer.train_with_tpu_summary_optimization_and_input_signature)
self.assertEqual(output, self.trainer.step_counter.numpy())
function = self.trainer.train_with_tpu_summary_optimization_and_input_signature
expected = (tf.TensorSpec((), dtype=tf.int32),)
input_signature = function.with_summaries.input_signature
self.assertEqual(input_signature, expected)
input_signature = function.without_summaries.input_signature
self.assertEqual(input_signature, expected)
if __name__ == "__main__":
tf.test.main()