File size: 4,851 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""The main entry point for running any of the T5X supported binaries.

Currently this includes train/infer/eval/precompile.

Example Local (CPU) Pretrain Gin usage

python -m t5x.main \
  --gin_file=t5x/examples/t5/t5_1_1/tiny.gin \
  --gin_file=t5x/configs/runs/pretrain.gin \
  --gin.MODEL_DIR=\"/tmp/t5x_pretrain\" \
  --gin.TRAIN_STEPS=10 \
  --gin.MIXTURE_OR_TASK_NAME=\"c4_v220_span_corruption\" \
  --gin.MIXTURE_OR_TASK_MODULE=\"t5.data.mixtures\" \
  --gin.TASK_FEATURE_LENGTHS="{'inputs': 128, 'targets': 30}" \
  --gin.DROPOUT_RATE=0.1 \
  --run_mode=train \
  --logtostderr
"""
import concurrent.futures  # pylint:disable=unused-import
import enum
import os
from typing import Optional, Sequence

from absl import app
from absl import flags
from absl import logging

import gin
import jax
import seqio

from t5x import eval as eval_lib
from t5x import gin_utils
from t5x import infer as infer_lib
from t5x import precompile as precompile_lib
from t5x import train as train_lib
from t5x import utils


@enum.unique
class RunMode(enum.Enum):
  """The run modes supported by T5X, selectable via `--run_mode`.

  Each member's value is the string accepted on the command line.
  """
  # Model training.
  TRAIN = 'train'
  # Evaluation.
  EVAL = 'eval'
  # Inference.
  INFER = 'infer'
  # Precompilation.
  PRECOMPILE = 'precompile'


# Gin configuration files to load, in order (later files override earlier).
_GIN_FILE = flags.DEFINE_multi_string(
    'gin_file',
    default=None,
    help='Path to gin configuration file. Multiple paths may be passed and '
    'will be imported in the given order, with later configurations '
    'overriding earlier ones.')

# Individual gin bindings applied on top of the gin files.
_GIN_BINDINGS = flags.DEFINE_multi_string(
    'gin_bindings', default=[], help='Individual gin bindings.')

# Prefixes used to resolve relative `--gin_file` paths.
_GIN_SEARCH_PATHS = flags.DEFINE_list(
    'gin_search_paths',
    default=['.'],
    help='Comma-separated list of gin config path prefixes to be prepended '
    'to suffixes given via `--gin_file`. If a file appears in multiple '
    'prefixes, only the first prefix that produces a valid path for each '
    'suffix will be used.')

# Which T5X binary to run; has no default, so it must be given explicitly.
_RUN_MODE = flags.DEFINE_enum_class(
    'run_mode',
    default=None,
    enum_class=RunMode,
    help='The mode to run T5X under.')

# Optional override of the TFDS data directory for all seqio Tasks.
_TFDS_DATA_DIR = flags.DEFINE_string(
    'tfds_data_dir', None,
    'If set, this directory will be used to store datasets prepared by '
    'TensorFlow Datasets that are not available in the public TFDS GCS '
    'bucket. Note that this flag overrides the `tfds_data_dir` attribute of '
    'all `Task`s.')

# When set, `main` parses and logs the gin config and then returns early.
_DRY_RUN = flags.DEFINE_bool(
    'dry_run', False,
    'If set, loads and logs the gin config but does not run the selected '
    'function.')


FLAGS = flags.FLAGS

# Fallback gin search path: the directory two levels above this file (the
# parent of the directory containing this module). Relative `--gin_file`
# paths are also resolved against it after any user-supplied prefixes.
_DEFAULT_GIN_SEARCH_PATHS = [
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.pardir))
]

# Each binary's entry point, exposed under a module-level name.
train, evaluate, infer, precompile = (
    train_lib.train,
    eval_lib.evaluate,
    infer_lib.infer,
    precompile_lib.precompile,
)

# Dispatch table used by `main` to pick the function for `--run_mode`.
_FUNC_MAP = {
    RunMode.TRAIN: train,
    RunMode.EVAL: evaluate,
    RunMode.INFER: infer,
    RunMode.PRECOMPILE: precompile,
}


def main(argv: Sequence[str]):
  """Parses the gin config and runs the T5X binary selected by `--run_mode`.

  Args:
    argv: Command-line arguments remaining after flag parsing, including the
      program name. Any additional positional argument is rejected.

  Raises:
    app.UsageError: If extra positional arguments are given, or if
      `--run_mode` was not set.
  """
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')

  # `--run_mode` defaults to None; without this check the dispatch lookup
  # below would fail with an opaque KeyError instead of a usage message.
  if _RUN_MODE.value is None:
    raise app.UsageError(
        '--run_mode must be set to one of: ' +
        ', '.join(mode.value for mode in RunMode))
  run_fn = _FUNC_MAP[_RUN_MODE.value]

  if _TFDS_DATA_DIR.value is not None:
    seqio.set_tfds_data_dir_override(_TFDS_DATA_DIR.value)

  # Register function explicitly under __main__ module, to maintain backward
  # compatibility of existing '__main__' module references.
  gin.register(run_fn, '__main__')
  if _GIN_SEARCH_PATHS.value != ['.']:
    logging.warning(
        'Using absolute paths for the gin files is strongly recommended.')

  # User-provided gin paths take precedence if relative paths conflict.
  gin_utils.parse_gin_flags(_GIN_SEARCH_PATHS.value + _DEFAULT_GIN_SEARCH_PATHS,
                            _GIN_FILE.value, _GIN_BINDINGS.value)

  if _DRY_RUN.value:
    return

  # All arguments of the selected function come from the parsed gin config.
  run_with_gin = gin.get_configurable(run_fn)

  run_with_gin()



def _flags_parser(args: Sequence[str]) -> Sequence[str]:
  """Rewrites gin-style arguments, then parses flags with absl.

  Intended to be passed as `flags_parser` to `absl.app.run`; see also
  `absl.app.parse_flags_with_usage`.

  Args:
    args: The full list of command-line arguments.

  Returns:
    A non-empty list of the command-line arguments left over after flag
    parsing, with the program name first.
  """
  rewritten_args = list(gin_utils.rewrite_gin_args(args))
  return app.parse_flags_with_usage(rewritten_args)


if __name__ == '__main__':
  # Hook JAX's config options into absl flag handling. This is done before
  # `app.run`, which is where command-line flags get parsed.
  jax.config.parse_flags_with_absl()
  # `_flags_parser` rewrites gin-style arguments before absl parses them.
  app.run(main, flags_parser=_flags_parser)