Cielciel's picture
Cielciel/aift-model-review-multiple-label-classification
bbc5ecf
raw
history blame
No virus
2.9 kB
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from trainer import experiment
def get_args(argv=None):
    """Define the task arguments and parse them.

    Args:
        argv: Optional list of argument strings to parse. When None (the
            default), argparse falls back to ``sys.argv[1:]``, so existing
            callers that pass nothing behave exactly as before.

    Returns:
        argparse.Namespace holding the experiment parameters:
        ``batch_size``, ``num_epochs``, ``seed``, ``learning_rate``,
        ``weight_decay``, ``hp_tune``, ``job_dir`` and ``model_name``.
    """
    args_parser = argparse.ArgumentParser()

    # Experiment arguments
    args_parser.add_argument(
        '--batch-size',
        help='Batch size for each training and evaluation step.',
        type=int,
        default=16)
    # The old help text referenced --train-size / --train-steps, which this
    # parser does not define.
    args_parser.add_argument(
        '--num-epochs',
        help='Number of passes over the training data (default: 1).',
        type=int,
        default=1)
    args_parser.add_argument(
        '--seed',
        help='Random seed (default: 42)',
        type=int,
        default=42)

    # Optimizer arguments
    args_parser.add_argument(
        '--learning-rate',
        help='Learning rate value for the optimizers.',
        type=float,
        default=2e-5)
    # The old help text described an exponential learning-rate decay schedule
    # and claimed a default of 0, contradicting default=0.01 below.
    # NOTE(review): the wording assumes experiment.run uses this value as the
    # optimizer's weight-decay coefficient. Confirm against the trainer code.
    args_parser.add_argument(
        '--weight-decay',
        help='Weight decay coefficient for the optimizers (default: 0.01). '
             'Set to 0 to disable weight decay.',
        type=float,
        default=0.01)

    # Hyperparameter-tuning switch (kept as a free-form string for
    # backward compatibility; documented values are "y" and "n").
    args_parser.add_argument(
        '--hp-tune',
        default="n",
        help='Enable hyperparameter tuning. Valid values are: '
             '"y" - enable, "n" - disable')

    # Saved model arguments
    args_parser.add_argument(
        '--job-dir',
        # Falls back to the AIP_MODEL_DIR environment variable (None if unset).
        default=os.getenv('AIP_MODEL_DIR'),
        help='GCS location to export models')
    args_parser.add_argument(
        '--model-name',
        default="finetuned-bert-classifier",
        help='The name of your saved model')

    return args_parser.parse_args(argv)
def main():
    """Parse the command-line arguments and hand them to ``experiment.run``."""
    cli_args = get_args()
    # Echo the effective configuration to stdout before starting.
    print(cli_args)
    experiment.run(cli_args)
if __name__ == "__main__":
    # Start the experiment only when executed as a script, not on import.
    main()