|
from typing import Any, Mapping, MutableMapping, Optional, Tuple |
|
|
|
import flax.core |
|
import flax.serialization |
|
import flax.struct |
|
import jax.numpy as jnp |
|
from flax import traverse_util |
|
from flax.core import scope as flax_scope |
|
from flax.linen import partitioning as flax_partitioning |
|
|
|
|
|
# Shared immutable empty mapping, used as the default for `flax_mutables`.
EMPTY_DICT = flax.core.freeze({})

# Short aliases for the flax scope dict types used throughout this module.
FrozenDict = flax_scope.FrozenDict

FrozenVariableDict = flax_scope.FrozenVariableDict

MutableVariableDict = flax_scope.MutableVariableDict

VariableDict = flax_scope.VariableDict
|
|
|
|
|
def _validate_params_axes(params_axes, params):
    """Checks that every parameter in `params` has an axis-name entry.

    Args:
      params_axes: Pytree of axis metadata; converted to logical axis names
        via `flax_partitioning.get_axis_names`.
      params: Pytree of parameters that must be fully covered by `params_axes`.

    Raises:
      ValueError: If any flattened parameter path has no matching axis name.
    """
    axis_names = flax_partitioning.get_axis_names(params_axes)
    param_paths = traverse_util.flatten_dict(params, sep="/").keys()
    axis_paths = traverse_util.flatten_dict(axis_names, sep="/").keys()
    missing_params_axes = set(param_paths) - set(axis_paths)
    if missing_params_axes:
        raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")
|
|
|
|
|
def _split_variables_and_axes(
    variables_and_axes: FrozenVariableDict,
) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
    """Splits `variables_and_axes` into two separate dicts with the same keys."""
    suffix = "_axes"
    variables = {}
    axes = {}
    for name, value in variables_and_axes.items():
        if not name.endswith(suffix):
            variables[name] = value
            continue
        # "foo_axes" holds the axis metadata for the "foo" collection; validate
        # that every variable in "foo" is covered before recording it.
        base = name[: -len(suffix)]
        axes[base] = value
        _validate_params_axes(value, variables_and_axes[base])
    return flax.core.freeze(variables), flax.core.freeze(axes)
|
|
|
|
|
class InferenceState(flax.struct.PyTreeNode):
    """State compatible with FlaxOptimTrainState without optimizer state."""

    # Scalar step counter (set to None in the logical-axes representation).
    step: jnp.ndarray
    # Model parameters.
    params: flax_scope.FrozenVariableDict
    # Axis metadata for `params`, when available.
    params_axes: Optional[flax_scope.FrozenVariableDict] = None
    # Non-parameter mutable collections (e.g. caches), possibly empty.
    flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
    # Axis metadata for `flax_mutables`, when available.
    flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None

    @classmethod
    def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
        """Builds an `InferenceState` at step 0 from raw model variables.

        Pops "params" (and "params_axes", if present) out of
        `model_variables`; everything remaining is split into mutable
        collections and their axis metadata.
        """
        remaining, params = model_variables.pop("params")
        if "params_axes" in remaining:
            remaining, params_axes = remaining.pop("params_axes")
            _validate_params_axes(params_axes, params)
        else:
            params_axes = None

        flax_mutables, flax_mutables_axes = _split_variables_and_axes(remaining)
        return InferenceState(
            step=jnp.array(0),
            params=params,
            params_axes=params_axes,
            flax_mutables=flax_mutables,
            # Normalize an empty axes dict to None.
            flax_mutables_axes=flax_mutables_axes or None,
        )

    @property
    def param_states(self) -> FrozenVariableDict:
        """The optimizer states of the parameters as a PyTree."""
        raise NotImplementedError("InferenceState has no optimizer states.")

    def apply_gradient(self, *args, **kwargs) -> "InferenceState":
        """Unsupported: inference state carries no optimizer."""
        raise NotImplementedError("InferenceState does not support `apply_gradient`.")

    def state_dict(self) -> MutableMapping[str, Any]:
        """Serializes to a nested dict ("target"/"state"[/"flax_mutables"])."""
        result: MutableMapping[str, Any] = {
            "target": flax.core.unfreeze(self.params),
            "state": {"step": self.step},
        }
        # Only emit mutables when there are any, keeping the dict minimal.
        if self.flax_mutables:
            result["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
        return result

    def replace_step(self, step: jnp.ndarray) -> "InferenceState":
        """Returns a copy with `step` replaced."""
        return self.replace(step=step)

    def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
        """Returns a copy with `params` replaced."""
        return self.replace(params=params)

    def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
        """Returns a copy with `flax_mutables` replaced."""
        return self.replace(flax_mutables=flax_mutables)

    def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
        """Inverse of `state_dict`: restores params, step, and mutables."""
        if "flax_mutables" in state_dict:
            restored_mutables = flax.core.freeze(state_dict["flax_mutables"])
        else:
            restored_mutables = EMPTY_DICT
        return self.replace(
            params=flax.core.freeze(state_dict["target"]),
            step=state_dict["state"]["step"],
            flax_mutables=restored_mutables,
        )

    def as_logical_axes(self) -> "InferenceState":
        """Returns a copy whose leaves are logical axis names, not values.

        `step` is set to None so tree-mapping partitioning utilities skip it.
        NOTE(review): assumes `params_axes` is not None here — confirm callers
        only invoke this on states created with axis metadata.
        """
        mutables_axes = self.flax_mutables_axes or EMPTY_DICT
        return InferenceState(
            step=None,
            params=flax_partitioning.get_axis_names(self.params_axes),
            flax_mutables=flax_partitioning.get_axis_names(mutables_axes),
        )
|
|