import contextlib
import warnings
from collections import defaultdict
from typing import Any, Dict, Iterator, Optional, Set, Tuple, Union

import torch
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor

__all__ = ["functional_call"]


def _untie_named_tensors_map(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
) -> Dict[str, Tensor]:
    """
    Unties all tied tensors in the module to parameters_and_buffers.

    This function returns a new untied_parameters_and_buffers dictionary and leaves the original
    parameters_and_buffers dictionary unchanged. It adds new (missing) keys for tied tensors
    in the module to untied_parameters_and_buffers. The value of each new key is the user-given value
    in the original parameters_and_buffers dictionary.

    If there is more than one user-given value for the same tied tensor, this function raises an error.
    For example, if the module has two tied weights self.foo and self.tied_foo and the user passes
    {'foo': foo_value, ...}, this will return {'foo': foo_value, 'tied_foo': foo_value, ...}. If the
    user passes {'foo': foo_value, 'tied_foo': tied_foo_value, ...}, it will raise an error. If the
    user passes {'foo': foo_value, 'tied_foo': foo_value, ...}, it will not raise an error.
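
    A minimal doctest-style sketch (``TiedFoo`` is a hypothetical module in which
    ``foo`` and ``tied_foo`` are the same parameter registered under two names)::

        >>> # xdoctest: +SKIP
        >>> class TiedFoo(torch.nn.Module):
        ...     def __init__(self):
        ...         super().__init__()
        ...         self.foo = torch.nn.Parameter(torch.zeros(()))
        ...         self.tied_foo = self.foo  # same Parameter, second name
        >>> mod = TiedFoo()
        >>> _untie_named_tensors_map(mod, {'foo': torch.ones(())})
        {'foo': tensor(1.), 'tied_foo': tensor(1.)}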

    Args:
        module (torch.nn.Module): the module to determine which tensors are tied.
        parameters_and_buffers (Dict[str, Tensor]): a map of {name: tensor} for reparametrizing the module.

    Returns:
        A new untied version of the parameters_and_buffers dictionary.

    Raises:
        ValueError: if there is more than one user-given value for the same tied tensor.
    """
    # A map of {name: tensor} for all tensors (including tied ones) in the module.
    all_named_tensors: Dict[str, Tensor] = {}
    all_named_tensors.update(module.named_parameters(remove_duplicate=False))
    all_named_tensors.update(module.named_buffers(remove_duplicate=False))

    # A map of {tensor: set(all_tied_names)} for all tensor names in the module.
    tensor_to_tied_names_map: Dict[Tensor, Set[str]] = defaultdict(set)
    for name, tensor in all_named_tensors.items():
        tensor_to_tied_names_map[tensor].add(name)

    # A map of {tied_name: set(all_tied_names)} for all tensor names in the module.
    # If a name is not tied, it will not be in this map.
    tied_names_map: Dict[str, Set[str]] = {}
    for tied_names in tensor_to_tied_names_map.values():
        if len(tied_names) > 1:
            for tied_name in tied_names:
                tied_names_map[tied_name] = tied_names

    # Make sure the user didn't pass multiple values for the same tied tensor.
    given_names = set(parameters_and_buffers.keys())
    given_names_for_tied_tensors = given_names.intersection(tied_names_map.keys())
    for given_name in given_names_for_tied_tensors:
        tied_names = tied_names_map[given_name]
        if (
            # Detect if there are multiple keys present for the same tied tensor.
            len(tied_names.intersection(given_names_for_tied_tensors)) > 1
            # Only raise an error if the user passed multiple values for the same tied tensor.
            # If all given values are the same, don't raise.
            and len({parameters_and_buffers[tied_name] for tied_name in tied_names})
            != 1
        ):
            raise ValueError(
                f"functional_call got multiple values for keys {sorted(tied_names)}, "
                f"which are tied. Consider using tie_weights=False"
            )

    # Untie the given named tensor map.
    # Make a copy so the original dict is not modified.
    untied_parameters_and_buffers = parameters_and_buffers.copy()
    for given_name in given_names_for_tied_tensors:
        for tied_name in tied_names_map[given_name]:
            untied_parameters_and_buffers[tied_name] = parameters_and_buffers[
                given_name
            ]
    return untied_parameters_and_buffers


@contextlib.contextmanager
def _reparametrize_module(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    *,
    tie_weights: bool = False,
    strict: bool = False,
) -> Iterator[None]:
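    """Context manager that temporarily swaps the module's parameters and buffers.

    The tensors in ``parameters_and_buffers`` are swapped into the module for the
    duration of the ``with`` block and the original tensors are restored on exit.
    Changes the module makes to its parameters/buffers inside the block are
    written back into ``parameters_and_buffers``.

    A minimal usage sketch (``mod`` and the zeroed ``params`` dict below are
    illustrative assumptions, not part of this module)::

        >>> # xdoctest: +SKIP
        >>> mod = torch.nn.Linear(3, 3)
        >>> params = {name: torch.zeros_like(p) for name, p in mod.named_parameters()}
        >>> with _reparametrize_module(mod, params, strict=True):
        ...     out = mod(torch.ones(1, 3))  # runs with the zeroed weight and bias
        >>> torch.equal(out, torch.zeros(1, 3))
        True
    """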
    if tie_weights:
        untied_parameters_and_buffers = _untie_named_tensors_map(
            module, parameters_and_buffers
        )
    else:
        untied_parameters_and_buffers = parameters_and_buffers

    accessor = NamedMemberAccessor(module)
    if strict:
        missing_keys, unexpected_keys = accessor.check_keys(
            untied_parameters_and_buffers
        )
        error_msgs = []
        if len(unexpected_keys) > 0:
            error_msgs.append(
                f"Unexpected key(s): {', '.join(map(repr, unexpected_keys))}."
            )
        if len(missing_keys) > 0:
            error_msgs.append(f"Missing key(s): {', '.join(map(repr, missing_keys))}.")
        if len(error_msgs) > 0:
            raise RuntimeError(
                "Error(s) in reparametrizing for {}:\n\t{}".format(
                    module._get_name(), "\n\t".join(error_msgs)
                )
            )

    orig_parameters_and_buffers: Dict[str, Tensor] = {}
    try:
        orig_parameters_and_buffers, _ = accessor.swap_tensors_dict(
            untied_parameters_and_buffers, allow_missing=True
        )
        yield
    finally:
        new_parameters_and_buffers, _ = accessor.swap_tensors_dict(
            orig_parameters_and_buffers, allow_missing=True
        )
        # Sometimes the module is not completely stateless and has some in-place modifications on
        # the _parameters and _buffers dictionaries.
        # Write the changed parameters and buffers back to the original dict.
        parameters_and_buffers.update(
            {
                k: new_parameters_and_buffers[k]
                for k in parameters_and_buffers
                if k in new_parameters_and_buffers
            }
        )


def functional_call(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,
):
r"""Perform a functional call on the module by replacing the module parameters and buffers with the provided ones. | |
.. warning:: | |
This API is deprecated as of PyTorch 2.0 and will be removed in a future | |
version of PyTorch. Please use :func:`torch.func.functional_call` instead, | |
which is a drop-in replacement for this API. | |
.. note:: If the module has active parametrizations, passing a value in the | |
:attr:`parameters_and_buffers` argument with the name set to the regular parameter | |
name will completely disable the parametrization. | |
If you want to apply the parametrization function to the value passed | |
please set the key as ``{submodule_name}.parametrizations.{parameter_name}.original``. | |
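
        A hedged sketch of the difference (``linear`` below is an assumed ``nn.Linear``
        whose ``weight`` has a single parametrization registered on the module itself,
        so there is no submodule prefix)::

            >>> # xdoctest: +SKIP
            >>> value = torch.randn(3, 3)
            >>> # the regular name replaces ``weight`` directly and disables the parametrization
            >>> functional_call(linear, {'weight': value}, torch.randn(1, 3))
            >>> # the ``.original`` key applies the parametrization function to ``value`` instead
            >>> functional_call(linear, {'parametrizations.weight.original': value}, torch.randn(1, 3))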

    .. note:: If the module performs in-place operations on parameters/buffers, these will be reflected
        in the ``parameters_and_buffers`` input.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # does self.foo = self.foo + 1
            >>> print(mod.foo)  # tensor(0.)
            >>> functional_call(mod, a, torch.ones(()))
            >>> print(mod.foo)  # tensor(0.)
            >>> print(a['foo'])  # tensor(1.)

    .. note:: If the module has tied weights, whether or not functional_call respects the tying is determined by the
        tie_weights flag.

        Example::

            >>> a = {'foo': torch.zeros(())}
            >>> # xdoctest: +SKIP
            >>> mod = Foo()  # has both self.foo and self.foo_tied which are tied. Returns x + self.foo + self.foo_tied
            >>> print(mod.foo)  # tensor(1.)
            >>> mod(torch.zeros(()))  # tensor(2.)
            >>> functional_call(mod, a, torch.zeros(()))  # tensor(0.) since it will change self.foo_tied too
            >>> functional_call(mod, a, torch.zeros(()), tie_weights=False)  # tensor(1.) -- self.foo_tied is not updated
            >>> new_a = {'foo': torch.zeros(()), 'foo_tied': torch.zeros(())}
            >>> functional_call(mod, new_a, torch.zeros(()))  # tensor(0.)

    Args:
        module (torch.nn.Module): the module to call
        parameters_and_buffers (dict of str and Tensor): the parameters that will be used in
            the module call.
        args (Any or tuple): arguments to be passed to the module call. If not a tuple, considered a single argument.
        kwargs (dict): keyword arguments to be passed to the module call
        tie_weights (bool, optional): If True, then parameters and buffers tied in the original model will be treated as
            tied in the reparametrized version. Therefore, if True and different values are passed for the tied
            parameters and buffers, it will error. If False, it will not respect the originally tied parameters and
            buffers unless the values passed for both weights are the same. Default: True.
        strict (bool, optional): If True, then the parameters and buffers passed in must match the parameters and
            buffers in the original module. Therefore, if True and there are any missing or unexpected keys, it will
            error. Default: False.

    Returns:
        Any: the result of calling ``module``.
    """
    warnings.warn(
        "This API is deprecated as of PyTorch 2.0 and will be removed in a future "
        "version of PyTorch. Please use torch.func.functional_call instead, "
        "which is a drop-in replacement for this API."
    )
    return _functional_call(
        module,
        parameters_and_buffers,
        args,
        kwargs,
        tie_weights=tie_weights,
        strict=strict,
    )


def _functional_call(
    module: "torch.nn.Module",
    parameters_and_buffers: Dict[str, Tensor],
    args: Union[Any, Tuple],
    kwargs: Optional[Dict[str, Any]] = None,
    *,
    tie_weights: bool = True,
    strict: bool = False,
):
    # TODO allow kwargs such as unsafe and others for parametrization
    if (
        torch.jit.is_tracing()
        or torch.jit.is_scripting()
        or isinstance(
            module,
            (
                torch.jit.RecursiveScriptModule,
                torch.jit.ScriptModule,
                torch.jit.ScriptFunction,
            ),
        )
    ):
        raise RuntimeError("The stateless API can't be used with Jitted modules")
    if isinstance(module, torch.nn.DataParallel):
        raise RuntimeError(
            "The stateless API can't be used with nn.DataParallel module"
        )
    if kwargs is None:
        kwargs = {}
    if not isinstance(args, tuple):
        args = (args,)
    with _reparametrize_module(
        module, parameters_and_buffers, tie_weights=tie_weights, strict=strict
    ):
        return module(*args, **kwargs)