# Provenance (scraped page header): author sanchit-gandhi,
# commit de07efc, "Saving train state of step 5000".
from typing import Any, Mapping, MutableMapping, Optional, Tuple
import flax.core
import flax.serialization
import flax.struct
import jax.numpy as jnp
from flax import traverse_util
from flax.core import scope as flax_scope
from flax.linen import partitioning as flax_partitioning
# Shared immutable empty mapping; used below as the default for
# `InferenceState.flax_mutables` and as a fallback for missing axes.
EMPTY_DICT = flax.core.freeze({})
# Short aliases for Flax's variable-dict types, used in annotations below.
FrozenDict = flax_scope.FrozenDict
FrozenVariableDict = flax_scope.FrozenVariableDict
MutableVariableDict = flax_scope.MutableVariableDict
VariableDict = flax_scope.VariableDict
def _validate_params_axes(params_axes, params):
    """Ensure every parameter leaf has a logical-axis annotation.

    Args:
      params_axes: Axis-metadata tree; it is converted with
        ``flax_partitioning.get_axis_names`` before comparison.
      params: Parameter tree whose ``/``-joined leaf paths must all appear in
        the converted axis tree.

    Raises:
      ValueError: If one or more parameter paths have no axis entry.
    """
    # Convert axis metadata first (same evaluation order as before).
    axis_name_tree = flax_partitioning.get_axis_names(params_axes)
    param_paths = set(traverse_util.flatten_dict(params, sep="/"))
    annotated_paths = set(traverse_util.flatten_dict(axis_name_tree, sep="/"))
    missing_params_axes = param_paths - annotated_paths
    if not missing_params_axes:
        return
    raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
def _split_variables_and_axes(
    variables_and_axes: FrozenVariableDict,
) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
    """Separate variable collections from their ``<name>_axes`` companions.

    A collection whose key ends in ``_axes`` is treated as axis metadata for
    the collection of the same name without that suffix. Each such axes
    collection is checked against its base collection with
    ``_validate_params_axes``.

    Args:
      variables_and_axes: Mapping of collection name to collection.

    Returns:
      ``(variables, axes)``, both frozen. ``axes`` is keyed by the base
      collection name (suffix removed).

    Raises:
      ValueError: From ``_validate_params_axes`` when a leaf lacks axes.
      KeyError: If an ``<name>_axes`` entry has no matching ``<name>`` entry
        (raised by the lookup of the base collection).
    """
    suffix = "_axes"
    plain_collections = {}
    axes_by_name = {}
    for name, collection in variables_and_axes.items():
        if not name.endswith(suffix):
            plain_collections[name] = collection
            continue
        base_name = name[: -len(suffix)]
        axes_by_name[base_name] = collection
        _validate_params_axes(collection, variables_and_axes[base_name])
    return flax.core.freeze(plain_collections), flax.core.freeze(axes_by_name)
class InferenceState(flax.struct.PyTreeNode):
    """State compatible with FlaxOptimTrainState without optimizer state.

    Holds a step counter, model parameters, non-parameter ("mutable")
    variable collections, and optional logical-axis metadata for both.
    Optimizer-specific operations (`param_states`, `apply_gradient`) raise
    `NotImplementedError`.
    """

    step: jnp.ndarray
    params: flax_scope.FrozenVariableDict
    # Axis metadata for `params`; None when the variables had no
    # "params_axes" collection (see `create`).
    params_axes: Optional[flax_scope.FrozenVariableDict] = None
    # All collections other than "params", "params_axes" and "*_axes".
    flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
    # Axis metadata for `flax_mutables`, keyed by collection name; None if
    # no "*_axes" collections were present.
    flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None

    @classmethod
    def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
        """Build an InferenceState at step 0 from a model's variables.

        Args:
          model_variables: Variable collections; must contain "params". May
            contain "params_axes" and other collections with optional
            "<name>_axes" companions. A plain ``dict`` is accepted and is
            frozen first (the caller's dict is not mutated).

        Returns:
          A new InferenceState with ``step == 0``.

        Raises:
          KeyError: If "params" is missing.
          ValueError: If a parameter or mutable leaf lacks axis metadata.
        """
        # `FrozenDict.pop` returns (remaining, value); `dict.pop` returns only
        # the value and mutates in place, which would break the unpacking
        # below. Freezing makes both input kinds behave the same.
        model_variables = flax.core.freeze(model_variables)
        other_variables, params = model_variables.pop("params")
        if "params_axes" in other_variables:
            other_variables, params_axes = other_variables.pop("params_axes")
            _validate_params_axes(params_axes, params)
        else:
            params_axes = None
        # Split other_variables into mutables and their corresponding axes.
        flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
        # Normalize an empty axes mapping to None.
        flax_mutables_axes = flax_mutables_axes or None
        return InferenceState(
            step=jnp.array(0),
            params=params,
            params_axes=params_axes,
            flax_mutables=flax_mutables,
            flax_mutables_axes=flax_mutables_axes,
        )

    @property
    def param_states(self) -> FrozenVariableDict:
        """The optimizer states of the parameters as a PyTree.

        Raises:
          NotImplementedError: Always; this state carries no optimizer.
        """
        raise NotImplementedError("InferenceState has no optimizer states.")

    def apply_gradient(self, *args, **kwargs) -> "InferenceState":
        """Unsupported for inference.

        Raises:
          NotImplementedError: Always.
        """
        raise NotImplementedError("InferenceState does not support `apply_gradient`.")

    def state_dict(self) -> MutableMapping[str, Any]:
        """Return a mutable checkpoint dict.

        Keys: "target" (unfrozen params), "state" ({"step": step}) and, only
        when non-empty, "flax_mutables" (unfrozen).
        """
        state_dict = {
            "target": flax.core.unfreeze(self.params),
            "state": {"step": self.step},
        }
        if self.flax_mutables:
            state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
        return state_dict

    def replace_step(self, step: jnp.ndarray) -> "InferenceState":
        """Return a copy with `step` replaced."""
        return self.replace(step=step)

    def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
        """Return a copy with `params` replaced."""
        return self.replace(params=params)

    def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
        """Return a copy with `flax_mutables` replaced."""
        return self.replace(flax_mutables=flax_mutables)

    def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
        """Inverse of `state_dict`: return a copy populated from `state_dict`.

        A missing "flax_mutables" entry restores to EMPTY_DICT.
        """
        return self.replace(
            params=flax.core.freeze(state_dict["target"]),
            step=state_dict["state"]["step"],
            flax_mutables=(
                flax.core.freeze(state_dict["flax_mutables"]) if "flax_mutables" in state_dict else EMPTY_DICT
            ),
        )

    def as_logical_axes(self) -> "InferenceState":
        """Return an InferenceState whose leaves are logical axis names."""
        # Set step to None so that when the logical axes are processed by the
        # flax.partitioning.logical_to_mesh_axes function, it will be skipped
        # because jax.tree_map will short circuit and never call the function
        # on the step.
        flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
        # NOTE(review): `params_axes` is None when `create` saw no
        # "params_axes"; get_axis_names presumably fails on None. Confirm
        # whether callers guarantee params_axes is set before calling this.
        return InferenceState(
            step=None,
            params=flax_partitioning.get_axis_names(self.params_axes),
            flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
        )