from typing import Any, Mapping, MutableMapping, Optional, Tuple

import flax.core
import flax.serialization
import flax.struct
import jax.numpy as jnp
from flax import traverse_util
from flax.core import scope as flax_scope
from flax.linen import partitioning as flax_partitioning

EMPTY_DICT = flax.core.freeze({})
FrozenDict = flax_scope.FrozenDict
FrozenVariableDict = flax_scope.FrozenVariableDict
MutableVariableDict = flax_scope.MutableVariableDict
VariableDict = flax_scope.VariableDict


def _validate_params_axes(params_axes, params):
  """Raises a ValueError if any parameter lacks an axis-name annotation."""
  axis_names = flax_partitioning.get_axis_names(params_axes)
  missing_params_axes = set(traverse_util.flatten_dict(params, sep="/")) - set(
      traverse_util.flatten_dict(axis_names, sep="/")
  )
  if missing_params_axes:
    raise ValueError(
        f"Missing axis names for parameters: {missing_params_axes}"
    )


def _split_variables_and_axes(
    variables_and_axes: FrozenVariableDict,
) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
  """Splits `variables_and_axes` into separate variable and axis dicts."""
  # A collection named `<key>_axes` holds the axis annotations for the
  # variable collection named `<key>`.
  variables = {}
  axes = {}
  for k, v in variables_and_axes.items():
    if k.endswith("_axes"):
      axes[k[:-5]] = v  # k without "_axes".
      _validate_params_axes(v, variables_and_axes[k[:-5]])  # k without "_axes".
    else:
      variables[k] = v
  return flax.core.freeze(variables), flax.core.freeze(axes)


class InferenceState(flax.struct.PyTreeNode):
  """State compatible with FlaxOptimTrainState without optimizer state."""

  step: jnp.ndarray
  params: flax_scope.FrozenVariableDict
  params_axes: Optional[flax_scope.FrozenVariableDict] = None
  flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
  flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None

  @classmethod
  def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
    other_variables, params = model_variables.pop("params")
    if "params_axes" in other_variables:
      other_variables, params_axes = other_variables.pop("params_axes")
      _validate_params_axes(params_axes, params)
    else:
      params_axes = None

    # Split other_variables into mutables and their corresponding axes.
    flax_mutables, flax_mutables_axes = _split_variables_and_axes(
        other_variables
    )
    flax_mutables_axes = flax_mutables_axes or None
    return InferenceState(
        step=jnp.array(0),
        params=params,
        params_axes=params_axes,
        flax_mutables=flax_mutables,
        flax_mutables_axes=flax_mutables_axes,
    )

  @property
  def param_states(self) -> FrozenVariableDict:
    """The optimizer states of the parameters as a PyTree."""
    raise NotImplementedError("InferenceState has no optimizer states.")

  def apply_gradient(self, *args, **kwargs) -> "InferenceState":
    raise NotImplementedError("InferenceState does not support `apply_gradient`.")

  def state_dict(self) -> MutableMapping[str, Any]:
    state_dict = {
        "target": flax.core.unfreeze(self.params),
        "state": {"step": self.step},
    }
    if self.flax_mutables:
      state_dict["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
    return state_dict

  def replace_step(self, step: jnp.ndarray) -> "InferenceState":
    return self.replace(step=step)

  def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
    return self.replace(params=params)

  def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
    return self.replace(flax_mutables=flax_mutables)

  def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
    return self.replace(
        params=flax.core.freeze(state_dict["target"]),
        step=state_dict["state"]["step"],
        flax_mutables=(
            flax.core.freeze(state_dict["flax_mutables"])
            if "flax_mutables" in state_dict
            else EMPTY_DICT
        ),
    )

  def as_logical_axes(self) -> "InferenceState":
    # Set step to None so that when the logical axes are processed by
    # flax.linen.partitioning.logical_to_mesh_axes, the step is skipped:
    # jax.tree_map treats None as an empty subtree and never calls the
    # function on it.
    flax_mutables_axes = self.flax_mutables_axes or EMPTY_DICT
    return InferenceState(
        step=None,
        params=flax_partitioning.get_axis_names(self.params_axes),
        flax_mutables=flax_partitioning.get_axis_names(flax_mutables_axes),
    )
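

# A minimal usage sketch (illustrative only, not part of the module above):
# it builds an `InferenceState` from a plain `flax.linen` model and
# round-trips it through `state_dict` / `restore_state`. The model, PRNG
# key, and input shape are assumptions chosen for demonstration. A model
# built with `flax_partitioning.param_with_axes` would also populate
# `params_axes`, enabling `as_logical_axes`; that path is not exercised
# here.
if __name__ == "__main__":
  import jax
  from flax import linen as nn

  model = nn.Dense(features=4)
  # Freeze explicitly: newer Flax versions may return a plain dict from
  # `init`, while `InferenceState.create` relies on `FrozenDict.pop`
  # returning `(remaining_dict, popped_value)`.
  variables = flax.core.freeze(
      model.init(jax.random.PRNGKey(0), jnp.ones((1, 8)))
  )

  state = InferenceState.create(variables)
  assert int(state.step) == 0

  # Serialize to a plain nested dict (e.g. for checkpointing) and restore.
  serialized = state.state_dict()
  restored = state.restore_state(serialized)
  assert jax.tree_util.tree_all(
      jax.tree_util.tree_map(jnp.array_equal, state.params, restored.params)
  )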