reach-vb's picture
reach-vb HF staff
da0609001d38541f2e1d84b2fab95a3e5cb5413337fc2247150c3f19aae1664e
8f05c80
raw
history blame
No virus
4.72 kB
import dis
import inspect
from typing import Sequence, Union
import torch
import functorch._C
from functorch._C import dim as _C
from .tree_map import tree_flatten, tree_map
from .wrap_type import wrap_type
# Runs once at import time, before any class below is defined.
# NOTE(review): presumably installs the C-level hooks on torch.Tensor --
# confirm against functorch._C.
_C._patch_tensor_class()

# Re-export the C-implemented dimension constructors from this package.
dims = _C.dims
DimList = _C.DimList
dimlists = _C.dimlists
class DimensionMismatchError(Exception):
    """Error type for mismatched dimensions.

    NOTE(review): meaning taken from the name only; nothing in this part of
    the module raises it.
    """
class DimensionBindError(Exception):
    """Error type for failures while binding a dimension.

    NOTE(review): meaning taken from the name only; nothing in this part of
    the module raises it.
    """
from . import op_properties
# use dict to avoid writing C++ bindings for set; every op listed in
# op_properties.pointwise maps to True.
pointwise = dict.fromkeys(op_properties.pointwise, True)

# Implementation switch: True binds the C extension (_C); False additionally
# imports the Python ``reference`` module, which the rest of this file then
# uses in place of most _C entry points.
use_c = True
if not use_c:
    from . import reference
class _Tensor:
    """Common base of this module's ``Dim`` and ``Tensor`` classes.

    Provides a few direct accessors plus the ``__torch_function__``,
    ``expand`` and ``index`` bindings. Further attributes are attached to
    this class later in the module (``_def``, ``__getitem__``, ``split``,
    ``order``, ...).

    Instances are expected to expose ``_levels``, ``_tensor`` and ``ndim``;
    none of them is defined in this class body.
    """

    # fast path around slow wrapping/unwrapping logic for simple queries used
    # by the implementation...

    @property
    def dims(self):
        """Return, as a tuple, the entries of ``self._levels`` that are ``Dim``s."""
        return tuple(level for level in self._levels if isinstance(level, Dim))

    def dim(self):
        """Return ``self.ndim`` (method spelling of the attribute)."""
        return self.ndim

    # Dispatch hook and ``expand`` come from the C extension, or from the
    # Python reference module when ``use_c`` is off.
    if use_c:
        __torch_function__ = classmethod(_C.__torch_function__)
        expand = _C._instancemethod(_C.expand)
    else:
        __torch_function__ = reference.__torch_function__
        expand = reference.expand

    # ``index`` is bound from the C extension in both configurations.
    index = _C._instancemethod(_C.index)

    def __repr__(self):
        """Render the underlying tensor, then its levels and its sizes."""
        tensor, levels, ndim = self._tensor, self._levels, self.ndim
        # Integer levels are shifted by ndim for display; other entries are
        # shown unchanged.
        # NOTE(review): presumably ints are negative positional indices and
        # the shift makes them non-negative -- confirm.
        shown_levels = tuple(
            level + ndim if isinstance(level, int) else level for level in levels
        )
        return f"{tensor}\nwith dims={shown_levels} sizes={tuple(tensor.size())}"
# Tuple of the two tensor base classes handled by this package; the tuple form
# is what isinstance()/issubclass() accept (no use site is visible in this
# part of the file).
TensorLike = (_Tensor, torch.Tensor)
class Dim(_C.Dim, _Tensor):
    """Python-level ``Dim`` class combining the C ``_C.Dim`` with ``_Tensor``."""

    # note that _C.Dim comes before _Tensor in the bases because we want the
    # Dim API for things like size to take precedence.
    # Tensor defines __format__, but we want to print Dims with special
    # formatting, so use object's implementation instead.
    __format__ = object.__format__
class Tensor(_Tensor, _C.Tensor):
    """Python-level ``Tensor`` class combining ``_Tensor`` with the C ``_C.Tensor``."""

    if not use_c:
        # Only bound when the Python reference implementation is selected.
        from_batched = staticmethod(_C.Tensor_from_batched)
    # Bound from the C extension in both configurations.
    from_positional = staticmethod(_C.Tensor_from_positional)
    sum = _C._instancemethod(_C.Tensor_sum)
def cat(tensors, dim, new_dim):
    """Stack ``tensors`` over a fresh dim, then index ``[fresh, dim]`` by ``new_dim``.

    Implemented as ``stack(tensors, fresh, dim).index([fresh, dim], new_dim)``,
    where ``fresh`` is the result of calling ``dims()`` with no arguments.

    Args:
        tensors: passed through as the first argument of ``stack``.
        dim: passed as the third argument of ``stack`` and as the second
            entry of the list given to ``index``.
        new_dim: passed as the second argument of ``index``.

    Returns:
        Whatever ``index`` returns for the stacked value.

    NOTE(review): presumably ``dims()`` yields one new Dim here and the final
    ``index`` call merges it with ``dim`` into ``new_dim`` -- confirm against
    the ``stack``/``index`` documentation.
    """
    stacked_dim = dims()
    stacked = stack(tensors, stacked_dim, dim)
    return stacked.index([stacked_dim, dim], new_dim)
# Select the implementation of the wrapping helpers and of
# __getitem__/stack/split: the C extension when ``use_c`` is on, otherwise
# the Python reference module.
if use_c:
    _wrap = _C._wrap

    def _def(name, *args, **kwargs):
        """Install a wrapped version of ``torch.Tensor.<name>`` on ``_Tensor``.

        Args:
            name: attribute name read from ``torch.Tensor`` and then set on
                ``_Tensor``.
            *args: extra positional arguments forwarded to ``_wrap``.
            **kwargs: extra keyword arguments forwarded to ``_wrap``.

        Raises:
            AttributeError: if ``torch.Tensor`` has no attribute ``name``.
        """
        orig = getattr(torch.Tensor, name)
        wrapped = _wrap(orig, *args, **kwargs)
        setattr(_Tensor, name, _C._instancemethod(wrapped))

    t__getitem__ = _C._instancemethod(_C.__getitem__)
    stack = _C.stack
    split = _C._instancemethod(_C.split)
else:
    _wrap, _def = reference._wrap, reference._def
    t__getitem__ = reference.t__getitem__
    stack = reference.stack
    split = reference.split
# note: there is no python reference for __setitem__, so it is bound from the
# C extension regardless of use_c
t__setitem__ = _C._instancemethod(_C.__setitem__)
# __getitem__/__setitem__ are only assigned on _Tensor here. For torch.Tensor
# this is patched in the C API because otherwise torch.Tensor will
# no longer be considered a sequence and things will break
# (intentionally disabled) torch.Tensor.__getitem__ = t__getitem__
_Tensor.__getitem__ = t__getitem__
# (intentionally disabled) torch.Tensor.__setitem__ = t__setitem__
_Tensor.__setitem__ = t__setitem__
# split is replaced on torch.Tensor itself as well as on _Tensor.
torch.Tensor.split = split
_Tensor.split = split
# expand and index are replaced on torch.Tensor with the C bindings.
# NOTE(review): these two use _C even when use_c is False, unlike
# _Tensor.expand which switches to the reference module -- confirm intended.
torch.Tensor.expand = _C._instancemethod(_C.expand)
torch.Tensor.index = _C._instancemethod(_C.index)
# NOTE(review): wrap_type is imported from .wrap_type and not shown here;
# presumably it exposes torch.Tensor's API on _Tensor, routed through the
# given __torch_function__ -- confirm.
wrap_type(use_c, _Tensor, torch.Tensor, _Tensor.__torch_function__)
# Remove the ndim attribute from _Tensor; must run after wrap_type.
# NOTE(review): presumably wrap_type added it and lookups should instead fall
# through to the C base classes -- confirm.
del _Tensor.ndim
# ``order`` is bound from the C extension, or from ``reference.positional``
# when the Python reference implementation is selected.
if use_c:
    _Tensor.order = _C._instancemethod(_C.order)
else:
    _Tensor.order = reference.positional
# Register wrapped torch.Tensor methods on _Tensor. Calls are grouped by the
# options forwarded to _wrap; the registration order is unchanged.

# Wrapped with _wrap's default options.
for _name in (
    "mean",
    "sum",
    "all",
    "amax",
    "amin",
    "aminmax",
    "any",
    "count_nonzero",
    "logsumexp",
    "nanmean",
    "nansum",
    "prod",
):
    _def(_name)

for _name in ("std", "var"):
    _def(_name, keepdim_offset=2)

for _name in (
    "max",
    "min",
    "argmax",
    "argmin",
    "kthvalue",
    "median",
    "nanmedian",
    "mode",
):
    _def(_name, single_dim=True)

for _name in ("sort", "argsort"):
    _def(_name, reduce=False)

_def("unbind", single_dim=True)
_def("chunk", dim_offset=1, reduce=False)

for _name in (
    "cummax",
    "cummin",
    "cumprod",
    "cumprod_",
    "cumsum",
    "cumsum_",
    "logcumsumexp",
):
    _def(_name, single_dim=True, reduce=False)

_def("renorm", dim_offset=1, single_dim=True, reduce=False)
_def("softmax", single_dim=True, reduce=False)

# Drop the loop variable so no extra name is left in the module namespace.
del _name

# Module-level function form of softmax, wrapped with the same options as the
# method registered just above.
softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
# stuff to handle in the future, because they require special
# binding logic for dims
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# should these all be subsumed by inplace indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce