# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Dataclasses for optimization configs.

This file define the dataclass for optimization configs (OptimizationConfig).
It also has two helper functions get_optimizer_config, and get_lr_config from
an OptimizationConfig class.
"""
from typing import Optional

import dataclasses

from official.modeling.hyperparams import base_config
from official.modeling.hyperparams import oneof
from official.modeling.optimization.configs import learning_rate_config as lr_cfg
from official.modeling.optimization.configs import optimizer_config as opt_cfg


@dataclasses.dataclass
class OptimizerConfig(oneof.OneOfConfig):
  """Configuration for optimizer.

  Attributes:
    type: 'str', type of optimizer to be used, on the of fields below.
    sgd: sgd optimizer config.
    adam: adam optimizer config.
    adamw: adam with weight decay.
    lamb: lamb optimizer.
    rmsprop: rmsprop optimizer.
    lars: lars optimizer.
    adagrad: adagrad optimizer.
    slide: slide optimizer.
    adafactor: adafactor optimizer.
    adafactor_keras: adafactor optimizer.
  """
  type: Optional[str] = None
  sgd: opt_cfg.SGDConfig = dataclasses.field(default_factory=opt_cfg.SGDConfig)
  sgd_experimental: opt_cfg.SGDExperimentalConfig = dataclasses.field(
      default_factory=opt_cfg.SGDExperimentalConfig
  )
  adam: opt_cfg.AdamConfig = dataclasses.field(
      default_factory=opt_cfg.AdamConfig
  )
  adam_experimental: opt_cfg.AdamExperimentalConfig = dataclasses.field(
      default_factory=opt_cfg.AdamExperimentalConfig
  )
  adamw: opt_cfg.AdamWeightDecayConfig = dataclasses.field(
      default_factory=opt_cfg.AdamWeightDecayConfig
  )
  adamw_experimental: opt_cfg.AdamWeightDecayExperimentalConfig = (
      dataclasses.field(
          default_factory=opt_cfg.AdamWeightDecayExperimentalConfig
      )
  )
  lamb: opt_cfg.LAMBConfig = dataclasses.field(
      default_factory=opt_cfg.LAMBConfig
  )
  rmsprop: opt_cfg.RMSPropConfig = dataclasses.field(
      default_factory=opt_cfg.RMSPropConfig
  )
  lars: opt_cfg.LARSConfig = dataclasses.field(
      default_factory=opt_cfg.LARSConfig
  )
  adagrad: opt_cfg.AdagradConfig = dataclasses.field(
      default_factory=opt_cfg.AdagradConfig
  )
  slide: opt_cfg.SLIDEConfig = dataclasses.field(
      default_factory=opt_cfg.SLIDEConfig
  )
  adafactor: opt_cfg.AdafactorConfig = dataclasses.field(
      default_factory=opt_cfg.AdafactorConfig
  )
  adafactor_keras: opt_cfg.AdafactorKerasConfig = dataclasses.field(
      default_factory=opt_cfg.AdafactorKerasConfig
  )

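# Usage sketch (illustrative, not part of the original module): because this
# is a oneof config, only the sub-config named by `type` is active, and
# `get()` (inherited from oneof.OneOfConfig) returns that sub-config. The
# `momentum` field is assumed to exist on opt_cfg.SGDConfig; verify in
# optimizer_config.
#
#   opt_config = OptimizerConfig(type='sgd')
#   opt_config.sgd.momentum = 0.9
#   active = opt_config.get()  # Returns the SGDConfig instance.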

@dataclasses.dataclass
class LrConfig(oneof.OneOfConfig):
  """Configuration for lr schedule.

  Attributes:
    type: 'str', type of lr schedule to be used, one of the fields below.
    constant: constant learning rate config.
    stepwise: stepwise learning rate config.
    exponential: exponential learning rate config.
    polynomial: polynomial learning rate config.
    cosine: cosine learning rate config.
    power: step^power learning rate config.
    power_linear: learning rate config of step^power followed by
      step^power*linear.
    power_with_offset: power decay with a step offset.
    step_cosine_with_offset: step cosine with a step offset.
  """
  type: Optional[str] = None
  constant: lr_cfg.ConstantLrConfig = dataclasses.field(
      default_factory=lr_cfg.ConstantLrConfig
  )
  stepwise: lr_cfg.StepwiseLrConfig = dataclasses.field(
      default_factory=lr_cfg.StepwiseLrConfig
  )
  exponential: lr_cfg.ExponentialLrConfig = dataclasses.field(
      default_factory=lr_cfg.ExponentialLrConfig
  )
  polynomial: lr_cfg.PolynomialLrConfig = dataclasses.field(
      default_factory=lr_cfg.PolynomialLrConfig
  )
  cosine: lr_cfg.CosineLrConfig = dataclasses.field(
      default_factory=lr_cfg.CosineLrConfig
  )
  power: lr_cfg.DirectPowerLrConfig = dataclasses.field(
      default_factory=lr_cfg.DirectPowerLrConfig
  )
  power_linear: lr_cfg.PowerAndLinearDecayLrConfig = dataclasses.field(
      default_factory=lr_cfg.PowerAndLinearDecayLrConfig
  )
  power_with_offset: lr_cfg.PowerDecayWithOffsetLrConfig = dataclasses.field(
      default_factory=lr_cfg.PowerDecayWithOffsetLrConfig
  )
  step_cosine_with_offset: lr_cfg.StepCosineLrConfig = dataclasses.field(
      default_factory=lr_cfg.StepCosineLrConfig
  )

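# Usage sketch (illustrative): selecting a cosine schedule through the oneof
# `type` field. `initial_learning_rate` and `decay_steps` are assumed field
# names on lr_cfg.CosineLrConfig; see learning_rate_config for the exact
# attributes.
#
#   lr_config = LrConfig(type='cosine')
#   lr_config.cosine.initial_learning_rate = 0.1
#   lr_config.cosine.decay_steps = 10000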

@dataclasses.dataclass
class WarmupConfig(oneof.OneOfConfig):
  """Configuration for lr schedule.

  Attributes:
    type: 'str', type of warmup schedule to be used, one of the fields below.
    linear: linear warmup config.
    polynomial: polynomial warmup config.
  """
  type: Optional[str] = None
  linear: lr_cfg.LinearWarmupConfig = dataclasses.field(
      default_factory=lr_cfg.LinearWarmupConfig
  )
  polynomial: lr_cfg.PolynomialWarmupConfig = dataclasses.field(
      default_factory=lr_cfg.PolynomialWarmupConfig
  )

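# Usage sketch (illustrative): a linear warmup ramps the learning rate up over
# the first `warmup_steps` steps before the main schedule takes over.
# `warmup_steps` is an assumed field name on lr_cfg.LinearWarmupConfig.
#
#   warmup_config = WarmupConfig(type='linear')
#   warmup_config.linear.warmup_steps = 1000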

@dataclasses.dataclass
class OptimizationConfig(base_config.Config):
  """Configuration for optimizer and learning rate schedule.

  Attributes:
    optimizer: optimizer oneof config.
    ema: optional exponential moving average optimizer config; if specified,
      the EMA optimizer will be used.
    learning_rate: learning rate oneof config.
    warmup: warmup oneof config.
  """
  optimizer: OptimizerConfig = dataclasses.field(
      default_factory=OptimizerConfig
  )
  ema: Optional[opt_cfg.EMAConfig] = None
  learning_rate: LrConfig = dataclasses.field(default_factory=LrConfig)
  warmup: WarmupConfig = dataclasses.field(default_factory=WarmupConfig)
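

# The sketch below is not part of the original module; it shows how the
# pieces compose. base_config.Config is assumed to accept a nested override
# dict, so a complete optimization setup can be declared in one place. The
# nested field names (`momentum`, `initial_learning_rate`, `decay_steps`,
# `warmup_steps`) are assumptions drawn from optimizer_config.py and
# learning_rate_config.py.
if __name__ == '__main__':
  config = OptimizationConfig({
      'optimizer': {'type': 'sgd', 'sgd': {'momentum': 0.9}},
      'learning_rate': {
          'type': 'cosine',
          'cosine': {'initial_learning_rate': 0.1, 'decay_steps': 10000},
      },
      'warmup': {'type': 'linear', 'linear': {'warmup_steps': 500}},
  })
  # OneOfConfig.get() returns the sub-config selected by `type`.
  print(config.optimizer.get())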