Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
import unittest | |
from absl import flags | |
import tensorflow as tf, tf_keras | |
from official.utils.flags import core as flags_core # pylint: disable=g-bad-import-order | |
def define_flags(): | |
flags_core.define_base( | |
clean=True, | |
num_gpu=False, | |
stop_threshold=True, | |
hooks=True, | |
train_epochs=True, | |
epochs_between_evals=True) | |
flags_core.define_performance( | |
num_parallel_calls=True, | |
inter_op=True, | |
intra_op=True, | |
loss_scale=True, | |
synthetic_data=True, | |
dtype=True) | |
flags_core.define_image() | |
flags_core.define_benchmark() | |
class BaseTester(unittest.TestCase): | |
def setUpClass(cls): | |
super(BaseTester, cls).setUpClass() | |
define_flags() | |
def test_default_setting(self): | |
"""Test to ensure fields exist and defaults can be set.""" | |
defaults = dict( | |
data_dir="dfgasf", | |
model_dir="dfsdkjgbs", | |
train_epochs=534, | |
epochs_between_evals=15, | |
batch_size=256, | |
hooks=["LoggingTensorHook"], | |
num_parallel_calls=18, | |
inter_op_parallelism_threads=5, | |
intra_op_parallelism_threads=10, | |
data_format="channels_first") | |
flags_core.set_defaults(**defaults) | |
flags_core.parse_flags() | |
for key, value in defaults.items(): | |
assert flags.FLAGS.get_flag_value(name=key, default=None) == value | |
def test_benchmark_setting(self): | |
defaults = dict( | |
hooks=["LoggingMetricHook"], | |
benchmark_log_dir="/tmp/12345", | |
gcp_project="project_abc", | |
) | |
flags_core.set_defaults(**defaults) | |
flags_core.parse_flags() | |
for key, value in defaults.items(): | |
assert flags.FLAGS.get_flag_value(name=key, default=None) == value | |
def test_booleans(self): | |
"""Test to ensure boolean flags trigger as expected.""" | |
flags_core.parse_flags([__file__, "--use_synthetic_data"]) | |
assert flags.FLAGS.use_synthetic_data | |
def test_parse_dtype_info(self): | |
flags_core.parse_flags([__file__, "--dtype", "fp16"]) | |
self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float16) | |
self.assertEqual( | |
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 2) | |
flags_core.parse_flags([__file__, "--dtype", "fp16", "--loss_scale", "5"]) | |
self.assertEqual( | |
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 5) | |
flags_core.parse_flags( | |
[__file__, "--dtype", "fp16", "--loss_scale", "dynamic"]) | |
self.assertEqual( | |
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), "dynamic") | |
flags_core.parse_flags([__file__, "--dtype", "fp32"]) | |
self.assertEqual(flags_core.get_tf_dtype(flags.FLAGS), tf.float32) | |
self.assertEqual( | |
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 1) | |
flags_core.parse_flags([__file__, "--dtype", "fp32", "--loss_scale", "5"]) | |
self.assertEqual( | |
flags_core.get_loss_scale(flags.FLAGS, default_for_fp16=2), 5) | |
with self.assertRaises(SystemExit): | |
flags_core.parse_flags([__file__, "--dtype", "int8"]) | |
with self.assertRaises(SystemExit): | |
flags_core.parse_flags( | |
[__file__, "--dtype", "fp16", "--loss_scale", "abc"]) | |
def test_get_nondefault_flags_as_str(self): | |
defaults = dict( | |
clean=True, | |
data_dir="abc", | |
hooks=["LoggingTensorHook"], | |
stop_threshold=1.5, | |
use_synthetic_data=False) | |
flags_core.set_defaults(**defaults) | |
flags_core.parse_flags() | |
expected_flags = "" | |
self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags) | |
flags.FLAGS.clean = False | |
expected_flags += "--noclean" | |
self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags) | |
flags.FLAGS.data_dir = "xyz" | |
expected_flags += " --data_dir=xyz" | |
self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags) | |
flags.FLAGS.hooks = ["aaa", "bbb", "ccc"] | |
expected_flags += " --hooks=aaa,bbb,ccc" | |
self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags) | |
flags.FLAGS.stop_threshold = 3. | |
expected_flags += " --stop_threshold=3.0" | |
self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags) | |
flags.FLAGS.use_synthetic_data = True | |
expected_flags += " --use_synthetic_data" | |
self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags) | |
# Assert that explicit setting a flag to its default value does not cause it | |
# to appear in the string | |
flags.FLAGS.use_synthetic_data = False | |
expected_flags = expected_flags[:-len(" --use_synthetic_data")] | |
self.assertEqual(flags_core.get_nondefault_flags_as_str(), expected_flags) | |
if __name__ == "__main__": | |
unittest.main() | |