File size: 4,747 Bytes
84a891e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
from typing import Any, Mapping, MutableMapping, Optional, Tuple

import flax.core
import flax.serialization
import flax.struct
import jax.numpy as jnp
from flax import traverse_util
from flax.core import scope as flax_scope
from flax.linen import partitioning as flax_partitioning


# Shared immutable empty dict; used below as the default / fallback value for
# `flax_mutables` when a state carries no mutable collections.
EMPTY_DICT = flax.core.freeze({})
# Short local aliases for the dict types defined in `flax.core.scope`, so the
# annotations in this module do not need the `flax_scope.` prefix.
FrozenDict = flax_scope.FrozenDict
FrozenVariableDict = flax_scope.FrozenVariableDict
MutableVariableDict = flax_scope.MutableVariableDict
VariableDict = flax_scope.VariableDict


def _validate_params_axes(params_axes, params):
    """Check that every entry of `params` has a matching entry in `params_axes`.

    Axis names are derived from `params_axes` with
    `flax_partitioning.get_axis_names`; both trees are then flattened to
    "/"-joined key paths and compared as sets.

    Args:
        params_axes: axes metadata accepted by `flax_partitioning.get_axis_names`.
        params: nested dict of parameters.

    Raises:
        ValueError: if any flattened key of `params` is absent from the
            flattened axis names.
    """
    named_axes = flax_partitioning.get_axis_names(params_axes)
    param_paths = set(traverse_util.flatten_dict(params, sep="/"))
    axes_paths = set(traverse_util.flatten_dict(named_axes, sep="/"))

    # Only the one-sided difference matters: axes without a parameter are allowed.
    missing_params_axes = param_paths.difference(axes_paths)
    if missing_params_axes:
        raise ValueError(f"Missing axis names for parameters: {missing_params_axes}")


def _split_variables_and_axes(
    variables_and_axes: FrozenVariableDict,
) -> Tuple[FrozenVariableDict, FrozenVariableDict]:
    """Separate variable collections from their `*_axes` companions.

    A top-level key ending in "_axes" is treated as the axes metadata of the
    key with that suffix removed, and is stored in the returned axes dict under
    the suffix-free name. Every other key is copied to the variables dict
    unchanged. Each axes entry is validated against the entry found under the
    suffix-free key in `variables_and_axes`.

    Args:
        variables_and_axes: mapping holding collections and, optionally, their
            `<name>_axes` counterparts.

    Returns:
        A `(variables, axes)` pair, both frozen with `flax.core.freeze`.

    Raises:
        ValueError: propagated from `_validate_params_axes` when a collection
            has entries without axis names.
    """
    suffix = "_axes"
    plain_collections = {}
    axes_by_collection = {}
    for name, value in variables_and_axes.items():
        if not name.endswith(suffix):
            plain_collections[name] = value
            continue
        # Drop the trailing "_axes" to get the collection these axes describe.
        base_name = name[: -len(suffix)]
        axes_by_collection[base_name] = value
        _validate_params_axes(value, variables_and_axes[base_name])
    return flax.core.freeze(plain_collections), flax.core.freeze(axes_by_collection)


class InferenceState(flax.struct.PyTreeNode):
    """State compatible with FlaxOptimTrainState without optimizer state.

    Holds a step counter, the model parameters, any additional mutable
    collections, and optional axes metadata for the parameters and mutables.
    Operations that need an optimizer (`param_states`, `apply_gradient`) raise
    `NotImplementedError`.
    """

    # Step counter; `create` initialises it to `jnp.array(0)`.
    step: jnp.ndarray
    # Model parameters.
    params: flax_scope.FrozenVariableDict
    # Axes metadata for `params`, or None when none was supplied.
    params_axes: Optional[flax_scope.FrozenVariableDict] = None
    # Non-parameter collections; empty by default.
    flax_mutables: flax_scope.FrozenDict = EMPTY_DICT
    # Axes metadata for `flax_mutables`, or None when there is none.
    flax_mutables_axes: Optional[flax_scope.FrozenVariableDict] = None

    @classmethod
    def create(cls, model_variables: FrozenVariableDict) -> "InferenceState":
        """Build a state at step 0 from a dict of model variables.

        Args:
            model_variables: frozen variables containing a "params" collection
                and, optionally, "params_axes", other collections and their
                `<name>_axes` counterparts.

        Returns:
            An `InferenceState` holding the params, the remaining collections
            as `flax_mutables`, and whatever axes metadata was present.

        Raises:
            ValueError: propagated from `_validate_params_axes` when some
                entries have no axis names.
        """
        remaining, params = model_variables.pop("params")

        params_axes = None
        if "params_axes" in remaining:
            remaining, params_axes = remaining.pop("params_axes")
            _validate_params_axes(params_axes, params)

        # Whatever is left is mutable collections plus their optional axes.
        mutables, mutables_axes = _split_variables_and_axes(remaining)
        if not mutables_axes:
            # Normalise "no axes at all" to None rather than an empty dict.
            mutables_axes = None

        return InferenceState(
            step=jnp.array(0),
            params=params,
            params_axes=params_axes,
            flax_mutables=mutables,
            flax_mutables_axes=mutables_axes,
        )

    @property
    def param_states(self) -> FrozenVariableDict:
        """Optimizer states of the parameters; always raises for this class.

        Raises:
            NotImplementedError: unconditionally.
        """
        raise NotImplementedError("InferenceState has no optimizer states.")

    def apply_gradient(self, *args, **kwargs) -> "InferenceState":
        """Unsupported operation; always raises `NotImplementedError`."""
        raise NotImplementedError("InferenceState does not support `apply_gradient`.")

    def state_dict(self) -> MutableMapping[str, Any]:
        """Return a plain-dict view of this state.

        Returns:
            A dict with "target" (unfrozen params) and "state" (a dict holding
            "step"). "flax_mutables" is included only when the state has a
            non-empty `flax_mutables`.
        """
        serialized: MutableMapping[str, Any] = {}
        serialized["target"] = flax.core.unfreeze(self.params)
        serialized["state"] = {"step": self.step}
        if self.flax_mutables:
            serialized["flax_mutables"] = flax.core.unfreeze(self.flax_mutables)
        return serialized

    def replace_step(self, step: jnp.ndarray) -> "InferenceState":
        """Return a copy of this state with `step` replaced."""
        return self.replace(step=step)

    def replace_params(self, params: FrozenVariableDict) -> "InferenceState":
        """Return a copy of this state with `params` replaced."""
        return self.replace(params=params)

    def replace_flax_mutables(self, flax_mutables: FrozenDict) -> "InferenceState":
        """Return a copy of this state with `flax_mutables` replaced."""
        return self.replace(flax_mutables=flax_mutables)

    def restore_state(self, state_dict: Mapping[str, Any]) -> "InferenceState":
        """Return a copy of this state populated from `state_dict`.

        Args:
            state_dict: mapping in the layout produced by `state_dict()`:
                "target", "state" -> "step", and optionally "flax_mutables".

        Returns:
            A new state whose params, step and flax_mutables come from
            `state_dict`; `flax_mutables` is `EMPTY_DICT` when the key is
            absent. The axes fields are not touched.
        """
        restored_params = flax.core.freeze(state_dict["target"])
        restored_step = state_dict["state"]["step"]
        if "flax_mutables" in state_dict:
            restored_mutables = flax.core.freeze(state_dict["flax_mutables"])
        else:
            restored_mutables = EMPTY_DICT
        return self.replace(
            params=restored_params,
            step=restored_step,
            flax_mutables=restored_mutables,
        )

    def as_logical_axes(self) -> "InferenceState":
        """Return a state whose leaves are logical axis names, not arrays.

        `params` and `flax_mutables` of the result are the axis names derived
        from `params_axes` and `flax_mutables_axes` respectively (an empty dict
        is used when there are no mutable axes). The result's `params_axes`
        and `flax_mutables_axes` are left at their defaults.
        """
        mutables_axes = self.flax_mutables_axes
        if not mutables_axes:
            mutables_axes = EMPTY_DICT

        param_axis_names = flax_partitioning.get_axis_names(self.params_axes)
        mutable_axis_names = flax_partitioning.get_axis_names(mutables_axes)

        # `step` is None so that, when the logical axes are processed by
        # flax.partitioning.logical_to_mesh_axes, the step is skipped:
        # jax.tree_map short-circuits and never calls the function on it.
        return InferenceState(
            step=None,
            params=param_axis_names,
            flax_mutables=mutable_axis_names,
        )