| import warnings |
|
|
| import numpy as _np |
|
|
| import autograd.builtins as builtins |
| from autograd.extend import notrace_primitive, primitive |
|
|
| if _np.lib.NumpyVersion(_np.__version__) >= "2.0.0": |
| from numpy._core.einsumfunc import _parse_einsum_input |
| else: |
| from numpy.core.einsumfunc import _parse_einsum_input |
|
|
# Version string of the numpy package this module wraps.
numpy_version = _np.__version__

# Numpy functions that only inspect metadata of their argument. wrap_namespace
# checks membership in this list first and wraps these with notrace_primitive
# instead of primitive.
notrace_functions = [
    _np.ndim,
    _np.shape,
    _np.iscomplexobj,
    _np.result_type,
]
|
|
|
|
def wrap_intdtype(cls):
    """Return a subclass of the integer scalar type ``cls`` with an untraced constructor.

    The subclass's ``__new__`` is ``cls.__new__`` wrapped in ``notrace_primitive``,
    so constructing an instance goes through that wrapper rather than ``primitive``.

    Args:
        cls: an integer scalar type taken from numpy's namespace.

    Returns:
        A new class named ``IntdtypeSubclass`` deriving from ``cls``.
    """
    untraced_new = notrace_primitive(cls.__new__)

    class IntdtypeSubclass(cls):
        __new__ = untraced_new

    return IntdtypeSubclass
|
|
|
|
def wrap_namespace(old, new):
    """Copy the entries of ``old`` into ``new``, wrapping each one for autograd.

    Rules, checked in order for every ``(name, value)`` in ``old``:

    1. members of ``notrace_functions`` -> ``notrace_primitive(value)``;
    2. any other callable that is not a class -> ``primitive(value)``;
    3. the integer scalar types listed below -> ``wrap_intdtype(value)``;
    4. values whose exact type is float, int, NoneType or type -> copied as is;
    5. anything else is skipped (not written to ``new``).

    Args:
        old: source mapping, e.g. a module ``__dict__``.
        new: destination mapping, mutated in place.
    """
    passthrough_types = {float, int, type(None), type}
    integer_types = {_np.int8, _np.int16, _np.int32, _np.int64, _np.integer}

    for attr_name, value in old.items():
        is_class = type(value) is type
        if value in notrace_functions:
            wrapped = notrace_primitive(value)
        elif callable(value) and not is_class:
            wrapped = primitive(value)
        elif is_class and value in integer_types:
            wrapped = wrap_intdtype(value)
        elif type(value) in passthrough_types:
            wrapped = value
        else:
            # Modules, strings, dicts, etc. are deliberately not re-exported.
            continue
        new[attr_name] = wrapped
|
|
|
|
# Populate this module's globals with wrapped versions of numpy's attributes;
# the explicit definitions further down override some of these names.
wrap_namespace(vars(_np), globals())
|
|
| |
|
|
|
|
@primitive
def concatenate_args(axis, *args):
    """Concatenate the positional arrays ``args`` along ``axis``.

    Takes the arrays as varargs (not a single sequence) so that each array is
    its own positional argument of this primitive. The result is viewed as
    this module's ``ndarray``.
    """
    joined = _np.concatenate(args, axis)
    return joined.view(ndarray)
|
|
|
|
def concatenate(arr_list, axis=0):
    """Join a sequence of arrays along an existing axis.

    Unpacks ``arr_list`` into the varargs primitive ``concatenate_args`` so that
    every input array is a separate positional argument.

    Args:
        arr_list: sequence of arrays to join.
        axis: axis along which to join (default 0).
    """
    return concatenate_args(axis, *arr_list)


def vstack(tup):
    """Stack arrays vertically (row-wise), as ``numpy.vstack`` does.

    Each input is promoted with ``atleast_2d`` and the results are joined
    along axis 0.
    """
    return concatenate([atleast_2d(_m) for _m in tup], axis=0)


# Historical numpy alias for vstack.
row_stack = vstack
|
|
|
|
def hstack(tup):
    """Stack arrays in sequence horizontally (column-wise), like ``numpy.hstack``.

    Every input is promoted with ``atleast_1d``. If the first promoted input
    is 1-D, the arrays are joined along axis 0; otherwise along axis 1.

    Raises:
        ValueError: when ``tup`` is empty. numpy's concatenate rejects an
            empty sequence, which matches ``numpy.hstack``; previously this
            surfaced as an IndexError from indexing ``arrs[0]``.
    """
    arrs = [atleast_1d(_m) for _m in tup]
    # Guard against an empty input the same way numpy does, so the error
    # comes from concatenate rather than from arrs[0].
    axis = 0 if arrs and arrs[0].ndim == 1 else 1
    return concatenate(arrs, axis)
|
|
|
|
def column_stack(tup):
    """Stack inputs as columns of a 2-D result (cf. ``numpy.column_stack``).

    Inputs with fewer than two dimensions are promoted with ``ndmin=2`` and
    transposed so each becomes a single column. Inputs with two or more
    dimensions are used unchanged. Everything is joined along axis 1.
    """
    def as_column(item):
        converted = array(item)
        if converted.ndim >= 2:
            return converted
        return array(converted, ndmin=2).T

    return concatenate([as_column(item) for item in tup], 1)
|
|
|
|
def array(A, *args, **kwargs):
    """Build an array from ``A``, recursing into lists and tuples.

    When ``A`` is a list or tuple (as reported by ``autograd.builtins.type``),
    each element is first converted by a recursive ``array`` call. The pieces
    are then assembled with the ``array_from_args`` primitive. Any other input
    goes straight to ``_array_from_scalar_or_array``. Extra positional and
    keyword arguments are forwarded to numpy's ``array`` at the outermost
    level only.
    """
    if builtins.type(A) not in (list, tuple):
        return _array_from_scalar_or_array(args, kwargs, A)
    converted = [array(elem) for elem in A]
    return array_from_args(args, kwargs, *converted)
|
|
|
|
def wrap_if_boxes_inside(raw_array, slow_op_name=None):
    """Rebuild ``raw_array`` element by element if numpy produced an object array.

    An object-dtype result (presumably because some inputs were autograd boxes)
    is flattened, reassembled with the ``array_from_args`` primitive and
    reshaped back. Any other dtype is returned unchanged, as the same object.

    Args:
        raw_array: the ndarray produced by a numpy helper such as ``np.r_``.
        slow_op_name: if given, a warning naming this operation is issued when
            the slow element-wise path is taken.

    Returns:
        ``raw_array`` itself, or the rebuilt array with the same shape.
    """
    # Use dtype equality: identity against the cached descriptor is an
    # implementation detail, not a supported comparison.
    if raw_array.dtype != _np.dtype("O"):
        return raw_array
    if slow_op_name:
        # stacklevel=3 attributes the warning to the code that wrote
        # r_[...] / c_[...] (the visible callers passing slow_op_name),
        # not to this helper.
        warnings.warn(
            "{} is slow for array inputs. np.concatenate() is faster.".format(slow_op_name),
            stacklevel=3,
        )
    return array_from_args((), {}, *raw_array.ravel()).reshape(raw_array.shape)
|
|
|
|
@primitive
def _array_from_scalar_or_array(array_args, array_kwargs, scalar):
    """Call ``numpy.array`` on a single non-list/tuple value.

    ``array_args`` and ``array_kwargs`` are forwarded to ``numpy.array`` as
    extra positional and keyword arguments. The value comes last so the
    options are the leading positional parameters of this primitive.
    """
    converted = _np.array(scalar, *array_args, **array_kwargs)
    return converted
|
|
|
|
@primitive
def array_from_args(array_args, array_kwargs, *args):
    """Call ``numpy.array`` on the tuple of remaining positional arguments.

    Each element is a separate positional argument of this primitive.
    ``array_args`` and ``array_kwargs`` are forwarded to ``numpy.array``.
    """
    assembled = _np.array(args, *array_args, **array_kwargs)
    return assembled
|
|
|
|
def select(condlist, choicelist, default=0):
    """Evaluate ``numpy.select`` and re-assemble the result through ``array``.

    Both lists are materialised before being passed to numpy. The output is
    flattened into a Python list, rebuilt with this module's ``array``
    (element-wise) and reshaped to the shape numpy produced.
    """
    selected = _np.select(list(condlist), list(choicelist), default=default)
    elements = list(selected.ravel())
    return array(elements).reshape(selected.shape)
|
|
|
|
def stack(arrays, axis=0):
    """Join equally shaped arrays along a new axis (cf. ``numpy.stack``).

    Each input is converted with ``array`` and gains a new length-1 axis at
    position ``axis``. The results are then concatenated along that axis.

    Raises:
        ValueError: if ``arrays`` is empty or the shapes differ.
        IndexError: if ``axis`` is outside ``[-(ndim+1), ndim+1)``.
    """
    converted = list(map(array, arrays))
    if len(converted) == 0:
        raise ValueError("need at least one array to stack")

    first_shape = converted[0].shape
    if any(a.shape != first_shape for a in converted[1:]):
        raise ValueError("all input arrays must have the same shape")

    out_ndim = converted[0].ndim + 1
    in_bounds = -out_ndim <= axis < out_ndim
    if not in_bounds:
        raise IndexError("axis {0} out of bounds [-{1}, {1})".format(axis, out_ndim))
    if axis < 0:
        axis += out_ndim

    # Index with `axis` full slices followed by None to insert the new axis.
    expander = (slice(None),) * axis + (None,)
    return concatenate([a[expander] for a in converted], axis=axis)
|
|
|
|
def append(arr, values, axis=None):
    """Append ``values`` to ``arr`` (cf. ``numpy.append``).

    With an explicit ``axis``, the two are concatenated along it. With
    ``axis=None``, ``arr`` is flattened unless it is already 1-D, ``values``
    is flattened, and the concatenation happens along the last axis of ``arr``.
    """
    arr = array(arr)
    if axis is not None:
        return concatenate((arr, values), axis=axis)

    if ndim(arr) != 1:
        arr = ravel(arr)
    flat_values = ravel(array(values))
    return concatenate((arr, flat_values), axis=ndim(arr) - 1)
|
|
|
|
| |
|
|
|
|
class r_class:
    """Index-expression helper mirroring ``numpy.r_``.

    ``r_[...]`` is evaluated by numpy, then passed through
    ``wrap_if_boxes_inside`` (tagged ``"r_"``) so object-dtype results are
    rebuilt element-wise.
    """

    def __getitem__(self, args):
        return wrap_if_boxes_inside(_np.r_[args], slow_op_name="r_")


r_ = r_class()
|
|
|
|
class c_class:
    """Index-expression helper mirroring ``numpy.c_``.

    ``c_[...]`` is evaluated by numpy, then passed through
    ``wrap_if_boxes_inside`` (tagged ``"c_"``) so object-dtype results are
    rebuilt element-wise.
    """

    def __getitem__(self, args):
        return wrap_if_boxes_inside(_np.c_[args], slow_op_name="c_")


c_ = c_class()
|
|
|
|
| |
@primitive
def make_diagonal(D, offset=0, axis1=0, axis2=1):
    """Embed the last axis of ``D`` as the diagonal of a new trailing square block.

    The result has shape ``D.shape + (D.shape[-1],)``, with
    ``result[..., i, i] == D[..., i]`` and zeros elsewhere. This is the layout
    read back by ``np.diagonal(result, offset=0, axis1=-1, axis2=-2)``.

    Only ``offset=0, axis1=-1, axis2=-2`` is supported. Every other
    combination raises, including the defaults, which match
    ``np.diagonal``'s signature.

    Args:
        D: ndarray whose last axis holds the diagonal entries.
        offset, axis1, axis2: must be ``0, -1, -2``.

    Returns:
        A new array of dtype ``result_type(D.dtype, float64)``. That is
        float64 for real/integer inputs and complex for complex inputs.

    Raises:
        NotImplementedError: for any unsupported offset/axis combination.
    """
    if not (offset == 0 and axis1 == -1 and axis2 == -2):
        raise NotImplementedError("Currently make_diagonal only supports offset=0, axis1=-1, axis2=-2")

    size = D.shape[-1]
    # Keep float64 for real inputs (the previous fixed dtype), but do not
    # silently drop the imaginary part of a complex D.
    out_dtype = _np.result_type(D.dtype, _np.float64)
    new_array = _np.zeros(D.shape + (size,), dtype=out_dtype)
    # Write the diagonal with fancy indexing instead of forcing the read-only
    # view returned by np.diagonal to be writeable.
    idx = _np.arange(size)
    new_array[..., idx, idx] = D
    return new_array
|
|
|
|
@notrace_primitive
def metadata(A):
    """Return ``(shape, ndim, result_type, is_complex)`` describing ``A``.

    Wrapped in ``notrace_primitive``. Presumably this means the call is not
    recorded on the autograd tape and boxed inputs are unwrapped first;
    confirm against autograd.extend.
    """
    shape = _np.shape(A)
    n_dims = _np.ndim(A)
    dtype = _np.result_type(A)
    is_complex = _np.iscomplexobj(A)
    return shape, n_dims, dtype, is_complex
|
|
|
|
@notrace_primitive
def parse_einsum_input(*args):
    """Parse einsum arguments with numpy's private ``_parse_einsum_input``.

    All positional arguments (subscripts and operands) are passed to the
    parser as a single tuple, exactly as received.
    """
    operands = args
    return _parse_einsum_input(operands)
|
|
|
|
@primitive
def _astype(A, dtype, order="K", casting="unsafe", subok=True, copy=True):
    """Primitive wrapper around ``A.astype``.

    All options are forwarded unchanged; the defaults are those of
    ``ndarray.astype``.
    """
    return A.astype(dtype, order=order, casting=casting, subok=subok, copy=copy)
|
|