# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NoReturn,
    Sequence,
    Tuple,
    Type,
    Union,
)

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor

# Utilities to make nn.Module "functional".
# In particular, the goal is to be able to provide a function that takes as
# input the parameters and evaluates the nn.Module using fixed inputs.


def raise_parameter_tying_error() -> NoReturn:
    raise RuntimeError(
        "make_functional(module): we don't yet support models that "
        "do parameter tying (also sometimes known as weight sharing). "
        "Please try to rewrite your model by replacing all instances of the "
        "tied parameter with another and/or comment on your use case at "
        "https://github.com/pytorch/functorch/issues/446"
    )


def create_names_map(
    named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
    tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
) -> Dict[str, List[str]]:
    """
    named_params is a dictionary of tensors: {'A': A, 'B': B}
    tied_named_params is another dictionary of tensors {'A': A, 'B': B, 'B_tied': B}
    with potentially tied (or 'duplicated') tensors.

    This function creates a mapping from the names in named_params to the
    names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
    """
    named_params = dict(named_params)
    tied_named_params = dict(tied_named_params)

    tensors_dict_keys = set(named_params.keys())
    tied_tensors_dict_keys = set(tied_named_params.keys())
    assert tensors_dict_keys.issubset(tied_tensors_dict_keys)

    tensor_to_mapping: Dict[Tensor, Tuple[str, List[str]]] = {}
    for key, tensor in named_params.items():
        tensor_to_mapping[tensor] = (key, [])
    for key, tensor in tied_named_params.items():
        assert tensor in tensor_to_mapping
        tensor_to_mapping[tensor][1].append(key)
    return dict(tensor_to_mapping.values())
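

# Illustrative sketch (not part of the original file's API): how
# create_names_map behaves for a module with tied weights. The module and
# names here are hypothetical, chosen only to demonstrate the mapping.
def _demo_create_names_map() -> None:
    lin = nn.Linear(3, 3)
    tied = nn.Linear(3, 3)
    tied.weight = lin.weight  # tie the two weight Parameters
    model = nn.Sequential(lin, tied)
    names_map = create_names_map(
        model.named_parameters(remove_duplicate=True),
        model.named_parameters(remove_duplicate=False),
    )
    # Both attribute names point at the same underlying tensor:
    assert names_map["0.weight"] == ["0.weight", "1.weight"]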


def _extract_members(
    mod: nn.Module,
    named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
    subclass: Callable[[Tensor], Tensor],
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    all_named_members = tuple(named_members(remove_duplicate=False))
    unique_named_members = tuple(named_members(remove_duplicate=True))
    names_map = create_names_map(unique_named_members, all_named_members)

    # Replace every member in the model with a meta-device placeholder,
    # reusing one placeholder per unique (possibly tied) tensor.
    memo = {}
    accessor = NamedMemberAccessor(mod)
    for name, p in all_named_members:
        if p not in memo:
            memo[p] = subclass(torch.empty_like(p, device="meta"))
        replacement = memo[p]
        accessor.set_tensor(name, replacement)

    if len(unique_named_members) == 0:
        names, params = (), ()
    else:
        names, params = zip(*unique_named_members)  # type: ignore[assignment]
    return params, names, names_map


def extract_weights(
    mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """
    This function removes all the Parameters from the model and
    returns them as a tuple, along with their original attribute names.
    The weights must be re-loaded with `load_weights` before the model
    can be used again.
    Note that this function modifies the model in place: each Parameter is
    replaced by a placeholder tensor on the "meta" device.
    """
    return _extract_members(mod, mod.named_parameters, nn.Parameter)


def extract_buffers(
    mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """
    Like `extract_weights`, but removes and returns the model's buffers
    instead of its Parameters.
    """
    return _extract_members(mod, mod.named_buffers, lambda x: x)


def load_weights(
    mod: nn.Module,
    names: Sequence[str],
    params: Sequence[Tensor],
    as_params: bool = False,
) -> None:
    """
    Reload a set of weights so that `mod` can be used again to perform a forward pass.
    Note that the `params` are regular Tensors (that can have history) and, unless
    `as_params` is True, they are set as plain Tensors rather than wrapped in
    nn.Parameter.
    """
    accessor = NamedMemberAccessor(mod)
    if as_params:
        params = [nn.Parameter(p) for p in params]
    accessor.set_tensors(names, params)
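

# Illustrative sketch (not part of the original file's API): a full
# extract/reload round trip. After extract_weights, the module holds only
# meta-device placeholders; load_weights makes it usable again.
def _demo_extract_and_load() -> None:
    model = nn.Linear(3, 3)
    weights, names, _ = extract_weights(model)
    assert model.weight.is_meta  # the real weights have been extracted
    load_weights(model, names, weights)
    assert model(torch.randn(4, 3)).shape == (4, 3)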


def _swap_state(
    mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
) -> List[Tensor]:
    result: List[Tensor] = []
    accessor = NamedMemberAccessor(mod)
    for (_, attr_names), elem in zip(names_map.items(), elems):
        for i, attr_name in enumerate(attr_names):
            if i == 0:
                # Swap the first occurrence so the old value can be restored later.
                result.append(accessor.swap_tensor(attr_name, elem))
            else:
                # Tied copies just get set; the old value is already recorded.
                accessor.set_tensor(attr_name, elem)
    return result
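

# Illustrative sketch (not part of the original file's API): _swap_state
# swaps new tensors in and returns the old ones, so a second call with the
# returned state restores the module.
def _demo_swap_state() -> None:
    model = nn.Linear(3, 3)
    names_map = {"weight": ["weight"], "bias": ["bias"]}
    new = (torch.zeros(3, 3), torch.zeros(3))
    old = _swap_state(model, names_map, new)
    assert model.weight.sum() == 0  # the zero weights are now live
    _swap_state(model, names_map, old)  # restore the original state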


def load_buffers(
    mod: nn.Module,
    names: Sequence[str],
    buffers: Sequence[Tensor],
    as_params: bool = False,  # accepted for API symmetry with load_weights; unused
) -> None:
    accessor = NamedMemberAccessor(mod)
    accessor.set_tensors(names, buffers)


def load_state(
    model: nn.Module,
    weights: Sequence[Tensor],
    weight_names: Sequence[str],
    buffers: Sequence[Tensor] = (),
    buffer_names: Sequence[str] = (),
) -> nn.Module:
    """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model

    load_state takes `weights` and `buffers` and assigns them to the model.
    This is the inverse operation of `make_functional_deprecated_v1`.
    """
    assert len(weight_names) == len(weights)
    load_weights(model, weight_names, weights)
    if len(buffers) > 0:
        assert len(buffer_names) == len(buffers)
        load_buffers(model, buffer_names, buffers)
    return model
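

# Illustrative sketch (not part of the original file's API): load_state
# undoes make_functional_deprecated_v1, turning the stateless model back
# into a regular stateful nn.Module.
def _demo_load_state() -> None:
    model = nn.Linear(3, 3)
    weights, func, weight_names = make_functional_deprecated_v1(model)
    model = load_state(model, weights, weight_names)
    assert model(torch.randn(4, 3)).shape == (4, 3)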


def make_functional_deprecated_v1(model: nn.Module):
    """make_functional_deprecated_v1(model) -> weights, func, weight_names

    Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
    and returns a functional version of the model, `func`. This makes
    it possible to use transforms over the parameters of `model`.

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
    func(weights, (x,))
    ```

    And here is an example of applying the grad transform:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
    grad_weights = grad(func)(weights, (x,))
    ```

    To put the state back into a model, use `load_state`.
    """
    buffers = list(model.buffers())
    if len(buffers) > 0:
        raise RuntimeError(
            "make_functional_deprecated_v1(model): `model` has buffers. Please use "
            "make_functional_with_buffers_deprecated_v1(model) instead."
        )
    weights, descriptors, _ = extract_weights(model)

    def fun(weights, data):
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, descriptors, weights)
        return mutable_model(*data)

    return weights, fun, descriptors


def make_functional_with_buffers_deprecated_v1(model: nn.Module):
    """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names

    Given an nn.Module, make_functional_with_buffers_deprecated_v1 extracts the
    state (weights and buffers) and returns a functional version of the model,
    `func`.

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    func(weights, buffers, (x,))
    ```

    And here is an example of applying the grad transform:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    grad_weights = grad(func)(weights, buffers, (x,))
    ```

    To put the state back into a model, use `load_state`.
    """
    weights, weight_descriptors, _ = extract_weights(model)
    buffers, buf_descriptors, _ = extract_buffers(model)

    def fun(weights, buffers, data):
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, weight_descriptors, weights)
        load_buffers(mutable_model, buf_descriptors, buffers)
        return mutable_model(*data)

    return weights, buffers, fun, weight_descriptors, buf_descriptors


class FunctionalModuleWithBuffers(nn.Module):
    """
    This is the callable object returned by :func:`make_functional_with_buffers`.
    """

    def __init__(
        self,
        stateless_model: nn.Module,
        param_names: Tuple[str, ...],
        buffer_names: Tuple[str, ...],
        param_names_map: Dict[str, List[str]],
        buffer_names_map: Dict[str, List[str]],
    ) -> None:
        super().__init__()
        self.stateless_model = stateless_model
        self.param_names = param_names
        self.buffer_names = buffer_names

        self.all_names_map = dict(param_names_map)
        self.all_names_map.update(buffer_names_map)

    @staticmethod
    def _create_from(
        model: nn.Module, disable_autograd_tracking: bool = False
    ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = copy.deepcopy(model)
        params, param_names, param_names_map = extract_weights(model_copy)
        buffers, buffer_names, buffer_names_map = extract_buffers(model_copy)
        if disable_autograd_tracking:
            for param in params:
                param.requires_grad_(False)
        return (
            FunctionalModuleWithBuffers(
                model_copy, param_names, buffer_names, param_names_map, buffer_names_map
            ),
            params,
            buffers,
        )

    def forward(
        self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs
    ) -> Any:
        # Temporarily load the state back onto self.stateless_model
        old_state = _swap_state(
            self.stateless_model,
            self.all_names_map,
            tuple(params) + tuple(buffers),
        )
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Remove the loaded state on self.stateless_model
            _swap_state(self.stateless_model, self.all_names_map, old_state)


class FunctionalModule(nn.Module):
    """
    This is the callable object returned by :func:`make_functional`.
    """

    def __init__(
        self,
        stateless_model: nn.Module,
        param_names: Tuple[str, ...],
        names_map: Dict[str, List[str]],
    ) -> None:
        super().__init__()
        self.stateless_model = stateless_model
        self.param_names = param_names
        self.names_map = names_map

    @staticmethod
    def _create_from(
        model: nn.Module, disable_autograd_tracking: bool = False
    ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
        # TODO: We don't need to copy the model to create a stateless copy
        model_copy = copy.deepcopy(model)
        params, param_names, names_map = extract_weights(model_copy)
        if disable_autograd_tracking:
            for param in params:
                param.requires_grad_(False)
        return FunctionalModule(model_copy, param_names, names_map), params

    def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any:
        # Temporarily load the state back onto self.stateless_model
        old_state = _swap_state(self.stateless_model, self.names_map, params)
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Remove the loaded state on self.stateless_model
            _swap_state(self.stateless_model, self.names_map, old_state)


def make_functional(
    model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
    """make_functional(model, disable_autograd_tracking=False) -> func, params

    Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
    (params) and returns a functional version of the model, ``func``. This
    makes it possible to use transforms over the parameters of ``model``.

    ``func`` can be invoked as follows:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional

        x = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params = make_functional(model)
        func(params, x)

    And here is an example of applying the grad transform over the parameters
    of a model.

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params = make_functional(model)

        def compute_loss(params, x, t):
            y = func(params, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(params, x, t)

    If the model has any buffers, please use :func:`make_functional_with_buffers` instead.

    Args:
        model (torch.nn.Module): Input model.
        disable_autograd_tracking (bool): Flag to disable gradient tracking for the
            output parameters. The returned params are unrelated to the set of params
            from the original model. If False (default), the params will have
            ``requires_grad=True`` (i.e., they will be trackable with regular PyTorch
            autograd), matching the requires_grad-ness of the params from the original
            model. Otherwise, the returned params will have ``requires_grad=False``.
            Default: False.

            If you plan on using regular PyTorch autograd (e.g., if you want to call
            ``.backward()`` or ``torch.autograd.grad()``), then set
            ``disable_autograd_tracking=False``. Otherwise, if you're only planning on
            using functorch's gradient transforms, then please set
            ``disable_autograd_tracking=True`` to avoid unnecessarily tracking history
            with PyTorch autograd.
    """
    buffers = list(model.buffers())
    if len(buffers) > 0:
        raise RuntimeError(
            "make_functional(model): `model` has buffers. Please use "
            "make_functional_with_buffers(model) instead."
        )
    return FunctionalModule._create_from(
        model, disable_autograd_tracking=disable_autograd_tracking
    )


def make_functional_with_buffers(
    model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers

    Given a ``torch.nn.Module``, make_functional_with_buffers extracts the
    state (params and buffers) and returns a functional version of the model
    ``func`` that can be invoked like a function.

    ``func`` can be invoked as follows:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional_with_buffers

        x = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params, buffers = make_functional_with_buffers(model)
        func(params, buffers, x)

    And here is an example of applying the grad transform over the parameters
    of a model:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional_with_buffers, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params, buffers = make_functional_with_buffers(model)

        def compute_loss(params, buffers, x, t):
            y = func(params, buffers, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(params, buffers, x, t)

    Args:
        model (torch.nn.Module): Input model.
        disable_autograd_tracking (bool): Flag to disable gradient tracking for the
            output parameters. The returned params are unrelated to the set of params
            from the original model. If False (default), the params will have
            ``requires_grad=True`` (i.e., they will be trackable with regular PyTorch
            autograd), matching the requires_grad-ness of the params from the original
            model. Otherwise, the returned params will have ``requires_grad=False``.
            Default: False.

            If you plan on using regular PyTorch autograd (e.g., if you want to call
            ``.backward()`` or ``torch.autograd.grad()``), then set
            ``disable_autograd_tracking=False``. Otherwise, if you're only planning on
            using functorch's gradient transforms, then please set
            ``disable_autograd_tracking=True`` to avoid unnecessarily tracking history
            with PyTorch autograd.
    """
    return FunctionalModuleWithBuffers._create_from(
        model, disable_autograd_tracking=disable_autograd_tracking
    )


def transpose_stack(
    tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
) -> Tuple[Tensor, ...]:
    """
    Given M per-model tuples of N tensors each, returns N tensors, where the
    i-th result stacks the i-th tensor of every model along a new leading
    dimension of size M.
    """
    tuple_of_tuple_of_tensors = tuple(zip(*tuple_of_tuple_of_tensors))
    results = tuple(
        torch.stack(shards).detach() for shards in tuple_of_tuple_of_tensors
    )
    return results
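

# Illustrative sketch (not part of the original file's API): transpose_stack
# turns M per-model parameter tuples into per-parameter stacked tensors.
def _demo_transpose_stack() -> None:
    per_model = tuple(
        (torch.randn(3, 3), torch.randn(3)) for _ in range(5)
    )  # 5 models, each with a (weight, bias) tuple
    weight, bias = transpose_stack(per_model)
    assert weight.shape == (5, 3, 3) and bias.shape == (5, 3)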


def combine_state_for_ensemble(
    models: Sequence[nn.Module],
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    """combine_state_for_ensemble(models) -> func, params, buffers

    Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.

    Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their
    parameters and buffers together to make ``params`` and ``buffers``.
    Each parameter and buffer in the result will have an additional dimension
    of size ``M``.

    :func:`combine_state_for_ensemble` also returns ``func``, a functional
    version of one of the models in :attr:`models`. One cannot directly run
    ``func(params, buffers, *args, **kwargs)``; you probably want to use
    ``vmap(func, ...)(params, buffers, *args, **kwargs)``.

    Here's an example of how to ensemble over a very simple model:

    .. code-block:: python

        num_models = 5
        batch_size = 64
        in_features, out_features = 3, 3
        models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
        data = torch.randn(batch_size, 3)

        fmodel, params, buffers = combine_state_for_ensemble(models)
        output = vmap(fmodel, (0, 0, None))(params, buffers, data)

        assert output.shape == (num_models, batch_size, out_features)

    .. warning::
        All of the modules being stacked together must be the same (except for
        the values of their parameters/buffers). For example, they should be in the
        same mode (training vs eval).

        This API is subject to change -- we're investigating better ways to
        create ensembles and would love your feedback on how to improve this.
    """
    if len(models) == 0:
        raise RuntimeError(
            "combine_state_for_ensemble: Expected at least one model, got 0."
        )
    if not (all(m.training for m in models) or all(not m.training for m in models)):
        raise RuntimeError(
            "combine_state_for_ensemble: Expected all models to "
            "have the same training/eval mode."
        )
    model0_typ = type(models[0])
    if not all(type(m) == model0_typ for m in models):
        raise RuntimeError(
            "combine_state_for_ensemble: Expected all models to be of the same class."
        )
    funcs, params, buffers = zip(
        *[make_functional_with_buffers(model) for model in models]
    )
    params = transpose_stack(params)
    buffers = transpose_stack(buffers)
    return funcs[0], params, buffers


def functional_init(
    model_class: Type[nn.Module],
    ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
    device: torch.types.Device = "cpu",
):
    """
    Returns a wrapper around `model_class` that constructs the model (or an
    ensemble of `ensemble_shape[0]` models) on `device` and makes it
    functional via `make_functional_deprecated_v1`.
    """

    def wrapped(*args, **kwargs):
        if len(ensemble_shape) >= 2:
            raise ValueError("NYI: ensemble_shape with more than 1 element")
        if len(ensemble_shape) == 0:
            model = model_class(*args, **kwargs).to(device)
            return make_functional_deprecated_v1(model)
        num_models = ensemble_shape[0]  # type: ignore[misc]
        if num_models <= 0:
            raise ValueError(f"num_models {num_models} should be > 0")
        # NB: Not very efficient, more of a POC
        models = tuple(
            model_class(*args, **kwargs).to(device) for _ in range(num_models)
        )
        _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
        weights = tuple(make_functional_deprecated_v1(model)[0] for model in models)
        weights = tuple(zip(*weights))
        weights = tuple(torch.stack(shards).detach() for shards in weights)
        return weights, fn, names

    return wrapped
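

# Illustrative sketch (not part of the original file's API): building a
# 3-model ensemble of nn.Linear layers with functional_init.
def _demo_functional_init() -> None:
    init_fn = functional_init(nn.Linear, ensemble_shape=(3,))
    weights, fn, names = init_fn(4, 4)  # args are forwarded to nn.Linear(4, 4)
    # Each stacked weight gains a leading ensemble dimension of size 3.
    assert weights[0].shape[0] == 3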


def functional_init_with_buffers(
    model_class: Type[nn.Module],
    ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
    device: torch.types.Device = "cpu",
):
    """
    Like `functional_init`, but for model classes whose instances have
    buffers; uses `make_functional_with_buffers_deprecated_v1`.
    """

    def wrapped(*args, **kwargs):
        if len(ensemble_shape) >= 2:
            raise ValueError("NYI: ensemble_shape with more than 1 element")
        if len(ensemble_shape) == 0:
            model = model_class(*args, **kwargs).to(device)
            return make_functional_with_buffers_deprecated_v1(model)
        num_models = ensemble_shape[0]  # type: ignore[misc]
        if num_models <= 0:
            raise ValueError(f"num_models {num_models} should be > 0")
        # NB: Not very efficient, more of a POC
        models = tuple(
            model_class(*args, **kwargs).to(device) for _ in range(num_models)
        )
        (
            _,
            _,
            fn,
            weight_names,
            buffer_names,
        ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
        weights, buffers = zip(
            *tuple(
                make_functional_with_buffers_deprecated_v1(model)[:2]
                for model in models
            )
        )
        weights = tuple(zip(*weights))
        weights = tuple(torch.stack(shards).detach() for shards in weights)
        buffers = tuple(zip(*buffers))
        buffers = tuple(torch.stack(shards).detach() for shards in buffers)
        return weights, buffers, fn, weight_names, buffer_names

    return wrapped
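

# Illustrative sketch (not part of the original file's API): the with-buffers
# variant, using nn.BatchNorm1d, whose running statistics are buffers.
def _demo_functional_init_with_buffers() -> None:
    init_fn = functional_init_with_buffers(nn.BatchNorm1d, ensemble_shape=(3,))
    weights, buffers, fn, weight_names, buffer_names = init_fn(4)
    # Stacked weights and buffers both gain a leading dimension of size 3.
    assert weights[0].shape[0] == 3 and buffers[0].shape[0] == 3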