"""Export torch work functions for binary ufuncs, rename/tweak to match numpy. This listing is further exported to public symbols in the `torch._numpy/_ufuncs.py` module. """ import torch from torch import ( # noqa: F401 add, # noqa: F401 arctan2, # noqa: F401 bitwise_and, # noqa: F401 bitwise_left_shift as left_shift, # noqa: F401 bitwise_or, # noqa: F401 bitwise_right_shift as right_shift, # noqa: F401 bitwise_xor, # noqa: F401 copysign, # noqa: F401 divide, # noqa: F401 eq as equal, # noqa: F401 float_power, # noqa: F401 floor_divide, # noqa: F401 fmax, # noqa: F401 fmin, # noqa: F401 fmod, # noqa: F401 gcd, # noqa: F401 greater, # noqa: F401 greater_equal, # noqa: F401 heaviside, # noqa: F401 hypot, # noqa: F401 lcm, # noqa: F401 ldexp, # noqa: F401 less, # noqa: F401 less_equal, # noqa: F401 logaddexp, # noqa: F401 logaddexp2, # noqa: F401 logical_and, # noqa: F401 logical_or, # noqa: F401 logical_xor, # noqa: F401 maximum, # noqa: F401 minimum, # noqa: F401 multiply, # noqa: F401 nextafter, # noqa: F401 not_equal, # noqa: F401 pow as power, # noqa: F401 remainder, # noqa: F401 remainder as mod, # noqa: F401 subtract, # noqa: F401 true_divide, # noqa: F401 ) from . import _dtypes_impl, _util # work around torch limitations w.r.t. 
# Work around torch limitations w.r.t. numpy.


def matmul(x, y):
    """Matrix-multiply ``x`` and ``y``, side-stepping ``torch.matmul`` dtype gaps.

    ``torch.matmul`` raises for some operand dtypes:

    - RuntimeError: expected scalar type Int but found Double
    - RuntimeError: "addmm_impl_cpu_" not implemented for 'Bool'
    - RuntimeError: "addmm_impl_cpu_" not implemented for 'Half'

    Both operands are therefore cast to one common working dtype before the
    multiplication, and the product is cast back to the promoted result dtype
    when the working dtype differs from it.

    Parameters
    ----------
    x, y
        Operands exposing ``.dtype`` and ``.is_cpu`` (presumably
        ``torch.Tensor`` -- confirm against callers).

    Returns
    -------
    The product, in the dtype reported by
    ``_dtypes_impl.result_type_impl(x, y)``.
    """
    dtype = _dtypes_impl.result_type_impl(x, y)

    work_dtype = dtype
    if dtype == torch.bool:
        # NOTE(review): uint8 accumulation may wrap around for long inner
        # dimensions before the cast back to bool -- confirm.
        work_dtype = torch.uint8
    elif dtype == torch.float16 and (x.is_cpu or y.is_cpu):
        # Key the Half workaround on the *promoted* dtype, not on the operand
        # dtypes.  If only one operand is float16 and the promoted dtype is
        # wider (float64, complex, ...), casting both operands to ``dtype``
        # already removes every Half tensor; forcing float32 there would lose
        # precision or drop the imaginary part of a complex operand.
        work_dtype = torch.float32

    x = _util.cast_if_needed(x, work_dtype)
    y = _util.cast_if_needed(y, work_dtype)

    result = torch.matmul(x, y)

    if work_dtype != dtype:
        result = result.to(dtype)

    return result


# a stub implementation of divmod, should be improved after
# https://github.com/pytorch/pytorch/issues/90820 is fixed in pytorch
def divmod(x, y):
    """Return the pair ``(x // y, x % y)``.

    The name intentionally mirrors the numpy ufunc and therefore shadows the
    Python builtin inside this module.
    """
    return x // y, x % y