# sanchit-gandhi's picture
# Saving train state of step 5000
# 84a891e verified
from typing import Any, Mapping, MutableMapping, Optional, Tuple
import flax.core
import flax.serialization
import flax.struct
import jax.numpy as jnp
from flax import traverse_util
from flax.core import scope as flax_scope
from flax.linen import partitioning as flax_partitioning
# Shared immutable empty collection; used below as the default/fallback value
# for flax mutable collections.
EMPTY_DICT = flax.core.freeze({})

# Short aliases for flax's variable-dict types, used in the annotations below.
(
    FrozenDict,
    FrozenVariableDict,
    MutableVariableDict,
    VariableDict,
) = (
    flax_scope.FrozenDict,
    flax_scope.FrozenVariableDict,
    flax_scope.MutableVariableDict,
    flax_scope.VariableDict,
)
def _validate_params_axes(params_axes, params):
    """Check that every parameter in `params` has a logical axis annotation.

    Args:
      params_axes: An axes collection; it is converted to logical axis names
        with `flax_partitioning.get_axis_names` before comparison.
      params: The variable tree whose leaves must all be annotated.

    Raises:
      ValueError: If any "/"-joined parameter path in `params` has no matching
        path among the axis names derived from `params_axes`.
    """
    logical_names = flax_partitioning.get_axis_names(params_axes)
    param_paths = set(traverse_util.flatten_dict(params, sep="/"))
    annotated_paths = set(traverse_util.flatten_dict(logical_names, sep="/"))
    missing_params_axes = param_paths - annotated_paths
    if not missing_params_axes:
        return
    raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
def _split_variables_and_axes(
    variables_and_axes: FrozenVariableDict,
) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
    """Split `variables_and_axes` into two separate dicts with the same keys.

    A collection named "<name>_axes" is treated as the axes of the collection
    "<name>". It is stored under "<name>" in the returned axes dict and
    validated against that collection with `_validate_params_axes`. Every other
    collection is copied into the returned variables dict unchanged.

    Args:
      variables_and_axes: Collections, possibly including "*_axes" entries.

    Returns:
      A `(variables, axes)` pair of frozen dicts.
    """
    axes_suffix = "_axes"
    variables = {}
    axes = {}
    for name, collection in variables_and_axes.items():
        if not name.endswith(axes_suffix):
            variables[name] = collection
            continue
        # Strip the suffix to find the collection these axes describe.
        target_name = name[: -len(axes_suffix)]
        axes[target_name] = collection
        _validate_params_axes(collection, variables_and_axes[target_name])
    return flax.core.freeze(variables), flax.core.freeze(axes)
class InferenceState(flax.struct.PyTreeNode):
    """State compatible with FlaxOptimTrainState without optimizer state.

    Stores the model parameters, a step counter and the remaining ("mutable")
    variable collections, together with their optional logical-axis
    annotations. Methods that only make sense with an optimizer raise
    `NotImplementedError`.
    """

    # Step counter; `create` initializes it to `jnp.array(0)`.
    step: jnp.ndarray
    # The "params" collection of the model variables.
    params: flax_scope.FrozenVariableDict
    # The "params_axes" collection, or None when none was provided.
    params_axes: Optional[flax_scope.FrozenVariableDict] = None
    # All remaining collections that are not "*_axes" collections.
    flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
    # The "<name>_axes" collections matching `flax_mutables` (keyed by
    # "<name>"), or None when there are none.
    flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None

    @classmethod
    def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
        """Build a state from a model's variable collections.

        Args:
          model_variables: Variable collections; must contain "params". An
            optional "params_axes" collection is validated against "params".
            Each other "<name>_axes" collection is paired with "<name>". Plain
            dicts are accepted and frozen first.

        Returns:
          A new state (of type `cls`) with `step` set to 0.

        Raises:
          ValueError: If some parameter has no axis annotation.
        """
        # `FrozenDict.pop` returns `(remaining, value)`. A plain `dict.pop`
        # returns only the value (and mutates the caller's dict), which would
        # break the tuple unpacking below, so normalize to a FrozenDict first.
        model_variables = flax.core.freeze(model_variables)
        other_variables, params = model_variables.pop("params")
        if "params_axes" in other_variables:
            other_variables, params_axes = other_variables.pop("params_axes")
            _validate_params_axes(params_axes, params)
        else:
            params_axes = None
        # Split other_variables into mutables and their corresponding axes.
        flax_mutables, flax_mutables_axes = _split_variables_and_axes(other_variables)
        # Store an empty axes dict as None, matching the field default.
        flax_mutables_axes = flax_mutables_axes or None
        # Use `cls` so that subclasses construct instances of their own type.
        return cls(
            step=jnp.array(0),
            params=params,
            params_axes=params_axes,
            flax_mutables=flax_mutables,
            flax_mutables_axes=flax_mutables_axes,
        )

    @property
    def param_states(self) -> FrozenVariableDict:
        """The optimizer states of the parameters as a PyTree.

        Raises:
          NotImplementedError: Always; an InferenceState has no optimizer.
        """
        raise NotImplementedError("InferenceState has no optimizer states.")

    def apply_gradient(self, *args, **kwargs) -> "InferenceState":
        """Unsupported: an InferenceState cannot be updated with gradients.

        Raises:
          NotImplementedError: Always.
        """
        raise NotImplementedError("InferenceState does not support `apply_gradient`.")

    def state_dict(self) -> MutableMapping[str, Any]:
        """Return a mutable dict representation suitable for checkpointing.

        Returns:
          A dict with "target" (unfrozen params) and "state" ({"step": step}).
          It also includes "flax_mutables" (unfrozen) when they are non-empty.
        """
        state_dict = {
            "target": flax.core.unfreeze(self.params),
            "state": {"step": self.step},
        }
        if self.flax_mutables:
            state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
        return state_dict

    def replace_step(self, step: jnp.ndarray) -> "InferenceState":
        """Return a copy with `step` replaced."""
        return self.replace(step=step)

    def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
        """Return a copy with `params` replaced."""
        return self.replace(params=params)

    def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
        """Return a copy with `flax_mutables` replaced."""
        return self.replace(flax_mutables=flax_mutables)

    def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
        """Return a copy populated from a dict in the `state_dict()` layout.

        Args:
          state_dict: Mapping with "target" and "state"/"step" entries, and
            optionally "flax_mutables" (which defaults to an empty dict when
            absent).
        """
        return self.replace(
            params=flax.core.freeze(state_dict["target"]),
            step=state_dict["state"]["step"],
            flax_mutables=(
                flax.core.freeze(state_dict["flax_mutables"])
                if "flax_mutables" in state_dict
                else EMPTY_DICT
            ),
        )

    def as_logical_axes(self) -> "InferenceState":
        """Return an InferenceState whose leaves are logical axis names.

        Raises:
          ValueError: If `params_axes` is None. Logical axes cannot be derived
            without it.
        """
        if self.params_axes is None:
            # Fail clearly here instead of deep inside flax's partitioning code.
            raise ValueError(
                "InferenceState.as_logical_axes requires `params_axes`, but it is "
                "None (no 'params_axes' collection was provided)."
            )
        # Set step to None so that when the logical axes are processed by the
        # flax.partitioning.logical_to_mesh_axes function, it will be skipped
        # because jax.tree_map will short-circuit and never call the function on
        # the step.
        flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
        return InferenceState(
            step=None,
            params=flax_partitioning.get_axis_names(self.params_axes),
            flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
        )