"""A thin pytorch / numpy compat layer. Things imported from here have numpy-compatible signatures but operate on pytorch tensors. """ # Contents of this module ends up in the main namespace via _funcs.py # where type annotations are used in conjunction with the @normalizer decorator. from __future__ import annotations import builtins import itertools import operator from typing import Optional, Sequence import torch from . import _dtypes_impl, _util from ._normalizations import ( ArrayLike, ArrayLikeOrScalar, CastingModes, DTypeLike, NDArray, NotImplementedType, OutArray, ) def copy( a: ArrayLike, order: NotImplementedType = "K", subok: NotImplementedType = False ): return a.clone() def copyto( dst: NDArray, src: ArrayLike, casting: Optional[CastingModes] = "same_kind", where: NotImplementedType = None, ): (src,) = _util.typecast_tensors((src,), dst.dtype, casting=casting) dst.copy_(src) def atleast_1d(*arys: ArrayLike): res = torch.atleast_1d(*arys) if isinstance(res, tuple): return list(res) else: return res def atleast_2d(*arys: ArrayLike): res = torch.atleast_2d(*arys) if isinstance(res, tuple): return list(res) else: return res def atleast_3d(*arys: ArrayLike): res = torch.atleast_3d(*arys) if isinstance(res, tuple): return list(res) else: return res def _concat_check(tup, dtype, out): if tup == (): raise ValueError("need at least one array to concatenate") """Check inputs in concatenate et al.""" if out is not None and dtype is not None: # mimic numpy raise TypeError( "concatenate() only takes `out` or `dtype` as an " "argument, but both were provided." 
) def _concat_cast_helper(tensors, out=None, dtype=None, casting="same_kind"): """Figure out dtypes, cast if necessary.""" if out is not None or dtype is not None: # figure out the type of the inputs and outputs out_dtype = out.dtype.torch_dtype if dtype is None else dtype else: out_dtype = _dtypes_impl.result_type_impl(*tensors) # cast input arrays if necessary; do not broadcast them agains `out` tensors = _util.typecast_tensors(tensors, out_dtype, casting) return tensors def _concatenate( tensors, axis=0, out=None, dtype=None, casting: Optional[CastingModes] = "same_kind" ): # pure torch implementation, used below and in cov/corrcoef below tensors, axis = _util.axis_none_flatten(*tensors, axis=axis) tensors = _concat_cast_helper(tensors, out, dtype, casting) return torch.cat(tensors, axis) def concatenate( ar_tuple: Sequence[ArrayLike], axis=0, out: Optional[OutArray] = None, dtype: Optional[DTypeLike] = None, casting: Optional[CastingModes] = "same_kind", ): _concat_check(ar_tuple, dtype, out=out) result = _concatenate(ar_tuple, axis=axis, out=out, dtype=dtype, casting=casting) return result def vstack( tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting: Optional[CastingModes] = "same_kind", ): _concat_check(tup, dtype, out=None) tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) return torch.vstack(tensors) row_stack = vstack def hstack( tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting: Optional[CastingModes] = "same_kind", ): _concat_check(tup, dtype, out=None) tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) return torch.hstack(tensors) def dstack( tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting: Optional[CastingModes] = "same_kind", ): # XXX: in numpy 1.24 dstack does not have dtype and casting keywords # but {h,v}stack do. Hence add them here for consistency. 
_concat_check(tup, dtype, out=None) tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) return torch.dstack(tensors) def column_stack( tup: Sequence[ArrayLike], *, dtype: Optional[DTypeLike] = None, casting: Optional[CastingModes] = "same_kind", ): # XXX: in numpy 1.24 column_stack does not have dtype and casting keywords # but row_stack does. (because row_stack is an alias for vstack, really). # Hence add these keywords here for consistency. _concat_check(tup, dtype, out=None) tensors = _concat_cast_helper(tup, dtype=dtype, casting=casting) return torch.column_stack(tensors) def stack( arrays: Sequence[ArrayLike], axis=0, out: Optional[OutArray] = None, *, dtype: Optional[DTypeLike] = None, casting: Optional[CastingModes] = "same_kind", ): _concat_check(arrays, dtype, out=out) tensors = _concat_cast_helper(arrays, dtype=dtype, casting=casting) result_ndim = tensors[0].ndim + 1 axis = _util.normalize_axis_index(axis, result_ndim) return torch.stack(tensors, axis=axis) def append(arr: ArrayLike, values: ArrayLike, axis=None): if axis is None: if arr.ndim != 1: arr = arr.flatten() values = values.flatten() axis = arr.ndim - 1 return _concatenate((arr, values), axis=axis) # ### split ### def _split_helper(tensor, indices_or_sections, axis, strict=False): if isinstance(indices_or_sections, int): return _split_helper_int(tensor, indices_or_sections, axis, strict) elif isinstance(indices_or_sections, (list, tuple)): # NB: drop split=..., it only applies to split_helper_int return _split_helper_list(tensor, list(indices_or_sections), axis) else: raise TypeError("split_helper: ", type(indices_or_sections)) def _split_helper_int(tensor, indices_or_sections, axis, strict=False): if not isinstance(indices_or_sections, int): raise NotImplementedError("split: indices_or_sections") axis = _util.normalize_axis_index(axis, tensor.ndim) # numpy: l%n chunks of size (l//n + 1), the rest are sized l//n l, n = tensor.shape[axis], indices_or_sections if n <= 0: raise 
ValueError() if l % n == 0: num, sz = n, l // n lst = [sz] * num else: if strict: raise ValueError("array split does not result in an equal division") num, sz = l % n, l // n + 1 lst = [sz] * num lst += [sz - 1] * (n - num) return torch.split(tensor, lst, axis) def _split_helper_list(tensor, indices_or_sections, axis): if not isinstance(indices_or_sections, list): raise NotImplementedError("split: indices_or_sections: list") # numpy expects indices, while torch expects lengths of sections # also, numpy appends zero-size arrays for indices above the shape[axis] lst = [x for x in indices_or_sections if x <= tensor.shape[axis]] num_extra = len(indices_or_sections) - len(lst) lst.append(tensor.shape[axis]) lst = [ lst[0], ] + [a - b for a, b in zip(lst[1:], lst[:-1])] lst += [0] * num_extra return torch.split(tensor, lst, axis) def array_split(ary: ArrayLike, indices_or_sections, axis=0): return _split_helper(ary, indices_or_sections, axis) def split(ary: ArrayLike, indices_or_sections, axis=0): return _split_helper(ary, indices_or_sections, axis, strict=True) def hsplit(ary: ArrayLike, indices_or_sections): if ary.ndim == 0: raise ValueError("hsplit only works on arrays of 1 or more dimensions") axis = 1 if ary.ndim > 1 else 0 return _split_helper(ary, indices_or_sections, axis, strict=True) def vsplit(ary: ArrayLike, indices_or_sections): if ary.ndim < 2: raise ValueError("vsplit only works on arrays of 2 or more dimensions") return _split_helper(ary, indices_or_sections, 0, strict=True) def dsplit(ary: ArrayLike, indices_or_sections): if ary.ndim < 3: raise ValueError("dsplit only works on arrays of 3 or more dimensions") return _split_helper(ary, indices_or_sections, 2, strict=True) def kron(a: ArrayLike, b: ArrayLike): return torch.kron(a, b) def vander(x: ArrayLike, N=None, increasing=False): return torch.vander(x, N, increasing) # ### linspace, geomspace, logspace and arange ### def linspace( start: ArrayLike, stop: ArrayLike, num=50, endpoint=True, 
retstep=False, dtype: Optional[DTypeLike] = None, axis=0, ): if axis != 0 or retstep or not endpoint: raise NotImplementedError if dtype is None: dtype = _dtypes_impl.default_dtypes().float_dtype # XXX: raises TypeError if start or stop are not scalars return torch.linspace(start, stop, num, dtype=dtype) def geomspace( start: ArrayLike, stop: ArrayLike, num=50, endpoint=True, dtype: Optional[DTypeLike] = None, axis=0, ): if axis != 0 or not endpoint: raise NotImplementedError base = torch.pow(stop / start, 1.0 / (num - 1)) logbase = torch.log(base) return torch.logspace( torch.log(start) / logbase, torch.log(stop) / logbase, num, base=base, ) def logspace( start, stop, num=50, endpoint=True, base=10.0, dtype: Optional[DTypeLike] = None, axis=0, ): if axis != 0 or not endpoint: raise NotImplementedError return torch.logspace(start, stop, num, base=base, dtype=dtype) def arange( start: Optional[ArrayLikeOrScalar] = None, stop: Optional[ArrayLikeOrScalar] = None, step: Optional[ArrayLikeOrScalar] = 1, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None, ): if step == 0: raise ZeroDivisionError if stop is None and start is None: raise TypeError if stop is None: # XXX: this breaks if start is passed as a kwarg: # arange(start=4) should raise (no stop) but doesn't start, stop = 0, start if start is None: start = 0 # the dtype of the result if dtype is None: dtype = ( _dtypes_impl.default_dtypes().float_dtype if any(_dtypes_impl.is_float_or_fp_tensor(x) for x in (start, stop, step)) else _dtypes_impl.default_dtypes().int_dtype ) work_dtype = torch.float64 if dtype.is_complex else dtype # RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'. Fall back to eager. 
if any(_dtypes_impl.is_complex_or_complex_tensor(x) for x in (start, stop, step)): raise NotImplementedError if (step > 0 and start > stop) or (step < 0 and start < stop): # empty range return torch.empty(0, dtype=dtype) result = torch.arange(start, stop, step, dtype=work_dtype) result = _util.cast_if_needed(result, dtype) return result # ### zeros/ones/empty/full ### def empty( shape, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "C", *, like: NotImplementedType = None, ): if dtype is None: dtype = _dtypes_impl.default_dtypes().float_dtype return torch.empty(shape, dtype=dtype) # NB: *_like functions deliberately deviate from numpy: it has subok=True # as the default; we set subok=False and raise on anything else. def empty_like( prototype: ArrayLike, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "K", subok: NotImplementedType = False, shape=None, ): result = torch.empty_like(prototype, dtype=dtype) if shape is not None: result = result.reshape(shape) return result def full( shape, fill_value: ArrayLike, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "C", *, like: NotImplementedType = None, ): if isinstance(shape, int): shape = (shape,) if dtype is None: dtype = fill_value.dtype if not isinstance(shape, (tuple, list)): shape = (shape,) return torch.full(shape, fill_value, dtype=dtype) def full_like( a: ArrayLike, fill_value, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "K", subok: NotImplementedType = False, shape=None, ): # XXX: fill_value broadcasts result = torch.full_like(a, fill_value, dtype=dtype) if shape is not None: result = result.reshape(shape) return result def ones( shape, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "C", *, like: NotImplementedType = None, ): if dtype is None: dtype = _dtypes_impl.default_dtypes().float_dtype return torch.ones(shape, dtype=dtype) def ones_like( a: ArrayLike, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "K", subok: 
NotImplementedType = False, shape=None, ): result = torch.ones_like(a, dtype=dtype) if shape is not None: result = result.reshape(shape) return result def zeros( shape, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "C", *, like: NotImplementedType = None, ): if dtype is None: dtype = _dtypes_impl.default_dtypes().float_dtype return torch.zeros(shape, dtype=dtype) def zeros_like( a: ArrayLike, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "K", subok: NotImplementedType = False, shape=None, ): result = torch.zeros_like(a, dtype=dtype) if shape is not None: result = result.reshape(shape) return result # ### cov & corrcoef ### def _xy_helper_corrcoef(x_tensor, y_tensor=None, rowvar=True): """Prepare inputs for cov and corrcoef.""" # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/function_base.py#L2636 if y_tensor is not None: # make sure x and y are at least 2D ndim_extra = 2 - x_tensor.ndim if ndim_extra > 0: x_tensor = x_tensor.view((1,) * ndim_extra + x_tensor.shape) if not rowvar and x_tensor.shape[0] != 1: x_tensor = x_tensor.mT x_tensor = x_tensor.clone() ndim_extra = 2 - y_tensor.ndim if ndim_extra > 0: y_tensor = y_tensor.view((1,) * ndim_extra + y_tensor.shape) if not rowvar and y_tensor.shape[0] != 1: y_tensor = y_tensor.mT y_tensor = y_tensor.clone() x_tensor = _concatenate((x_tensor, y_tensor), axis=0) return x_tensor def corrcoef( x: ArrayLike, y: Optional[ArrayLike] = None, rowvar=True, bias=None, ddof=None, *, dtype: Optional[DTypeLike] = None, ): if bias is not None or ddof is not None: # deprecated in NumPy raise NotImplementedError xy_tensor = _xy_helper_corrcoef(x, y, rowvar) is_half = (xy_tensor.dtype == torch.float16) and xy_tensor.is_cpu if is_half: # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" dtype = torch.float32 xy_tensor = _util.cast_if_needed(xy_tensor, dtype) result = torch.corrcoef(xy_tensor) if is_half: result = result.to(torch.float16) return result def cov( m: ArrayLike, y: 
Optional[ArrayLike] = None, rowvar=True, bias=False, ddof=None, fweights: Optional[ArrayLike] = None, aweights: Optional[ArrayLike] = None, *, dtype: Optional[DTypeLike] = None, ): m = _xy_helper_corrcoef(m, y, rowvar) if ddof is None: ddof = 1 if bias == 0 else 0 is_half = (m.dtype == torch.float16) and m.is_cpu if is_half: # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" dtype = torch.float32 m = _util.cast_if_needed(m, dtype) result = torch.cov(m, correction=ddof, aweights=aweights, fweights=fweights) if is_half: result = result.to(torch.float16) return result def _conv_corr_impl(a, v, mode): dt = _dtypes_impl.result_type_impl(a, v) a = _util.cast_if_needed(a, dt) v = _util.cast_if_needed(v, dt) padding = v.shape[0] - 1 if mode == "full" else mode if padding == "same" and v.shape[0] % 2 == 0: # UserWarning: Using padding='same' with even kernel lengths and odd # dilation may require a zero-padded copy of the input be created # (Triggered internally at pytorch/aten/src/ATen/native/Convolution.cpp:1010.) 
raise NotImplementedError("mode='same' and even-length weights") # NumPy only accepts 1D arrays; PyTorch requires 2D inputs and 3D weights aa = a[None, :] vv = v[None, None, :] result = torch.nn.functional.conv1d(aa, vv, padding=padding) # torch returns a 2D result, numpy returns a 1D array return result[0, :] def convolve(a: ArrayLike, v: ArrayLike, mode="full"): # NumPy: if v is longer than a, the arrays are swapped before computation if a.shape[0] < v.shape[0]: a, v = v, a # flip the weights since numpy does and torch does not v = torch.flip(v, (0,)) return _conv_corr_impl(a, v, mode) def correlate(a: ArrayLike, v: ArrayLike, mode="valid"): v = torch.conj_physical(v) return _conv_corr_impl(a, v, mode) # ### logic & element selection ### def bincount(x: ArrayLike, /, weights: Optional[ArrayLike] = None, minlength=0): if x.numel() == 0: # edge case allowed by numpy x = x.new_empty(0, dtype=int) int_dtype = _dtypes_impl.default_dtypes().int_dtype (x,) = _util.typecast_tensors((x,), int_dtype, casting="safe") return torch.bincount(x, weights, minlength) def where( condition: ArrayLike, x: Optional[ArrayLikeOrScalar] = None, y: Optional[ArrayLikeOrScalar] = None, /, ): if (x is None) != (y is None): raise ValueError("either both or neither of x and y should be given") if condition.dtype != torch.bool: condition = condition.to(torch.bool) if x is None and y is None: result = torch.where(condition) else: result = torch.where(condition, x, y) return result # ###### module-level queries of object properties def ndim(a: ArrayLike): return a.ndim def shape(a: ArrayLike): return tuple(a.shape) def size(a: ArrayLike, axis=None): if axis is None: return a.numel() else: return a.shape[axis] # ###### shape manipulations and indexing def expand_dims(a: ArrayLike, axis): shape = _util.expand_shape(a.shape, axis) return a.view(shape) # never copies def flip(m: ArrayLike, axis=None): # XXX: semantic difference: np.flip returns a view, torch.flip copies if axis is None: axis = 
tuple(range(m.ndim)) else: axis = _util.normalize_axis_tuple(axis, m.ndim) return torch.flip(m, axis) def flipud(m: ArrayLike): return torch.flipud(m) def fliplr(m: ArrayLike): return torch.fliplr(m) def rot90(m: ArrayLike, k=1, axes=(0, 1)): axes = _util.normalize_axis_tuple(axes, m.ndim) return torch.rot90(m, k, axes) # ### broadcasting and indices ### def broadcast_to(array: ArrayLike, shape, subok: NotImplementedType = False): return torch.broadcast_to(array, size=shape) # This is a function from tuples to tuples, so we just reuse it from torch import broadcast_shapes def broadcast_arrays(*args: ArrayLike, subok: NotImplementedType = False): return torch.broadcast_tensors(*args) def meshgrid(*xi: ArrayLike, copy=True, sparse=False, indexing="xy"): ndim = len(xi) if indexing not in ["xy", "ij"]: raise ValueError("Valid values for `indexing` are 'xy' and 'ij'.") s0 = (1,) * ndim output = [x.reshape(s0[:i] + (-1,) + s0[i + 1 :]) for i, x in enumerate(xi)] if indexing == "xy" and ndim > 1: # switch first and second axis output[0] = output[0].reshape((1, -1) + s0[2:]) output[1] = output[1].reshape((-1, 1) + s0[2:]) if not sparse: # Return the full N-D matrix (not only the 1-D vector) output = torch.broadcast_tensors(*output) if copy: output = [x.clone() for x in output] return list(output) # match numpy, return a list def indices(dimensions, dtype: Optional[DTypeLike] = int, sparse=False): # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1691-L1791 dimensions = tuple(dimensions) N = len(dimensions) shape = (1,) * N if sparse: res = tuple() else: res = torch.empty((N,) + dimensions, dtype=dtype) for i, dim in enumerate(dimensions): idx = torch.arange(dim, dtype=dtype).reshape( shape[:i] + (dim,) + shape[i + 1 :] ) if sparse: res = res + (idx,) else: res[i] = idx return res # ### tri*-something ### def tril(m: ArrayLike, k=0): return torch.tril(m, k) def triu(m: ArrayLike, k=0): return torch.triu(m, k) def tril_indices(n, k=0, m=None): if m is 
None: m = n return torch.tril_indices(n, m, offset=k) def triu_indices(n, k=0, m=None): if m is None: m = n return torch.triu_indices(n, m, offset=k) def tril_indices_from(arr: ArrayLike, k=0): if arr.ndim != 2: raise ValueError("input array must be 2-d") # Return a tensor rather than a tuple to avoid a graphbreak return torch.tril_indices(arr.shape[0], arr.shape[1], offset=k) def triu_indices_from(arr: ArrayLike, k=0): if arr.ndim != 2: raise ValueError("input array must be 2-d") # Return a tensor rather than a tuple to avoid a graphbreak return torch.triu_indices(arr.shape[0], arr.shape[1], offset=k) def tri( N, M=None, k=0, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None, ): if M is None: M = N tensor = torch.ones((N, M), dtype=dtype) return torch.tril(tensor, diagonal=k) # ### equality, equivalence, allclose ### def isclose(a: ArrayLike, b: ArrayLike, rtol=1.0e-5, atol=1.0e-8, equal_nan=False): dtype = _dtypes_impl.result_type_impl(a, b) a = _util.cast_if_needed(a, dtype) b = _util.cast_if_needed(b, dtype) return torch.isclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) def allclose(a: ArrayLike, b: ArrayLike, rtol=1e-05, atol=1e-08, equal_nan=False): dtype = _dtypes_impl.result_type_impl(a, b) a = _util.cast_if_needed(a, dtype) b = _util.cast_if_needed(b, dtype) return torch.allclose(a, b, rtol=rtol, atol=atol, equal_nan=equal_nan) def _tensor_equal(a1, a2, equal_nan=False): # Implementation of array_equal/array_equiv. 
if a1.shape != a2.shape: return False cond = a1 == a2 if equal_nan: cond = cond | (torch.isnan(a1) & torch.isnan(a2)) return cond.all().item() def array_equal(a1: ArrayLike, a2: ArrayLike, equal_nan=False): return _tensor_equal(a1, a2, equal_nan=equal_nan) def array_equiv(a1: ArrayLike, a2: ArrayLike): # *almost* the same as array_equal: _equiv tries to broadcast, _equal does not try: a1_t, a2_t = torch.broadcast_tensors(a1, a2) except RuntimeError: # failed to broadcast => not equivalent return False return _tensor_equal(a1_t, a2_t) def nan_to_num( x: ArrayLike, copy: NotImplementedType = True, nan=0.0, posinf=None, neginf=None ): # work around RuntimeError: "nan_to_num" not implemented for 'ComplexDouble' if x.is_complex(): re = torch.nan_to_num(x.real, nan=nan, posinf=posinf, neginf=neginf) im = torch.nan_to_num(x.imag, nan=nan, posinf=posinf, neginf=neginf) return re + 1j * im else: return torch.nan_to_num(x, nan=nan, posinf=posinf, neginf=neginf) # ### put/take_along_axis ### def take( a: ArrayLike, indices: ArrayLike, axis=None, out: Optional[OutArray] = None, mode: NotImplementedType = "raise", ): (a,), axis = _util.axis_none_flatten(a, axis=axis) axis = _util.normalize_axis_index(axis, a.ndim) idx = (slice(None),) * axis + (indices, ...) result = a[idx] return result def take_along_axis(arr: ArrayLike, indices: ArrayLike, axis): (arr,), axis = _util.axis_none_flatten(arr, axis=axis) axis = _util.normalize_axis_index(axis, arr.ndim) return torch.take_along_dim(arr, indices, axis) def put( a: NDArray, ind: ArrayLike, v: ArrayLike, mode: NotImplementedType = "raise", ): v = v.type(a.dtype) # If ind is larger than v, expand v to at least the size of ind. Any # unnecessary trailing elements are then trimmed. if ind.numel() > v.numel(): ratio = (ind.numel() + v.numel() - 1) // v.numel() v = v.unsqueeze(0).expand((ratio,) + v.shape) # Trim unnecessary elements, regardless if v was expanded or not. Note # np.put() trims v to match ind by default too. 
if ind.numel() < v.numel(): v = v.flatten() v = v[: ind.numel()] a.put_(ind, v) return None def put_along_axis(arr: ArrayLike, indices: ArrayLike, values: ArrayLike, axis): (arr,), axis = _util.axis_none_flatten(arr, axis=axis) axis = _util.normalize_axis_index(axis, arr.ndim) indices, values = torch.broadcast_tensors(indices, values) values = _util.cast_if_needed(values, arr.dtype) result = torch.scatter(arr, axis, indices, values) arr.copy_(result.reshape(arr.shape)) return None def choose( a: ArrayLike, choices: Sequence[ArrayLike], out: Optional[OutArray] = None, mode: NotImplementedType = "raise", ): # First, broadcast elements of `choices` choices = torch.stack(torch.broadcast_tensors(*choices)) # Use an analog of `gather(choices, 0, a)` which broadcasts `choices` vs `a`: # (taken from https://github.com/pytorch/pytorch/issues/9407#issuecomment-1427907939) idx_list = [ torch.arange(dim).view((1,) * i + (dim,) + (1,) * (choices.ndim - i - 1)) for i, dim in enumerate(choices.shape) ] idx_list[0] = a return choices[idx_list].squeeze(0) # ### unique et al ### def unique( ar: ArrayLike, return_index: NotImplementedType = False, return_inverse=False, return_counts=False, axis=None, *, equal_nan: NotImplementedType = True, ): (ar,), axis = _util.axis_none_flatten(ar, axis=axis) axis = _util.normalize_axis_index(axis, ar.ndim) result = torch.unique( ar, return_inverse=return_inverse, return_counts=return_counts, dim=axis ) return result def nonzero(a: ArrayLike): return torch.nonzero(a, as_tuple=True) def argwhere(a: ArrayLike): return torch.argwhere(a) def flatnonzero(a: ArrayLike): return torch.flatten(a).nonzero(as_tuple=True)[0] def clip( a: ArrayLike, min: Optional[ArrayLike] = None, max: Optional[ArrayLike] = None, out: Optional[OutArray] = None, ): return torch.clamp(a, min, max) def repeat(a: ArrayLike, repeats: ArrayLikeOrScalar, axis=None): return torch.repeat_interleave(a, repeats, axis) def tile(A: ArrayLike, reps): if isinstance(reps, int): reps = 
(reps,) return torch.tile(A, reps) def resize(a: ArrayLike, new_shape=None): # implementation vendored from # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/fromnumeric.py#L1420-L1497 if new_shape is None: return a if isinstance(new_shape, int): new_shape = (new_shape,) a = a.flatten() new_size = 1 for dim_length in new_shape: new_size *= dim_length if dim_length < 0: raise ValueError("all elements of `new_shape` must be non-negative") if a.numel() == 0 or new_size == 0: # First case must zero fill. The second would have repeats == 0. return torch.zeros(new_shape, dtype=a.dtype) repeats = -(-new_size // a.numel()) # ceil division a = concatenate((a,) * repeats)[:new_size] return reshape(a, new_shape) # ### diag et al ### def diagonal(a: ArrayLike, offset=0, axis1=0, axis2=1): axis1 = _util.normalize_axis_index(axis1, a.ndim) axis2 = _util.normalize_axis_index(axis2, a.ndim) return torch.diagonal(a, offset, axis1, axis2) def trace( a: ArrayLike, offset=0, axis1=0, axis2=1, dtype: Optional[DTypeLike] = None, out: Optional[OutArray] = None, ): result = torch.diagonal(a, offset, dim1=axis1, dim2=axis2).sum(-1, dtype=dtype) return result def eye( N, M=None, k=0, dtype: Optional[DTypeLike] = None, order: NotImplementedType = "C", *, like: NotImplementedType = None, ): if dtype is None: dtype = _dtypes_impl.default_dtypes().float_dtype if M is None: M = N z = torch.zeros(N, M, dtype=dtype) z.diagonal(k).fill_(1) return z def identity(n, dtype: Optional[DTypeLike] = None, *, like: NotImplementedType = None): return torch.eye(n, dtype=dtype) def diag(v: ArrayLike, k=0): return torch.diag(v, k) def diagflat(v: ArrayLike, k=0): return torch.diagflat(v, k) def diag_indices(n, ndim=2): idx = torch.arange(n) return (idx,) * ndim def diag_indices_from(arr: ArrayLike): if not arr.ndim >= 2: raise ValueError("input array must be at least 2-d") # For more than d=2, the strided formula is only valid for arrays with # all dimensions equal, so we check first. 
s = arr.shape if s[1:] != s[:-1]: raise ValueError("All dimensions of input must be of equal length") return diag_indices(s[0], arr.ndim) def fill_diagonal(a: ArrayLike, val: ArrayLike, wrap=False): if a.ndim < 2: raise ValueError("array must be at least 2-d") if val.numel() == 0 and not wrap: a.fill_diagonal_(val) return a if val.ndim == 0: val = val.unsqueeze(0) # torch.Tensor.fill_diagonal_ only accepts scalars # If the size of val is too large, then val is trimmed if a.ndim == 2: tall = a.shape[0] > a.shape[1] # wrap does nothing for wide matrices... if not wrap or not tall: # Never wraps diag = a.diagonal() diag.copy_(val[: diag.numel()]) else: # wraps and tall... leaving one empty line between diagonals?! max_, min_ = a.shape idx = torch.arange(max_ - max_ // (min_ + 1)) mod = idx % min_ div = idx // min_ a[(div * (min_ + 1) + mod, mod)] = val[: idx.numel()] else: idx = diag_indices_from(a) # a.shape = (n, n, ..., n) a[idx] = val[: a.shape[0]] return a def vdot(a: ArrayLike, b: ArrayLike, /): # 1. torch only accepts 1D arrays, numpy flattens # 2. torch requires matching dtype, while numpy casts (?) 
t_a, t_b = torch.atleast_1d(a, b) if t_a.ndim > 1: t_a = t_a.flatten() if t_b.ndim > 1: t_b = t_b.flatten() dtype = _dtypes_impl.result_type_impl(t_a, t_b) is_half = dtype == torch.float16 and (t_a.is_cpu or t_b.is_cpu) is_bool = dtype == torch.bool # work around torch's "dot" not implemented for 'Half', 'Bool' if is_half: dtype = torch.float32 elif is_bool: dtype = torch.uint8 t_a = _util.cast_if_needed(t_a, dtype) t_b = _util.cast_if_needed(t_b, dtype) result = torch.vdot(t_a, t_b) if is_half: result = result.to(torch.float16) elif is_bool: result = result.to(torch.bool) return result def tensordot(a: ArrayLike, b: ArrayLike, axes=2): if isinstance(axes, (list, tuple)): axes = [[ax] if isinstance(ax, int) else ax for ax in axes] target_dtype = _dtypes_impl.result_type_impl(a, b) a = _util.cast_if_needed(a, target_dtype) b = _util.cast_if_needed(b, target_dtype) return torch.tensordot(a, b, dims=axes) def dot(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): dtype = _dtypes_impl.result_type_impl(a, b) is_bool = dtype == torch.bool if is_bool: dtype = torch.uint8 a = _util.cast_if_needed(a, dtype) b = _util.cast_if_needed(b, dtype) if a.ndim == 0 or b.ndim == 0: result = a * b else: result = torch.matmul(a, b) if is_bool: result = result.to(torch.bool) return result def inner(a: ArrayLike, b: ArrayLike, /): dtype = _dtypes_impl.result_type_impl(a, b) is_half = dtype == torch.float16 and (a.is_cpu or b.is_cpu) is_bool = dtype == torch.bool if is_half: # work around torch's "addmm_impl_cpu_" not implemented for 'Half'" dtype = torch.float32 elif is_bool: dtype = torch.uint8 a = _util.cast_if_needed(a, dtype) b = _util.cast_if_needed(b, dtype) result = torch.inner(a, b) if is_half: result = result.to(torch.float16) elif is_bool: result = result.to(torch.bool) return result def outer(a: ArrayLike, b: ArrayLike, out: Optional[OutArray] = None): return torch.outer(a, b) def cross(a: ArrayLike, b: ArrayLike, axisa=-1, axisb=-1, axisc=-1, axis=None): # 
implementation vendored from # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1486-L1685 if axis is not None: axisa, axisb, axisc = (axis,) * 3 # Check axisa and axisb are within bounds axisa = _util.normalize_axis_index(axisa, a.ndim) axisb = _util.normalize_axis_index(axisb, b.ndim) # Move working axis to the end of the shape a = torch.moveaxis(a, axisa, -1) b = torch.moveaxis(b, axisb, -1) msg = "incompatible dimensions for cross product\n(dimension must be 2 or 3)" if a.shape[-1] not in (2, 3) or b.shape[-1] not in (2, 3): raise ValueError(msg) # Create the output array shape = broadcast_shapes(a[..., 0].shape, b[..., 0].shape) if a.shape[-1] == 3 or b.shape[-1] == 3: shape += (3,) # Check axisc is within bounds axisc = _util.normalize_axis_index(axisc, len(shape)) dtype = _dtypes_impl.result_type_impl(a, b) cp = torch.empty(shape, dtype=dtype) # recast arrays as dtype a = _util.cast_if_needed(a, dtype) b = _util.cast_if_needed(b, dtype) # create local aliases for readability a0 = a[..., 0] a1 = a[..., 1] if a.shape[-1] == 3: a2 = a[..., 2] b0 = b[..., 0] b1 = b[..., 1] if b.shape[-1] == 3: b2 = b[..., 2] if cp.ndim != 0 and cp.shape[-1] == 3: cp0 = cp[..., 0] cp1 = cp[..., 1] cp2 = cp[..., 2] if a.shape[-1] == 2: if b.shape[-1] == 2: # a0 * b1 - a1 * b0 cp[...] = a0 * b1 - a1 * b0 return cp else: assert b.shape[-1] == 3 # cp0 = a1 * b2 - 0 (a2 = 0) # cp1 = 0 - a0 * b2 (a2 = 0) # cp2 = a0 * b1 - a1 * b0 cp0[...] = a1 * b2 cp1[...] = -a0 * b2 cp2[...] = a0 * b1 - a1 * b0 else: assert a.shape[-1] == 3 if b.shape[-1] == 3: cp0[...] = a1 * b2 - a2 * b1 cp1[...] = a2 * b0 - a0 * b2 cp2[...] = a0 * b1 - a1 * b0 else: assert b.shape[-1] == 2 cp0[...] = -a2 * b1 cp1[...] = a2 * b0 cp2[...] 
= a0 * b1 - a1 * b0 return torch.moveaxis(cp, -1, axisc) def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False): # Have to manually normalize *operands and **kwargs, following the NumPy signature # We have a local import to avoid poluting the global space, as it will be then # exported in funcs.py from ._ndarray import ndarray from ._normalizations import ( maybe_copy_to, normalize_array_like, normalize_casting, normalize_dtype, wrap_tensors, ) dtype = normalize_dtype(dtype) casting = normalize_casting(casting) if out is not None and not isinstance(out, ndarray): raise TypeError("'out' must be an array") if order != "K": raise NotImplementedError("'order' parameter is not supported.") # parse arrays and normalize them sublist_format = not isinstance(operands[0], str) if sublist_format: # op, str, op, str ... [sublistout] format: normalize every other argument # - if sublistout is not given, the length of operands is even, and we pick # odd-numbered elements, which are arrays. # - if sublistout is given, the length of operands is odd, we peel off # the last one, and pick odd-numbered elements, which are arrays. # Without [:-1], we would have picked sublistout, too. array_operands = operands[:-1][::2] else: # ("ij->", arrays) format subscripts, array_operands = operands[0], operands[1:] tensors = [normalize_array_like(op) for op in array_operands] target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype # work around 'bmm' not implemented for 'Half' etc is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors) if is_half: target_dtype = torch.float32 is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32] if is_short_int: target_dtype = torch.int64 tensors = _util.typecast_tensors(tensors, target_dtype, casting) from torch.backends import opt_einsum try: # set the global state to handle the optimize=... 
def einsum(*operands, out=None, dtype=None, order="K", casting="safe", optimize=False):
    """Evaluate an Einstein summation with ``torch.einsum`` using NumPy's signature.

    Both calling conventions are handled: ``einsum("ij->", a)`` and the
    sublist form ``einsum(a, [0, 1], b, [1, 2], [0, 2])``.

    Raises
    ------
    TypeError
        If `out` is given but is not an ``ndarray``.
    NotImplementedError
        If `order` is anything other than ``"K"``.
    """
    # Have to manually normalize *operands and **kwargs, following the NumPy signature
    # We have a local import to avoid polluting the global space, as it will be then
    # exported in funcs.py
    from ._ndarray import ndarray
    from ._normalizations import (
        maybe_copy_to,
        normalize_array_like,
        normalize_casting,
        normalize_dtype,
        wrap_tensors,
    )

    dtype = normalize_dtype(dtype)
    casting = normalize_casting(casting)
    if out is not None and not isinstance(out, ndarray):
        raise TypeError("'out' must be an array")
    if order != "K":
        raise NotImplementedError("'order' parameter is not supported.")

    # parse arrays and normalize them
    sublist_format = not isinstance(operands[0], str)
    if sublist_format:
        # op, str, op, str ... [sublistout] format: normalize every other argument

        # - if sublistout is not given, the length of operands is even, and we pick
        #   odd-numbered elements, which are arrays.
        # - if sublistout is given, the length of operands is odd, we peel off
        #   the last one, and pick odd-numbered elements, which are arrays.
        #   Without [:-1], we would have picked sublistout, too.
        array_operands = operands[:-1][::2]
    else:
        # ("ij->", arrays) format
        subscripts, array_operands = operands[0], operands[1:]

    tensors = [normalize_array_like(op) for op in array_operands]
    target_dtype = _dtypes_impl.result_type_impl(*tensors) if dtype is None else dtype

    # work around 'bmm' not implemented for 'Half' etc
    is_half = target_dtype == torch.float16 and all(t.is_cpu for t in tensors)
    if is_half:
        target_dtype = torch.float32

    is_short_int = target_dtype in [torch.uint8, torch.int8, torch.int16, torch.int32]
    if is_short_int:
        target_dtype = torch.int64

    tensors = _util.typecast_tensors(tensors, target_dtype, casting)

    from torch.backends import opt_einsum

    try:
        # set the global state to handle the optimize=... argument, restore on exit
        if opt_einsum.is_available():
            old_strategy = torch.backends.opt_einsum.strategy
            old_enabled = torch.backends.opt_einsum.enabled

            # torch.einsum calls opt_einsum.contract_path, which runs into
            # https://github.com/dgasmith/opt_einsum/issues/219
            # for strategy={True, False}
            if optimize is True:
                optimize = "auto"
            elif optimize is False:
                torch.backends.opt_einsum.enabled = False

            torch.backends.opt_einsum.strategy = optimize

        if sublist_format:
            # recombine operands
            sublists = operands[1::2]
            has_sublistout = len(operands) % 2 == 1
            if has_sublistout:
                sublistout = operands[-1]
            operands = list(itertools.chain(*zip(tensors, sublists)))
            if has_sublistout:
                operands.append(sublistout)

            result = torch.einsum(*operands)
        else:
            result = torch.einsum(subscripts, *tensors)

    finally:
        # always put the global opt_einsum state back, even if einsum raised
        if opt_einsum.is_available():
            torch.backends.opt_einsum.strategy = old_strategy
            torch.backends.opt_einsum.enabled = old_enabled

    result = maybe_copy_to(out, result)
    return wrap_tensors(result)


# ### sort and partition ###


def _sort_helper(tensor, axis, kind, order):
    """Validate and normalize the arguments shared by `sort` and `argsort`.

    Returns the (possibly flattened) tensor, the normalized axis and a flag
    that is True only when ``kind == "stable"``. Complex inputs are rejected.
    """
    if tensor.dtype.is_complex:
        raise NotImplementedError(f"sorting {tensor.dtype} is not supported")
    (tensor,), axis = _util.axis_none_flatten(tensor, axis=axis)
    axis = _util.normalize_axis_index(axis, tensor.ndim)

    stable = kind == "stable"

    return tensor, axis, stable


def sort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
    """Return the values of `a` sorted along `axis` (via ``torch.sort``)."""
    # `order` keyword arg is only relevant for structured dtypes; so not supported here.
    a, axis, stable = _sort_helper(a, axis, kind, order)
    result = torch.sort(a, dim=axis, stable=stable)
    return result.values


def argsort(a: ArrayLike, axis=-1, kind=None, order: NotImplementedType = None):
    """Return the indices that sort `a` along `axis` (via ``torch.argsort``)."""
    a, axis, stable = _sort_helper(a, axis, kind, order)
    return torch.argsort(a, dim=axis, stable=stable)


def searchsorted(
    a: ArrayLike, v: ArrayLike, side="left", sorter: Optional[ArrayLike] = None
):
    """Find insertion indices of `v` in `a` via ``torch.searchsorted``.

    Complex `a` is rejected with NotImplementedError.
    """
    if a.dtype.is_complex:
        raise NotImplementedError(f"searchsorted with dtype={a.dtype}")

    return torch.searchsorted(a, v, side=side, sorter=sorter)


# ### swap/move/roll axis ###


def moveaxis(a: ArrayLike, source, destination):
    """Move axes of `a` from `source` positions to `destination` positions."""
    source = _util.normalize_axis_tuple(source, a.ndim, "source")
    destination = _util.normalize_axis_tuple(destination, a.ndim, "destination")
    return torch.moveaxis(a, source, destination)


def swapaxes(a: ArrayLike, axis1, axis2):
    """Interchange `axis1` and `axis2` of `a`."""
    axis1 = _util.normalize_axis_index(axis1, a.ndim)
    axis2 = _util.normalize_axis_index(axis2, a.ndim)
    return torch.swapaxes(a, axis1, axis2)


def rollaxis(a: ArrayLike, axis, start=0):
    """Roll `axis` backwards until it lies in position `start`.

    Raises ``_util.AxisError`` when `start` is outside ``[-ndim, ndim]``.
    """
    # Straight vendor from:
    # https://github.com/numpy/numpy/blob/v1.24.0/numpy/core/numeric.py#L1259
    #
    # Also note this function in NumPy is mostly retained for backwards compat
    # (https://stackoverflow.com/questions/29891583/reason-why-numpy-rollaxis-is-so-confusing)
    # so let's not touch it unless hard pressed.
    n = a.ndim
    axis = _util.normalize_axis_index(axis, n)
    if start < 0:
        start += n
    msg = "'%s' arg requires %d <= %s < %d, but %d was passed in"
    if not (0 <= start < n + 1):
        raise _util.AxisError(msg % ("start", -n, "start", n + 1, start))
    if axis < start:
        # it's been removed
        start -= 1
    if axis == start:
        # numpy returns a view, here we try returning the tensor itself
        # return tensor[...]
        return a
    axes = list(range(0, n))
    axes.remove(axis)
    axes.insert(start, axis)
    # `axes` is a permutation of the dimensions, not a target shape, so the
    # axes must be permuted (NumPy calls `a.transpose(axes)` here); `view`
    # would try to reshape the data to shape `axes`.
    return a.permute(axes)


def roll(a: ArrayLike, shift, axis=None):
    """Roll elements of `a` by `shift` along `axis` (via ``torch.roll``)."""
    if axis is not None:
        axis = _util.normalize_axis_tuple(axis, a.ndim, allow_duplicate=True)
        if not isinstance(shift, tuple):
            # a single shift applies to every requested axis
            shift = (shift,) * len(axis)
    return torch.roll(a, shift, axis)


# ### shape manipulations ###


def squeeze(a: ArrayLike, axis=None):
    """Remove length-one dimensions of `a`.

    ``axis=()`` returns `a` unchanged, ``axis=None`` squeezes every
    length-one dimension, otherwise only the given axis (an int or a tuple of
    ints) is squeezed.
    """
    if axis == ():
        return a
    if axis is None:
        return a.squeeze()
    # Tensor.squeeze accepts an int or a tuple of dims and squeezes all of
    # them at once; squeezing them one at a time on the original tensor
    # would only keep the effect of the last axis.
    return a.squeeze(axis)


def reshape(a: ArrayLike, newshape, order: NotImplementedType = "C"):
    """Give `a` the shape `newshape` (an int, a shape tuple, or ``(shape,)``)."""
    if isinstance(newshape, int):
        # a bare int has no len(); Tensor.reshape takes it directly
        return a.reshape(newshape)
    # if sh = (1, 2, 3), numpy allows both .reshape(sh) and .reshape(*sh)
    newshape = newshape[0] if len(newshape) == 1 else newshape
    return a.reshape(newshape)


# NB: cannot use torch.reshape(a, newshape) above, because of
# (Pdb) torch.reshape(torch.as_tensor([1]), 1)
# *** TypeError: reshape(): argument 'shape' (position 2) must be tuple of SymInts, not int


def transpose(a: ArrayLike, axes=None):
    """Permute the dimensions of `a`; with no `axes`, reverse their order."""
    # numpy allows both .transpose(sh) and .transpose(*sh)
    # also older code uses axes being a list
    if axes in [(), None, (None,)]:
        axes = tuple(reversed(range(a.ndim)))
    elif len(axes) == 1:
        axes = axes[0]
    return a.permute(axes)


def ravel(a: ArrayLike, order: NotImplementedType = "C"):
    """Return `a` flattened to one dimension."""
    return torch.flatten(a)


def diff(
    a: ArrayLike,
    n=1,
    axis=-1,
    prepend: Optional[ArrayLike] = None,
    append: Optional[ArrayLike] = None,
):
    """Compute the `n`-th discrete difference of `a` along `axis`.

    `prepend` / `append` are broadcast to the shape of `a` (except along
    `axis`) before being handed to ``torch.diff``.

    Raises ValueError if `n` is negative.
    """
    axis = _util.normalize_axis_index(axis, a.ndim)

    if n < 0:
        raise ValueError(f"order must be non-negative but got {n}")

    if n == 0:
        # match numpy and return the input immediately
        return a

    if prepend is not None:
        shape = list(a.shape)
        shape[axis] = prepend.shape[axis] if prepend.ndim > 0 else 1
        prepend = torch.broadcast_to(prepend, shape)

    if append is not None:
        shape = list(a.shape)
        shape[axis] = append.shape[axis] if append.ndim > 0 else 1
        append = torch.broadcast_to(append, shape)

    return torch.diff(a, n, axis=axis, prepend=prepend, append=append)


# ### math functions ###


def angle(z: ArrayLike, deg=False):
    """Return the angle of `z`, in degrees when `deg` is truthy, else radians."""
    result = torch.angle(z)
    if deg:
        result = result * (180 / torch.pi)
    return result


def sinc(x: ArrayLike):
    """Return ``torch.sinc(x)``."""
    return torch.sinc(x)


# NB: have to normalize *varargs manually
def gradient(f: ArrayLike, *varargs, axis=None, edge_order=1):
    """Numerical gradient of `f`, vendored from NumPy's ``gradient``.

    `varargs` gives the spacing: nothing (unit spacing), a single scalar for
    all axes, or one scalar / 1-D coordinate array per axis. Returns a single
    tensor when one axis is differentiated, otherwise a list of tensors.

    Raises
    ------
    TypeError
        If the number of spacing arguments matches none of the above.
    ValueError
        For spacing arrays of the wrong rank or length, ``edge_order > 2``,
        or an axis with fewer than ``edge_order + 1`` elements.
    """
    N = f.ndim  # number of dimensions

    varargs = _util.ndarrays_to_tensors(varargs)

    if axis is None:
        axes = tuple(range(N))
    else:
        axes = _util.normalize_axis_tuple(axis, N)

    len_axes = len(axes)
    n = len(varargs)
    if n == 0:
        # no spacing argument - use 1 in all axes
        dx = [1.0] * len_axes
    elif n == 1 and (_dtypes_impl.is_scalar(varargs[0]) or varargs[0].ndim == 0):
        # single scalar or 0D tensor for all axes (np.ndim(varargs[0]) == 0)
        dx = varargs * len_axes
    elif n == len_axes:
        # scalar or 1d array for each axis
        dx = list(varargs)
        for i, distances in enumerate(dx):
            distances = torch.as_tensor(distances)
            if distances.ndim == 0:
                continue
            elif distances.ndim != 1:
                raise ValueError("distances must be either scalars or 1d")
            if len(distances) != f.shape[axes[i]]:
                raise ValueError(
                    "when 1d, distances must match "
                    "the length of the corresponding dimension"
                )
            if not (distances.dtype.is_floating_point or distances.dtype.is_complex):
                distances = distances.double()

            diffx = torch.diff(distances)
            # if distances are constant reduce to the scalar case
            # since it brings a consistent speedup
            if (diffx == diffx[0]).all():
                diffx = diffx[0]
            dx[i] = diffx
    else:
        raise TypeError("invalid number of arguments")

    if edge_order > 2:
        raise ValueError("'edge_order' greater than 2 not supported")

    # use central differences on interior and one-sided differences on the
    # endpoints. This preserves second order-accuracy over the full domain.

    outvals = []

    # create slice objects --- initially all are [:, :, ..., :]
    slice1 = [slice(None)] * N
    slice2 = [slice(None)] * N
    slice3 = [slice(None)] * N
    slice4 = [slice(None)] * N

    otype = f.dtype
    if _dtypes_impl.python_type_for_torch(otype) in (int, bool):
        # Convert to floating point.
        # First check if f is a numpy integer type; if so, convert f to float64
        # to avoid modular arithmetic when computing the changes in f.
        f = f.double()
        otype = torch.float64

    for axis, ax_dx in zip(axes, dx):
        if f.shape[axis] < edge_order + 1:
            raise ValueError(
                "Shape of array too small to calculate a numerical gradient, "
                "at least (edge_order + 1) elements are required."
            )
        # result allocation
        out = torch.empty_like(f, dtype=otype)

        # spacing for the current axis (NB: np.ndim(ax_dx) == 0)
        uniform_spacing = _dtypes_impl.is_scalar(ax_dx) or ax_dx.ndim == 0

        # Numerical differentiation: 2nd order interior
        slice1[axis] = slice(1, -1)
        slice2[axis] = slice(None, -2)
        slice3[axis] = slice(1, -1)
        slice4[axis] = slice(2, None)

        if uniform_spacing:
            out[tuple(slice1)] = (f[tuple(slice4)] - f[tuple(slice2)]) / (2.0 * ax_dx)
        else:
            dx1 = ax_dx[0:-1]
            dx2 = ax_dx[1:]
            a = -(dx2) / (dx1 * (dx1 + dx2))
            b = (dx2 - dx1) / (dx1 * dx2)
            c = dx1 / (dx2 * (dx1 + dx2))
            # fix the shape for broadcasting
            shape = [1] * N
            shape[axis] = -1
            a = a.reshape(shape)
            b = b.reshape(shape)
            c = c.reshape(shape)
            # 1D equivalent -- out[1:-1] = a * f[:-2] + b * f[1:-1] + c * f[2:]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        # Numerical differentiation: 1st order edges
        if edge_order == 1:
            slice1[axis] = 0
            slice2[axis] = 1
            slice3[axis] = 0
            dx_0 = ax_dx if uniform_spacing else ax_dx[0]
            # 1D equivalent -- out[0] = (f[1] - f[0]) / (x[1] - x[0])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_0

            slice1[axis] = -1
            slice2[axis] = -1
            slice3[axis] = -2
            dx_n = ax_dx if uniform_spacing else ax_dx[-1]
            # 1D equivalent -- out[-1] = (f[-1] - f[-2]) / (x[-1] - x[-2])
            out[tuple(slice1)] = (f[tuple(slice2)] - f[tuple(slice3)]) / dx_n

        # Numerical differentiation: 2nd order edges
        else:
            slice1[axis] = 0
            slice2[axis] = 0
            slice3[axis] = 1
            slice4[axis] = 2
            if uniform_spacing:
                a = -1.5 / ax_dx
                b = 2.0 / ax_dx
                c = -0.5 / ax_dx
            else:
                dx1 = ax_dx[0]
                dx2 = ax_dx[1]
                a = -(2.0 * dx1 + dx2) / (dx1 * (dx1 + dx2))
                b = (dx1 + dx2) / (dx1 * dx2)
                c = -dx1 / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[0] = a * f[0] + b * f[1] + c * f[2]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

            slice1[axis] = -1
            slice2[axis] = -3
            slice3[axis] = -2
            slice4[axis] = -1
            if uniform_spacing:
                a = 0.5 / ax_dx
                b = -2.0 / ax_dx
                c = 1.5 / ax_dx
            else:
                dx1 = ax_dx[-2]
                dx2 = ax_dx[-1]
                a = (dx2) / (dx1 * (dx1 + dx2))
                b = -(dx2 + dx1) / (dx1 * dx2)
                c = (2.0 * dx2 + dx1) / (dx2 * (dx1 + dx2))
            # 1D equivalent -- out[-1] = a * f[-3] + b * f[-2] + c * f[-1]
            out[tuple(slice1)] = (
                a * f[tuple(slice2)] + b * f[tuple(slice3)] + c * f[tuple(slice4)]
            )

        outvals.append(out)

        # reset the slice object in this dimension to ":"
        slice1[axis] = slice(None)
        slice2[axis] = slice(None)
        slice3[axis] = slice(None)
        slice4[axis] = slice(None)

    if len_axes == 1:
        return outvals[0]
    else:
        return outvals


# ### Type/shape etc queries ###


def round(a: ArrayLike, decimals=0, out: Optional[OutArray] = None):
    """Round `a` to `decimals` places.

    Complex inputs are rounded component-wise; inputs that are neither
    floating point nor complex are returned unchanged.
    """
    # NOTE(review): `out` is not used in this body; presumably the
    # @normalizer decorator handles it -- confirm.
    if a.is_floating_point():
        result = torch.round(a, decimals=decimals)
    elif a.is_complex():
        # RuntimeError: "round_cpu" not implemented for 'ComplexFloat'
        result = torch.complex(
            torch.round(a.real, decimals=decimals),
            torch.round(a.imag, decimals=decimals),
        )
    else:
        # RuntimeError: "round_cpu" not implemented for 'int'
        result = a
    return result


around = round
round_ = round


def real_if_close(a: ArrayLike, tol=100):
    """Return the real part of complex `a` if every imaginary part is below `tol`."""
    if not torch.is_complex(a):
        return a
    if tol > 1:
        # Undocumented in numpy: if tol < 1, it's an absolute tolerance!
        # Otherwise, tol > 1 is relative tolerance, in units of the dtype epsilon
        # https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L577
        tol = tol * torch.finfo(a.dtype).eps

    mask = torch.abs(a.imag) < tol
    return a.real if mask.all() else a


def real(a: ArrayLike):
    """Return the real part of `a`."""
    return torch.real(a)


def imag(a: ArrayLike):
    """Return the imaginary part of `a` (zeros for non-complex input)."""
    if a.is_complex():
        return a.imag
    return torch.zeros_like(a)


def iscomplex(x: ArrayLike):
    """Element-wise test for a non-zero imaginary part."""
    if torch.is_complex(x):
        return x.imag != 0
    return torch.zeros_like(x, dtype=torch.bool)


def isreal(x: ArrayLike):
    """Element-wise test for a zero imaginary part."""
    if torch.is_complex(x):
        return x.imag == 0
    return torch.ones_like(x, dtype=torch.bool)


def iscomplexobj(x: ArrayLike):
    """Return True if `x` has a complex dtype."""
    return torch.is_complex(x)


def isrealobj(x: ArrayLike):
    """Return True if `x` does not have a complex dtype."""
    return not torch.is_complex(x)


def isneginf(x: ArrayLike, out: Optional[OutArray] = None):
    """Element-wise test for negative infinity."""
    return torch.isneginf(x)


def isposinf(x: ArrayLike, out: Optional[OutArray] = None):
    """Element-wise test for positive infinity."""
    return torch.isposinf(x)


def i0(x: ArrayLike):
    """Return ``torch.special.i0(x)``."""
    return torch.special.i0(x)


def isscalar(a):
    """Return True if `a` normalizes to a tensor with exactly one element."""
    # We need to use normalize_array_like, but we don't want to export it in funcs.py
    from ._normalizations import normalize_array_like

    try:
        t = normalize_array_like(a)
        return t.numel() == 1
    except Exception:
        # anything that cannot be normalized to a tensor is not a scalar
        return False


# ### Filter windows ###


def hamming(M):
    """Return a symmetric Hamming window of length `M` in the default float dtype."""
    dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.hamming_window(M, periodic=False, dtype=dtype)


def hanning(M):
    """Return a symmetric Hann window of length `M` in the default float dtype."""
    dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.hann_window(M, periodic=False, dtype=dtype)


def kaiser(M, beta):
    """Return a symmetric Kaiser window of length `M` in the default float dtype."""
    dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.kaiser_window(M, beta=beta, periodic=False, dtype=dtype)


def blackman(M):
    """Return a symmetric Blackman window of length `M` in the default float dtype."""
    dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.blackman_window(M, periodic=False, dtype=dtype)


def bartlett(M):
    """Return a symmetric Bartlett window of length `M` in the default float dtype."""
    dtype = _dtypes_impl.default_dtypes().float_dtype
    return torch.bartlett_window(M, periodic=False, dtype=dtype)


# ### Dtype routines ###

# vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/type_check.py#L666


# Row 0: real dtypes, row 1: complex dtypes; the column is the precision index.
array_type = [
    [torch.float16, torch.float32, torch.float64],
    [None, torch.complex64, torch.complex128],
]
# Maps a floating/complex dtype to its column in `array_type`.
array_precision = {
    torch.float16: 0,
    torch.float32: 1,
    torch.float64: 2,
    torch.complex64: 1,
    torch.complex128: 2,
}


def common_type(*tensors: ArrayLike):
    """Return the float/complex dtype common to all `tensors`.

    Inputs that are neither floating point nor complex count as double
    precision. Raises TypeError for a floating/complex dtype that is missing
    from `array_precision`.
    """
    is_complex = False
    precision = 0
    for a in tensors:
        t = a.dtype
        if iscomplexobj(a):
            is_complex = True
        if not (t.is_floating_point or t.is_complex):
            p = 2  # array_precision[_nx.double]
        else:
            p = array_precision.get(t, None)
            if p is None:
                raise TypeError("can't get common type for non-numeric array")
        precision = builtins.max(precision, p)
    if is_complex:
        return array_type[1][precision]
    else:
        return array_type[0][precision]


# ### histograms ###


def histogram(
    a: ArrayLike,
    bins: ArrayLike = 10,
    range=None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """Compute the histogram of `a` via ``torch.histogram``.

    Returns ``(hist, bin_edges)``. Counts are cast to integers when
    `density` is falsy and the weights are absent or non-floating; bin edges
    are cast to integers when `a` was not floating point or complex.

    Raises
    ------
    ValueError
        If the deprecated `normed` argument is passed.
    NotImplementedError
        For complex `weights`.
    """
    if normed is not None:
        raise ValueError("normed argument is deprecated, use density= instead")

    if weights is not None and weights.dtype.is_complex:
        raise NotImplementedError("complex weights histogram.")

    is_a_int = not (a.dtype.is_floating_point or a.dtype.is_complex)
    is_w_int = weights is None or not weights.dtype.is_floating_point
    if is_a_int:
        a = a.double()

    if weights is not None:
        weights = _util.cast_if_needed(weights, a.dtype)

    if isinstance(bins, torch.Tensor):
        if bins.ndim == 0:
            # bins was a single int
            bins = operator.index(bins)
        else:
            bins = _util.cast_if_needed(bins, a.dtype)

    if range is None:
        h, b = torch.histogram(a, bins, weight=weights, density=bool(density))
    else:
        h, b = torch.histogram(
            a, bins, range=range, weight=weights, density=bool(density)
        )

    if not density and is_w_int:
        h = h.long()
    if is_a_int:
        b = b.long()

    return h, b


def histogram2d(
    x,
    y,
    bins=10,
    range: Optional[ArrayLike] = None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """Compute a 2-D histogram by delegating to `histogramdd`.

    Returns ``(hist, xedges, yedges)``. Raises ValueError when `x` and `y`
    differ in length.
    """
    # vendored from https://github.com/numpy/numpy/blob/v1.24.0/numpy/lib/twodim_base.py#L655-L821
    if len(x) != len(y):
        raise ValueError("x and y must have the same length.")

    try:
        N = len(bins)
    except TypeError:
        N = 1

    if N != 1 and N != 2:
        # a single edge array is used for both dimensions
        bins = [bins, bins]

    h, e = histogramdd((x, y), bins, range, normed, weights, density)

    return h, e[0], e[1]


def histogramdd(
    sample,
    bins=10,
    range: Optional[ArrayLike] = None,
    normed=None,
    weights: Optional[ArrayLike] = None,
    density=None,
):
    """Compute a multidimensional histogram via ``torch.histogramdd``.

    Returns ``(hist, edges)``. When `bins` holds arrays, the returned edges
    are cast back to the dtypes of the corresponding input arrays.

    Raises ValueError if the deprecated `normed` argument is passed.
    """
    # have to normalize manually because `sample` interpretation differs
    # for a list of lists and a 2D array
    if normed is not None:
        raise ValueError("normed argument is deprecated, use density= instead")

    from ._normalizations import normalize_array_like, normalize_seq_array_like

    if isinstance(sample, (list, tuple)):
        sample = normalize_array_like(sample).T
    else:
        sample = normalize_array_like(sample)

    sample = torch.atleast_2d(sample)

    if not (sample.dtype.is_floating_point or sample.dtype.is_complex):
        sample = sample.double()

    # bins is either an int, or a sequence of ints or a sequence of arrays
    bins_is_array = not (
        isinstance(bins, int) or builtins.all(isinstance(b, int) for b in bins)
    )
    if bins_is_array:
        bins = normalize_seq_array_like(bins)
        bins_dtypes = [b.dtype for b in bins]
        bins = [_util.cast_if_needed(b, sample.dtype) for b in bins]

    if range is not None:
        range = range.flatten().tolist()

    if weights is not None:
        if range is None:
            # range=... is required : interleave min and max values per dimension.
            # Only fill it in from the data when the caller gave none, so an
            # explicit `range` is not silently discarded.
            mm = sample.aminmax(dim=0)
            range = torch.cat(mm).reshape(2, -1).T.flatten()
            range = tuple(range.tolist())
        weights = _util.cast_if_needed(weights, sample.dtype)
        w_kwd = {"weight": weights}
    else:
        w_kwd = {}

    h, b = torch.histogramdd(sample, bins, range, density=bool(density), **w_kwd)

    if bins_is_array:
        b = [_util.cast_if_needed(bb, dtyp) for bb, dtyp in zip(b, bins_dtypes)]

    return h, b


# ### odds and ends


def min_scalar_type(a: ArrayLike, /):
    """Return the smallest dtype (wrapped in ``DType``) able to hold the scalar `a`.

    Inputs that do not hold exactly one element get their own dtype back.
    """
    # https://github.com/numpy/numpy/blob/maintenance/1.24.x/numpy/core/src/multiarray/convert_datatype.c#L1288

    from ._dtypes import DType

    if a.numel() != 1:
        # numpy docs: "For non-scalar array a, returns the vector’s dtype unmodified."
        # Empty inputs take this path too: they have no value to range-check.
        return DType(a.dtype)

    if a.dtype == torch.bool:
        dtype = torch.bool

    elif a.dtype.is_complex:
        fi = torch.finfo(torch.float32)
        fits_in_single = a.dtype == torch.complex64 or (
            fi.min <= a.real <= fi.max and fi.min <= a.imag <= fi.max
        )
        dtype = torch.complex64 if fits_in_single else torch.complex128

    elif a.dtype.is_floating_point:
        # inf / nan fail every range comparison below but are representable
        # in the smallest float type, so start from that.
        dtype = torch.float16
        if torch.isfinite(a).all():
            for dt in [torch.float16, torch.float32, torch.float64]:
                fi = torch.finfo(dt)
                if fi.min <= a <= fi.max:
                    dtype = dt
                    break
    else:
        # must be integer
        for dt in [torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64]:
            # Prefer unsigned int where possible, as numpy does.
            ii = torch.iinfo(dt)
            if ii.min <= a <= ii.max:
                dtype = dt
                break

    return DType(dtype)


def pad(array: ArrayLike, pad_width: ArrayLike, mode="constant", **kwargs):
    """Pad `array` with a constant value via ``torch.nn.functional.pad``.

    Only ``mode="constant"`` is supported; the fill value is read from the
    ``constant_values`` keyword (default 0).
    """
    if mode != "constant":
        raise NotImplementedError
    value = kwargs.get("constant_values", 0)
    # `value` must be a python scalar for torch.nn.functional.pad
    typ = _dtypes_impl.python_type_for_torch(array.dtype)
    value = typ(value)

    pad_width = torch.broadcast_to(pad_width, (array.ndim, 2))
    # the flip reverses the per-axis rows before flattening them for torch's pad
    pad_width = torch.flip(pad_width, (0,)).flatten()

    return torch.nn.functional.pad(array, tuple(pad_width), value=value)