"""Assorted utilities, which do not need anything other then torch and stdlib.
"""
import operator

import torch

from . import _dtypes_impl


# https://github.com/numpy/numpy/blob/v1.23.0/numpy/distutils/misc_util.py#L497-L504
def is_sequence(seq):
    if isinstance(seq, str):
        return False
    try:
        len(seq)
    except Exception:
        return False
    return True
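
# A minimal doctest-style sketch (illustrative only, not executed on import):
#
#   >>> is_sequence([1, 2]), is_sequence("ab"), is_sequence(3)
#   (True, False, False)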


class AxisError(ValueError, IndexError):
    pass


class UFuncTypeError(TypeError, RuntimeError):
    pass


def cast_if_needed(tensor, dtype):
    # NB: no casting if dtype=None
    if dtype is not None and tensor.dtype != dtype:
        tensor = tensor.to(dtype)
    return tensor


def cast_int_to_float(x):
    # cast integers and bools to the default float dtype
    if _dtypes_impl._category(x.dtype) < 2:
        x = x.to(_dtypes_impl.default_dtypes().float_dtype)
    return x
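
# A minimal usage sketch (illustrative only, not executed on import; the exact
# result dtype depends on the configured default float dtype):
#
#   >>> cast_int_to_float(torch.arange(3)).dtype   # ints/bools are promoted
#   torch.float64
#   >>> cast_int_to_float(torch.ones(3)).dtype     # floats pass through
#   torch.float32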


# a replica of the version in ./numpy/numpy/core/src/multiarray/common.h
def normalize_axis_index(ax, ndim, argname=None):
    if not (-ndim <= ax < ndim):
        # match numpy and prefix the message with the argument name, if given
        prefix = f"{argname}: " if argname else ""
        raise AxisError(f"{prefix}axis {ax} is out of bounds for array of dimension {ndim}")
    if ax < 0:
        ax += ndim
    return ax
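
# A minimal doctest-style sketch (illustrative only, not executed on import):
#
#   >>> normalize_axis_index(-1, 3)   # negative axes count from the end
#   2
#   >>> normalize_axis_index(3, 3)    # out of range
#   Traceback (most recent call last):
#       ...
#   AxisError: axis 3 is out of bounds for array of dimension 3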


# from https://github.com/numpy/numpy/blob/main/numpy/core/numeric.py#L1378
def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False):
    """
    Normalizes an axis argument into a tuple of non-negative integer axes.

    This handles shorthands such as ``1`` and converts them to ``(1,)``,
    as well as performing the handling of negative indices covered by
    `normalize_axis_index`.

    By default, this forbids axes from being specified multiple times.
    Used internally by multi-axis-checking logic.

    Parameters
    ----------
    axis : int, iterable of int
        The un-normalized index or indices of the axis.
    ndim : int
        The number of dimensions of the array that `axis` should be normalized
        against.
    argname : str, optional
        A prefix to put before the error message, typically the name of the
        argument.
    allow_duplicate : bool, optional
        If False, the default, disallow an axis from being specified twice.

    Returns
    -------
    normalized_axes : tuple of int
        The normalized axis indices, such that `0 <= normalized_axis < ndim`.
    """
    # Optimization to speed-up the most common cases.
    if type(axis) not in (tuple, list):
        try:
            axis = [operator.index(axis)]
        except TypeError:
            pass
    # Going via an iterator directly is slower than via list comprehension.
    axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis])
    if not allow_duplicate and len(set(axis)) != len(axis):
        if argname:
            raise ValueError(f"repeated axis in `{argname}` argument")
        else:
            raise ValueError("repeated axis")
    return axis
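
# A minimal doctest-style sketch (illustrative only, not executed on import):
#
#   >>> normalize_axis_tuple(1, 3)        # scalar shorthand becomes a tuple
#   (1,)
#   >>> normalize_axis_tuple((0, -1), 3)  # negative axes are normalized
#   (0, 2)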


def allow_only_single_axis(axis):
    if axis is None:
        return axis
    if len(axis) != 1:
        raise NotImplementedError("does not handle tuple axis")
    return axis[0]


def expand_shape(arr_shape, axis):
    # taken from numpy 1.23.x, expand_dims function
    if type(axis) not in (list, tuple):
        axis = (axis,)
    out_ndim = len(axis) + len(arr_shape)
    axis = normalize_axis_tuple(axis, out_ndim)
    shape_it = iter(arr_shape)
    shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)]
    return shape
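
# A minimal doctest-style sketch (illustrative only, not executed on import):
#
#   >>> expand_shape((2, 3), axis=0)       # insert a length-1 dim in front
#   [1, 2, 3]
#   >>> expand_shape((2, 3), axis=(1, 3))  # multiple insertions at once
#   [2, 1, 3, 1]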


def apply_keepdims(tensor, axis, ndim):
    if axis is None:
        # tensor was a scalar
        shape = (1,) * ndim
        tensor = tensor.expand(shape).contiguous()
    else:
        shape = expand_shape(tensor.shape, axis)
        tensor = tensor.reshape(shape)
    return tensor
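
# A minimal usage sketch (illustrative only, not executed on import): after a
# reduction such as `tensor.sum(axis)`, this reinstates the reduced axes as
# length-1 dimensions, mimicking numpy's `keepdims=True`.
#
#   >>> t = torch.ones(2, 3).sum(1)        # shape (2,)
#   >>> apply_keepdims(t, (1,), 2).shape   # reduced axis restored as size 1
#   torch.Size([2, 1])
#   >>> s = torch.ones(2, 3).sum()         # 0-dim result, axis=None
#   >>> apply_keepdims(s, None, 2).shape
#   torch.Size([1, 1])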


def axis_none_flatten(*tensors, axis=None):
    """Flatten the arrays if axis is None."""
    if axis is None:
        tensors = tuple(ar.flatten() for ar in tensors)
        return tensors, 0
    else:
        return tensors, axis
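
# A minimal usage sketch (illustrative only, not executed on import):
#
#   >>> (t,), ax = axis_none_flatten(torch.ones(2, 3), axis=None)
#   >>> t.shape, ax                        # flattened, axis becomes 0
#   (torch.Size([6]), 0)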


def typecast_tensor(t, target_dtype, casting):
    """Dtype-cast tensor to target_dtype.

    Parameters
    ----------
    t : torch.Tensor
        The tensor to cast
    target_dtype : torch dtype object
        The array dtype to cast all tensors to
    casting : str
        The casting mode, see `np.can_cast`

    Returns
    -------
    `torch.Tensor` of the `target_dtype` dtype

    Raises
    ------
    TypeError
        if the argument cannot be cast according to the `casting` rule
    """
    can_cast = _dtypes_impl.can_cast_impl
    if not can_cast(t.dtype, target_dtype, casting=casting):
        raise TypeError(
            f"Cannot cast array data from {t.dtype} to"
            f" {target_dtype} according to the rule '{casting}'"
        )
    return cast_if_needed(t, target_dtype)
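
# A minimal usage sketch (illustrative only, not executed on import; the
# outcomes assume numpy-style casting rules as implemented in `_dtypes_impl`):
#
#   >>> typecast_tensor(torch.ones(3, dtype=torch.int64), torch.float64, "same_kind")
#   tensor([1., 1., 1.], dtype=torch.float64)
#   >>> typecast_tensor(torch.ones(3, dtype=torch.float64), torch.int64, "same_kind")
#   Traceback (most recent call last):
#       ...
#   TypeError: Cannot cast array data from torch.float64 to torch.int64 ...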


def typecast_tensors(tensors, target_dtype, casting):
    return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors)


def _try_convert_to_tensor(obj):
    try:
        tensor = torch.as_tensor(obj)
    except Exception as e:
        mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}."
        raise NotImplementedError(mesg)  # noqa: TRY200
    return tensor
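
# A minimal usage sketch (illustrative only, not executed on import):
#
#   >>> _try_convert_to_tensor(object())   # anything torch cannot ingest
#   Traceback (most recent call last):
#       ...
#   NotImplementedError: failed to convert <object object at 0x...> to ndarray. ...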


def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0):
    """The core logic of the array(...) function.

    Parameters
    ----------
    obj : tensor_like
        The thing to coerce
    dtype : torch.dtype object or None
        Coerce to this torch dtype
    copy : bool
        Copy or not
    ndmin : int
        The result has at least this many dimensions

    Returns
    -------
    tensor : torch.Tensor
        a tensor object with requested dtype, ndim and copy semantics.

    Notes
    -----
    This is almost a "tensor_like" coercion function. Does not handle wrapper
    ndarrays (those should be handled in the ndarray-aware layer prior to
    invoking this function).
    """
    if isinstance(obj, torch.Tensor):
        tensor = obj
    else:
        # tensor.dtype is the pytorch default, typically float32. If obj's elements
        # are not exactly representable in float32, we've lost precision:
        # >>> torch.as_tensor(1e12).item() - 1e12
        # -4096.0
        default_dtype = torch.get_default_dtype()
        torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32))
        try:
            tensor = _try_convert_to_tensor(obj)
        finally:
            torch.set_default_dtype(default_dtype)

    # type cast if requested
    tensor = cast_if_needed(tensor, dtype)

    # adjust ndim if needed
    ndim_extra = ndmin - tensor.ndim
    if ndim_extra > 0:
        tensor = tensor.view((1,) * ndim_extra + tensor.shape)

    # copy if requested
    if copy:
        tensor = tensor.clone()

    return tensor
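
# A minimal usage sketch (illustrative only, not executed on import):
#
#   >>> _coerce_to_tensor([1, 2, 3], dtype=torch.float64, ndmin=2)
#   tensor([[1., 2., 3.]], dtype=torch.float64)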


def ndarrays_to_tensors(*inputs):
    """Convert all ndarrays from `inputs` to tensors. (other things are intact)"""
    from ._ndarray import ndarray

    if len(inputs) == 0:
        raise ValueError("expected at least one input")
    elif len(inputs) == 1:
        input_ = inputs[0]
        if isinstance(input_, ndarray):
            return input_.tensor
        elif isinstance(input_, tuple):
            result = []
            for sub_input in input_:
                sub_result = ndarrays_to_tensors(sub_input)
                result.append(sub_result)
            return tuple(result)
        else:
            return input_
    else:
        assert isinstance(inputs, tuple)  # sanity check
        return ndarrays_to_tensors(inputs)
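
# A minimal usage sketch (illustrative only, not executed on import; `a` is
# assumed to be a wrapper `ndarray` from ._ndarray):
#
#   >>> ndarrays_to_tensors(a, (a, 1), "x")   # ndarrays unwrapped, rest intact
#   (tensor(...), (tensor(...), 1), 'x')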