Spaces:
Runtime error
Runtime error
"""Assorted utilities, which do not need anything other than torch and stdlib.
"""
import operator | |
import torch | |
from . import _dtypes_impl | |
# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504 | |
def is_sequence(seq):
    """Return True when `seq` is a sized, non-string object.

    Strings are rejected explicitly even though they support ``len``. Any
    other object counts as a sequence exactly when ``len(seq)`` succeeds.
    """
    if not isinstance(seq, str):
        try:
            len(seq)
        except Exception:
            # Anything that cannot report a length is not a sequence.
            pass
        else:
            return True
    return False
class AxisError(ValueError, IndexError):
    """Error for an axis that is out of bounds for the array's dimension.

    Inherits from both ``ValueError`` and ``IndexError``, so an ``except``
    clause for either of them catches it.
    """
class UFuncTypeError(TypeError, RuntimeError):
    """Exception that is both a ``TypeError`` and a ``RuntimeError``.

    NOTE(review): presumably raised for ufunc dtype/casting failures,
    mirroring NumPy's exception of the same name -- no raise site is visible
    in this file, so confirm against the callers.
    """
def cast_if_needed(tensor, dtype):
    """Return `tensor` converted to `dtype`, skipping needless conversions.

    ``dtype=None`` means "no cast requested". A tensor that already has the
    requested dtype is likewise returned unchanged (the very same object).
    """
    if dtype is None or tensor.dtype == dtype:
        return tensor
    return tensor.to(dtype)
def cast_int_to_float(x):
    """Promote a bool or integer tensor to the default float dtype.

    The decision is made on ``_dtypes_impl._category(x.dtype)``: categories
    below 2 (bools and integers) are converted, anything else is returned
    unchanged.
    """
    needs_promotion = _dtypes_impl._category(x.dtype) < 2
    if not needs_promotion:
        return x
    float_dtype = _dtypes_impl.default_dtypes().float_dtype
    return x.to(float_dtype)
# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h | |
def normalize_axis_index(ax, ndim, argname=None):
    """Normalize a possibly negative axis index into the range ``[0, ndim)``.

    Parameters
    ----------
    ax : int
        The un-normalized axis index; negative values count from the end.
    ndim : int
        The number of dimensions the axis is normalized against.
    argname : str, optional
        Name of the argument being validated. When given, it is prepended to
        the error message as ``"{argname}: "``.

    Returns
    -------
    int
        The equivalent non-negative axis index.

    Raises
    ------
    AxisError
        If `ax` lies outside ``[-ndim, ndim)``.
    """
    if not (-ndim <= ax < ndim):
        # Prefix the message with the argument name so callers that validate
        # several axis-like arguments can tell which one was rejected.
        prefix = f"{argname}: " if argname else ""
        raise AxisError(
            f"{prefix}axis {ax} is out of bounds for array of dimension {ndim}"
        )
    return ax + ndim if ax < 0 else ax
# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378 | |
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Convert an axis specification into a tuple of non-negative axes.

    A bare integer such as ``1`` is accepted as shorthand for ``(1,)``. Every
    entry is bounds-checked and wrapped around by `normalize_axis_index`.
    Unless `allow_duplicate` is set, naming the same axis twice is an error.

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` is normalized
        against.
    argname : str, optional
        The name of the argument, used in error messages.
    allow_duplicate : bool, optional
        If False (the default), reject an axis that is specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axes, each satisfying ``0 <= axis < ndim``.

    Raises
    ------
    ValueError
        If an axis is repeated and `allow_duplicate` is False.
    """
    # Fast path for the most common case: wrap a lone integer-like value.
    if type(axis) not in (list, tuple):
        try:
            axis = [operator.index(axis)]
        except TypeError:
            # Not integer-like; fall through and iterate over it below.
            pass

    # A list comprehension is faster here than feeding an iterator to tuple().
    normalized = [normalize_axis_index(item, ndim, argname) for item in axis]
    normalized = tuple(normalized)

    if allow_duplicate or len(normalized) == len(set(normalized)):
        return normalized

    if argname:
        message = f"repeated axis in `{argname}` argument"
    else:
        message = "repeated axis"
    raise ValueError(message)
def allow_only_single_axis(axis):
    """Unwrap a one-element axis sequence into its single entry.

    ``None`` is passed through unchanged. A sequence whose length is not
    exactly one is rejected with ``NotImplementedError``.
    """
    if axis is None:
        return None
    n_axes = len(axis)
    if n_axes == 1:
        return axis[0]
    raise NotImplementedError("does not handle tuple axis")
def expand_shape(arr_shape, axis):
    """Return `arr_shape` with length-1 dimensions inserted at `axis`.

    `axis` may be a single value or a list/tuple of values. The positions
    refer to the *output* shape, whose rank is
    ``len(arr_shape) + number of axes``. The result is a list.
    """
    # taken from numpy 1.23.x, expand_dims function
    axes = axis if type(axis) in (list, tuple) else (axis,)
    out_ndim = len(arr_shape) + len(axes)
    axes = normalize_axis_tuple(axes, out_ndim)

    # Walk the output positions: new axes get size 1, all other positions
    # consume the original sizes in order.
    remaining = iter(arr_shape)
    shape = []
    for position in range(out_ndim):
        shape.append(1 if position in axes else next(remaining))
    return shape
def apply_keepdims(tensor, axis, ndim):
    """Reinstate dimensions as length-1 dimensions on `tensor`.

    Parameters
    ----------
    tensor : torch.Tensor
        The tensor to reshape.
    axis : None, int or list/tuple of int
        Positions at which length-1 dimensions are inserted. ``None`` means
        `tensor` was a scalar.
    ndim : int
        Rank of the result; only used when `axis` is None.

    Returns
    -------
    torch.Tensor
        For ``axis=None`` a contiguous tensor of shape ``(1,) * ndim``;
        otherwise `tensor` reshaped to ``expand_shape(tensor.shape, axis)``.
    """
    if axis is not None:
        kept_shape = expand_shape(tensor.shape, axis)
        return tensor.reshape(kept_shape)

    # tensor was a scalar
    all_ones = (1,) * ndim
    return tensor.expand(all_ones).contiguous()
def axis_none_flatten(*tensors, axis=None):
    """Flatten the arrays if axis is None.

    Returns a ``(tensors, axis)`` pair. With ``axis=None`` every tensor is
    flattened and the returned axis is 0; otherwise the inputs and `axis`
    are handed back untouched.
    """
    if axis is not None:
        return tensors, axis
    flattened = [item.flatten() for item in tensors]
    return tuple(flattened), 0
def typecast_tensor(t, target_dtype, casting):
    """Cast a tensor to `target_dtype`, subject to a casting rule.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast
    target_dtype : torch dtype object
        The dtype to cast the tensor to
    casting : str
        The casting mode, see `np.can_cast`

    Returns
    -------
    `torch.Tensor` of the `target_dtype` dtype

    Raises
    ------
    TypeError
        if the argument cannot be cast according to the `casting` rule
    """
    cast_allowed = _dtypes_impl.can_cast_impl(
        t.dtype, target_dtype, casting=casting
    )
    if cast_allowed:
        return cast_if_needed(t, target_dtype)

    message = (
        f"Cannot cast array data from {t.dtype} to"
        f" {target_dtype} according to the rule '{casting}'"
    )
    raise TypeError(message)
def typecast_tensors(tensors, target_dtype, casting):
    """Apply `typecast_tensor` to each tensor and return the results as a tuple."""
    converted = [typecast_tensor(item, target_dtype, casting) for item in tensors]
    return tuple(converted)
def _try_convert_to_tensor(obj): | |
try: | |
tensor = torch.as_tensor(obj) | |
except Exception as e: | |
mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}." | |
raise NotImplementedError(mesg) # noqa: TRY200 | |
return tensor | |
def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """Turn `obj` into a tensor; the core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce
    dtype : torch.dtype object or None
        Coerce to this torch dtype; ``None`` requests no cast
    copy : bool
        If True, the result is cloned before it is returned
    ndmin : int
        The result has at least this many dimensions; missing ones are
        prepended as length-1 dimensions

    Returns
    -------
    tensor : torch.Tensor
        a tensor object with requested dtype, ndim and copy semantics.

    Notes
    -----
    This is almost a "tensor_like" coercion function. It does not handle
    wrapper ndarrays (those should be handled in the ndarray-aware layer prior
    to invoking this function).

    The steps run in a fixed order: conversion, dtype cast, ndmin padding,
    copy.
    """
    if not isinstance(obj, torch.Tensor):
        # torch.as_tensor uses the pytorch default dtype, typically float32.
        # If obj's elements are not exactly representable in float32,
        # precision is lost:
        # >>> torch.as_tensor(1e12).item() - 1e12
        # -4096.0
        # So convert under the default chosen by _dtypes_impl and always
        # restore the previous global default afterwards.
        saved_default = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            obj = _try_convert_to_tensor(obj)
        finally:
            torch.set_default_dtype(saved_default)

    # type cast if requested
    result = cast_if_needed(obj, dtype)

    # adjust ndim if needed
    missing_dims = ndmin - result.ndim
    if missing_dims > 0:
        result = result.view((1,) * missing_dims + result.shape)

    # copy if requested
    return result.clone() if copy else result
def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other things are intact)

    Parameters
    ----------
    *inputs
        One or more objects. An ``ndarray`` is replaced by its ``.tensor``
        attribute, a tuple is converted element by element (recursively), and
        anything else is returned as is.

    Returns
    -------
    object
        For a single input, the converted object. For several inputs, a tuple
        with one converted entry per input.

    Raises
    ------
    ValueError
        If called without any inputs.
    """
    if not inputs:
        # Raise (rather than return) the error, and do it before the
        # package-relative import so the check does not depend on it.
        raise ValueError("ndarrays_to_tensors expects at least one input")

    from ._ndarray import ndarray

    if len(inputs) > 1:
        # Several positional arguments are handled as one tuple of inputs.
        return ndarrays_to_tensors(inputs)

    input_ = inputs[0]
    if isinstance(input_, ndarray):
        return input_.tensor
    if isinstance(input_, tuple):
        return tuple(ndarrays_to_tensors(sub_input) for sub_input in input_)
    return input_