# Spaces: Running on Zero
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# Adapted from https://botorch.org/api/_modules/botorch/utils/torch.html
# TODO: To be removed once (if) https://github.com/pytorch/pytorch/pull/37385 lands
from __future__ import annotations

import collections
import collections.abc
from collections import OrderedDict

import torch
from torch.nn import Module
class BufferDict(Module):
    r"""
    Holds buffers in a dictionary.

    BufferDict can be indexed like a regular Python dictionary, but buffers it contains are properly registered, and
    will be visible by all Module methods. `torch.nn.BufferDict` is an **ordered** dictionary that respects

    * the order of insertion, and
    * in `torch.nn.BufferDict.update`, the order of the merged `OrderedDict` or another `torch.nn.BufferDict` (the
      argument to `torch.nn.BufferDict.update`).

    Note that `torch.nn.BufferDict.update` with other unordered mapping types (e.g., Python's plain `dict`) does not
    preserve the order of the merged mapping: entries of such mappings are inserted in sorted key order.

    Args:
        buffers (iterable, optional):
            a mapping (dictionary) of (string : `torch.Tensor`) or an iterable of key-value pairs of type (string,
            `torch.Tensor`)
        persistent (`bool`, optional):
            value forwarded as `persistent=` to `register_buffer` for every buffer stored in this container.
            Defaults to `False`.

    ```python
    class MyModule(nn.Module):
        def __init__(self):
            super().__init__()
            self.buffers = nn.BufferDict({"left": torch.randn(5, 10), "right": torch.randn(5, 10)})

        def forward(self, x, choice):
            x = self.buffers[choice].mm(x)
            return x
    ```
    """

    def __init__(self, buffers=None, persistent: bool = False):
        r"""
        Args:
            buffers (`dict`):
                A mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).
            persistent (`bool`):
                Passed as `persistent=` to `register_buffer` whenever a buffer is stored.
        """
        super().__init__()
        # `persistent` must exist before `update` runs: `update` stores entries through `__setitem__`, which reads
        # `self.persistent`. Assigning it afterwards made every construction with initial buffers fail.
        self.persistent = persistent
        if buffers is not None:
            self.update(buffers)

    def __getitem__(self, key):
        r"""Return the buffer stored under `key` (raises `KeyError` if it is absent)."""
        return self._buffers[key]

    def __setitem__(self, key, buffer):
        r"""Register `buffer` under `key`, using this container's `persistent` setting."""
        self.register_buffer(key, buffer, persistent=self.persistent)

    def __delitem__(self, key):
        r"""Remove the buffer stored under `key`."""
        del self._buffers[key]

    def __len__(self):
        r"""Return the number of stored buffers."""
        return len(self._buffers)

    def __iter__(self):
        r"""Iterate over the keys of the stored buffers."""
        return iter(self._buffers)

    def __contains__(self, key):
        r"""Return whether a buffer is stored under `key`."""
        return key in self._buffers

    def clear(self):
        """Remove all items from the BufferDict."""
        self._buffers.clear()

    def pop(self, key):
        r"""Remove key from the BufferDict and return its buffer.

        Args:
            key (`str`):
                Key to pop from the BufferDict
        """
        value = self[key]
        del self[key]
        return value

    def keys(self):
        r"""Return an iterable of the BufferDict keys."""
        return self._buffers.keys()

    def items(self):
        r"""Return an iterable of the BufferDict key/value pairs."""
        return self._buffers.items()

    def values(self):
        r"""Return an iterable of the BufferDict values."""
        return self._buffers.values()

    def update(self, buffers):
        r"""
        Update the `torch.nn.BufferDict` with the key-value pairs from a mapping or an iterable, overwriting existing
        keys.

        Note:
            If `buffers` is an `OrderedDict`, a `torch.nn.BufferDict`, or an iterable of key-value pairs, the order of
            new elements in it is preserved. Entries of any other mapping type are inserted in sorted key order.

        Args:
            buffers (iterable):
                a mapping (dictionary) from string to `torch.Tensor`, or an iterable of key-value pairs of type
                (string, `torch.Tensor`).

        Raises:
            TypeError: if `buffers` is not iterable, or one of its elements is not iterable.
            ValueError: if an element of a non-mapping `buffers` does not have exactly two entries.
        """
        if not isinstance(buffers, collections.abc.Iterable):
            raise TypeError(
                "BufferDict.update should be called with an "
                "iterable of key/value pairs, but got " + type(buffers).__name__
            )

        if isinstance(buffers, collections.abc.Mapping):
            if isinstance(buffers, (OrderedDict, BufferDict)):
                # Ordered containers: keep their own ordering.
                pairs = buffers.items()
            else:
                # Other mappings: insert in sorted key order. Sorting on the key alone keeps the values (tensors)
                # out of the comparison.
                pairs = sorted(buffers.items(), key=lambda item: item[0])
            for key, buffer in pairs:
                self[key] = buffer
            return

        for index, pair in enumerate(buffers):
            if not isinstance(pair, collections.abc.Iterable):
                raise TypeError(
                    "BufferDict update sequence element "
                    "#" + str(index) + " should be Iterable; is " + type(pair).__name__
                )
            if len(pair) != 2:
                raise ValueError(
                    "BufferDict update sequence element "
                    "#" + str(index) + " has length " + str(len(pair)) + "; 2 is required"
                )
            self[pair[0]] = pair[1]

    def extra_repr(self):
        r"""Return one line per stored buffer giving its key, tensor type name, size and (for CUDA tensors) device."""
        child_lines = []
        for key, buffer in self._buffers.items():
            size_str = "x".join(str(size) for size in buffer.size())
            device_str = f" (GPU {buffer.get_device()})" if buffer.is_cuda else ""
            parastr = f"Buffer containing: [{torch.typename(buffer)} of size {size_str}{device_str}]"
            child_lines.append(" (" + key + "): " + parastr)
        return "\n".join(child_lines)

    def __call__(self, input):
        r"""Always raise `RuntimeError`: this container only stores buffers and is not meant to be called."""
        raise RuntimeError("BufferDict should not be called.")