Spaces:
Runtime error
Runtime error
File size: 10,356 Bytes
5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 |
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Misc for Transformer."""
# pylint: disable=g-bad-import-order
from absl import flags
import tensorflow as tf, tf_keras
from official.legacy.transformer import model_params
from official.utils.flags import core as flags_core
from official.utils.misc import keras_utils
# Shorthand for the global absl flag values.
FLAGS = flags.FLAGS

# Named hyperparameter sets selectable via --param_set; the keys of this dict
# are also the allowed values of that flag.
PARAMS_MAP = dict(
    tiny=model_params.TINY_PARAMS,
    base=model_params.BASE_PARAMS,
    big=model_params.BIG_PARAMS,
)
def get_model_params(param_set, num_gpus):
  """Returns a copy of the predefined hyperparameters for `param_set`.

  Args:
    param_set: Name of the parameter set, i.e. a key of `PARAMS_MAP`
      ('tiny', 'base' or 'big').
    num_gpus: Number of GPUs. When greater than 1, the multi-GPU variant of
      the 'big' or 'base' parameter set is selected instead.

  Returns:
    A `.copy()` of the selected parameter object, so callers can modify it
    without touching the module-level defaults.

  Raises:
    ValueError: if `num_gpus > 1` and `param_set` is neither 'big' nor 'base'.
    KeyError: if `num_gpus <= 1` and `param_set` is not a key of `PARAMS_MAP`.
  """
  use_multi_gpu = num_gpus > 1
  if not use_multi_gpu:
    return PARAMS_MAP[param_set].copy()

  # Only 'big' and 'base' have multi-GPU variants.
  if param_set == 'big':
    chosen = model_params.BIG_MULTI_GPU_PARAMS
  elif param_set == 'base':
    chosen = model_params.BASE_MULTI_GPU_PARAMS
  else:
    raise ValueError(
        f'Not valid params: param_set={param_set} num_gpus={num_gpus}')
  return chosen.copy()
def define_transformer_flags():
  """Adds flags and flag validators for running transformer_main.

  Defines the shared flags from `flags_core` (base, performance, benchmark
  and device flags) and the Transformer-specific flags. It also overrides the
  defaults of --data_dir, --model_dir and --batch_size, and registers
  validators that check the BLEU-related flags are set consistently.
  """
  # Add common flags (data_dir, model_dir, etc.).
  flags_core.define_base(num_gpu=True, distribution_strategy=True)
  flags_core.define_performance(
      num_parallel_calls=True,
      inter_op=False,
      intra_op=False,
      synthetic_data=True,
      max_train_steps=False,
      dtype=True,
      loss_scale=True,
      all_reduce_alg=True,
      num_packs=True,
      tf_gpu_thread_mode=True,
      datasets_num_private_threads=True,
      enable_xla=True,
      fp16_implementation=True)

  flags_core.define_benchmark()
  flags_core.define_device(tpu=True)

  flags.DEFINE_integer(
      name='train_steps',
      short_name='ts',
      default=300000,
      help=flags_core.help_wrap('The number of steps used to train.'))
  flags.DEFINE_integer(
      name='steps_between_evals',
      short_name='sbe',
      default=5000,
      help=flags_core.help_wrap(
          'The Number of training steps to run between evaluations. This is '
          'used if --train_steps is defined.'))
  flags.DEFINE_boolean(
      name='enable_time_history',
      default=True,
      help='Whether to enable TimeHistory callback.')
  flags.DEFINE_boolean(
      name='enable_tensorboard',
      default=False,
      help='Whether to enable Tensorboard callback.')
  flags.DEFINE_boolean(
      name='enable_metrics_in_training',
      default=False,
      help='Whether to enable metrics during training.')
  flags.DEFINE_boolean(
      name='enable_mlir_bridge',
      default=False,
      help='Whether to enable the TF to XLA bridge.')

  # Set flags from the flags_core module as 'key flags' so they're listed when
  # the '-h' flag is used. Without this line, the flags defined above are
  # only shown in the full `--helpful` help text.
  flags.adopt_module_key_flags(flags_core)

  # Add transformer-specific flags
  flags.DEFINE_enum(
      name='param_set',
      short_name='mp',
      default='big',
      # Pass a concrete list rather than a live dict view of PARAMS_MAP.
      enum_values=list(PARAMS_MAP),
      help=flags_core.help_wrap(
          'Parameter set to use when creating and training the model. The '
          'parameters define the input shape (batch size and max length), '
          'model configuration (size of embedding, # of hidden layers, etc.), '
          'and various other settings. The big parameter set increases the '
          'default batch size, embedding/hidden size, and filter size. For a '
          'complete list of parameters, please see model/model_params.py.'))

  flags.DEFINE_bool(
      name='static_batch',
      short_name='sb',
      default=False,
      help=flags_core.help_wrap(
          'Whether the batches in the dataset should have static shapes. In '
          'general, this setting should be False. Dynamic shapes allow the '
          'inputs to be grouped so that the number of padding tokens is '
          'minimized, and helps model training. In cases where the input shape '
          'must be static (e.g. running on TPU), this setting will be ignored '
          'and static batching will always be used.'))
  flags.DEFINE_integer(
      name='max_length',
      short_name='ml',
      default=256,
      help=flags_core.help_wrap(
          'Max sentence length for Transformer. Default is 256. Note: Usually '
          'it is more effective to use a smaller max length if static_batch is '
          'enabled, e.g. 64.'))

  # Flags for training with steps (may be used for debugging)
  flags.DEFINE_integer(
      name='validation_steps',
      short_name='vs',
      default=64,
      help=flags_core.help_wrap('The number of steps used in validation.'))

  # BLEU score computation
  flags.DEFINE_string(
      name='bleu_source',
      short_name='bls',
      default=None,
      help=flags_core.help_wrap(
          'Path to source file containing text to translate when calculating '
          'the official BLEU score. Both --bleu_source and --bleu_ref must be '
          'set. '))
  flags.DEFINE_string(
      name='bleu_ref',
      short_name='blr',
      default=None,
      help=flags_core.help_wrap(
          'Path to file containing the reference translations used when '
          'calculating the official BLEU score. Both --bleu_source and '
          '--bleu_ref must be set. '))
  flags.DEFINE_string(
      name='vocab_file',
      short_name='vf',
      default=None,
      help=flags_core.help_wrap(
          'Path to subtoken vocabulary file. If data_download.py was used to '
          'download and encode the training data, look in the data_dir to find '
          'the vocab file.'))
  flags.DEFINE_string(
      name='mode',
      default='train',
      help=flags_core.help_wrap('mode: train, eval, or predict'))
  flags.DEFINE_bool(
      name='use_ctl',
      default=False,
      help=flags_core.help_wrap(
          'Whether the model runs with custom training loop.'))

  # Autoregressive decoding.
  flags.DEFINE_integer(
      name='decode_batch_size',
      default=32,
      help=flags_core.help_wrap(
          'Global batch size used for Transformer autoregressive decoding on '
          'TPU.'))
  flags.DEFINE_integer(
      name='decode_max_length',
      default=97,
      help=flags_core.help_wrap(
          'Max sequence length of the decode/eval data. This is used by '
          'Transformer autoregressive decoding on TPU to have minimum '
          'paddings.'))
  flags.DEFINE_bool(
      name='padded_decode',
      default=False,
      help=flags_core.help_wrap(
          'Whether the autoregressive decoding runs with input data padded to '
          'the decode_max_length. For TPU/XLA-GPU runs, this flag has to be '
          'set due to the static shape requirement. Although CPU/GPU could '
          'also use padded_decode, it has not been tested. In addition, this '
          'method will introduce unnecessary overheads which grow '
          'quadratically with the max sequence length.'))

  # Checkpointing.
  flags.DEFINE_bool(
      name='enable_checkpointing',
      default=True,
      help=flags_core.help_wrap(
          'Whether to do checkpointing during training. When running under '
          'benchmark harness, we will avoid checkpointing.'))
  flags.DEFINE_bool(
      name='save_weights_only',
      default=True,
      help=flags_core.help_wrap(
          'Only used when above `enable_checkpointing` is True. '
          'If True, then only the model\'s weights will be saved '
          '(`model.save_weights(filepath)`), else the full model is saved '
          '(`model.save(filepath)`)'))

  # NOTE(review): batch_size=None presumably lets the selected param_set
  # supply the batch size downstream -- confirm in transformer_main.
  flags_core.set_defaults(
      data_dir='/tmp/translate_ende',
      model_dir='/tmp/transformer_model',
      batch_size=None)

  # The validators below are registered by their decorators; the local names
  # are never referenced afterwards.
  # pylint: disable=unused-variable
  @flags.multi_flags_validator(
      ['bleu_source', 'bleu_ref'],
      message='Both or neither --bleu_source and --bleu_ref must be defined.')
  def _check_bleu_files(flags_dict):
    # Valid only when both flags are set or both are left as None.
    return (flags_dict['bleu_source'] is None) == (
        flags_dict['bleu_ref'] is None)

  @flags.multi_flags_validator(
      ['bleu_source', 'bleu_ref', 'vocab_file'],
      message='--vocab_file must be defined if --bleu_source and --bleu_ref '
      'are defined.')
  def _check_bleu_vocab_file(flags_dict):
    # Uses truthiness (unlike _check_bleu_files), so empty-string BLEU paths
    # do not require --vocab_file.
    if flags_dict['bleu_source'] and flags_dict['bleu_ref']:
      return flags_dict['vocab_file'] is not None
    return True
  # pylint: enable=unused-variable
def get_callbacks():
  """Creates the Keras callbacks requested via command-line flags.

  Returns:
    A list holding, in this order, a `keras_utils.TimeHistory` callback if
    --enable_time_history is set and a `tf_keras.callbacks.TensorBoard`
    callback if --enable_tensorboard is set. When both are enabled, both
    log to --model_dir.
  """
  enabled_callbacks = []

  if FLAGS.enable_time_history:
    # TimeHistory only gets a log directory when TensorBoard output is on.
    history_logdir = FLAGS.model_dir if FLAGS.enable_tensorboard else None
    enabled_callbacks.append(
        keras_utils.TimeHistory(
            FLAGS.batch_size, FLAGS.log_steps, logdir=history_logdir))

  if FLAGS.enable_tensorboard:
    enabled_callbacks.append(
        tf_keras.callbacks.TensorBoard(log_dir=FLAGS.model_dir))

  return enabled_callbacks
def update_stats(history, stats, callbacks):
  """Normalizes and updates a dictionary of training stats in place.

  Args:
    history: Results of the training step, or None. Presumably the `History`
      object returned by `keras.Model.fit`; its `history` attribute is read
      as a dict of per-epoch metric lists.
    stats: Dict with pre-existing training stats; modified in place.
    callbacks: A list of callbacks which might include a
      `keras_utils.TimeHistory` callback used during keras.fit. May be None
      or empty.
  """
  if history and history.history:
    losses = history.history.get('loss')
    # Gets final loss from training. Skip quietly if no loss was recorded
    # instead of failing on a missing key or an empty list.
    if losses is not None and len(losses) > 0:
      stats['loss'] = float(losses[-1])

  if not callbacks:
    return

  # Look for the time history callback which was used during keras.fit.
  for callback in callbacks:
    if not isinstance(callback, keras_utils.TimeHistory):
      continue
    timestamp_log = callback.timestamp_log
    stats['step_timestamp_log'] = timestamp_log
    stats['train_finish_time'] = callback.train_finish_time
    if len(timestamp_log) > 1:
      elapsed = timestamp_log[-1].timestamp - timestamp_log[0].timestamp
      # Skip the throughput estimate when the logged timestamps span no
      # time; the division below would otherwise raise ZeroDivisionError.
      if elapsed > 0:
        # Each interval between consecutive log entries covers
        # `log_steps` batches of `batch_size` examples.
        stats['avg_exp_per_second'] = (
            callback.batch_size * callback.log_steps *
            (len(timestamp_log) - 1) / elapsed)
|