# Scraped page header (not part of the module): Image-Generator / torch/_numpy/_funcs_impl.py
# Uploaded by Adi-69s in "Upload 5061 files" (commit b2659ad, verified).
"""A thin pytorch / numpy compat layer.
Things imported from here have numpy-compatible signatures but operate on
pytorch tensors.
"""
# Contents of this module ends up in the main namespace via _funcs.py
# where type annotations are used in conjunction with the @normalizer decorator.
from __future__ import annotations
import builtins
import itertools
import operator
from typing import Optional, Sequence
import torch
from . import _dtypes_impl, _util
from ._normalizations import (
ArrayLike,
ArrayLikeOrScalar,
CastingModes,
DTypeLike,
NDArray,
NotImplementedType,
OutArray,
)
def copy(
    a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False
):
    """Return a freshly allocated duplicate of ``a``.

    ``order`` and ``subok`` are accepted for signature compatibility; the
    body does not read them.
    """
    duplicate = a.clone()
    return duplicate
def copyto(
    dst: NDArray,
    src: ArrayLike,
    casting: Optional[CastingModes] = "same_kind",
    where: NotImplementedType = None,
):
    """Copy ``src`` into ``dst`` in place.

    ``src`` is first cast to ``dst.dtype`` under the given ``casting`` rule,
    then written with ``Tensor.copy_``. Nothing is returned. ``where`` is
    not read by the body.
    """
    (converted,) = _util.typecast_tensors((src,), dst.dtype, casting=casting)
    dst.copy_(converted)
def atleast_1d(*arys: ArrayLike):
    """View each input as a tensor with at least one dimension.

    A single input yields a single tensor; when ``torch.atleast_1d`` hands
    back a tuple, it is converted to a list.
    """
    promoted = torch.atleast_1d(*arys)
    return list(promoted) if isinstance(promoted, tuple) else promoted
def atleast_2d(*arys: ArrayLike):
    """View each input as a tensor with at least two dimensions.

    A single input yields a single tensor; when ``torch.atleast_2d`` hands
    back a tuple, it is converted to a list.
    """
    promoted = torch.atleast_2d(*arys)
    return list(promoted) if isinstance(promoted, tuple) else promoted
def atleast_3d(*arys: ArrayLike):
    """View each input as a tensor with at least three dimensions.

    A single input yields a single tensor; when ``torch.atleast_3d`` hands
    back a tuple, it is converted to a list.
    """
    promoted = torch.atleast_3d(*arys)
    return list(promoted) if isinstance(promoted, tuple) else promoted
def _concat_check(tup, dtype, out):
if tup == ():
raise ValueError("need at least one array to concatenate")
"""Check inputs in concatenate et al."""
if out is not None and dtype is not None:
# mimic numpy
raise TypeError(
"concatenate() only takes `out` or `dtype` as an "
"argument, but both were provided."
)
def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"):
    """Pick the target dtype and cast every input tensor to it.

    The target is ``dtype`` when given, otherwise ``out.dtype.torch_dtype``
    when ``out`` is given, otherwise the result type of all inputs.
    """
    if out is None and dtype is None:
        target = _dtypes_impl.result_type_impl(*tensors)
    elif dtype is None:
        target = out.dtype.torch_dtype
    else:
        target = dtype

    # cast input arrays if necessary; do not broadcast them against `out`
    return _util.typecast_tensors(tensors, target, casting)
def _concatenate(
    tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind"
):
    """Pure-torch concatenation, shared by ``concatenate``, ``append`` and cov/corrcoef.

    Inputs go through ``_util.axis_none_flatten`` (which also returns the
    axis to use), are cast to a common dtype and joined with ``torch.cat``.
    """
    prepared, cat_axis = _util.axis_none_flatten(*tensors, axis=axis)
    prepared = _concat_cast_helper(prepared, out, dtype, casting)
    return torch.cat(prepared, cat_axis)
def concatenate(
    ar_tuple: Sequence[ArrayLike],
    axis=0,
    out: Optional[OutArray] = None,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Join a sequence of tensors along ``axis``.

    Arguments are validated by ``_concat_check`` and the work is delegated
    to ``_concatenate``.
    """
    _concat_check(ar_tuple, dtype, out=out)
    return _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting)
def vstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Stack tensors with ``torch.vstack`` after casting them to a common dtype."""
    _concat_check(tup, dtype, out=None)
    cast_inputs = _concat_cast_helper(tup, dtype=dtype, casting=casting)
    stacked = torch.vstack(cast_inputs)
    return stacked


# `row_stack` is an alias for `vstack`.
row_stack = vstack
def hstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Stack tensors with ``torch.hstack`` after casting them to a common dtype."""
    _concat_check(tup, dtype, out=None)
    cast_inputs = _concat_cast_helper(tup, dtype=dtype, casting=casting)
    stacked = torch.hstack(cast_inputs)
    return stacked
def dstack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Stack tensors with ``torch.dstack`` after casting them to a common dtype."""
    # XXX: in numpy 1.24 dstack does not have dtype and casting keywords
    # but {h,v}stack do. Hence add them here for consistency.
    _concat_check(tup, dtype, out=None)
    cast_inputs = _concat_cast_helper(tup, dtype=dtype, casting=casting)
    stacked = torch.dstack(cast_inputs)
    return stacked
def column_stack(
    tup: Sequence[ArrayLike],
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Stack tensors with ``torch.column_stack`` after casting them to a common dtype."""
    # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords
    # but row_stack does. (because row_stack is an alias for vstack, really).
    # Hence add these keywords here for consistency.
    _concat_check(tup, dtype, out=None)
    cast_inputs = _concat_cast_helper(tup, dtype=dtype, casting=casting)
    stacked = torch.column_stack(cast_inputs)
    return stacked
def stack(
    arrays: Sequence[ArrayLike],
    axis=0,
    out: Optional[OutArray] = None,
    *,
    dtype: Optional[DTypeLike] = None,
    casting: Optional[CastingModes] = "same_kind",
):
    """Join tensors along a new axis.

    The inputs are cast to a common dtype, ``axis`` is normalized against the
    rank of the result (one more than the rank of the first input) and the
    tensors are combined with ``torch.stack``. ``out`` is only used for the
    argument check here.
    """
    _concat_check(arrays, dtype, out=out)

    cast_inputs = _concat_cast_helper(arrays, dtype=dtype, casting=casting)
    out_rank = cast_inputs[0].ndim + 1
    new_axis = _util.normalize_axis_index(axis, out_rank)
    return torch.stack(cast_inputs, axis=new_axis)
def append(arr: ArrayLike, values: ArrayLike, axis=None):
    """Append ``values`` to ``arr`` along ``axis``.

    With ``axis=None`` both inputs are flattened (``arr`` only when it is not
    already 1-D) and joined along the last axis of the flattened ``arr``.
    """
    if axis is None:
        flat_arr = arr if arr.ndim == 1 else arr.flatten()
        flat_values = values.flatten()
        return _concatenate((flat_arr, flat_values), axis=flat_arr.ndim - 1)
    return _concatenate((arr, values), axis=axis)
# ### split ###
def _split_helper(tensor, indices_or_sections, axis, strict=False):
if isinstance(indices_or_sections, int):
return _split_helper_int(tensor, indices_or_sections, axis, strict)
elif isinstance(indices_or_sections, (list, tuple)):
# NB: drop split=..., it only applies to split_helper_int
return _split_helper_list(tensor, list(indices_or_sections), axis)
else:
raise TypeError("split_helper: ", type(indices_or_sections))
def _split_helper_int(tensor, indices_or_sections, axis, strict=False):
    """Split ``tensor`` along ``axis`` into ``indices_or_sections`` pieces.

    Following numpy, when the axis length ``length`` is not divisible by the
    number of sections ``n``, the first ``length % n`` chunks have size
    ``length // n + 1`` and the rest have size ``length // n``.

    Raises
    ------
    NotImplementedError
        If ``indices_or_sections`` is not an int.
    ValueError
        If the number of sections is not positive, or if ``strict`` is set
        and the split is not an equal division.
    """
    if not isinstance(indices_or_sections, int):
        raise NotImplementedError("split: indices_or_sections")

    axis = _util.normalize_axis_index(axis, tensor.ndim)

    length, n_sections = tensor.shape[axis], indices_or_sections

    if n_sections <= 0:
        raise ValueError("number sections must be larger than 0.")

    base, remainder = divmod(length, n_sections)
    if remainder and strict:
        raise ValueError("array split does not result in an equal division")

    # `remainder` larger chunks first, then the regular ones.
    sizes = [base + 1] * remainder + [base] * (n_sections - remainder)

    return torch.split(tensor, sizes, axis)
def _split_helper_list(tensor, indices_or_sections, axis):
if not isinstance(indices_or_sections, list):
raise NotImplementedError("split: indices_or_sections: list")
# numpy expects indices, while torch expects lengths of sections
# also, numpy appends zero-size arrays for indices above the shape[axis]
lst = [x for x in indices_or_sections if x <= tensor.shape[axis]]
num_extra = len(indices_or_sections) - len(lst)
lst.append(tensor.shape[axis])
lst = [
lst[0],
] + [a - b for a, b in zip(lst[1:], lst[:-1])]
lst += [0] * num_extra
return torch.split(tensor, lst, axis)
def array_split(ary: ArrayLike, indices_or_sections, axis=0):
    """Split ``ary`` along ``axis``; unequal divisions are allowed (non-strict)."""
    pieces = _split_helper(ary, indices_or_sections, axis)
    return pieces
def split(ary: ArrayLike, indices_or_sections, axis=0):
    """Split ``ary`` along ``axis``, passing ``strict=True`` to the split helper."""
    pieces = _split_helper(ary, indices_or_sections, axis, strict=True)
    return pieces
def hsplit(ary: ArrayLike, indices_or_sections):
    """Strict split along axis 1, or along axis 0 for 1-D input.

    Raises ``ValueError`` for 0-D input.
    """
    rank = ary.ndim
    if rank == 0:
        raise ValueError("hsplit only works on arrays of 1 or more dimensions")
    if rank > 1:
        split_axis = 1
    else:
        split_axis = 0
    return _split_helper(ary, indices_or_sections, split_axis, strict=True)
def vsplit(ary: ArrayLike, indices_or_sections):
    """Strict split along axis 0; the input must have at least two dimensions."""
    too_few_dims = ary.ndim < 2
    if too_few_dims:
        raise ValueError("vsplit only works on arrays of 2 or more dimensions")
    return _split_helper(ary, indices_or_sections, 0, strict=True)
def dsplit(ary: ArrayLike, indices_or_sections):
    """Strict split along axis 2; the input must have at least three dimensions."""
    too_few_dims = ary.ndim < 3
    if too_few_dims:
        raise ValueError("dsplit only works on arrays of 3 or more dimensions")
    return _split_helper(ary, indices_or_sections, 2, strict=True)
def kron(a: ArrayLike, b: ArrayLike):
    """Kronecker product of ``a`` and ``b`` via ``torch.kron``."""
    product = torch.kron(a, b)
    return product
def vander(x: ArrayLike, N=None, increasing=False):
    """Vandermonde matrix of ``x`` via ``torch.vander``."""
    matrix = torch.vander(x, N, increasing)
    return matrix
# ### linspace, geomspace, logspace and arange ###
def linspace(
    start: ArrayLike,
    stop: ArrayLike,
    num=50,
    endpoint=True,
    retstep=False,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    """Return ``num`` evenly spaced values from ``start`` to ``stop``.

    Only ``axis=0``, ``retstep=False`` and ``endpoint=True`` are supported;
    anything else raises ``NotImplementedError``. When ``dtype`` is omitted
    the default float dtype is used.
    """
    unsupported = axis != 0 or retstep or not endpoint
    if unsupported:
        raise NotImplementedError
    out_dtype = _dtypes_impl.default_dtypes().float_dtype if dtype is None else dtype
    # XXX: raises TypeError if start or stop are not scalars
    return torch.linspace(start, stop, num, dtype=out_dtype)
def geomspace(
    start: ArrayLike,
    stop: ArrayLike,
    num=50,
    endpoint=True,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    """Return ``num`` values spaced evenly on a log scale from ``start`` to ``stop``.

    Only ``axis=0`` and ``endpoint=True`` are supported; anything else raises
    ``NotImplementedError``.

    NOTE(review): ``dtype`` is accepted but never read by this body.
    """
    if axis != 0 or not endpoint:
        raise NotImplementedError
    # Common ratio between consecutive samples.
    base = torch.pow(stop / start, 1.0 / (num - 1))
    logbase = torch.log(base)
    # Express both endpoints as exponents of `base` and let logspace fill in.
    return torch.logspace(
        torch.log(start) / logbase,
        torch.log(stop) / logbase,
        num,
        base=base,
    )
def logspace(
    start,
    stop,
    num=50,
    endpoint=True,
    base=10.0,
    dtype: Optional[DTypeLike] = None,
    axis=0,
):
    """Return ``num`` values ``base ** e`` with ``e`` evenly spaced in [start, stop].

    Only ``axis=0`` and ``endpoint=True`` are supported; anything else raises
    ``NotImplementedError``.
    """
    supported = axis == 0 and endpoint
    if not supported:
        raise NotImplementedError
    samples = torch.logspace(start, stop, num, base=base, dtype=dtype)
    return samples
def arange(
    start: Optional[ArrayLikeOrScalar] = None,
    stop: Optional[ArrayLikeOrScalar] = None,
    step: Optional[ArrayLikeOrScalar] = 1,
    dtype: Optional[DTypeLike] = None,
    *,
    like: NotImplementedType = None,
):
    """Return evenly spaced values in the half-open interval [start, stop).

    A single positional argument is interpreted as ``stop`` with ``start=0``.
    When ``dtype`` is omitted, the default float dtype is used if any of
    start/stop/step is floating point, else the default int dtype.

    Raises
    ------
    ZeroDivisionError
        If ``step`` is zero.
    TypeError
        If neither ``start`` nor ``stop`` is given.
    NotImplementedError
        If any of start/stop/step is complex.
    """
    if step == 0:
        raise ZeroDivisionError
    if stop is None and start is None:
        raise TypeError
    if stop is None:
        # XXX: this breaks if start is passed as a kwarg:
        # arange(start=4) should raise (no stop) but doesn't
        start, stop = 0, start
    if start is None:
        start = 0

    # the dtype of the result
    if dtype is None:
        dtype = (
            _dtypes_impl.default_dtypes().float_dtype
            if any(_dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step))
            else _dtypes_impl.default_dtypes().int_dtype
        )
    # Generate in float64 when a complex result is requested; the result is
    # cast to `dtype` at the end.
    work_dtype = torch.float64 if dtype.is_complex else dtype

    # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager.
    if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)):
        raise NotImplementedError

    if (step > 0 and start > stop) or (step < 0 and start < stop):
        # empty range
        return torch.empty(0, dtype=dtype)

    result = torch.arange(start, stop, step, dtype=work_dtype)
    result = _util.cast_if_needed(result, dtype)
    return result
# ### zeros/ones/empty/full ###
def empty(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Return an uninitialized tensor of the given shape.

    The default float dtype is used when ``dtype`` is omitted.
    """
    out_dtype = _dtypes_impl.default_dtypes().float_dtype if dtype is None else dtype
    return torch.empty(shape, dtype=out_dtype)
# NB: *_like functions deliberately deviate from numpy: it has subok=True
# as the default; we set subok=False and raise on anything else.
def empty_like(
    prototype: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Return an uninitialized tensor modelled on ``prototype``.

    Parameters
    ----------
    prototype : tensor providing the default shape and dtype.
    dtype : optional override of the result dtype.
    shape : optional override of the result shape. It may describe a
        different number of elements than ``prototype`` holds.
    """
    if shape is None:
        return torch.empty_like(prototype, dtype=dtype)
    # `shape` overrides the prototype's shape outright; reshaping a
    # prototype-shaped result would fail whenever the element counts differ.
    return prototype.new_empty(shape, dtype=dtype)
def full(
    shape,
    fill_value: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Return a tensor of the given shape filled with ``fill_value``.

    ``dtype`` defaults to ``fill_value.dtype``. A shape that is not a tuple
    or list (for example a bare int) is wrapped into a 1-tuple.
    """
    out_dtype = fill_value.dtype if dtype is None else dtype
    out_shape = shape if isinstance(shape, (tuple, list)) else (shape,)
    return torch.full(out_shape, fill_value, dtype=out_dtype)
def full_like(
    a: ArrayLike,
    fill_value,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Return a tensor modelled on ``a`` and filled with ``fill_value``.

    Parameters
    ----------
    a : tensor providing the default shape and dtype.
    fill_value : value written to every element.
    dtype : optional override of the result dtype.
    shape : optional override of the result shape. It may describe a
        different number of elements than ``a`` holds.
    """
    # XXX: fill_value broadcasts
    if shape is None:
        return torch.full_like(a, fill_value, dtype=dtype)
    # `shape` overrides the shape of `a` outright; reshaping an `a`-shaped
    # result would fail whenever the element counts differ.
    return a.new_full(shape, fill_value, dtype=dtype)
def ones(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Return a tensor of ones of the given shape.

    The default float dtype is used when ``dtype`` is omitted.
    """
    out_dtype = _dtypes_impl.default_dtypes().float_dtype if dtype is None else dtype
    return torch.ones(shape, dtype=out_dtype)
def ones_like(
    a: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Return a tensor of ones modelled on ``a``.

    Parameters
    ----------
    a : tensor providing the default shape and dtype.
    dtype : optional override of the result dtype.
    shape : optional override of the result shape. It may describe a
        different number of elements than ``a`` holds.
    """
    if shape is None:
        return torch.ones_like(a, dtype=dtype)
    # `shape` overrides the shape of `a` outright; reshaping an `a`-shaped
    # result would fail whenever the element counts differ.
    return a.new_ones(shape, dtype=dtype)
def zeros(
    shape,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Return a tensor of zeros of the given shape.

    The default float dtype is used when ``dtype`` is omitted.
    """
    out_dtype = _dtypes_impl.default_dtypes().float_dtype if dtype is None else dtype
    return torch.zeros(shape, dtype=out_dtype)
def zeros_like(
    a: ArrayLike,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "K",
    subok: NotImplementedType = False,
    shape=None,
):
    """Return a tensor of zeros modelled on ``a``.

    Parameters
    ----------
    a : tensor providing the default shape and dtype.
    dtype : optional override of the result dtype.
    shape : optional override of the result shape. It may describe a
        different number of elements than ``a`` holds.
    """
    if shape is None:
        return torch.zeros_like(a, dtype=dtype)
    # `shape` overrides the shape of `a` outright; reshaping an `a`-shaped
    # result would fail whenever the element counts differ.
    return a.new_zeros(shape, dtype=dtype)
# ### cov & corrcoef ###
def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True):
"""Prepare inputs for cov and corrcoef."""
# https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636
if y_tensor is not None:
# make sure x and y are at least 2D
ndim_extra = 2 - x_tensor.ndim
if ndim_extra > 0:
x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape)
if not rowvar and x_tensor.shape[0] != 1:
x_tensor = x_tensor.mT
x_tensor = x_tensor.clone()
ndim_extra = 2 - y_tensor.ndim
if ndim_extra > 0:
y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape)
if not rowvar and y_tensor.shape[0] != 1:
y_tensor = y_tensor.mT
y_tensor = y_tensor.clone()
x_tensor = _concatenate((x_tensor, y_tensor), axis=0)
return x_tensor
def corrcoef(
    x: ArrayLike,
    y: Optional[ArrayLike] = None,
    rowvar=True,
    bias=None,
    ddof=None,
    *,
    dtype: Optional[DTypeLike] = None,
):
    """Correlation coefficients of ``x`` (optionally stacked with ``y``).

    ``bias`` and ``ddof`` are rejected with ``NotImplementedError``.
    float16 input on CPU is computed in float32 and cast back to float16.
    """
    deprecated_args_given = bias is not None or ddof is not None
    if deprecated_args_given:
        # deprecated in NumPy
        raise NotImplementedError
    stacked = _xy_helper_corrcoef(x, y, rowvar)

    half_on_cpu = (stacked.dtype == torch.float16) and stacked.is_cpu
    if half_on_cpu:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        dtype = torch.float32

    stacked = _util.cast_if_needed(stacked, dtype)
    coeffs = torch.corrcoef(stacked)

    return coeffs.to(torch.float16) if half_on_cpu else coeffs
def cov(
    m: ArrayLike,
    y: Optional[ArrayLike] = None,
    rowvar=True,
    bias=False,
    ddof=None,
    fweights: Optional[ArrayLike] = None,
    aweights: Optional[ArrayLike] = None,
    *,
    dtype: Optional[DTypeLike] = None,
):
    """Covariance matrix of ``m`` (optionally stacked with ``y``).

    When ``ddof`` is omitted it is 1 if ``bias == 0`` and 0 otherwise.
    float16 input on CPU is computed in float32 and cast back to float16.
    """
    data = _xy_helper_corrcoef(m, y, rowvar)

    if ddof is None:
        if bias == 0:
            ddof = 1
        else:
            ddof = 0

    half_on_cpu = (data.dtype == torch.float16) and data.is_cpu
    if half_on_cpu:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        dtype = torch.float32

    data = _util.cast_if_needed(data, dtype)
    covariance = torch.cov(data, correction=ddof, aweights=aweights, fweights=fweights)

    return covariance.to(torch.float16) if half_on_cpu else covariance
def _conv_corr_impl(a, v, mode):
    """Shared implementation of ``convolve`` and ``correlate`` on 1-D tensors.

    Both inputs are cast to their common result type and passed to
    ``torch.nn.functional.conv1d``. ``mode="full"`` becomes a padding of
    ``len(v) - 1``; any other mode is forwarded as the padding argument.
    ``mode="same"`` with an even-length ``v`` raises ``NotImplementedError``.
    """
    common = _dtypes_impl.result_type_impl(a, v)
    a = _util.cast_if_needed(a, common)
    v = _util.cast_if_needed(v, common)

    kernel_len = v.shape[0]
    if mode == "full":
        padding = kernel_len - 1
    else:
        padding = mode

    if padding == "same" and kernel_len % 2 == 0:
        # UserWarning: Using padding='same' with even kernel lengths and odd
        # dilation may require a zero-padded copy of the input be created
        # (Triggered internally at pytorch/aten/src/ATen/native/Convolution.cpp:1010.)
        raise NotImplementedError("mode='same' and even-length weights")

    # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights
    signal = a[None, :]
    kernel = v[None, None, :]

    convolved = torch.nn.functional.conv1d(signal, kernel, padding=padding)

    # torch returns a 2D result, numpy returns a 1D array
    return convolved[0, :]
def convolve(a: ArrayLike, v: ArrayLike, mode="full"):
    """Discrete linear convolution of two 1-D tensors."""
    # NumPy: if v is longer than a, the arrays are swapped before computation
    if a.shape[0] < v.shape[0]:
        signal, kernel = v, a
    else:
        signal, kernel = a, v

    # flip the weights since numpy does and torch does not
    flipped_kernel = torch.flip(kernel, (0,))

    return _conv_corr_impl(signal, flipped_kernel, mode)
def correlate(a: ArrayLike, v: ArrayLike, mode="valid"):
    """Cross-correlation of two 1-D tensors; ``v`` is conjugated first."""
    conjugated = torch.conj_physical(v)
    return _conv_corr_impl(a, conjugated, mode)
# ### logic & element selection ###
def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0):
    """Count occurrences of each value in ``x`` via ``torch.bincount``.

    ``x`` is cast to the default int dtype with "safe" casting first. An
    empty ``x`` is replaced by an empty integer tensor before the cast.
    """
    if x.numel() == 0:
        # edge case allowed by numpy
        x = x.new_empty(0, dtype=int)

    target = _dtypes_impl.default_dtypes().int_dtype
    (as_int,) = _util.typecast_tensors((x,), target, casting="safe")

    return torch.bincount(as_int, weights, minlength)
def where(
    condition: ArrayLike,
    x: Optional[ArrayLikeOrScalar] = None,
    y: Optional[ArrayLikeOrScalar] = None,
    /,
):
    """Select from ``x`` or ``y`` by ``condition``, or locate true entries.

    With neither ``x`` nor ``y`` this returns ``torch.where(condition)``.
    Giving only one of them raises ``ValueError``. ``condition`` is
    converted to bool when it has another dtype.
    """
    x_missing = x is None
    y_missing = y is None
    if x_missing != y_missing:
        raise ValueError("either both or neither of x and y should be given")

    mask = condition if condition.dtype == torch.bool else condition.to(torch.bool)

    if x_missing and y_missing:
        return torch.where(mask)
    return torch.where(mask, x, y)
# ###### module-level queries of object properties
def ndim(a: ArrayLike):
    """Return the number of dimensions of ``a``."""
    rank = a.ndim
    return rank
def shape(a: ArrayLike):
    """Return the shape of ``a`` as a plain tuple."""
    dims = tuple(a.shape)
    return dims
def size(a: ArrayLike, axis=None):
    """Return the total element count, or the length along ``axis`` if given."""
    if axis is not None:
        return a.shape[axis]
    return a.numel()
# ###### shape manipulations and indexing
def expand_dims(a: ArrayLike, axis):
    """Return a view of ``a`` with the shape computed by ``_util.expand_shape``."""
    expanded_shape = _util.expand_shape(a.shape, axis)
    # `view` never copies
    return a.view(expanded_shape)
def flip(m: ArrayLike, axis=None):
    """Reverse the order of elements along ``axis`` (all axes when ``None``)."""
    # XXX: semantic difference: np.flip returns a view, torch.flip copies
    if axis is not None:
        dims = _util.normalize_axis_tuple(axis, m.ndim)
    else:
        dims = tuple(range(m.ndim))
    return torch.flip(m, dims)
def flipud(m: ArrayLike):
    """Flip ``m`` in the up/down direction via ``torch.flipud``."""
    flipped = torch.flipud(m)
    return flipped
def fliplr(m: ArrayLike):
    """Flip ``m`` in the left/right direction via ``torch.fliplr``."""
    flipped = torch.fliplr(m)
    return flipped
def rot90(m: ArrayLike, k=1, axes=(0, 1)):
    """Rotate ``m`` by 90 degrees ``k`` times in the plane given by ``axes``."""
    plane = _util.normalize_axis_tuple(axes, m.ndim)
    rotated = torch.rot90(m, k, plane)
    return rotated
# ### broadcasting and indices ###
def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False):
    """Broadcast ``array`` to ``shape`` via ``torch.broadcast_to``."""
    expanded = torch.broadcast_to(array, size=shape)
    return expanded
# This is a function from tuples to tuples, so we just reuse it
from torch import broadcast_shapes
def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False):
    """Broadcast all inputs against each other via ``torch.broadcast_tensors``."""
    expanded = torch.broadcast_tensors(*args)
    return expanded
def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"):
    """Build coordinate grids from 1-D coordinate tensors.

    Returns a list with one tensor per input. ``indexing`` must be "xy" or
    "ij"; with "xy" and more than one input, the first two axes are swapped.
    ``sparse=True`` keeps the grids un-broadcast; ``copy=True`` clones each
    output.
    """
    n_inputs = len(xi)

    if indexing not in ["xy", "ij"]:
        raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.")

    unit = (1,) * n_inputs
    grids = []
    for pos, coords in enumerate(xi):
        grids.append(coords.reshape(unit[:pos] + (-1,) + unit[pos + 1 :]))

    if indexing == "xy" and n_inputs > 1:
        # switch first and second axis
        tail = unit[2:]
        grids[0] = grids[0].reshape((1, -1) + tail)
        grids[1] = grids[1].reshape((-1, 1) + tail)

    if not sparse:
        # Return the full N-D matrix (not only the 1-D vector)
        grids = torch.broadcast_tensors(*grids)

    if copy:
        grids = [g.clone() for g in grids]

    return list(grids)  # match numpy, return a list
def indices(dimensions, dtype: Optional[DTypeLike] = int, sparse=False):
    """Return index grids for a tensor of shape ``dimensions``.

    Dense mode returns one tensor of shape ``(N, *dimensions)``; sparse mode
    returns a tuple of ``N`` tensors, each shaped so that it broadcasts
    against the others.
    """
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791
    dims = tuple(dimensions)
    rank = len(dims)
    unit = (1,) * rank

    sparse_parts = []
    dense = None if sparse else torch.empty((rank,) + dims, dtype=dtype)

    for pos, extent in enumerate(dims):
        ramp_shape = unit[:pos] + (extent,) + unit[pos + 1 :]
        ramp = torch.arange(extent, dtype=dtype).reshape(ramp_shape)
        if sparse:
            sparse_parts.append(ramp)
        else:
            dense[pos] = ramp

    return tuple(sparse_parts) if sparse else dense
# ### tri*-something ###
def tril(m: ArrayLike, k=0):
    """Lower triangle of ``m`` relative to diagonal ``k`` via ``torch.tril``."""
    lower = torch.tril(m, k)
    return lower
def triu(m: ArrayLike, k=0):
    """Upper triangle of ``m`` relative to diagonal ``k`` via ``torch.triu``."""
    upper = torch.triu(m, k)
    return upper
def tril_indices(n, k=0, m=None):
    """Indices of the lower triangle of an ``(n, m)`` matrix; ``m`` defaults to ``n``."""
    n_cols = n if m is None else m
    return torch.tril_indices(n, n_cols, offset=k)
def triu_indices(n, k=0, m=None):
    """Indices of the upper triangle of an ``(n, m)`` matrix; ``m`` defaults to ``n``."""
    n_cols = n if m is None else m
    return torch.triu_indices(n, n_cols, offset=k)
def tril_indices_from(arr: ArrayLike, k=0):
    """Lower-triangle indices for the 2-D tensor ``arr``.

    Raises ``ValueError`` when ``arr`` is not 2-D.
    """
    if arr.ndim != 2:
        raise ValueError("input array must be 2-d")
    n_rows, n_cols = arr.shape[0], arr.shape[1]
    # Return a tensor rather than a tuple to avoid a graphbreak
    return torch.tril_indices(n_rows, n_cols, offset=k)
def triu_indices_from(arr: ArrayLike, k=0):
    """Upper-triangle indices for the 2-D tensor ``arr``.

    Raises ``ValueError`` when ``arr`` is not 2-D.
    """
    if arr.ndim != 2:
        raise ValueError("input array must be 2-d")
    n_rows, n_cols = arr.shape[0], arr.shape[1]
    # Return a tensor rather than a tuple to avoid a graphbreak
    return torch.triu_indices(n_rows, n_cols, offset=k)
def tri(
    N,
    M=None,
    k=0,
    dtype: Optional[DTypeLike] = None,
    *,
    like: NotImplementedType = None,
):
    """Return an ``(N, M)`` tensor with ones at and below diagonal ``k``.

    ``M`` defaults to ``N``.
    """
    n_cols = N if M is None else M
    all_ones = torch.ones((N, n_cols), dtype=dtype)
    return torch.tril(all_ones, diagonal=k)
# ### equality, equivalence, allclose ###
def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False):
    """Elementwise closeness test after casting both inputs to their result type."""
    common = _dtypes_impl.result_type_impl(a, b)
    lhs = _util.cast_if_needed(a, common)
    rhs = _util.cast_if_needed(b, common)
    return torch.isclose(lhs, rhs, rtol=rtol, atol=atol, equal_nan=equal_nan)
def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Whole-tensor closeness test after casting both inputs to their result type."""
    common = _dtypes_impl.result_type_impl(a, b)
    lhs = _util.cast_if_needed(a, common)
    rhs = _util.cast_if_needed(b, common)
    return torch.allclose(lhs, rhs, rtol=rtol, atol=atol, equal_nan=equal_nan)
def _tensor_equal(a1, a2, equal_nan=False):
# Implementation of array_equal/array_equiv.
if a1.shape != a2.shape:
return False
cond = a1 == a2
if equal_nan:
cond = cond | (torch.isnan(a1) & torch.isnan(a2))
return cond.all().item()
def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False):
    """True when both tensors have the same shape and equal elements."""
    verdict = _tensor_equal(a1, a2, equal_nan=equal_nan)
    return verdict
def array_equiv(a1: ArrayLike, a2: ArrayLike):
    """True when the inputs broadcast against each other and are then equal.

    Inputs that fail to broadcast are reported as not equivalent.
    """
    # *almost* the same as array_equal: _equiv tries to broadcast, _equal does not
    try:
        lhs, rhs = torch.broadcast_tensors(a1, a2)
    except RuntimeError:
        # failed to broadcast => not equivalent
        return False
    else:
        return _tensor_equal(lhs, rhs)
def nan_to_num(
    x: ArrayLike, copy: NotImplementedType = True, nan=0.0, posinf=None, neginf=None
):
    """Replace NaN and infinities in ``x`` via ``torch.nan_to_num``.

    Complex input is handled by processing the real and imaginary parts
    separately and recombining them.
    """
    replacements = dict(nan=nan, posinf=posinf, neginf=neginf)
    if not x.is_complex():
        return torch.nan_to_num(x, **replacements)

    # work around RuntimeError: "nan_to_num" not implemented for 'ComplexDouble'
    real_part = torch.nan_to_num(x.real, **replacements)
    imag_part = torch.nan_to_num(x.imag, **replacements)
    return real_part + 1j * imag_part
# ### put/take_along_axis ###
def take(
    a: ArrayLike,
    indices: ArrayLike,
    axis=None,
    out: Optional[OutArray] = None,
    mode: NotImplementedType = "raise",
):
    """Pick elements of ``a`` at ``indices`` along ``axis``.

    ``a`` and ``axis`` first go through ``_util.axis_none_flatten``; the
    selection is done by indexing with full slices on the leading axes.
    ``out`` is not read by the body.
    """
    (a,), axis = _util.axis_none_flatten(a, axis=axis)
    axis = _util.normalize_axis_index(axis, a.ndim)
    leading = (slice(None),) * axis
    selector = leading + (indices, ...)
    return a[selector]
def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis):
    """Gather values of ``arr`` along ``axis`` via ``torch.take_along_dim``."""
    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
    dim = _util.normalize_axis_index(axis, arr.ndim)
    gathered = torch.take_along_dim(arr, indices, dim)
    return gathered
def put(
    a: NDArray,
    ind: ArrayLike,
    v: ArrayLike,
    mode: NotImplementedType = "raise",
):
    """Write ``v`` into ``a`` in place at the flat positions ``ind``.

    ``v`` is converted to the dtype of ``a``, repeated when it is shorter
    than ``ind`` and trimmed when it is longer. Returns ``None``.
    """
    v = v.type(a.dtype)
    n_targets = ind.numel()
    n_values = v.numel()

    # If ind is larger than v, expand v to at least the size of ind. Any
    # unnecessary trailing elements are then trimmed.
    if n_targets > n_values:
        n_copies = (n_targets + n_values - 1) // n_values
        v = v.unsqueeze(0).expand((n_copies,) + v.shape)

    # Trim unnecessary elements, regardless if v was expanded or not. Note
    # np.put() trims v to match ind by default too.
    if n_targets < v.numel():
        v = v.flatten()[:n_targets]

    a.put_(ind, v)
    return None
def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis):
    """Scatter ``values`` into ``arr`` along ``axis`` and write the result back.

    ``indices`` and ``values`` are broadcast against each other and
    ``values`` is cast to the dtype of ``arr``. Returns ``None``.
    """
    (arr,), axis = _util.axis_none_flatten(arr, axis=axis)
    dim = _util.normalize_axis_index(axis, arr.ndim)

    indices, values = torch.broadcast_tensors(indices, values)
    values = _util.cast_if_needed(values, arr.dtype)
    scattered = torch.scatter(arr, dim, indices, values)
    arr.copy_(scattered.reshape(arr.shape))
    return None
def choose(
    a: ArrayLike,
    choices: Sequence[ArrayLike],
    out: Optional[OutArray] = None,
    mode: NotImplementedType = "raise",
):
    """Build a tensor by picking, per position, from one of ``choices``.

    Parameters
    ----------
    a : integer tensor; each entry selects which element of ``choices`` the
        output takes at that position.
    choices : sequence of tensors, broadcast against each other and stacked
        along a new leading axis.
    out, mode : not read by the body.
    """
    # First, broadcast elements of `choices`
    choices = torch.stack(torch.broadcast_tensors(*choices))

    # Use an analog of `gather(choices, 0, a)` which broadcasts `choices` vs `a`:
    # (taken from https://github.com/pytorch/pytorch/issues/9407#issuecomment-1427907939)
    idx_list = [
        torch.arange(dim).view((1,) * i + (dim,) + (1,) * (choices.ndim - i - 1))
        for i, dim in enumerate(choices.shape)
    ]

    idx_list[0] = a
    # Index with a tuple: a plain list of tensors is a non-tuple sequence
    # index, which PyTorch deprecates for multidimensional indexing.
    return choices[tuple(idx_list)].squeeze(0)
# ### unique et al ###
def unique(
    ar: ArrayLike,
    return_index: NotImplementedType = False,
    return_inverse=False,
    return_counts=False,
    axis=None,
    *,
    equal_nan: NotImplementedType = True,
):
    """Unique elements of ``ar`` via ``torch.unique``.

    ``ar`` and ``axis`` first go through ``_util.axis_none_flatten``.
    ``return_inverse`` and ``return_counts`` are forwarded to torch.
    """
    (ar,), axis = _util.axis_none_flatten(ar, axis=axis)
    dim = _util.normalize_axis_index(axis, ar.ndim)

    return torch.unique(
        ar, return_inverse=return_inverse, return_counts=return_counts, dim=dim
    )
def nonzero(a: ArrayLike):
    """Return a tuple of index tensors, one per dimension, of the nonzero entries."""
    per_dim_indices = torch.nonzero(a, as_tuple=True)
    return per_dim_indices
def argwhere(a: ArrayLike):
    """Return the indices of nonzero entries, one row per entry."""
    positions = torch.argwhere(a)
    return positions
def flatnonzero(a: ArrayLike):
    """Return the indices of nonzero entries in the flattened ``a``."""
    flat = torch.flatten(a)
    (positions,) = flat.nonzero(as_tuple=True)
    return positions
def clip(
    a: ArrayLike,
    min: Optional[ArrayLike] = None,
    max: Optional[ArrayLike] = None,
    out: Optional[OutArray] = None,
):
    """Limit the values of ``a`` to [min, max] via ``torch.clamp``.

    ``out`` is not read by the body.
    """
    clamped = torch.clamp(a, min, max)
    return clamped
def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None):
    """Repeat elements of ``a`` via ``torch.repeat_interleave``."""
    repeated = torch.repeat_interleave(a, repeats, axis)
    return repeated
def tile(A: ArrayLike, reps):
    """Tile ``A`` via ``torch.tile``; an int ``reps`` is wrapped into a 1-tuple."""
    rep_counts = (reps,) if isinstance(reps, int) else reps
    return torch.tile(A, rep_counts)
def resize(a: ArrayLike, new_shape=None):
    """Return a tensor of ``new_shape`` filled by cycling through ``a``.

    ``new_shape=None`` returns ``a`` unchanged. An empty input or an empty
    target shape gives zeros. A negative entry in ``new_shape`` raises
    ``ValueError``.
    """
    # implementation vendored from
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497
    if new_shape is None:
        return a

    if isinstance(new_shape, int):
        new_shape = (new_shape,)

    flat = a.flatten()

    total = 1
    for extent in new_shape:
        total *= extent
        if extent < 0:
            raise ValueError("all elements of `new_shape` must be non-negative")

    n_available = flat.numel()
    if n_available == 0 or total == 0:
        # First case must zero fill. The second would have repeats == 0.
        return torch.zeros(new_shape, dtype=flat.dtype)

    n_copies = -(-total // n_available)  # ceil division
    cycled = concatenate((flat,) * n_copies)[:total]

    return reshape(cycled, new_shape)
# ### diag et al ###
def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1):
    """Return the diagonal of ``a`` taken over ``axis1`` and ``axis2``."""
    dim1 = _util.normalize_axis_index(axis1, a.ndim)
    dim2 = _util.normalize_axis_index(axis2, a.ndim)
    return torch.diagonal(a, offset, dim1, dim2)
def trace(
    a: ArrayLike,
    offset=0,
    axis1=0,
    axis2=1,
    dtype: Optional[DTypeLike] = None,
    out: Optional[OutArray] = None,
):
    """Sum along the diagonal of ``a`` taken over ``axis1`` and ``axis2``.

    ``dtype`` is passed to the summation. ``out`` is not read by the body.
    """
    diag_values = torch.diagonal(a, offset, dim1=axis1, dim2=axis2)
    return diag_values.sum(-1, dtype=dtype)
def eye(
    N,
    M=None,
    k=0,
    dtype: Optional[DTypeLike] = None,
    order: NotImplementedType = "C",
    *,
    like: NotImplementedType = None,
):
    """Return an ``(N, M)`` tensor with ones on diagonal ``k`` and zeros elsewhere.

    ``M`` defaults to ``N``; ``dtype`` defaults to the default float dtype.
    """
    out_dtype = _dtypes_impl.default_dtypes().float_dtype if dtype is None else dtype
    n_cols = N if M is None else M
    result = torch.zeros(N, n_cols, dtype=out_dtype)
    # `diagonal` returns a view, so filling it writes into `result`.
    result.diagonal(k).fill_(1)
    return result
def identity(n, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None):
    """Return the ``n`` x ``n`` identity matrix via ``torch.eye``."""
    matrix = torch.eye(n, dtype=dtype)
    return matrix
def diag(v: ArrayLike, k=0):
    """Extract a diagonal or build a diagonal matrix via ``torch.diag``."""
    result = torch.diag(v, k)
    return result
def diagflat(v: ArrayLike, k=0):
    """Build a 2-D tensor with the flattened ``v`` on diagonal ``k``."""
    result = torch.diagflat(v, k)
    return result
def diag_indices(n, ndim=2):
    """Return ``ndim`` references to ``arange(n)``, indexing the main diagonal."""
    ramp = torch.arange(n)
    return ndim * (ramp,)
def diag_indices_from(arr: ArrayLike):
    """Return main-diagonal indices for ``arr``.

    Raises ``ValueError`` when ``arr`` has fewer than two dimensions or when
    its dimensions are not all of equal length.
    """
    if arr.ndim < 2:
        raise ValueError("input array must be at least 2-d")
    # For more than d=2, the strided formula is only valid for arrays with
    # all dimensions equal, so we check first.
    dims = arr.shape
    all_equal = dims[1:] == dims[:-1]
    if not all_equal:
        raise ValueError("All dimensions of input must be of equal length")
    return diag_indices(dims[0], arr.ndim)
def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False):
    """Fill the main diagonal of ``a`` in place with ``val`` and return ``a``.

    ``val`` may hold several elements; it is trimmed when it is longer than
    the diagonal. For 2-D tall matrices ``wrap=True`` continues the diagonal
    further down, leaving one row between consecutive diagonals.

    Raises
    ------
    ValueError
        If ``a`` has fewer than two dimensions.
    """
    if a.ndim < 2:
        raise ValueError("array must be at least 2-d")
    # An empty `val` without wrapping is handed straight to torch.
    if val.numel() == 0 and not wrap:
        a.fill_diagonal_(val)
        return a

    # Make a 0-d value sliceable.
    if val.ndim == 0:
        val = val.unsqueeze(0)

    # torch.Tensor.fill_diagonal_ only accepts scalars
    # If the size of val is too large, then val is trimmed
    if a.ndim == 2:
        tall = a.shape[0] > a.shape[1]
        # wrap does nothing for wide matrices...
        if not wrap or not tall:
            # Never wraps
            diag = a.diagonal()
            diag.copy_(val[: diag.numel()])
        else:
            # wraps and tall... leaving one empty line between diagonals?!
            max_, min_ = a.shape
            # One index per filled element: every (min_ + 1)-th row is skipped.
            idx = torch.arange(max_ - max_ // (min_ + 1))
            mod = idx % min_
            div = idx // min_
            a[(div * (min_ + 1) + mod, mod)] = val[: idx.numel()]
    else:
        idx = diag_indices_from(a)
        # a.shape = (n, n, ..., n)
        a[idx] = val[: a.shape[0]]

    return a
def vdot(a: ArrayLike, b: ArrayLike, /):
    """Dot product of the flattened inputs via ``torch.vdot``.

    Both inputs are made at least 1-D, flattened and cast to their common
    result type. float16 (when either input is on CPU) is computed in
    float32 and bool in uint8; the result is cast back afterwards.
    """
    # 1. torch only accepts 1D arrays, numpy flattens
    # 2. torch requires matching dtype, while numpy casts (?)
    lhs, rhs = torch.atleast_1d(a, b)
    if lhs.ndim > 1:
        lhs = lhs.flatten()
    if rhs.ndim > 1:
        rhs = rhs.flatten()

    work_dtype = _dtypes_impl.result_type_impl(lhs, rhs)
    half_on_cpu = work_dtype == torch.float16 and (lhs.is_cpu or rhs.is_cpu)
    bool_inputs = work_dtype == torch.bool

    # work around torch's "dot" not implemented for 'Half', 'Bool'
    if half_on_cpu:
        work_dtype = torch.float32
    elif bool_inputs:
        work_dtype = torch.uint8

    lhs = _util.cast_if_needed(lhs, work_dtype)
    rhs = _util.cast_if_needed(rhs, work_dtype)

    product = torch.vdot(lhs, rhs)

    if half_on_cpu:
        return product.to(torch.float16)
    if bool_inputs:
        return product.to(torch.bool)
    return product
def tensordot(a: ArrayLike, b: ArrayLike, axes=2):
    """Tensor contraction of ``a`` and ``b`` via ``torch.tensordot``.

    When ``axes`` is a list or tuple, bare-int entries are wrapped into
    one-element lists. Inputs are cast to their common result type.
    """
    if isinstance(axes, (list, tuple)):
        wrapped = []
        for entry in axes:
            wrapped.append([entry] if isinstance(entry, int) else entry)
        axes = wrapped

    common = _dtypes_impl.result_type_impl(a, b)
    lhs = _util.cast_if_needed(a, common)
    rhs = _util.cast_if_needed(b, common)

    return torch.tensordot(lhs, rhs, dims=axes)
def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
    """Dot product of ``a`` and ``b``.

    A 0-D operand turns the operation into an elementwise product; otherwise
    ``torch.matmul`` is used. Bool inputs are computed in uint8 and cast
    back to bool. ``out`` is not read by the body.
    """
    work_dtype = _dtypes_impl.result_type_impl(a, b)
    bool_inputs = work_dtype == torch.bool
    if bool_inputs:
        work_dtype = torch.uint8

    lhs = _util.cast_if_needed(a, work_dtype)
    rhs = _util.cast_if_needed(b, work_dtype)

    has_scalar_operand = lhs.ndim == 0 or rhs.ndim == 0
    product = lhs * rhs if has_scalar_operand else torch.matmul(lhs, rhs)

    return product.to(torch.bool) if bool_inputs else product
def inner(a: ArrayLike, b: ArrayLike, /):
    """Inner product of ``a`` and ``b`` via ``torch.inner``.

    float16 (when either input is on CPU) is computed in float32 and bool in
    uint8; the result is cast back afterwards.
    """
    work_dtype = _dtypes_impl.result_type_impl(a, b)
    half_on_cpu = work_dtype == torch.float16 and (a.is_cpu or b.is_cpu)
    bool_inputs = work_dtype == torch.bool

    if half_on_cpu:
        # work around torch's "addmm_impl_cpu_" not implemented for 'Half'"
        work_dtype = torch.float32
    elif bool_inputs:
        work_dtype = torch.uint8

    lhs = _util.cast_if_needed(a, work_dtype)
    rhs = _util.cast_if_needed(b, work_dtype)

    product = torch.inner(lhs, rhs)

    if half_on_cpu:
        return product.to(torch.float16)
    if bool_inputs:
        return product.to(torch.bool)
    return product
def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None):
    """Outer product of two vectors via ``torch.outer``.

    ``out`` is not read by the body.
    """
    product = torch.outer(a, b)
    return product
def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None):
    """Cross product of vectors stored along ``axisa`` of ``a`` and ``axisb`` of ``b``.

    The vectors must have 2 or 3 components. When both have 2 components
    only the z-component is returned (no trailing vector axis); otherwise
    the result has a length-3 vector axis placed at ``axisc``. A non-None
    ``axis`` overrides ``axisa``, ``axisb`` and ``axisc``.

    Raises
    ------
    ValueError
        If either vector axis has a length other than 2 or 3.
    """
    # implementation vendored from
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1486-L1685
    if axis is not None:
        axisa, axisb, axisc = (axis,) * 3

    # Check axisa and axisb are within bounds
    axisa = _util.normalize_axis_index(axisa, a.ndim)
    axisb = _util.normalize_axis_index(axisb, b.ndim)

    # Move working axis to the end of the shape
    a = torch.moveaxis(a, axisa, -1)
    b = torch.moveaxis(b, axisb, -1)
    msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)"
    if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3):
        raise ValueError(msg)

    # Create the output array
    shape = broadcast_shapes(a[..., 0].shape, b[..., 0].shape)
    if a.shape[-1] == 3 or b.shape[-1] == 3:
        shape += (3,)
        # Check axisc is within bounds
        axisc = _util.normalize_axis_index(axisc, len(shape))
    dtype = _dtypes_impl.result_type_impl(a, b)
    cp = torch.empty(shape, dtype=dtype)

    # recast arrays as dtype
    a = _util.cast_if_needed(a, dtype)
    b = _util.cast_if_needed(b, dtype)

    # create local aliases for readability
    a0 = a[..., 0]
    a1 = a[..., 1]
    if a.shape[-1] == 3:
        a2 = a[..., 2]
    b0 = b[..., 0]
    b1 = b[..., 1]
    if b.shape[-1] == 3:
        b2 = b[..., 2]
    # cp0/cp1/cp2 are views into `cp`; writing to them fills the output.
    if cp.ndim != 0 and cp.shape[-1] == 3:
        cp0 = cp[..., 0]
        cp1 = cp[..., 1]
        cp2 = cp[..., 2]

    if a.shape[-1] == 2:
        if b.shape[-1] == 2:
            # a0 * b1 - a1 * b0
            cp[...] = a0 * b1 - a1 * b0
            return cp
        else:
            assert b.shape[-1] == 3
            # cp0 = a1 * b2 - 0  (a2 = 0)
            # cp1 = 0 - a0 * b2  (a2 = 0)
            # cp2 = a0 * b1 - a1 * b0
            cp0[...] = a1 * b2
            cp1[...] = -a0 * b2
            cp2[...] = a0 * b1 - a1 * b0
    else:
        assert a.shape[-1] == 3
        if b.shape[-1] == 3:
            cp0[...] = a1 * b2 - a2 * b1
            cp1[...] = a2 * b0 - a0 * b2
            cp2[...] = a0 * b1 - a1 * b0
        else:
            assert b.shape[-1] == 2
            # b2 = 0, so the terms involving b2 vanish
            cp0[...] = -a2 * b1
            cp1[...] = a2 * b0
            cp2[...] = a0 * b1 - a1 * b0

    # Move the vector axis of the result to the requested position.
    return torch.moveaxis(cp, -1, axisc)
def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
    """Evaluate an Einstein summation over the operands via ``torch.einsum``.

    Both calling conventions are handled: a subscripts string followed by
    arrays, and the sublist format (array, sublist, array, sublist, ...
    [, output sublist]).

    Array operands are cast to ``dtype`` (or to their common result type).
    float16 on CPU is computed in float32, and uint8/int8/int16/int32 are
    computed in int64. The global opt_einsum backend state is changed for
    the duration of the call according to ``optimize`` and restored on exit.

    Raises
    ------
    TypeError
        If ``out`` is given and is not an ``ndarray``.
    NotImplementedError
        If ``order`` is not ``"K"``.
    """
    # Have to manually normalize *operands and **kwargs, following the NumPy signature
    # We have a local import to avoid polluting the global space, as it will be then
    # exported in funcs.py
    from ._ndarray import ndarray
    from ._normalizations import (
        maybe_copy_to,
        normalize_array_like,
        normalize_casting,
        normalize_dtype,
        wrap_tensors,
    )

    dtype = normalize_dtype(dtype)
    casting = normalize_casting(casting)
    if out is not None and not isinstance(out, ndarray):
        raise TypeError("'out' must be an array")
    if order != "K":
        raise NotImplementedError("'order' parameter is not supported.")

    # parse arrays and normalize them
    sublist_format = not isinstance(operands[0], str)
    if sublist_format:
        # op, str, op, str ... [sublistout] format: normalize every other argument

        # - if sublistout is not given, the length of operands is even, and we pick
        #   odd-numbered elements, which are arrays.
        # - if sublistout is given, the length of operands is odd, we peel off
        #   the last one, and pick odd-numbered elements, which are arrays.
        #   Without [:-1], we would have picked sublistout, too.
        array_operands = operands[:-1][::2]
    else:
        # ("ij->", arrays) format
        subscripts, array_operands = operands[0], operands[1:]

    tensors = [normalize_array_like(op) for op in array_operands]
    target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype

    # work around 'bmm' not implemented for 'Half' etc
    is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors)
    if is_half:
        target_dtype = torch.float32

    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
    if is_short_int:
        target_dtype = torch.int64

    tensors = _util.typecast_tensors(tensors, target_dtype, casting)

    from torch.backends import opt_einsum

    try:
        # set the global state to handle the optimize=... argument, restore on exit
        if opt_einsum.is_available():
            old_strategy = torch.backends.opt_einsum.strategy
            old_enabled = torch.backends.opt_einsum.enabled

            # torch.einsum calls opt_einsum.contract_path, which runs into
            # https://github.com/dgasmith/opt_einsum/issues/219
            # for strategy={True, False}
            if optimize is True:
                optimize = "auto"
            elif optimize is False:
                torch.backends.opt_einsum.enabled = False

            torch.backends.opt_einsum.strategy = optimize

        if sublist_format:
            # recombine operands
            sublists = operands[1::2]
            has_sublistout = len(operands) % 2 == 1
            if has_sublistout:
                sublistout = operands[-1]
            # Interleave the normalized tensors with their sublists again.
            operands = list(itertools.chain(*zip(tensors, sublists)))
            if has_sublistout:
                operands.append(sublistout)

            result = torch.einsum(*operands)
        else:
            result = torch.einsum(subscripts, *tensors)

    finally:
        # Restore the backend state captured before the call.
        if opt_einsum.is_available():
            torch.backends.opt_einsum.strategy = old_strategy
            torch.backends.opt_einsum.enabled = old_enabled

    result = maybe_copy_to(out, result)
    return wrap_tensors(result)
# ### sort and partition ###
def _sort_helper(tensor, axis, kind, order):
    """Prepare the arguments shared by ``sort`` and ``argsort``.

    Rejects complex input, applies ``_util.axis_none_flatten`` and
    ``_util.normalize_axis_index``, and returns ``(tensor, axis, stable)``
    where ``stable`` is True only for ``kind == "stable"``. ``order`` is
    accepted but not used here.
    """
    dtype = tensor.dtype
    if dtype.is_complex:
        raise NotImplementedError(f"sorting {dtype} is not supported")

    flattened, axis = _util.axis_none_flatten(tensor, axis=axis)
    (tensor,) = flattened
    normalized_axis = _util.normalize_axis_index(axis, tensor.ndim)
    return tensor, normalized_axis, kind == "stable"
def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
    """Return the values of ``a`` sorted along ``axis``.

    `order` is only relevant for structured dtypes, so it is not supported.
    """
    tensor, dim, is_stable = _sort_helper(a, axis, kind, order)
    return torch.sort(tensor, dim=dim, stable=is_stable).values
def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
    """Return the indices that would sort ``a`` along ``axis``."""
    tensor, dim, is_stable = _sort_helper(a, axis, kind, order)
    indices = torch.argsort(tensor, dim=dim, stable=is_stable)
    return indices
def searchsorted(
    a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None
):
    """Find insertion indices for ``v`` in ``a`` via ``torch.searchsorted``.

    Complex ``a`` is rejected with ``NotImplementedError``.
    """
    haystack_dtype = a.dtype
    if haystack_dtype.is_complex:
        raise NotImplementedError(f"searchsorted with dtype={haystack_dtype}")
    positions = torch.searchsorted(a, v, side=side, sorter=sorter)
    return positions
# ### swap/move/roll axis ###
def moveaxis(a: ArrayLike, source, destination):
    """Move the ``source`` axes of ``a`` to the ``destination`` positions.

    Both axis arguments are normalized with ``_util.normalize_axis_tuple``
    before being handed to ``torch.moveaxis``.
    """
    ndim = a.ndim
    src = _util.normalize_axis_tuple(source, ndim, "source")
    dst = _util.normalize_axis_tuple(destination, ndim, "destination")
    return torch.moveaxis(a, src, dst)
def swapaxes(a: ArrayLike, axis1, axis2):
    """Interchange ``axis1`` and ``axis2`` of ``a``.

    Both indices are normalized with ``_util.normalize_axis_index`` first.
    """
    ndim = a.ndim
    first = _util.normalize_axis_index(axis1, ndim)
    second = _util.normalize_axis_index(axis2, ndim)
    return torch.swapaxes(a, first, second)
def rollaxis(a: ArrayLike, axis, start=0):
    """Roll ``axis`` backwards until it lies before position ``start``.

    Parameters
    ----------
    a : tensor
        Input tensor.
    axis : int
        Axis to roll; normalized with ``_util.normalize_axis_index``.
    start : int, optional
        Target position. Negative values are offset by ``a.ndim``; after that
        the value must satisfy ``0 <= start <= a.ndim``.

    Returns
    -------
    tensor
        ``a`` itself when the axis is already in place, otherwise ``a`` with
        its dimensions permuted.

    Raises
    ------
    _util.AxisError
        If ``start`` is out of range.
    """
    # Straight vendor from:
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259
    #
    # Also note this function in NumPy is mostly retained for backwards compat
    # (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing)
    # so let's not touch it unless hard pressed.
    n = a.ndim
    axis = _util.normalize_axis_index(axis, n)
    if start < 0:
        start += n
    msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1):
        raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
    if axis < start:
        # it's been removed
        start -= 1
    if axis == start:
        # numpy returns a view, here we try returning the tensor itself
        # return tensor[...]
        return a
    axes = list(range(0, n))
    axes.remove(axis)
    axes.insert(start, axis)
    # `axes` is a permutation of the dimensions, not a shape: NumPy calls
    # a.transpose(axes) here, whose torch counterpart is permute. Tensor.view
    # would instead try to reshape `a` to the shape given by `axes`.
    return a.permute(axes)
def roll(a: ArrayLike, shift, axis=None):
    """Roll the elements of ``a`` by ``shift`` along ``axis``.

    With ``axis=None`` the arguments go to ``torch.roll`` untouched. Otherwise
    the axes are normalized (duplicates allowed) and a non-tuple ``shift`` is
    repeated once per axis.
    """
    if axis is None:
        return torch.roll(a, shift, axis)

    axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
    if not isinstance(shift, tuple):
        shift = (shift,) * len(axis)
    return torch.roll(a, shift, axis)
# ### shape manipulations ###
def squeeze(a: ArrayLike, axis=None):
    """Remove length-one dimensions from ``a``.

    Parameters
    ----------
    a : tensor
        Input tensor.
    axis : None, int or tuple of ints, optional
        ``None`` squeezes every length-one dimension, ``()`` squeezes nothing,
        an int squeezes that dimension, and a tuple squeezes each listed
        dimension.

    Returns
    -------
    tensor
        The squeezed tensor.
    """
    if axis == ():
        return a
    if axis is None:
        return a.squeeze()
    if not isinstance(axis, tuple):
        return a.squeeze(axis)

    ndim = a.ndim
    # Map negative indices to their non-negative form and drop duplicates;
    # out-of-range values are left alone so that Tensor.squeeze reports them.
    wanted = {ax + ndim if ax < 0 else ax for ax in axis}
    result = a
    # Go from the highest axis down so that removing one dimension does not
    # shift the positions of the dimensions still to be processed. Each step
    # must build on `result`, otherwise only the last axis takes effect.
    for ax in sorted(wanted, reverse=True):
        result = result.squeeze(ax)
    return result
def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"):
    """Give ``a`` a new shape without changing its data.

    Parameters
    ----------
    a : tensor
        Input tensor.
    newshape : int or sequence
        Target shape. A bare integer, a sequence of ints, or a one-element
        sequence wrapping the shape (as produced by ``.reshape(sh)`` when the
        caller collects ``*shape``) are all accepted.
    order
        Not supported; present for signature compatibility.

    Returns
    -------
    tensor
        The result of ``a.reshape(newshape)``.
    """
    # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
    try:
        wraps_single_item = len(newshape) == 1
    except TypeError:
        # a bare integer such as reshape(a, 6) or reshape(a, -1) has no len()
        wraps_single_item = False
    if wraps_single_item:
        newshape = newshape[0]

    # NB: cannot use torch.reshape(a, newshape) here, because of
    # (Pdb) torch.reshape(torch.as_tensor([1]), 1)
    # *** TypeError: reshape(): argument 'shape' (position 2) must be tuple of SymInts, not int
    return a.reshape(newshape)
def transpose(a: ArrayLike, axes=None):
    """Permute the dimensions of ``a``.

    With no axes (``()``, ``None`` or ``(None,)``) the dimension order is
    reversed. A one-element ``axes`` is unwrapped, so both ``.transpose(sh)``
    and ``.transpose(*sh)`` work; older code passing a list is handled too.
    """
    if axes in [(), None, (None,)]:
        reversed_order = tuple(range(a.ndim - 1, -1, -1))
        return a.permute(reversed_order)
    if len(axes) == 1:
        return a.permute(axes[0])
    return a.permute(axes)
def ravel(a: ArrayLike, order: NotImplementedType = "C"):
    """Return ``a`` flattened to one dimension (``order`` is not supported)."""
    flat = torch.flatten(a)
    return flat
def diff(
    a: ArrayLike,
    n=1,
    axis=-1,
    prepend: Optional[ArrayLike] = None,
    append: Optional[ArrayLike] = None,
):
    """Compute the ``n``-th discrete difference of ``a`` along ``axis``.

    ``prepend`` / ``append`` are broadcast to the shape of ``a`` (keeping
    their own length along ``axis``, or length 1 for 0-d values) and then
    passed to ``torch.diff``.

    Raises
    ------
    ValueError
        If ``n`` is negative.
    """
    axis = _util.normalize_axis_index(axis, a.ndim)

    if n < 0:
        raise ValueError(f"order must be non-negative but got {n}")
    if n == 0:
        # match numpy and return the input immediately
        return a

    def _broadcast_edge(edge):
        # Expand an optional prepend/append value to the shape of `a`,
        # except along `axis`.
        if edge is None:
            return None
        target = list(a.shape)
        target[axis] = edge.shape[axis] if edge.ndim > 0 else 1
        return torch.broadcast_to(edge, target)

    prepend = _broadcast_edge(prepend)
    append = _broadcast_edge(append)
    return torch.diff(a, n, axis=axis, prepend=prepend, append=append)
# ### math functions ###
def angle(z: ArrayLike, deg=False):
    """Return the argument of ``z``, in degrees when ``deg`` is truthy."""
    radians = torch.angle(z)
    if not deg:
        return radians
    return radians * (180 / torch.pi)
def sinc(x: ArrayLike):
    """Apply ``torch.sinc`` elementwise to ``x``."""
    values = torch.sinc(x)
    return values
# NB: have to normalize *varargs manually
def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
    """Estimate the gradient of ``f`` with finite differences.

    Interior points use second-order central differences; the boundaries use
    first- or second-order one-sided differences depending on ``edge_order``.

    Parameters
    ----------
    f : tensor
        Values to differentiate. Integer and boolean input is converted to
        float64 first.
    *varargs
        Spacing specification: nothing (unit spacing), a single scalar or 0-d
        tensor applied to every axis, or one scalar / 1-d coordinate tensor
        per axis.
    axis : None, int or tuple of ints, optional
        Axes to differentiate along; ``None`` means all axes.
    edge_order : int, optional
        ``1`` selects first-order edge differences; any other accepted value
        takes the second-order branch. Values above 2 are rejected.

    Returns
    -------
    tensor or list of tensors
        A single tensor when exactly one axis is processed, otherwise a list
        with one tensor per axis.

    Raises
    ------
    ValueError
        If a spacing is neither scalar nor 1-d, if a 1-d spacing does not
        match the length of its axis, if ``edge_order`` is greater than 2, or
        if an axis has fewer than ``edge_order + 1`` elements.
    TypeError
        If the number of spacing arguments matches none of the accepted forms.
    """
    N = f.ndim  # number of dimensions

    varargs = _util.ndarrays_to_tensors(varargs)

    if axis is None:
        axes = tuple(range(N))
    else:
        axes = _util.normalize_axis_tuple(axis, N)

    len_axes = len(axes)
    n = len(varargs)
    if n == 0:
        # no spacing argument - use 1 in all axes
        dx = [1.0] * len_axes
    elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0):
        # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0)
        dx = varargs * len_axes
    elif n == len_axes:
        # scalar or 1d array for each axis
        dx = list(varargs)
        for i, distances in enumerate(dx):
            distances = torch.as_tensor(distances)
            if distances.ndim == 0:
                # scalar spacing: keep the entry of dx as it is
                continue
            elif distances.ndim != 1:
                raise ValueError("distances must be either scalars or 1d")
            if len(distances) != f.shape[axes[i]]:
                raise ValueError(
                    "when 1d, distances must match "
                    "the length of the corresponding dimension"
                )
            if not (distances.dtype.is_floating_point or distances.dtype.is_complex):
                distances = distances.double()

            # turn coordinates into spacings between consecutive points
            diffx = torch.diff(distances)
            # if distances are constant reduce to the scalar case
            # since it brings a consistent speedup
            if (diffx == diffx[0]).all():
                diffx = diffx[0]
            dx[i] = diffx
    else:
        raise TypeError("invalid number of arguments")

    if edge_order > 2:
        raise ValueError("'edge_order' greater than 2 not supported")

    # use central differences on interior and one-sided differences on the
    # endpoints. This preserves second order-accuracy over the full domain.

    outvals = []

    # create slice objects --- initially all are [:, :, ..., :]
    slice1 = [slice(None)] * N
    slice2 = [slice(None)] * N
    slice3 = [slice(None)] * N
    slice4 = [slice(None)] * N

    otype = f.dtype
    if _dtypes_impl.python_type_for_torch(otype) in (int, bool):
        # Convert to floating point.
        # First check if f is a numpy integer type; if so, convert f to float64
        # to avoid modular arithmetic when computing the changes in f.
        f = f.double()
        otype = torch.float64

    # NB: the loop variable below rebinds the `axis` keyword argument
    for axis, ax_dx in zip(axes, dx):
        if f.shape[axis] < edge_order + 1:
            raise ValueError(
                "Shape of array too small to calculate a numerical gradient, "
                "at least (edge_order + 1) elements are required."
            )
        # result allocation
        out = torch.empty_like(f, dtype=otype)

        # spacing for the current axis (NB: np.ndim(ax_dx) == 0)
        uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0

        # Numerical differentiation: 2nd order interior
        slice1[axis] = slice(1, -1)
        slice2[axis] = slice(None, -2)
        slice3[axis] = slice(1, -1)
        slice4[axis] = slice(2, None)

        if uniform_spacing:
            out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx)
        else:
            dx1 = ax_dx[0:-1]
            dx2 = ax_dx[1:]
            a = -(dx2) / (dx1 * (dx1 + dx2))
            b = (dx2 - dx1) / (dx1 * dx2)
            c = dx1 / (dx2 * (dx1 + dx2))
            # fix the shape for broadcasting
            shape = [1] * N
            shape[axis] = -1
            a = a.reshape(shape)
            b = b.reshape(shape)
            c = c.reshape(shape)
            # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        # Numerical differentiation: 1st order edges
        if edge_order == 1:
            slice1[axis] = 0
            slice2[axis] = 1
            slice3[axis] = 0
            dx_0 = ax_dx if uniform_spacing else ax_dx[0]
            # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0

            slice1[axis] = -1
            slice2[axis] = -1
            slice3[axis] = -2
            dx_n = ax_dx if uniform_spacing else ax_dx[-1]
            # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n

        # Numerical differentiation: 2nd order edges
        else:
            slice1[axis] = 0
            slice2[axis] = 0
            slice3[axis] = 1
            slice4[axis] = 2
            if uniform_spacing:
                a = -1.5 / ax_dx
                b = 2.0 / ax_dx
                c = -0.5 / ax_dx
            else:
                dx1 = ax_dx[0]
                dx2 = ax_dx[1]
                a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2))
                b = (dx1 + dx2) / (dx1 * dx2)
                c = -dx1 / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

            slice1[axis] = -1
            slice2[axis] = -3
            slice3[axis] = -2
            slice4[axis] = -1
            if uniform_spacing:
                a = 0.5 / ax_dx
                b = -2.0 / ax_dx
                c = 1.5 / ax_dx
            else:
                dx1 = ax_dx[-2]
                dx2 = ax_dx[-1]
                a = (dx2) / (dx1 * (dx1 + dx2))
                b = -(dx2 + dx1) / (dx1 * dx2)
                c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        outvals.append(out)

        # reset the slice object in this dimension to ":"
        slice1[axis] = slice(None)
        slice2[axis] = slice(None)
        slice3[axis] = slice(None)
        slice4[axis] = slice(None)

    if len_axes == 1:
        return outvals[0]
    else:
        return outvals
# ### Type/shape etc queries ###
def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
    """Round ``a`` to ``decimals`` decimal places.

    Floating-point tensors go through ``torch.round``; complex tensors are
    rounded separately on their real and imaginary parts; any other dtype is
    returned unchanged. ``out`` is not used inside this function.
    """
    if a.is_complex():
        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
        real_part = torch.round(a.real, decimals=decimals)
        imag_part = torch.round(a.imag, decimals=decimals)
        return torch.complex(real_part, imag_part)
    if not a.is_floating_point():
        # RuntimeError: "round_cpu" not implemented for 'int'
        return a
    return torch.round(a, decimals=decimals)


around = round
round_ = round
def real_if_close(a: ArrayLike, tol=100):
    """Return the real part of ``a`` when every imaginary part is below ``tol``.

    Non-complex input is returned unchanged.
    """
    if not torch.is_complex(a):
        return a

    # Undocumented in numpy: if tol < 1, it's an absolute tolerance!
    # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577
    threshold = tol * torch.finfo(a.dtype).eps if tol > 1 else tol

    negligible = torch.abs(a.imag) < threshold
    if negligible.all():
        return a.real
    return a
def real(a: ArrayLike):
    """Return ``torch.real(a)``."""
    real_part = torch.real(a)
    return real_part
def imag(a: ArrayLike):
    """Return the imaginary part of ``a``; zeros shaped like ``a`` if it is not complex."""
    return a.imag if a.is_complex() else torch.zeros_like(a)
def iscomplex(x: ArrayLike):
    """Elementwise test for a non-zero imaginary part (all False for non-complex ``x``)."""
    if not torch.is_complex(x):
        return torch.zeros_like(x, dtype=torch.bool)
    return x.imag != 0
def isreal(x: ArrayLike):
    """Elementwise test for a zero imaginary part (all True for non-complex ``x``)."""
    if not torch.is_complex(x):
        return torch.ones_like(x, dtype=torch.bool)
    return x.imag == 0
def iscomplexobj(x: ArrayLike):
    """Return True when ``x`` has a complex dtype."""
    has_complex_dtype = torch.is_complex(x)
    return has_complex_dtype
def isrealobj(x: ArrayLike):
    """Return True when ``x`` does not have a complex dtype."""
    has_complex_dtype = torch.is_complex(x)
    return not has_complex_dtype
def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
    """Elementwise test for negative infinity (``out`` is not used here)."""
    mask = torch.isneginf(x)
    return mask
def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
    """Elementwise test for positive infinity (``out`` is not used here)."""
    mask = torch.isposinf(x)
    return mask
def i0(x: ArrayLike):
    """Apply ``torch.special.i0`` elementwise to ``x``."""
    values = torch.special.i0(x)
    return values
def isscalar(a):
    """Return True when ``a`` normalizes to a tensor holding exactly one element.

    Any exception raised while normalizing ``a`` is treated as "not a scalar".
    """
    # We need to use normalize_array_like, but we don't want to export it in funcs.py
    from ._normalizations import normalize_array_like

    try:
        return normalize_array_like(a).numel() == 1
    except Exception:
        return False
# ### Filter windows ###
def hamming(M):
    """Return the non-periodic Hamming window of length ``M`` in the default float dtype."""
    return torch.hamming_window(
        M, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
def hanning(M):
    """Return the non-periodic Hann window of length ``M`` in the default float dtype."""
    return torch.hann_window(
        M, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
def kaiser(M, beta):
    """Return the non-periodic Kaiser window of length ``M`` with shape parameter ``beta``."""
    return torch.kaiser_window(
        M,
        beta=beta,
        periodic=False,
        dtype=_dtypes_impl.default_dtypes().float_dtype,
    )
def blackman(M):
    """Return the non-periodic Blackman window of length ``M`` in the default float dtype."""
    return torch.blackman_window(
        M, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
def bartlett(M):
    """Return the non-periodic Bartlett window of length ``M`` in the default float dtype."""
    return torch.bartlett_window(
        M, periodic=False, dtype=_dtypes_impl.default_dtypes().float_dtype
    )
# ### Dtype routines ###
# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666
# Lookup tables used by common_type.
# array_type[row][precision]: row 0 holds the real dtypes, row 1 the complex
# ones; the column index is a precision level from array_precision.
# There is no complex dtype at precision level 0, hence the None.
array_type = [
    [torch.float16, torch.float32, torch.float64],
    [None, torch.complex64, torch.complex128],
]
# Precision level of each supported inexact dtype.
array_precision = {
    torch.float16: 0,
    torch.float32: 1,
    torch.float64: 2,
    torch.complex64: 1,
    torch.complex128: 2,
}
def common_type(*tensors: ArrayLike):
    """Return the inexact dtype common to all ``tensors``.

    Dtypes that are neither floating point nor complex count as double
    precision. The result is complex when any input is complex.

    Raises
    ------
    TypeError
        If a floating/complex dtype is missing from ``array_precision``.
    """
    any_complex = False
    level = 0
    for tensor in tensors:
        dt = tensor.dtype
        if iscomplexobj(tensor):
            any_complex = True
        if dt.is_floating_point or dt.is_complex:
            rank = array_precision.get(dt, None)
            if rank is None:
                raise TypeError("can't get common type for non-numeric array")
        else:
            rank = 2  # array_precision[_nx.double]
        level = builtins.max(level, rank)

    row = array_type[1] if any_complex else array_type[0]
    return row[level]
# ### histograms ###
def histogram(
    a: ArrayLike,
    bins: ArrayLike = 10,
    range=None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """Compute the histogram of ``a`` with ``torch.histogram``.

    Parameters
    ----------
    a : tensor
        Input data. Non-floating, non-complex input is converted to float64.
    bins : int, 0-d tensor or 1-d tensor
        Number of bins, or the explicit bin edges.
    range : optional
        Passed to ``torch.histogram`` when given.
    normed
        Deprecated; any value other than ``None`` raises ``ValueError``.
    weights : tensor, optional
        Per-sample weights, cast to the dtype of ``a``. Complex weights are
        not supported.
    density : optional
        Truthiness is forwarded as ``density=`` to ``torch.histogram``.

    Returns
    -------
    (h, b)
        Histogram values and bin edges. ``h`` is converted to int64 when
        ``density`` is falsy and the weights are absent or non-floating.
        ``b`` is cast back to the dtype of ``bins`` only when explicit edges
        were supplied; edges generated from a bin count stay floating point.

    Raises
    ------
    ValueError
        If ``normed`` is given.
    NotImplementedError
        If ``weights`` is complex.
    """
    if normed is not None:
        raise ValueError("normed argument is deprecated, use density= instead")

    if weights is not None and weights.dtype.is_complex:
        raise NotImplementedError("complex weights histogram.")

    is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
    is_w_int = weights is None or not weights.dtype.is_floating_point
    if is_a_int:
        a = a.double()

    if weights is not None:
        weights = _util.cast_if_needed(weights, a.dtype)

    # dtype of user-supplied bin edges; stays None when bins is a count
    edges_dtype = None
    if isinstance(bins, torch.Tensor):
        if bins.ndim == 0:
            # bins was a single int
            bins = operator.index(bins)
        else:
            edges_dtype = bins.dtype
            bins = _util.cast_if_needed(bins, a.dtype)

    if range is None:
        h, b = torch.histogram(a, bins, weight=weights, density=bool(density))
    else:
        h, b = torch.histogram(
            a, bins, range=range, weight=weights, density=bool(density)
        )

    if not density and is_w_int:
        h = h.long()
    if edges_dtype is not None:
        # Hand explicit edges back in the dtype they came in. Edges derived
        # from a bin count are generally fractional even for integer data
        # (e.g. [1, 1.5, 2, ...]), so they must not be truncated to integers.
        b = _util.cast_if_needed(b, edges_dtype)

    return h, b
def histogram2d(
    x,
    y,
    bins=10,
    range: Optional[ArrayLike] = None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """Compute the two-dimensional histogram of the samples ``x`` and ``y``.

    Delegates to ``histogramdd`` and returns ``(h, xedges, yedges)``.

    Raises
    ------
    ValueError
        If ``x`` and ``y`` differ in length.
    """
    # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/twodim_base.py#L655-L821
    if len(x) != len(y):
        raise ValueError("x and y must have the same length.")

    try:
        bins_len = len(bins)
    except TypeError:
        # a scalar bin count has no len()
        bins_len = 1

    if bins_len not in (1, 2):
        # a single sequence of edges is shared by both dimensions
        bins = [bins, bins]

    hist, edges = histogramdd((x, y), bins, range, normed, weights, density)
    xedges, yedges = edges[0], edges[1]
    return hist, xedges, yedges
def histogramdd(
    sample,
    bins=10,
    range: Optional[ArrayLike] = None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """Compute a multidimensional histogram with ``torch.histogramdd``.

    Parameters
    ----------
    sample : list, tuple or array-like
        A list/tuple is normalized and then transposed; anything else is
        normalized as is. The result is made at least 2-d and converted to
        float64 unless it is already floating point or complex.
    bins : int, sequence of ints or sequence of arrays
        Bin counts or explicit bin edges per dimension.
    range : tensor, optional
        Flattened into a plain list before the torch call. An explicit range
        is honoured whether or not ``weights`` is given.
    normed
        Deprecated; any value other than ``None`` raises ``ValueError``.
    weights : tensor, optional
        Per-sample weights, cast to the dtype of ``sample``.
    density : optional
        Truthiness is forwarded as ``density=`` to ``torch.histogramdd``.

    Returns
    -------
    (h, b)
        Histogram and list of bin edges. When explicit edges were supplied,
        each returned edge tensor is cast back to the dtype of its input.

    Raises
    ------
    ValueError
        If ``normed`` is given.
    """
    # have to normalize manually because `sample` interpretation differs
    # for a list of lists and a 2D array
    if normed is not None:
        raise ValueError("normed argument is deprecated, use density= instead")

    from ._normalizations import normalize_array_like, normalize_seq_array_like

    if isinstance(sample, (list, tuple)):
        sample = normalize_array_like(sample).T
    else:
        sample = normalize_array_like(sample)

    sample = torch.atleast_2d(sample)

    if not (sample.dtype.is_floating_point or sample.dtype.is_complex):
        sample = sample.double()

    # bins is either an int, or a sequence of ints or a sequence of arrays
    bins_is_array = not (
        isinstance(bins, int) or builtins.all(isinstance(b, int) for b in bins)
    )
    if bins_is_array:
        bins = normalize_seq_array_like(bins)
        bins_dtypes = [b.dtype for b in bins]
        bins = [_util.cast_if_needed(b, sample.dtype) for b in bins]

    if range is not None:
        range = range.flatten().tolist()

    if weights is not None:
        if range is None:
            # range=... is required : interleave min and max values per dimension.
            # Only derive it from the data when the caller gave none, so that
            # an explicit range is not silently discarded.
            mm = sample.aminmax(dim=0)
            range = torch.cat(mm).reshape(2, -1).T.flatten()
            range = tuple(range.tolist())
        weights = _util.cast_if_needed(weights, sample.dtype)
        w_kwd = {"weight": weights}
    else:
        w_kwd = {}

    h, b = torch.histogramdd(sample, bins, range, density=bool(density), **w_kwd)

    if bins_is_array:
        b = [_util.cast_if_needed(bb, dtyp) for bb, dtyp in zip(b, bins_dtypes)]

    return h, b
# ### odds and ends
def min_scalar_type(a: ArrayLike, /):
    """Return the smallest dtype able to hold the single value in ``a``.

    Parameters
    ----------
    a : tensor
        Input. A tensor with more than one element keeps its own dtype.

    Returns
    -------
    DType
        ``bool`` for boolean input; ``complex64`` or ``complex128`` for
        complex input; the narrowest of float16/float32/float64 whose finite
        range contains the value for floating input (float16 for inf/nan);
        the first of uint8/int8/int16/int32/int64 that contains the value
        for integer input.
    """
    # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288

    from ._dtypes import DType

    if a.numel() > 1:
        # numpy docs: "For non-scalar array a, returns the vector's dtype unmodified."
        return DType(a.dtype)

    if a.dtype == torch.bool:
        dtype = torch.bool

    elif a.dtype.is_complex:
        fi = torch.finfo(torch.float32)
        fits_in_single = a.dtype == torch.complex64 or (
            fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max
        )
        dtype = torch.complex64 if fits_in_single else torch.complex128

    elif a.dtype.is_floating_point:
        # inf and nan fail every finite-range check below, yet they are
        # representable at any width, so default to the narrowest float
        # instead of leaving `dtype` unbound.
        dtype = torch.float16
        for dt in [torch.float16, torch.float32, torch.float64]:
            fi = torch.finfo(dt)
            if fi.min <= a <= fi.max:
                dtype = dt
                break
    else:
        # must be integer
        for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
            # Prefer unsigned int where possible, as numpy does.
            ii = torch.iinfo(dt)
            if ii.min <= a <= ii.max:
                dtype = dt
                break

    return DType(dtype)
def pad(array: ArrayLike, pad_width: ArrayLike, mode="constant", **kwargs):
    """Pad ``array`` with a constant value via ``torch.nn.functional.pad``.

    Only ``mode="constant"`` is implemented. The fill value comes from the
    ``constant_values`` keyword (default 0).

    Raises
    ------
    NotImplementedError
        For any other ``mode``.
    """
    if mode != "constant":
        raise NotImplementedError

    fill = kwargs.get("constant_values", 0)
    # `value` must be a python scalar for torch.nn.functional.pad
    scalar_type = _dtypes_impl.python_type_for_torch(array.dtype)
    fill = scalar_type(fill)

    # one (before, after) pair per dimension, flipped along the dimension
    # axis and flattened before being handed to torch
    widths = torch.broadcast_to(pad_width, (array.ndim, 2))
    widths = torch.flip(widths, (0,)).flatten()
    return torch.nn.functional.pad(array, tuple(widths), value=fill)