Spaces:
Runtime error
Runtime error
| # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved. | |
| # | |
| # This source code is licensed under the BSD license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from typing import List, Optional, Sequence, Tuple, Union | |
| import torch | |
| from .common import _get_storage_base | |
def get_stack_strides(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[int, ...]]:
    """
    If the tensors are already stacked on dimension :code:`dim`,
    returns the strides of the stacked tensors.
    Otherwise returns :code:`None`.

    "Already stacked" means that every tensor has the same shape and strides,
    that their storage offsets are evenly spaced (by a non-negative step), and
    that they all live in the same underlying storage.

    Args:
        tensors: candidate slices of a bigger tensor, in stacking order.
        dim: position of the stacked dimension in the result. Must be in
            ``[0, tensors[0].ndim]``; any other value (including negative
            dims) yields :code:`None`.

    Returns:
        The strides of the stacked view (one more entry than the inputs have
        dimensions), or :code:`None` if no such view exists.
    """
    if len(tensors) <= 1:
        return None
    first = tensors[0]
    # Only dims in [0, ndim] are handled; anything else is reported as
    # "not stacked" rather than producing strides for the wrong position.
    if dim < 0 or dim > first.ndim:
        return None

    base_offset = first.storage_offset()
    # Distance (in elements) between two consecutive tensors in the storage.
    stack_stride = tensors[1].storage_offset() - base_offset
    # A strided view cannot walk backwards through the storage, so tensors
    # laid out in decreasing offset order cannot be expressed as one view.
    if stack_stride < 0:
        return None

    base_shape = first.shape
    base_strides = first.stride()
    final_stride = (*base_strides[:dim], stack_stride, *base_strides[dim:])

    storage_data_ptr: Optional[int] = None
    for position, x in enumerate(tensors[1:], start=1):
        # Sanity checks
        if x.shape != base_shape or x.stride() != base_strides:
            return None
        if x.storage_offset() != base_offset + position * stack_stride:
            return None
        # Resolved lazily so that the cheap checks above can bail out first.
        if storage_data_ptr is None:
            storage_data_ptr = _get_storage_base(first)
        # Actual storage check
        if _get_storage_base(x) != storage_data_ptr:
            return None
    return final_stride
def _stack_or_none_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> Optional[torch.Tensor]:
    """
    Builds the stacked tensor as a view on the storage of ``tensors[0]``.

    Args:
        tensors: tensors to stack.
        dim: position of the new dimension. Negative values count from the
            end of the *output* shape, as with :attr:`torch.stack`.

    Returns:
        A strided view equivalent to ``torch.stack(tensors, dim)`` when
        :code:`get_stack_strides` finds one, otherwise :code:`None`.
    """
    if not tensors:
        return None
    first = tensors[0]
    out_ndim = first.ndim + 1
    # Normalize negative dims against the output rank: `list.insert` with a
    # negative index would otherwise place the new axis one slot too early.
    if dim < 0:
        dim += out_ndim
    if not 0 <= dim < out_ndim:
        return None
    strides = get_stack_strides(tensors, dim)
    if strides is None:
        return None
    out_shape = list(first.shape)
    out_shape.insert(dim, len(tensors))
    return first.as_strided(out_shape, strides)
def _stack_fw(
    tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
    dim: int,
) -> torch.Tensor:
    """
    Stacks ``tensors`` along ``dim``.

    Uses the result of :code:`_stack_or_none_fw` when it is not
    :code:`None`, and falls back to :attr:`torch.stack` otherwise.
    """
    stacked_view = _stack_or_none_fw(tensors, dim)
    if stacked_view is not None:
        return stacked_view
    return torch.stack(tensors, dim=dim)
class _Unbind(torch.autograd.Function):
    """
    See function `unbind`
    """

    @staticmethod
    # type: ignore
    def forward(ctx, x: torch.Tensor, dim: int):
        """Unbinds ``x`` along ``dim`` and remembers ``dim`` for backward."""
        ctx.dim = dim
        return x.unbind(dim)

    @classmethod
    # type: ignore
    def backward(cls, ctx, *tensors: torch.Tensor):
        """
        Re-assembles the per-slice gradients into the gradient of ``x``.

        Returns one entry per `forward` input: the stacked gradient for
        ``x`` and :code:`None` for the non-differentiable ``dim``.
        """
        return _stack_fw(tensors, ctx.dim), None
class _StackOrNone(torch.autograd.Function):
    """
    See function `stack_or_none`
    """

    @staticmethod
    # type: ignore
    def forward(ctx, dim: int, *tensors: torch.Tensor):
        """
        Returns the result of :code:`_stack_or_none_fw` for ``tensors`` and
        remembers ``dim`` for backward.
        """
        ctx.dim = dim
        return _stack_or_none_fw(tensors, dim=dim)

    @classmethod
    # type: ignore
    def backward(cls, ctx, grad: torch.Tensor):
        """
        Splits ``grad`` back into one gradient per stacked tensor.

        The leading :code:`None` is the gradient slot of the ``dim`` argument.
        """
        return (None, *grad.unbind(dim=ctx.dim))
def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
    """
    Forward pass: identical to :attr:`torch.unbind`.

    Backward pass: when the incoming gradients are already views laid out
    next to each other in one storage, they are re-assembled as a view
    instead of being copied by :attr:`torch.stack`.
    """
    slices = _Unbind.apply(x, dim)
    return slices
def stack_or_none(
    tensors: Sequence[torch.Tensor], dim: int
) -> Optional[torch.Tensor]:
    """
    Does exactly the same as :attr:`torch.stack` if the tensors can be
    stacked without any memory operation. Otherwise returns None.

    Args:
        tensors: tensors to stack.
        dim: dimension along which the tensors are stacked.

    Returns:
        The stacked tensor, or :code:`None` when stacking would require
        copying data.
    """
    return _StackOrNone.apply(dim, *tensors)