File size: 10,276 Bytes
b100e1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
# Copyright 2022 The T5X Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Train state for passing around objects during training."""

from typing import Any, Mapping, MutableMapping, Optional, Tuple

from flax import traverse_util
import flax.core
from flax.core import scope as flax_scope
from flax.linen import partitioning as flax_partitioning
import flax.serialization
import flax.struct
import jax.numpy as jnp
from t5x import optimizers

import typing_extensions

# Short aliases for the Flax variable-collection types used in this module.
FrozenDict = flax_scope.FrozenDict
FrozenVariableDict = flax_scope.FrozenVariableDict
MutableVariableDict = flax_scope.MutableVariableDict
VariableDict = flax_scope.VariableDict

# Shared immutable empty mapping; used below as the default `flax_mutables`.
EMPTY_DICT = flax.core.freeze(dict())


class TrainState(typing_extensions.Protocol):
  """TrainState interface.

  Structural (duck-typed) interface for the train states in this module. The
  `replace_*`, `restore_state`, `apply_gradient` and `as_logical_axes` methods
  return a new state object rather than being documented as mutating `self`.
  """

  @property
  def step(self) -> jnp.ndarray:
    """The current training step as an integer scalar."""
    ...

  @property
  def params(self) -> FrozenVariableDict:
    """The parameters of the model as a PyTree matching the Flax module."""
    ...

  @property
  def param_states(self) -> FrozenVariableDict:
    """The optimizer states of the parameters as a PyTree."""
    ...

  @property
  def flax_mutables(self) -> FrozenVariableDict:
    """Flax mutable collection."""
    ...

  def state_dict(self) -> MutableVariableDict:
    """Returns a mutable representation of the state for checkpointing."""
    ...

  def restore_state(self, state_dict: Mapping[str, Any]) -> 'TrainState':
    """Restores the object state from a state dict."""
    ...

  def replace_params(self, params: VariableDict) -> 'TrainState':
    """Returns a copy of this state with `params` replaced by `params`."""
    ...

  def replace_step(self, step: jnp.ndarray) -> 'TrainState':
    """Returns a copy of this state with the training step set to `step`."""
    ...

  def apply_gradient(self,
                     grads,
                     learning_rate,
                     flax_mutables=EMPTY_DICT) -> 'TrainState':
    """Applies gradient, increments step, and returns an updated TrainState."""
    ...

  def as_logical_axes(self) -> 'TrainState':
    """Returns a state whose leaves are logical axis names.

    `params`, `param_states` and `flax_mutables` are replaced by the logical
    axis names of the corresponding trees, e.g. for use in partitioning.
    """
    ...


def _validate_params_axes(params_axes, params):
  """Checks that every leaf of `params` has axis-name metadata.

  Args:
    params_axes: axis-metadata tree emitted by the model (e.g. the
      'params_axes' collection); converted with
      `flax_partitioning.get_axis_names` before comparison.
    params: the variable tree that `params_axes` should describe.

  Raises:
    ValueError: if some '/'-joined leaf path of `params` has no matching path
      in the derived axis names. Extra axis entries that have no parameter are
      not reported.
  """
  names_tree = flax_partitioning.get_axis_names(params_axes)
  param_paths = set(traverse_util.flatten_dict(params, sep='/'))
  named_paths = set(traverse_util.flatten_dict(names_tree, sep='/'))
  missing_params_axes = param_paths - named_paths
  if not missing_params_axes:
    return
  raise ValueError(
      f'Missing axis names for parameters: {missing_params_axes}')


def _split_variables_and_axes(
    variables_and_axes: FrozenVariableDict
) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
  """Separates variable collections from their `*_axes` metadata collections.

  Every collection whose name ends in "_axes" is treated as axis metadata for
  the collection of the same name without that suffix. The metadata is
  validated against that collection with `_validate_params_axes`. All other
  collections are returned unchanged.

  Args:
    variables_and_axes: mapping of collection name to collection.

  Returns:
    A `(variables, axes)` pair of frozen dicts. `variables` holds the
    non-"_axes" collections. `axes` maps each base collection name (suffix
    stripped) to its metadata. The key sets need not match: collections
    without metadata appear only in `variables`.
  """
  suffix = '_axes'
  variables = {}
  axes = {}
  for name, value in variables_and_axes.items():
    if not name.endswith(suffix):
      variables[name] = value
      continue
    base_name = name[:-len(suffix)]
    axes[base_name] = value
    # Raises KeyError if the base collection itself is absent.
    _validate_params_axes(value, variables_and_axes[base_name])
  return flax.core.freeze(variables), flax.core.freeze(axes)


class FlaxOptimTrainState(flax.struct.PyTreeNode):
  """Simple train state for holding parameters, step, optimizer state.

  The parameters, step and per-parameter optimizer states all live inside the
  wrapped optimizer (its `target`, `state.step` and `state.param_states`).
  The model's other Flax collections are kept in `flax_mutables`; the `*_axes`
  metadata collections are excluded from it.
  """
  # Wrapped optimizer; the source of `params`, `step` and `param_states`.
  _optimizer: optimizers.OptimizerType
  # Contains axis metadata (e.g., names) matching parameter tree.
  params_axes: Optional[FrozenVariableDict] = None
  # Flax mutable fields.
  flax_mutables: FrozenDict = EMPTY_DICT
  # Contains axis metadata (e.g., names) matching flax_mutables tree.
  flax_mutables_axes: Optional[FrozenVariableDict] = EMPTY_DICT

  @classmethod
  def create(cls, optimizer_def: optimizers.OptimizerDefType,
             model_variables: FrozenVariableDict) -> 'FlaxOptimTrainState':
    """Builds a train state from an optimizer definition and model variables.

    Args:
      optimizer_def: definition whose `create` is called on the 'params'
        collection. If it has a `set_param_axes` method, that method is first
        given the axis names derived from the 'params_axes' collection.
      model_variables: frozen model variables. Must contain 'params'. May
        contain 'params_axes' and, for any other collection `c`, a metadata
        collection named `c_axes`.

    Returns:
      A new instance of `cls`, so subclasses get their own type back.

    Raises:
      ValueError: if axis metadata is missing for some parameters or mutable
        variables, or if the optimizer supports `set_param_axes` but
        `model_variables` has no 'params_axes'.
    """
    # FrozenDict.pop returns (remaining_dict, popped_value).
    other_variables, params = model_variables.pop('params')
    if 'params_axes' in other_variables:
      other_variables, params_axes = other_variables.pop('params_axes')
      _validate_params_axes(params_axes, params)
    else:
      params_axes = None

    # Split other_variables into mutables and their corresponding axes.
    flax_mutables, flax_mutables_axes = _split_variables_and_axes(
        other_variables)

    # If the optimizer supports `set_param_axes`, then assume that the model
    # code is emitting these axes and use it.
    if hasattr(optimizer_def, 'set_param_axes'):
      if params_axes is None:
        raise ValueError('The optimizer supports params_axes for model-based '
                         'partitioning, but the model is not emitting them.')
      # `get_axis_names` removes "_axes" suffix in the leaf name and replaces
      # `AxisMetadata` with `PartitionSpec`.
      axis_names = flax_partitioning.get_axis_names(params_axes)
      optimizer_def.set_param_axes(axis_names)

    optimizer = optimizer_def.create(params)
    # Use `cls` rather than the hard-coded class name so alternate-constructor
    # calls on a subclass return that subclass.
    return cls(
        optimizer,
        params_axes=params_axes,
        flax_mutables=flax_mutables,
        flax_mutables_axes=flax_mutables_axes)

  @property
  def step(self) -> jnp.ndarray:
    """The current training step, read from the optimizer state."""
    return self._optimizer.state.step

  @property
  def params(self) -> FrozenVariableDict:
    """The model parameters (the optimizer's `target`)."""
    return self._optimizer.target

  @property
  def param_states(self) -> FrozenVariableDict:
    """The optimizer's per-parameter states."""
    return self._optimizer.state.param_states

  def state_dict(self) -> MutableVariableDict:
    """Returns the optimizer's state dict, extended with the mutables.

    If `flax_mutables` is non-empty, an unfrozen copy of it is stored under
    the 'flax_mutables' key.
    """
    state_dict = self._optimizer.state_dict()
    if self.flax_mutables:
      state_dict['flax_mutables'] = flax.core.unfreeze(self.flax_mutables)
    return state_dict

  def apply_gradient(self,
                     grads,
                     learning_rate,
                     flax_mutables=EMPTY_DICT) -> 'FlaxOptimTrainState':
    """Returns a new state after one optimizer update with `grads`.

    The step and parameter update are delegated to the optimizer's
    `apply_gradient`. `flax_mutables` replaces the stored mutables
    unconditionally, so omitting it resets them to `EMPTY_DICT`.
    """
    new_optimizer = self._optimizer.apply_gradient(
        grads, learning_rate=learning_rate)
    return self.replace(_optimizer=new_optimizer, flax_mutables=flax_mutables)

  def replace_params(self, params: VariableDict) -> 'FlaxOptimTrainState':
    """Returns a copy with the optimizer's `target` replaced by `params`."""
    return self.replace(_optimizer=self._optimizer.replace(target=params))

  def replace_step(self, step: jnp.ndarray) -> 'FlaxOptimTrainState':
    """Returns a copy with the step set to `step`.

    Implemented as a round trip through `state_dict` / `restore_state`.
    """
    state_dict = self.state_dict()
    state_dict['state']['step'] = step
    return self.restore_state(state_dict)

  def restore_state(self, state_dict: VariableDict) -> 'FlaxOptimTrainState':
    """Returns a copy restored from `state_dict`.

    The optimizer is restored via its own `restore_state`. `flax_mutables` is
    taken from the 'flax_mutables' key, or reset to `EMPTY_DICT` when that
    key is absent.
    """
    new_optimizer = self._optimizer.restore_state(state_dict)
    return self.replace(
        _optimizer=new_optimizer,
        flax_mutables=flax.core.freeze(state_dict['flax_mutables'])
        if 'flax_mutables' in state_dict else EMPTY_DICT)

  def as_logical_axes(self) -> 'FlaxOptimTrainState':
    """Returns a state whose leaves are logical axis names.

    The optimizer tree is derived by the optimizer definition's
    `derive_logical_axes`. `flax_mutables` is replaced by the axis names from
    `flax_mutables_axes`.

    Raises:
      ValueError: if the optimizer definition has no `derive_logical_axes`.
    """
    if not hasattr(self._optimizer.optimizer_def, 'derive_logical_axes'):
      raise ValueError(
          f"Optimizer '{self._optimizer.optimizer_def.__class__.__name__}' "
          'requires a `derive_logical_axes` method to be used with named axis '
          'partitioning.')
    return FlaxOptimTrainState(
        _optimizer=self._optimizer.optimizer_def.derive_logical_axes(
            self._optimizer,
            flax_partitioning.get_axis_names(self.params_axes)),
        flax_mutables=flax_partitioning.get_axis_names(self.flax_mutables_axes))


class InferenceState(flax.struct.PyTreeNode):
  """State compatible with FlaxOptimTrainState without optimizer state.

  Holds the step, the parameters and the Flax mutable collections, plus their
  axis metadata. Its state dict uses the 'target', 'state'/'step' and
  'flax_mutables' keys. Operations that need an optimizer (`param_states`,
  `apply_gradient`) raise `NotImplementedError`.
  """

  step: jnp.ndarray
  params: flax_scope.FrozenVariableDict
  params_axes: Optional[flax_scope.FrozenVariableDict] = None
  flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
  flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None

  @classmethod
  def create(cls, model_variables: FrozenVariableDict) -> 'InferenceState':
    """Builds an inference state at step 0 from frozen model variables.

    Args:
      model_variables: must contain 'params'. May contain 'params_axes' and,
        for any other collection `c`, a metadata collection named `c_axes`.

    Returns:
      A new instance of `cls`, so subclasses get their own type back.

    Raises:
      ValueError: if axis metadata is missing for some parameters or mutable
        variables.
    """
    # FrozenDict.pop returns (remaining_dict, popped_value).
    other_variables, params = model_variables.pop('params')
    if 'params_axes' in other_variables:
      other_variables, params_axes = other_variables.pop('params_axes')
      _validate_params_axes(params_axes, params)
    else:
      params_axes = None

    # Split other_variables into mutables and their corresponding axes.
    flax_mutables, flax_mutables_axes = _split_variables_and_axes(
        other_variables)

    # Use `cls` so alternate-constructor calls on a subclass return it.
    return cls(
        step=jnp.array(0),
        params=params,
        params_axes=params_axes,
        flax_mutables=flax_mutables,
        flax_mutables_axes=flax_mutables_axes)

  @property
  def param_states(self) -> FrozenVariableDict:
    """The optimizer states of the parameters as a PyTree."""
    raise NotImplementedError('InferenceState has no optimizer states.')

  def apply_gradient(self, *args, **kwargs) -> 'InferenceState':
    """Not supported: an inference state has no optimizer."""
    raise NotImplementedError(
        'InferenceState does not support `apply_gradient`.')

  def state_dict(self) -> MutableMapping[str, Any]:
    """Returns a mutable state dict.

    The dict holds 'target' (the unfrozen params) and 'state': {'step': ...}.
    It also holds 'flax_mutables' when those are non-empty.
    """
    state_dict = {
        'target': flax.core.unfreeze(self.params),
        'state': {
            'step': self.step
        }
    }
    if self.flax_mutables:
      state_dict['flax_mutables'] = flax.core.unfreeze(self.flax_mutables)
    return state_dict

  def replace_step(self, step: jnp.ndarray) -> 'InferenceState':
    """Returns a copy with `step` replaced."""
    return self.replace(step=step)

  def replace_params(self, params: FrozenVariableDict) -> 'InferenceState':
    """Returns a copy with `params` replaced."""
    return self.replace(params=params)

  def restore_state(self, state_dict: Mapping[str, Any]) -> 'InferenceState':
    """Returns a copy restored from a dict in the `state_dict` layout.

    `flax_mutables` is reset to `EMPTY_DICT` when the 'flax_mutables' key is
    absent.
    """
    return self.replace(
        params=flax.core.freeze(state_dict['target']),
        step=state_dict['state']['step'],
        flax_mutables=flax.core.freeze(state_dict['flax_mutables'])
        if 'flax_mutables' in state_dict else EMPTY_DICT)

  def as_logical_axes(self) -> 'InferenceState':
    """Returns a state whose params/mutables leaves are logical axis names."""
    # Set step to None so that when the logical axes are processed by
    # flax.linen.partitioning.logical_to_mesh_axes, the step is skipped:
    # jax tree utilities treat None as an empty subtree, so the mapped
    # function is never called on it.
    return InferenceState(
        step=None,
        params=flax_partitioning.get_axis_names(self.params_axes),
        flax_mutables=flax_partitioning.get_axis_names(self.flax_mutables_axes))