# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
#
# This source code is licensed under the BSD license found in the
# LICENSE file in the root directory of this source tree.
from typing import List, Optional, Sequence, Tuple, Union

import torch

from .common import _get_storage_base


def get_stack_strides(
tensors: Sequence[torch.Tensor], dim: int
) -> Optional[Tuple[int, ...]]:
"""
If the tensors are already stacked on dimension :code:`dim`, \
returns the strides of the stacked tensors. \
Otherwise returns :code:`None`.
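
    A minimal illustrative example (assuming the chunks are views of one
    contiguous tensor, which is what makes a copy-free re-stack possible)::

        t = torch.arange(6).reshape(2, 3)
        chunks = t.unbind(0)                    # two views of t's storage
        get_stack_strides(chunks, dim=0)        # (3, 1) == t.stride()
        get_stack_strides([t[0].clone(), t[1].clone()], dim=0)  # None: separate storages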
"""
if len(tensors) <= 1 or dim > tensors[0].ndim:
return None
    final_stride = []
    for i in range(tensors[0].ndim + 1):
        if i == dim:
            # The stride of the new stacking dimension is the offset between
            # two consecutive tensors within their (shared) storage.
            final_stride.append(
                tensors[1].storage_offset() - tensors[0].storage_offset()
            )
            continue
        if i > dim:
            i -= 1
        final_stride.append(tensors[0].stride(i))
storage_data_ptr: Optional[int] = None
for i, x in enumerate(tensors[1:]):
# Sanity checks
if x.shape != tensors[0].shape:
return None
if x.stride() != tensors[0].stride():
return None
if (
x.storage_offset()
!= tensors[0].storage_offset() + (i + 1) * final_stride[dim]
):
return None
if storage_data_ptr is None:
storage_data_ptr = _get_storage_base(tensors[0])
# Actual storage check
if _get_storage_base(x) != storage_data_ptr:
return None
    return tuple(final_stride)


def _stack_or_none_fw(
tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
dim: int,
) -> Optional[torch.Tensor]:
    # Rebuild the stacked tensor as a zero-copy view over the shared storage
    # if possible; otherwise return None and let the caller decide what to do.
    strides = get_stack_strides(tensors, dim)
if strides is not None:
input_shape = list(tensors[0].shape)
input_shape.insert(dim, len(tensors))
return tensors[0].as_strided(input_shape, strides)
    return None


def _stack_fw(
tensors: Union[Tuple[torch.Tensor, ...], List[torch.Tensor]],
dim: int,
) -> torch.Tensor:
    # Prefer the zero-copy path; fall back to a real (copying) torch.stack.
    out = _stack_or_none_fw(tensors, dim)
if out is None:
out = torch.stack(tensors, dim=dim)
    return out


class _Unbind(torch.autograd.Function):
"""
See function `unbind`
"""
@staticmethod
# type: ignore
def forward(ctx, x: torch.Tensor, dim: int):
ctx.dim = dim
return x.unbind(dim)
@classmethod
# type: ignore
def backward(cls, ctx, *tensors: torch.Tensor):
        return _stack_fw(tensors, ctx.dim), None


class _StackOrNone(torch.autograd.Function):
"""
See function `stack_or_none`
"""
@staticmethod
# type: ignore
def forward(ctx, dim: int, *tensors: torch.Tensor):
ctx.dim = dim
return _stack_or_none_fw(tensors, dim=dim)
@classmethod
# type: ignore
def backward(cls, ctx, grad: torch.Tensor):
        return (None, *grad.unbind(dim=ctx.dim))


def unbind(x: torch.Tensor, dim: int) -> Tuple[torch.Tensor, ...]:
"""
    Does exactly the same as :attr:`torch.unbind` for the forward.
    In the backward pass, avoids the copy done by :attr:`torch.stack` if the
    gradients are already multiple views of the same storage.
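
    A small illustrative sketch (this module's ``unbind``, not :attr:`torch.unbind`)::

        x = torch.randn(2, 3, requires_grad=True)
        a, b = unbind(x, dim=0)
        (a.sum() + 2 * b.sum()).backward()      # gradient flows back into x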
"""
    return _Unbind.apply(x, dim)


def stack_or_none(tensors: Sequence[torch.Tensor], dim: int) -> Optional[torch.Tensor]:
"""
    Does exactly the same as :attr:`torch.stack` if the tensors can be stacked
    without any memory copy. Otherwise returns :code:`None`.
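
    A minimal illustrative example (``x.unbind(0)`` yields views of a single
    storage, so the re-stack below is copy-free)::

        x = torch.randn(2, 3)
        chunks = x.unbind(0)
        stack_or_none(chunks, dim=0)                         # view over x, no copy
        stack_or_none([torch.zeros(3), torch.ones(3)], 0)    # None: separate storages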
"""
return _StackOrNone.apply(dim, *tensors)
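

# A minimal, hypothetical self-test sketch (not part of the original module):
# unbinding a contiguous tensor and re-stacking its chunks should be zero-copy.
if __name__ == "__main__":
    qkv = torch.randn(8, 3, 16, requires_grad=True)
    q, k, v = unbind(qkv, dim=1)
    restacked = stack_or_none([q, k, v], dim=1)
    assert restacked is not None
    assert restacked.data_ptr() == qkv.data_ptr()  # same storage, no copy
    (q.sum() + 2 * k.sum() + 3 * v.sum()).backward()
    print("round-trip OK, qkv.grad shape:", tuple(qkv.grad.shape))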