Spaces:
Runtime error
Runtime error
File size: 23,374 Bytes
b2659ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 |
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import copy
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NoReturn,
Sequence,
Tuple,
Type,
Union,
)
import torch
import torch.nn as nn
from torch import Tensor
from torch.nn.utils._named_member_accessor import NamedMemberAccessor
# Utilities to make nn.Module "functional"
# In particular the goal is to be able to provide a function that takes as input
# the parameters and evaluate the nn.Module using fixed inputs.
def raise_parameter_tying_error() -> NoReturn:
    """Abort ``make_functional`` for a model whose parameters are tied.

    Raises:
        RuntimeError: always, with a pointer to the functorch tracking issue.
    """
    message = (
        "make_functional(module): we don't yet support models that "
        "do parameter tying (also sometimes known as weight sharing). "
        "Please try to rewrite your model by replacing all instances of the "
        "tied parameter with another and/or comment your support in "
        "https://github.com/pytorch/functorch/issues/446"
    )
    raise RuntimeError(message)
def create_names_map(
    named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
    tied_named_params: Union[Dict[str, Tensor], Iterable[Tuple[str, Tensor]]],
) -> Dict[str, List[str]]:
    """Map each canonical name to every name bound to the same tensor.

    ``named_params`` has one entry per distinct tensor, e.g. ``{'A': A, 'B': B}``.
    ``tied_named_params`` may expose the same tensor under several names, e.g.
    ``{'A': A, 'B': B, 'B_tied': B}``. For that input the result is
    ``{'A': ['A'], 'B': ['B', 'B_tied']}``.

    Tensors are matched by identity (they are used as dict keys), not by value.
    Both arguments may be a dict or an iterable of ``(name, tensor)`` pairs.

    Raises:
        AssertionError: if a name of ``named_params`` is absent from
            ``tied_named_params``, or a tensor of ``tied_named_params`` is not
            one of the tensors in ``named_params``.
    """
    canonical = dict(named_params)
    aliased = dict(tied_named_params)
    assert set(canonical).issubset(set(aliased))

    # tensor -> (canonical name, every name that refers to this tensor)
    by_tensor: Dict[Tensor, Tuple[str, List[str]]] = {
        tensor: (name, []) for name, tensor in canonical.items()
    }
    for alias, tensor in aliased.items():
        assert tensor in by_tensor
        by_tensor[tensor][1].append(alias)
    return dict(by_tensor.values())
def _extract_members(
    mod: nn.Module,
    named_members: Callable[..., Iterable[Tuple[str, Tensor]]],
    subclass: Callable[[Tensor], Tensor],
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """Strip one family of tensors (parameters or buffers) out of ``mod``.

    Args:
        mod: module to modify in place.
        named_members: a bound ``named_parameters``/``named_buffers``-style
            method; it is called with ``remove_duplicate=False`` and with
            ``remove_duplicate=True``.
        subclass: wraps each placeholder tensor (e.g. ``nn.Parameter``).

    Returns:
        ``(tensors, names, names_map)``: ``tensors`` and ``names`` list each
        distinct member once, and ``names_map`` maps each name in ``names`` to
        every attribute path that shares that tensor.

    Every attribute path is rebound to a ``meta``-device placeholder built with
    ``torch.empty_like``. Names that shared a tensor get a shared placeholder.
    """
    every_member = tuple(named_members(remove_duplicate=False))
    distinct_members = tuple(named_members(remove_duplicate=True))
    names_map = create_names_map(distinct_members, every_member)

    accessor = NamedMemberAccessor(mod)
    placeholders: Dict[Tensor, Tensor] = {}
    for attr_path, original in every_member:
        # One placeholder per distinct tensor keeps tying intact.
        if original not in placeholders:
            placeholders[original] = subclass(
                torch.empty_like(original, device="meta")
            )
        accessor.set_tensor(attr_path, placeholders[original])

    if not distinct_members:
        return (), (), names_map
    names, params = zip(*distinct_members)  # type: ignore[assignment]
    return params, names, names_map
def extract_weights(
    mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """Remove every parameter from ``mod`` and return them to the caller.

    ``mod`` is modified in place. Each parameter attribute is replaced by a
    ``meta``-device ``nn.Parameter`` placeholder, so the module cannot run a
    real forward pass until weights are re-loaded (e.g. with ``load_weights``).

    Returns:
        ``(params, names, names_map)``: the original parameter tensors (one per
        distinct tensor), their attribute names, and a map from each name to
        all attribute paths that share that tensor.
    """
    return _extract_members(
        mod,
        named_members=mod.named_parameters,
        subclass=nn.Parameter,
    )
def extract_buffers(
    mod: nn.Module,
) -> Tuple[Tuple[Tensor, ...], Tuple[str, ...], Dict[str, List[str]]]:
    """Remove every buffer from ``mod`` and return them to the caller.

    ``mod`` is modified in place. Each buffer is replaced by an unwrapped
    ``meta``-device placeholder tensor.

    Returns:
        ``(buffers, names, names_map)``, in the same layout that
        ``extract_weights`` uses for parameters.
    """

    def _keep_as_is(t: Tensor) -> Tensor:
        return t

    return _extract_members(mod, named_members=mod.named_buffers, subclass=_keep_as_is)
def load_weights(
    mod: nn.Module,
    names: Sequence[str],
    params: Sequence[Tensor],
    as_params: bool = False,
) -> None:
    """Re-attach tensors to ``mod`` so it can run a forward pass again.

    ``params[i]`` is stored under the attribute name ``names[i]`` (as yielded by
    ``named_parameters``). By default the tensors are stored unchanged, so they
    may carry autograd history and are not ``nn.Parameter`` instances. Pass
    ``as_params=True`` to wrap each tensor in ``nn.Parameter`` first.
    """
    tensors = [nn.Parameter(p) for p in params] if as_params else params
    NamedMemberAccessor(mod).set_tensors(names, tensors)
def _swap_state(
mod: nn.Module, names_map: Dict[str, List[str]], elems: Iterable[Tensor]
) -> List[Tensor]:
result: List[Tensor] = []
accessor = NamedMemberAccessor(mod)
for (_, attr_names), elem in zip(names_map.items(), elems):
for i, attr_name in enumerate(attr_names):
if i == 0:
result.append(accessor.swap_tensor(attr_name, elem))
else:
accessor.set_tensor(attr_name, elem)
return result
def load_buffers(
    mod: nn.Module,
    names: Sequence[str],
    buffers: Sequence[Tensor],
    as_params: bool = False,
) -> None:
    """Re-attach buffer tensors to ``mod`` under the matching attribute names.

    ``as_params`` is accepted but ignored; it presumably exists only to mirror
    the signature of ``load_weights``. Buffers are always stored unwrapped.
    """
    NamedMemberAccessor(mod).set_tensors(names, buffers)
def load_state(
    model: nn.Module,
    weights: Sequence[Tensor],
    weight_names: Sequence[str],
    buffers: Sequence[Tensor] = (),
    buffer_names: Sequence[str] = (),
) -> nn.Module:
    """load_state(model, weights, weight_names, buffers=(), buffer_names=()) -> model

    Assign ``weights`` (and ``buffers``, when any are given) onto ``model``
    under the matching names, and return the same ``model`` object. This
    undoes the state extraction done by ``make_functional_deprecated_v1``.

    Raises:
        AssertionError: if a names sequence and its tensors differ in length.
    """
    assert len(weight_names) == len(weights)
    load_weights(model, weight_names, weights)
    if len(buffers) == 0:
        return model
    assert len(buffer_names) == len(buffers)
    load_buffers(model, buffer_names, buffers)
    return model
def make_functional_deprecated_v1(model: nn.Module):
    """make_functional_deprecated_v1(model) -> weights, func, weight_names

    Given an nn.Module, make_functional_deprecated_v1 extracts the state (weights)
    and returns a functional version of the model, `func`. This makes it
    possible to use transforms over the parameters of `model`.

    Note that `model` itself is modified in place: its parameters are removed
    (replaced by meta-device placeholders) and returned as `weights`. Each call
    of `func` runs on a fresh deep copy of the stripped model.

    `func` can be invoked as follows:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)
    func(weights, (x,))
    ```

    And here is an example of applying the grad transform (which requires a
    scalar output, hence the `.sum()`):
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, func, _ = make_functional_deprecated_v1(model)

    def loss(weights, data):
        return func(weights, data).sum()

    grad_weights = grad(loss)(weights, (x,))
    ```

    To put the state back into a model, use `load_state`.

    Raises:
        RuntimeError: if `model` has any buffers; use
            `make_functional_with_buffers_deprecated_v1` instead.
    """
    # Only need to know whether at least one buffer exists.
    if next(iter(model.buffers()), None) is not None:
        raise RuntimeError(
            "make_functional_deprecated_v1(model): `model` has buffers. Please use "
            "make_functional_with_buffers_deprecated_v1(model) instead."
        )
    weights, descriptors, _ = extract_weights(model)

    def fun(weights, data):
        # Load into a throwaway copy so the stripped `model` stays reusable.
        mutable_model = copy.deepcopy(model)
        load_weights(mutable_model, descriptors, weights)
        return mutable_model(*data)

    return weights, fun, descriptors
def make_functional_with_buffers_deprecated_v1(model: nn.Module):
    """make_functional_with_buffers_deprecated_v1(model) -> weights, buffers, func, weight_names, buffer_names

    Extracts both the parameters and the buffers of ``model`` and returns a
    functional version of it, ``func(weights, buffers, data)``, where ``data``
    is a tuple of positional inputs.

    ``model`` is modified in place: its parameters and buffers are replaced by
    meta-device placeholders. Each call of ``func`` runs on a fresh deep copy of
    the stripped model, so calls do not affect each other.

    Example:
    ```
    x = torch.randn(4, 3)
    model = nn.Linear(3, 3)
    weights, buffers, func, _, _ = make_functional_with_buffers_deprecated_v1(model)
    func(weights, buffers, (x,))
    ```

    With the grad transform (which needs a scalar output):
    ```
    def loss(weights, buffers, data):
        return func(weights, buffers, data).sum()

    grad_weights = grad(loss)(weights, buffers, (x,))
    ```

    To put the state back into a model, use `load_state`.
    """
    weights, weight_names, _ = extract_weights(model)
    buffers, buffer_names, _ = extract_buffers(model)

    def fun(weights, buffers, data):
        replica = copy.deepcopy(model)
        load_weights(replica, weight_names, weights)
        load_buffers(replica, buffer_names, buffers)
        return replica(*data)

    return weights, buffers, fun, weight_names, buffer_names
class FunctionalModuleWithBuffers(nn.Module):
    """
    This is the callable object returned by :func:`make_functional_with_buffers`.

    It holds a stateless copy of a module, whose parameters and buffers are
    meta-device placeholders. Each call temporarily installs the ``params`` and
    ``buffers`` passed in, runs the module, then restores the placeholders.
    """

    def __init__(
        self,
        stateless_model: nn.Module,
        param_names: Tuple[str, ...],
        buffer_names: Tuple[str, ...],
        param_names_map: Dict[str, List[str]],
        buffer_names_map: Dict[str, List[str]],
    ) -> None:
        super().__init__()
        self.stateless_model = stateless_model
        self.param_names = param_names
        self.buffer_names = buffer_names
        # Parameter entries come first, then buffer entries. forward() relies on
        # this order because it concatenates params before buffers.
        self.all_names_map = {**param_names_map, **buffer_names_map}

    @staticmethod
    def _create_from(
        model: nn.Module, disable_autograd_tracking: bool = False
    ) -> Tuple["FunctionalModuleWithBuffers", Tuple[Tensor, ...], Tuple[Tensor, ...]]:
        """Build the wrapper from a deep copy of ``model`` (``model`` is untouched).

        Returns ``(wrapper, params, buffers)``. With ``disable_autograd_tracking``,
        the returned params have ``requires_grad`` switched off in place.
        """
        # TODO: We don't need to copy the model to create a stateless copy
        stateless = copy.deepcopy(model)
        params, param_names, param_names_map = extract_weights(stateless)
        buffers, buffer_names, buffer_names_map = extract_buffers(stateless)
        if disable_autograd_tracking:
            for p in params:
                p.requires_grad_(False)
        wrapper = FunctionalModuleWithBuffers(
            stateless, param_names, buffer_names, param_names_map, buffer_names_map
        )
        return wrapper, params, buffers

    def forward(
        self, params: Iterable[Tensor], buffers: Iterable[Tensor], *args, **kwargs
    ) -> Any:
        state = (*params, *buffers)
        # Swap the supplied tensors in and keep the placeholders they displace.
        displaced = _swap_state(self.stateless_model, self.all_names_map, state)
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Put the placeholders back even if the forward pass raised.
            _swap_state(self.stateless_model, self.all_names_map, displaced)
class FunctionalModule(nn.Module):
    """
    This is the callable object returned by :func:`make_functional`.

    It holds a stateless copy of a module, whose parameters are meta-device
    placeholders. Each call temporarily installs the ``params`` passed in, runs
    the module, then restores the placeholders.
    """

    def __init__(
        self,
        stateless_model: nn.Module,
        param_names: Tuple[str, ...],
        names_map: Dict[str, List[str]],
    ) -> None:
        super().__init__()
        self.stateless_model = stateless_model
        self.param_names = param_names
        # Maps each parameter name to every attribute path tied to it.
        self.names_map = names_map

    @staticmethod
    def _create_from(
        model: nn.Module, disable_autograd_tracking: bool = False
    ) -> Tuple["FunctionalModule", Tuple[Tensor, ...]]:
        """Build the wrapper from a deep copy of ``model`` (``model`` is untouched).

        Returns ``(wrapper, params)``. With ``disable_autograd_tracking``, the
        returned params have ``requires_grad`` switched off in place.
        """
        # TODO: We don't need to copy the model to create a stateless copy
        stateless = copy.deepcopy(model)
        params, param_names, names_map = extract_weights(stateless)
        if disable_autograd_tracking:
            for p in params:
                p.requires_grad_(False)
        wrapper = FunctionalModule(stateless, param_names, names_map)
        return wrapper, params

    def forward(self, params: Iterable[Tensor], *args, **kwargs) -> Any:
        # Swap the supplied params in and keep the placeholders they displace.
        displaced = _swap_state(self.stateless_model, self.names_map, params)
        try:
            return self.stateless_model(*args, **kwargs)
        finally:
            # Put the placeholders back even if the forward pass raised.
            _swap_state(self.stateless_model, self.names_map, displaced)
def make_functional(
model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModule, Tuple[Tensor, ...]]:
"""make_functional(model, disable_autograd_tracking=False) -> func, params
Given a ``torch.nn.Module``, :func:`make_functional` extracts the state
(params) and returns a functional version of the model, ``func``. This
makes it so that it is possible use transforms over the parameters of
``model``.
``func`` can be invoked as follows:
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional
x = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
func(params, x)
And here is an example of applying the grad transform over the parameters
of a model.
.. code-block:: python
import torch
import torch.nn as nn
from functorch import make_functional, grad
x = torch.randn(4, 3)
t = torch.randn(4, 3)
model = nn.Linear(3, 3)
func, params = make_functional(model)
def compute_loss(params, x, t):
y = func(params, x)
return nn.functional.mse_loss(y, t)
grad_weights = grad(compute_loss)(params, x, t)
If the model has any buffers, please use :func:`make_functional_with_buffers` instead.
Args:
model (torch.nn.Module): Input model.
disable_autograd_tracking (bool): Flag to disable gradients tracking for output parameters.
The returned params are unrelated to the set of params from the original model. If False (default),
the params will have ``requires_grad=True`` on them (aka they will be trackable with regular
PyTorch autograd), matching the requires_grad-ness of the params from the original model.
Otherwise, the returned params will have ``requires_grad=False``. Default, False.
If you plan on using regular PyTorch autograd (e.g., if you want to call ``.backward()`` or
``torch.autograd.grad()``, then set ``disable_autograd_tracking=False``.
Otherwise, if you're only planning on using functorch's gradient transforms,
then please set ``disable_autograd_tracking=True`` to avoid unnecessarily tracking
history with PyTorch autograd.
"""
buffers = list(model.buffers())
if len(buffers) > 0:
raise RuntimeError(
"make_functional(model): `model` has buffers. Please use "
"make_functional_with_buffers(model) instead."
)
return FunctionalModule._create_from(
model, disable_autograd_tracking=disable_autograd_tracking
)
def make_functional_with_buffers(
    model: nn.Module, disable_autograd_tracking: bool = False
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
    """make_functional_with_buffers(model, disable_autograd_tracking=False) -> func, params, buffers

    Split a ``torch.nn.Module`` into a stateless callable ``func`` plus its
    parameters ``params`` and buffers ``buffers``. ``func`` is invoked as
    ``func(params, buffers, *args, **kwargs)``. The state is taken from a deep
    copy, so ``model`` itself is left intact.

    ``func`` can be invoked as follows:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional_with_buffers

        x = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params, buffers = make_functional_with_buffers(model)
        func(params, buffers, x)

    And here is an example of applying the grad transform over the parameters
    of a model:

    .. code-block:: python

        import torch
        import torch.nn as nn
        from functorch import make_functional_with_buffers, grad

        x = torch.randn(4, 3)
        t = torch.randn(4, 3)
        model = nn.Linear(3, 3)
        func, params, buffers = make_functional_with_buffers(model)

        def compute_loss(params, buffers, x, t):
            y = func(params, buffers, x)
            return nn.functional.mse_loss(y, t)

        grad_weights = grad(compute_loss)(params, buffers, x, t)

    Args:
        model (torch.nn.Module): Input model.
        disable_autograd_tracking (bool): Flag to disable gradient tracking for
            the output parameters (buffers are not affected). The returned
            params are unrelated to the params of the original model. If False
            (default), they keep the ``requires_grad`` setting of the original
            model's params, so regular PyTorch autograd can track them. If True,
            every returned param has ``requires_grad=False``. Use False if you
            plan to call ``.backward()`` or ``torch.autograd.grad()``. If you
            only use functorch's gradient transforms, set True to avoid
            needlessly recording autograd history.
    """
    return FunctionalModuleWithBuffers._create_from(
        model=model, disable_autograd_tracking=disable_autograd_tracking
    )
def transpose_stack(
    tuple_of_tuple_of_tensors: Tuple[Tuple[Tensor, ...], ...]
) -> Tuple[Tensor, ...]:
    """Stack the i-th tensor of every inner tuple into one tensor.

    Given ``M`` inner tuples of ``K`` tensors each, returns ``K`` tensors, each
    with a new leading dimension of size ``M``. Like ``zip``, it truncates to
    the shortest inner tuple. The results are detached from autograd history.
    """
    columns = zip(*tuple_of_tuple_of_tensors)
    return tuple(torch.stack(column).detach() for column in columns)
def combine_state_for_ensemble(
models: Sequence[nn.Module],
) -> Tuple[FunctionalModuleWithBuffers, Tuple[Tensor, ...], Tuple[Tensor, ...]]:
"""combine_state_for_ensemble(models) -> func, params, buffers
Prepares a list of torch.nn.Modules for ensembling with :func:`vmap`.
Given a list of ``M`` ``nn.Modules`` of the same class, stacks all of their
parameters and buffers together to make ``params`` and ``buffers``.
Each parameter and buffer in the result will have an additional dimension
of size ``M``.
:func:`combine_state_for_ensemble` also returns ``func``, a functional
version of one of the models in :attr:`models`. One cannot directly run
``func(params, buffers, *args, **kwargs)`` directly, you probably want to
use ``vmap(func, ...)(params, buffers, *args, **kwargs)``
Here's an example of how to ensemble over a very simple model:
.. code-block:: python
num_models = 5
batch_size = 64
in_features, out_features = 3, 3
models = [torch.nn.Linear(in_features, out_features) for i in range(num_models)]
data = torch.randn(batch_size, 3)
fmodel, params, buffers = combine_state_for_ensemble(models)
output = vmap(fmodel, (0, 0, None))(params, buffers, data)
assert output.shape == (num_models, batch_size, out_features)
.. warning::
All of the modules being stacked together must be the same (except for
the values of their parameters/buffers). For example, they should be in the
same mode (training vs eval).
This API is subject to change -- we're investigating better ways to
create ensembles and would love your feedback how to improve this.
"""
if len(models) == 0:
raise RuntimeError(
"combine_state_for_ensemble: Expected at least one model, got 0."
)
if not (all(m.training for m in models) or all(not m.training for m in models)):
raise RuntimeError(
"combine_state_for_ensemble: Expected all models to "
"have the same training/eval mode."
)
model0_typ = type(models[0])
if not all(type(m) == model0_typ for m in models):
raise RuntimeError(
"combine_state_for_ensemble: Expected all models to be of the same class."
)
funcs, params, buffers = zip(
*[make_functional_with_buffers(model) for model in models]
)
params = transpose_stack(params)
buffers = transpose_stack(buffers)
return funcs[0], params, buffers
def functional_init(
    model_class: Type[nn.Module],
    ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
    device: torch.types.Device = "cpu",
):
    """Return a constructor that builds ``model_class`` in functional form.

    The returned ``wrapped(*args, **kwargs)`` forwards its arguments to
    ``model_class``:

    * ``ensemble_shape == ()``: builds one model on ``device`` and returns
      ``make_functional_deprecated_v1(model)``.
    * ``ensemble_shape == (M,)``: builds ``M`` models on ``device`` and returns
      ``(weights, fn, names)``. Each weight is the detached stack of the ``M``
      models' corresponding weights, and ``fn``/``names`` come from an extra,
      separately constructed instance.

    Raises (from ``wrapped``):
        ValueError: if ``ensemble_shape`` has two or more elements, or ``M <= 0``.
    """

    def wrapped(*args, **kwargs):
        if len(ensemble_shape) >= 2:
            raise ValueError("NYI: ensemble_shape with more than 1 element")
        if len(ensemble_shape) == 0:
            single = model_class(*args, **kwargs).to(device)
            return make_functional_deprecated_v1(single)
        num_models = ensemble_shape[0]  # type: ignore[misc]
        if num_models <= 0:
            raise ValueError(f"num_models {num_models} should be > 0")
        # NB: Not very efficient, more of a POC
        # Construction order is kept as-is: it determines RNG consumption.
        members = tuple(
            model_class(*args, **kwargs).to(device) for _ in range(num_models)
        )
        _, fn, names = make_functional_deprecated_v1(model_class(*args, **kwargs))
        per_model = tuple(make_functional_deprecated_v1(m)[0] for m in members)
        return transpose_stack(per_model), fn, names

    return wrapped
def functional_init_with_buffers(
    model_class: Type[nn.Module],
    ensemble_shape: Union[Tuple[()], Tuple[int]] = (),
    device: torch.types.Device = "cpu",
):
    """Return a constructor that builds ``model_class`` in functional form,
    including its buffers.

    The returned ``wrapped(*args, **kwargs)`` forwards its arguments to
    ``model_class``. In both modes it returns
    ``(weights, buffers, fn, weight_names, buffer_names)``, in the layout of
    ``make_functional_with_buffers_deprecated_v1``:

    * ``ensemble_shape == ()``: one model is built on ``device``.
    * ``ensemble_shape == (M,)``: ``M`` models are built on ``device``. Each
      weight and buffer is the detached stack of the ``M`` models'
      corresponding tensors, and ``fn`` and the names come from an extra,
      separately constructed instance.

    Raises (from ``wrapped``):
        ValueError: if ``ensemble_shape`` has two or more elements, or ``M <= 0``.
    """

    def wrapped(*args, **kwargs):
        if len(ensemble_shape) >= 2:
            raise ValueError("NYI: ensemble_shape with more than 1 element")
        if len(ensemble_shape) == 0:
            model = model_class(*args, **kwargs).to(device)
            # Must be the *with_buffers* variant. The plain v1 helper raises on
            # any model with buffers and returns a 3-tuple, unlike the 5-tuple
            # produced by the ensemble branch below.
            return make_functional_with_buffers_deprecated_v1(model)
        num_models = ensemble_shape[0]  # type: ignore[misc]
        if num_models <= 0:
            raise ValueError(f"num_models {num_models} should be > 0")
        # NB: Not very efficient, more of a POC
        models = tuple(
            model_class(*args, **kwargs).to(device) for _ in range(num_models)
        )
        (
            _,
            _,
            fn,
            weight_names,
            buffer_names,
        ) = make_functional_with_buffers_deprecated_v1(model_class(*args, **kwargs))
        weights, buffers = zip(
            *tuple(
                make_functional_with_buffers_deprecated_v1(model)[:2]
                for model in models
            )
        )
        # Transpose per-model tuples into per-tensor groups, then stack each group.
        weights = tuple(zip(*weights))
        weights = tuple(torch.stack(shards).detach() for shards in weights)
        buffers = tuple(zip(*buffers))
        buffers = tuple(torch.stack(shards).detach() for shards in buffers)
        return weights, buffers, fn, weight_names, buffer_names

    return wrapped
|