# Copyright 2021 The Flax Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for defining custom classes that can be used with jax transformations.
"""
import typing
from typing import TypeVar, Callable, Tuple, Union, Any
from . import serialization
import dataclasses
import jax
_T = TypeVar("_T")


# Static analysis tools read ``__dataclass_transform__`` as a hint that the
# decorator or metaclass it marks causes dataclass-like behavior.
# See https://github.com/microsoft/pyright/blob/main/specs/dataclass_transforms.md
# for more information about the __dataclass_transform__ magic.
def __dataclass_transform__(
    *,
    eq_default: bool = True,
    order_default: bool = False,
    kw_only_default: bool = False,
    field_descriptors: Tuple[Union[type, Callable[..., Any]], ...] = (),
) -> Callable[[_T], _T]:
  """Return a no-op decorator that only serves as a static-analysis marker.

  All keyword arguments are accepted for the benefit of static analysis tools
  and are ignored at runtime.

  Args:
    eq_default: unused at runtime.
    order_default: unused at runtime.
    kw_only_default: unused at runtime.
    field_descriptors: unused at runtime.
  Returns:
    A decorator that returns the decorated object unchanged.
  """
  # If used within a stub file, this implementation can be replaced
  # with "...".
  def _identity(decorated: _T) -> _T:
    return decorated

  return _identity
@__dataclass_transform__()
def dataclass(clz: _T) -> _T:
  """Create a class which can be passed to functional transformations.

  NOTE: Inherit from ``PyTreeNode`` instead to avoid type checking issues when
  using PyType.

  Jax transformations such as `jax.jit` and `jax.grad` require objects that are
  immutable and can be mapped over using the `jax.tree_util` methods.
  The `dataclass` decorator makes it easy to define custom classes that can be
  passed safely to Jax. For example::

    from flax import struct

    @struct.dataclass
    class Model():
      params: Any
      # use pytree_node=False to indicate an attribute should not be touched
      # by Jax transformations.
      apply_fn: FunctionType = struct.field(pytree_node=False)

      def __apply__(self, *args):
        return self.apply_fn(*args)

    model = Model(params, apply_fn)

    model.params = params_b  # Model is immutable. This will raise an error.
    model_b = model.replace(params=params_b)  # Use the replace method instead.

    # This class can now be used safely in Jax to compute gradients w.r.t. the
    # parameters.
    model = Model(params, apply_fn)
    model_grad = jax.grad(some_loss_fn)(model)

  The transformation is applied at most once per class: a class that has
  already been processed (for instance a ``PyTreeNode`` subclass that is also
  decorated explicitly) is returned unchanged.

  Args:
    clz: the class that will be transformed by the decorator.
  Returns:
    The transformed class.
  """
  # Skip classes that were already transformed. Only the class' own __dict__
  # is inspected so that subclasses of a transformed class are still
  # processed. Registering the same class with jax twice raises an error.
  if '_flax_dataclass' in vars(clz):
    return clz

  # workaround for pytype not recognizing __dataclass_fields__
  data_clz: Any = dataclasses.dataclass(frozen=True)(clz)

  # Split the fields into pytree children (data) and static metadata (meta),
  # based on the 'pytree_node' flag stored in the field metadata.
  meta_fields = []
  data_fields = []
  for name, field_info in data_clz.__dataclass_fields__.items():
    is_pytree_node = field_info.metadata.get('pytree_node', True)
    if is_pytree_node:
      data_fields.append(name)
    else:
      meta_fields.append(name)

  def replace(self, **updates):
    """Returns a new object replacing the specified fields with new values."""
    return dataclasses.replace(self, **updates)

  data_clz.replace = replace

  def iterate_clz(x):
    """Flatten ``x`` into ``(children, aux_data)`` for ``jax.tree_util``."""
    meta = tuple(getattr(x, name) for name in meta_fields)
    data = tuple(getattr(x, name) for name in data_fields)
    return data, meta

  def clz_from_iterable(meta, data):
    """Rebuild an instance from the aux data and children of ``iterate_clz``."""
    meta_args = tuple(zip(meta_fields, meta))
    data_args = tuple(zip(data_fields, data))
    kwargs = dict(meta_args + data_args)
    return data_clz(**kwargs)

  jax.tree_util.register_pytree_node(data_clz,
                                     iterate_clz,
                                     clz_from_iterable)

  def to_state_dict(x):
    """Serialize the pytree (data) fields of ``x`` into a dict keyed by name."""
    state_dict = {name: serialization.to_state_dict(getattr(x, name))
                  for name in data_fields}
    return state_dict

  def from_state_dict(x, state):
    """Restore the state of a data class.

    Args:
      x: the instance whose data fields are used as restore targets.
      state: mapping from data field name to its serialized state.
    Returns:
      A copy of ``x`` with every data field replaced by its restored value.
    Raises:
      ValueError: if a data field is missing from ``state`` or if ``state``
        contains names that are not data fields.
    """
    state = state.copy()  # copy the state so we can pop the restored fields.
    updates = {}
    for name in data_fields:
      if name not in state:
        raise ValueError(f'Missing field {name} in state dict while restoring'
                         f' an instance of {clz.__name__}')
      value = getattr(x, name)
      value_state = state.pop(name)
      updates[name] = serialization.from_state_dict(value, value_state)
    # Anything left over does not correspond to a data field.
    if state:
      names = ','.join(state.keys())
      raise ValueError(f'Unknown field(s) "{names}" in state dict while'
                       f' restoring an instance of {clz.__name__}')
    return x.replace(**updates)

  serialization.register_serialization_state(
      data_clz, to_state_dict, from_state_dict)

  # Mark the class so that a repeated call returns early (see the check above).
  data_clz._flax_dataclass = True

  return data_clz
def field(pytree_node=True, **kwargs):
  """Create a dataclass field that records whether it is a pytree node.

  Args:
    pytree_node: stored in the field metadata under the ``'pytree_node'`` key.
      ``dataclass`` reads this flag to decide whether the attribute is treated
      as pytree data (``True``) or as static metadata (``False``).
    **kwargs: forwarded to ``dataclasses.field``. A ``metadata`` mapping may be
      supplied; it is merged with the ``'pytree_node'`` entry, and the explicit
      ``pytree_node`` argument takes precedence over a same-named key.
  Returns:
    The object returned by ``dataclasses.field``.
  """
  # Merge instead of passing ``metadata`` twice, which would raise a TypeError.
  metadata = dict(kwargs.pop('metadata', None) or {})
  metadata['pytree_node'] = pytree_node
  return dataclasses.field(metadata=metadata, **kwargs)
# Type variable used to annotate ``PyTreeNode.replace`` so that it is typed as
# returning the same (sub)type as the instance it is called on.
TNode = TypeVar('TNode', bound='PyTreeNode')


# Only static type checkers (``typing.TYPE_CHECKING``) see a dedicated
# metaclass marked with ``__dataclass_transform__``; at runtime the metaclass
# is plain ``type``.
if typing.TYPE_CHECKING:
  @__dataclass_transform__()
  class PyTreeNodeMeta(type):
    pass
else:
  PyTreeNodeMeta = type
class PyTreeNode(metaclass=PyTreeNodeMeta):
  """Base class for dataclasses that should act like a JAX pytree node.

  Every subclass is passed through ``flax.struct.dataclass`` when it is
  defined; see that function for the ``jax.tree_util`` behavior.
  Using this base class additionally avoids type checking errors when using
  PyType.

  Example::

    from flax import struct

    class Model(struct.PyTreeNode):
      params: Any
      # use pytree_node=False to indicate an attribute should not be touched
      # by Jax transformations.
      apply_fn: FunctionType = struct.field(pytree_node=False)

      def __apply__(self, *args):
        return self.apply_fn(*args)

    model = Model(params, apply_fn)

    model.params = params_b  # Model is immutable. This will raise an error.
    model_b = model.replace(params=params_b)  # Use the replace method instead.

    # This class can now be used safely in Jax to compute gradients w.r.t. the
    # parameters.
    model = Model(params, apply_fn)
    model_grad = jax.grad(some_loss_fn)(model)
  """

  def __init_subclass__(cls):
    # Transform each subclass as soon as it is created.
    dataclass(cls)  # pytype: disable=wrong-arg-types

  def __init__(self, *args, **kwargs):
    # Placeholder so pytype accepts arbitrary constructor arguments; the base
    # class itself cannot be instantiated.
    raise NotImplementedError

  def replace(self: TNode, **overrides) -> TNode:
    # Placeholder so pytype knows the method exists; ``dataclass`` assigns the
    # working ``replace`` to each subclass.
    raise NotImplementedError