from .module import Module

from typing import Tuple, Union
from torch import Tensor
from torch.types import _size

__all__ = ['Flatten', 'Unflatten']


class Flatten(Module):
    r"""
    Flattens a contiguous range of dims into a tensor. For use with :class:`~nn.Sequential`.

    Shape:
        - Input: :math:`(*, S_{\text{start}}, ..., S_{i}, ..., S_{\text{end}}, *)`,
          where :math:`S_{i}` is the size at dimension :math:`i` and :math:`*` means any
          number of dimensions including none.
        - Output: :math:`(*, \prod_{i=\text{start}}^{\text{end}} S_{i}, *)`.

    Args:
        start_dim: first dim to flatten (default = 1).
        end_dim: last dim to flatten (default = -1).

    Examples::
        >>> input = torch.randn(32, 1, 5, 5)
        >>> # With default parameters
        >>> m = nn.Flatten()
        >>> output = m(input)
        >>> output.size()
        torch.Size([32, 25])
        >>> # With non-default parameters
        >>> m = nn.Flatten(0, 2)
        >>> output = m(input)
        >>> output.size()
        torch.Size([160, 5])
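        >>> # The module simply delegates to Tensor.flatten over the same dims
        >>> torch.equal(output, input.flatten(0, 2))
        True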
    """

    __constants__ = ['start_dim', 'end_dim']
    start_dim: int
    end_dim: int

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input: Tensor) -> Tensor:
        # Collapse dims [start_dim, end_dim] of the input into a single dimension.
        return input.flatten(self.start_dim, self.end_dim)

    def extra_repr(self) -> str:
        return 'start_dim={}, end_dim={}'.format(
            self.start_dim, self.end_dim
        )


class Unflatten(Module):
    r"""
    Unflattens a tensor dim, expanding it to a desired shape. For use with :class:`~nn.Sequential`.

    * :attr:`dim` specifies the dimension of the input tensor to be unflattened, and it can
      be either `int` or `str` when `Tensor` or `NamedTensor` is used, respectively.

    * :attr:`unflattened_size` is the new shape of the unflattened dimension of the tensor and it can be
      a `tuple` of ints, a `list` of ints, or `torch.Size` for `Tensor` input; a `NamedShape`
      (tuple of `(name, size)` tuples) for `NamedTensor` input.

    Shape:
        - Input: :math:`(*, S_{\text{dim}}, *)`, where :math:`S_{\text{dim}}` is the size at
          dimension :attr:`dim` and :math:`*` means any number of dimensions including none.
        - Output: :math:`(*, U_1, ..., U_n, *)`, where :math:`U` = :attr:`unflattened_size` and
          :math:`\prod_{i=1}^n U_i = S_{\text{dim}}`.

    Args:
        dim (Union[int, str]): Dimension to be unflattened
        unflattened_size (Union[torch.Size, Tuple, List, NamedShape]): New shape of the unflattened dimension

    Examples::
        >>> input = torch.randn(2, 50)
        >>> # With tuple of ints
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, (2, 5, 5))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With torch.Size
        >>> m = nn.Sequential(
        ...     nn.Linear(50, 50),
        ...     nn.Unflatten(1, torch.Size([2, 5, 5]))
        ... )
        >>> output = m(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
        >>> # With namedshape (tuple of tuples)
        >>> input = torch.randn(2, 50, names=('N', 'features'))
        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
        >>> output = unflatten(input)
        >>> output.size()
        torch.Size([2, 2, 5, 5])
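        >>> # The new dims pick up the names given in the (name, size) pairs
        >>> output.names
        ('N', 'C', 'H', 'W')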
    """

    NamedShape = Tuple[Tuple[str, int], ...]

    __constants__ = ['dim', 'unflattened_size']
    dim: Union[int, str]
    unflattened_size: Union[_size, NamedShape]

    def __init__(self, dim: Union[int, str], unflattened_size: Union[_size, NamedShape]) -> None:
        super().__init__()

        # An int dim refers to a regular tensor dim and takes a plain shape of ints;
        # a str dim refers to a named tensor dim and takes (name, size) pairs.
        if isinstance(dim, int):
            self._require_tuple_int(unflattened_size)
        elif isinstance(dim, str):
            self._require_tuple_tuple(unflattened_size)
        else:
            raise TypeError("invalid argument type for dim parameter")

        self.dim = dim
        self.unflattened_size = unflattened_size

    def _require_tuple_tuple(self, input):
        # Named-tensor case: unflattened_size must be a tuple whose elements are
        # themselves tuples (intended as (name, size) pairs).
        if isinstance(input, tuple):
            for idx, elem in enumerate(input):
                if not isinstance(elem, tuple):
                    raise TypeError("unflattened_size must be tuple of tuples, " +
                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
            return
        raise TypeError("unflattened_size must be a tuple of tuples, " +
                        "but found type {}".format(type(input).__name__))

    def _require_tuple_int(self, input):
        # Regular-tensor case: unflattened_size must be a tuple or list of ints.
        if isinstance(input, (tuple, list)):
            for idx, elem in enumerate(input):
                if not isinstance(elem, int):
                    raise TypeError("unflattened_size must be tuple of ints, " +
                                    "but found element of type {} at pos {}".format(type(elem).__name__, idx))
            return
        raise TypeError("unflattened_size must be a tuple of ints, but found type {}".format(type(input).__name__))

    def forward(self, input: Tensor) -> Tensor:
        # Expand dim `self.dim` of the input into the dims given by `self.unflattened_size`.
        return input.unflatten(self.dim, self.unflattened_size)

    def extra_repr(self) -> str:
        return 'dim={}, unflattened_size={}'.format(self.dim, self.unflattened_size)