|  |  | 
					
						
						|  |  | 
					
						
"""Assorted utilities, which do not need anything other than torch and stdlib.
"""
					
						
						|  |  | 
					
						
						|  | import operator | 
					
						
						|  |  | 
					
						
						|  | import torch | 
					
						
						|  |  | 
					
						
						|  | from . import _dtypes_impl | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def is_sequence(seq): | 
					
						
						|  | if isinstance(seq, str): | 
					
						
						|  | return False | 
					
						
						|  | try: | 
					
						
						|  | len(seq) | 
					
						
						|  | except Exception: | 
					
						
						|  | return False | 
					
						
						|  | return True | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class AxisError(ValueError, IndexError): | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | class UFuncTypeError(TypeError, RuntimeError): | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def cast_if_needed(tensor, dtype): | 
					
						
						|  |  | 
					
						
						|  | if dtype is not None and tensor.dtype != dtype: | 
					
						
						|  | tensor = tensor.to(dtype) | 
					
						
						|  | return tensor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def cast_int_to_float(x): | 
					
						
						|  |  | 
					
						
						|  | if _dtypes_impl._category(x.dtype) < 2: | 
					
						
						|  | x = x.to(_dtypes_impl.default_dtypes().float_dtype) | 
					
						
						|  | return x | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def normalize_axis_index(ax, ndim, argname=None): | 
					
						
						|  | if not (-ndim <= ax < ndim): | 
					
						
						|  | raise AxisError(f"axis {ax} is out of bounds for array of dimension {ndim}") | 
					
						
						|  | if ax < 0: | 
					
						
						|  | ax += ndim | 
					
						
						|  | return ax | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def normalize_axis_tuple(axis, ndim, argname=None, allow_duplicate=False): | 
					
						
						|  | """ | 
					
						
						|  | Normalizes an axis argument into a tuple of non-negative integer axes. | 
					
						
						|  |  | 
					
						
						|  | This handles shorthands such as ``1`` and converts them to ``(1,)``, | 
					
						
						|  | as well as performing the handling of negative indices covered by | 
					
						
						|  | `normalize_axis_index`. | 
					
						
						|  |  | 
					
						
						|  | By default, this forbids axes from being specified multiple times. | 
					
						
						|  | Used internally by multi-axis-checking logic. | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | axis : int, iterable of int | 
					
						
						|  | The un-normalized index or indices of the axis. | 
					
						
						|  | ndim : int | 
					
						
						|  | The number of dimensions of the array that `axis` should be normalized | 
					
						
						|  | against. | 
					
						
						|  | argname : str, optional | 
					
						
						|  | A prefix to put before the error message, typically the name of the | 
					
						
						|  | argument. | 
					
						
						|  | allow_duplicate : bool, optional | 
					
						
						|  | If False, the default, disallow an axis from being specified twice. | 
					
						
						|  |  | 
					
						
						|  | Returns | 
					
						
						|  | ------- | 
					
						
						|  | normalized_axes : tuple of int | 
					
						
						|  | The normalized axis index, such that `0 <= normalized_axis < ndim` | 
					
						
						|  | """ | 
					
						
						|  |  | 
					
						
						|  | if type(axis) not in (tuple, list): | 
					
						
						|  | try: | 
					
						
						|  | axis = [operator.index(axis)] | 
					
						
						|  | except TypeError: | 
					
						
						|  | pass | 
					
						
						|  |  | 
					
						
						|  | axis = tuple([normalize_axis_index(ax, ndim, argname) for ax in axis]) | 
					
						
						|  | if not allow_duplicate and len(set(map(int, axis))) != len(axis): | 
					
						
						|  | if argname: | 
					
						
						|  | raise ValueError(f"repeated axis in `{argname}` argument") | 
					
						
						|  | else: | 
					
						
						|  | raise ValueError("repeated axis") | 
					
						
						|  | return axis | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def allow_only_single_axis(axis): | 
					
						
						|  | if axis is None: | 
					
						
						|  | return axis | 
					
						
						|  | if len(axis) != 1: | 
					
						
						|  | raise NotImplementedError("does not handle tuple axis") | 
					
						
						|  | return axis[0] | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def expand_shape(arr_shape, axis): | 
					
						
						|  |  | 
					
						
						|  | if type(axis) not in (list, tuple): | 
					
						
						|  | axis = (axis,) | 
					
						
						|  | out_ndim = len(axis) + len(arr_shape) | 
					
						
						|  | axis = normalize_axis_tuple(axis, out_ndim) | 
					
						
						|  | shape_it = iter(arr_shape) | 
					
						
						|  | shape = [1 if ax in axis else next(shape_it) for ax in range(out_ndim)] | 
					
						
						|  | return shape | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def apply_keepdims(tensor, axis, ndim): | 
					
						
						|  | if axis is None: | 
					
						
						|  |  | 
					
						
						|  | shape = (1,) * ndim | 
					
						
						|  | tensor = tensor.expand(shape).contiguous() | 
					
						
						|  | else: | 
					
						
						|  | shape = expand_shape(tensor.shape, axis) | 
					
						
						|  | tensor = tensor.reshape(shape) | 
					
						
						|  | return tensor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def axis_none_flatten(*tensors, axis=None): | 
					
						
						|  | """Flatten the arrays if axis is None.""" | 
					
						
						|  | if axis is None: | 
					
						
						|  | tensors = tuple(ar.flatten() for ar in tensors) | 
					
						
						|  | return tensors, 0 | 
					
						
						|  | else: | 
					
						
						|  | return tensors, axis | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def typecast_tensor(t, target_dtype, casting): | 
					
						
						|  | """Dtype-cast tensor to target_dtype. | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | t : torch.Tensor | 
					
						
						|  | The tensor to cast | 
					
						
						|  | target_dtype : torch dtype object | 
					
						
						|  | The array dtype to cast all tensors to | 
					
						
						|  | casting : str | 
					
						
						|  | The casting mode, see `np.can_cast` | 
					
						
						|  |  | 
					
						
						|  | Returns | 
					
						
						|  | ------- | 
					
						
						|  | `torch.Tensor` of the `target_dtype` dtype | 
					
						
						|  |  | 
					
						
						|  | Raises | 
					
						
						|  | ------ | 
					
						
						|  | ValueError | 
					
						
						|  | if the argument cannot be cast according to the `casting` rule | 
					
						
						|  |  | 
					
						
						|  | """ | 
					
						
						|  | can_cast = _dtypes_impl.can_cast_impl | 
					
						
						|  |  | 
					
						
						|  | if not can_cast(t.dtype, target_dtype, casting=casting): | 
					
						
						|  | raise TypeError( | 
					
						
						|  | f"Cannot cast array data from {t.dtype} to" | 
					
						
						|  | f" {target_dtype} according to the rule '{casting}'" | 
					
						
						|  | ) | 
					
						
						|  | return cast_if_needed(t, target_dtype) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def typecast_tensors(tensors, target_dtype, casting): | 
					
						
						|  | return tuple(typecast_tensor(t, target_dtype, casting) for t in tensors) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _try_convert_to_tensor(obj): | 
					
						
						|  | try: | 
					
						
						|  | tensor = torch.as_tensor(obj) | 
					
						
						|  | except Exception as e: | 
					
						
						|  | mesg = f"failed to convert {obj} to ndarray. \nInternal error is: {str(e)}." | 
					
						
						|  | raise NotImplementedError(mesg) | 
					
						
						|  | return tensor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def _coerce_to_tensor(obj, dtype=None, copy=False, ndmin=0): | 
					
						
						|  | """The core logic of the array(...) function. | 
					
						
						|  |  | 
					
						
						|  | Parameters | 
					
						
						|  | ---------- | 
					
						
						|  | obj : tensor_like | 
					
						
						|  | The thing to coerce | 
					
						
						|  | dtype : torch.dtype object or None | 
					
						
						|  | Coerce to this torch dtype | 
					
						
						|  | copy : bool | 
					
						
						|  | Copy or not | 
					
						
						|  | ndmin : int | 
					
						
						|  | The results as least this many dimensions | 
					
						
						|  | is_weak : bool | 
					
						
						|  | Whether obj is a weakly typed python scalar. | 
					
						
						|  |  | 
					
						
						|  | Returns | 
					
						
						|  | ------- | 
					
						
						|  | tensor : torch.Tensor | 
					
						
						|  | a tensor object with requested dtype, ndim and copy semantics. | 
					
						
						|  |  | 
					
						
						|  | Notes | 
					
						
						|  | ----- | 
					
						
						|  | This is almost a "tensor_like" coersion function. Does not handle wrapper | 
					
						
						|  | ndarrays (those should be handled in the ndarray-aware layer prior to | 
					
						
						|  | invoking this function). | 
					
						
						|  | """ | 
					
						
						|  | if isinstance(obj, torch.Tensor): | 
					
						
						|  | tensor = obj | 
					
						
						|  | else: | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | default_dtype = torch.get_default_dtype() | 
					
						
						|  | torch.set_default_dtype(_dtypes_impl.get_default_dtype_for(torch.float32)) | 
					
						
						|  | try: | 
					
						
						|  | tensor = _try_convert_to_tensor(obj) | 
					
						
						|  | finally: | 
					
						
						|  | torch.set_default_dtype(default_dtype) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | tensor = cast_if_needed(tensor, dtype) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | ndim_extra = ndmin - tensor.ndim | 
					
						
						|  | if ndim_extra > 0: | 
					
						
						|  | tensor = tensor.view((1,) * ndim_extra + tensor.shape) | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | if copy: | 
					
						
						|  | tensor = tensor.clone() | 
					
						
						|  |  | 
					
						
						|  | return tensor | 
					
						
						|  |  | 
					
						
						|  |  | 
					
						
						|  | def ndarrays_to_tensors(*inputs): | 
					
						
						|  | """Convert all ndarrays from `inputs` to tensors. (other things are intact)""" | 
					
						
						|  | from ._ndarray import ndarray | 
					
						
						|  |  | 
					
						
						|  | if len(inputs) == 0: | 
					
						
						|  | return ValueError() | 
					
						
						|  | elif len(inputs) == 1: | 
					
						
						|  | input_ = inputs[0] | 
					
						
						|  | if isinstance(input_, ndarray): | 
					
						
						|  | return input_.tensor | 
					
						
						|  | elif isinstance(input_, tuple): | 
					
						
						|  | result = [] | 
					
						
						|  | for sub_input in input_: | 
					
						
						|  | sub_result = ndarrays_to_tensors(sub_input) | 
					
						
						|  | result.append(sub_result) | 
					
						
						|  | return tuple(result) | 
					
						
						|  | else: | 
					
						
						|  | return input_ | 
					
						
						|  | else: | 
					
						
						|  | assert isinstance(inputs, tuple) | 
					
						
						|  | return ndarrays_to_tensors(inputs) | 
					
						
						|  |  |