# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions or classes shared between BERT benchmarks.""" | |
from __future__ import absolute_import | |
from __future__ import division | |
from __future__ import print_function | |
import time | |
# pylint: disable=g-bad-import-order | |
import numpy as np | |
from absl import flags | |
import tensorflow as tf | |
# pylint: enable=g-bad-import-order | |
from official.utils.flags import core as flags_core | |
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark | |
FLAGS = flags.FLAGS | |


class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
  """Callback that records time it takes to run each batch."""

  def __init__(self, num_batches_to_skip=10):
    super(BenchmarkTimerCallback, self).__init__()
    self.batch_start_times = {}
    self.batch_stop_times = {}

  def on_batch_begin(self, batch, logs=None):
    self.batch_start_times[batch] = time.time()

  def on_batch_end(self, batch, logs=None):
    # If there are multiple steps_per_loop, the end batch index will not be the
    # same as the starting index. Use the last starting index instead.
    if batch not in self.batch_start_times:
      batch = max(self.batch_start_times.keys())

    self.batch_stop_times[batch] = time.time()
  def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
    """Returns average examples per second, skipping warm-up batches."""
    batch_durations = []
    for batch in self.batch_start_times:
      # Only count batches that completed and are past the warm-up window.
      if batch in self.batch_stop_times and batch >= num_batches_to_skip:
        batch_durations.append(self.batch_stop_times[batch] -
                               self.batch_start_times[batch])
    return batch_size / np.mean(batch_durations)

  def get_startup_time(self, program_start_time):
    """Returns the time from program start to the start of the first batch."""
    return self.batch_start_times[0] - program_start_time
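

# Illustrative usage (not part of the original module): the callback is meant
# to be attached to Keras `fit`, and throughput is read back after training.
# `model` and `train_dataset` below are placeholder names.
#
#   timer = BenchmarkTimerCallback()
#   model.fit(train_dataset, epochs=1, callbacks=[timer])
#   examples_per_sec = timer.get_examples_per_sec(batch_size=32)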


class BertBenchmarkBase(PerfZeroBenchmark):
  """Base class to hold methods common to test classes."""
  local_flags = None

  def __init__(self, output_dir=None, tpu=None, **kwargs):
    super(BertBenchmarkBase, self).__init__(
        output_dir=output_dir, tpu=tpu, **kwargs)
    self.num_gpus = 8
    self.timer_callback = None

  def _setup(self):
    """Sets up and resets flags before each test."""
    super(BertBenchmarkBase, self)._setup()
    self.timer_callback = BenchmarkTimerCallback()

  def _report_benchmark(self, stats, wall_time_sec, min_accuracy,
                        max_accuracy):
    """Reports benchmark results by writing to a local protobuf file.

    Args:
      stats: dict returned from BERT models with known entries.
      wall_time_sec: the duration of the benchmark execution in seconds.
      min_accuracy: Minimum classification accuracy constraint to verify
        correctness of the model.
      max_accuracy: Maximum classification accuracy constraint to verify
        correctness of the model.
    """
    metrics = [{
        'name': 'training_loss',
        'value': stats['train_loss'],
    }]
    if self.timer_callback:
      metrics.append({
          'name': 'exp_per_second',
          'value': self.timer_callback.get_examples_per_sec(
              FLAGS.train_batch_size * FLAGS.steps_per_loop)
      })
    else:
      metrics.append({
          'name': 'exp_per_second',
          'value': 0.0,
      })
    if self.timer_callback and 'start_time_sec' in stats:
      metrics.append({
          'name': 'startup_time',
          'value': self.timer_callback.get_startup_time(stats['start_time_sec'])
      })
    if 'eval_metrics' in stats:
      metrics.append({
          'name': 'eval_accuracy',
          'value': stats['eval_metrics'],
          'min_value': min_accuracy,
          'max_value': max_accuracy,
      })

    flags_str = flags_core.get_nondefault_flags_as_str()
    self.report_benchmark(
        iters=stats['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics,
        extras={'flags': flags_str})
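

# Sketch of how a concrete benchmark might use this base class (illustrative
# only; the class name, method name, and `run_bert_classifier` helper are
# assumptions, not part of this module):
#
#   class BertClassifyBenchmark(BertBenchmarkBase):
#
#     def benchmark_example(self):
#       self._setup()
#       start_time_sec = time.time()
#       stats = run_bert_classifier(...)  # yields 'train_loss', etc.
#       wall_time_sec = time.time() - start_time_sec
#       self._report_benchmark(
#           stats, wall_time_sec, min_accuracy=0.0, max_accuracy=1.0)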