import functools
from contextlib import nullcontext
from typing import Any, Callable, Dict, Optional, Sequence
from warnings import warn

import torch

import torch._decomp
import torch._prims

import torch._refs
import torch._refs.nn
import torch._refs.nn.functional
import torch._refs.special
import torch.overrides
from torch._prims.nvfuser_executor import NvfuserPrimOperatorSupport

from torch._prims_common import torch_function_passthrough
from torch.fx.experimental.proxy_tensor import get_isolated_graphmodule


@functools.lru_cache(None)
def torch_to_refs_map():
    """
    Mapping of torch API functions to torch._refs functions.
    E.g. torch_to_refs_map()[torch.add] == torch._refs.add
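
    Tensor methods that share a name with a ref are mapped as well; an
    illustrative (skipped) doctest:

    >>> # xdoctest: +SKIP("illustrative only")
    >>> torch_to_refs_map()[torch.Tensor.add] is torch._refs.add
    True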
    """
    modules = [
        (torch, torch._refs),
        (torch.nn, torch._refs.nn),
        (torch.nn.functional, torch._refs.nn.functional),
        (torch.special, torch._refs.special),
        (torch.fft, torch._refs.fft),
        (torch.linalg, torch._refs.linalg),
    ]
    r: Dict[Any, Any] = {
        torch.Tensor.__invert__: torch._refs.bitwise_not,
        torch.Tensor.__xor__: torch._refs.bitwise_xor,
        torch.Tensor.__and__: torch._refs.bitwise_and,
        torch.Tensor.__or__: torch._refs.bitwise_or,
        torch.Tensor.__eq__: torch._refs.eq,
        torch.Tensor.__rsub__: torch._refs.rsub,
        torch.Tensor.__rtruediv__: torch._refs.rtruediv,
        torch.Tensor.__floordiv__: torch._refs.floor_divide,
        torch.Tensor.__rfloordiv__: torch._refs.rfloordiv,
        torch.Tensor.__pow__: torch._refs.pow,
        torch.Tensor.__rpow__: torch._refs.rpow,
        torch.Tensor.new_empty: torch._refs.new_empty,
        torch.Tensor.new_full: torch._refs.new_full,
        torch.Tensor.new_zeros: torch._refs.new_zeros,
        torch.Tensor.new_ones: torch._refs.new_ones,
        torch.Tensor.fill_: torch._refs.fill_,
        torch.Tensor.zero_: torch._refs.zero_,
        torch.Tensor.to: torch._refs.to,
        torch.Tensor.sum_to_size: torch._refs.sum_to_size,
        # TODO: Should these methods be mapped some other way?
        torch.Tensor.copy_: torch._prims.copy_to,
        torch.Tensor.resize: torch._prims.resize,
    }
    for mod_torch, mod_refs in modules:
        for s in mod_refs.__all__:  # type: ignore[attr-defined]
            r[mod_torch.__dict__.get(s)] = mod_refs.__dict__.get(s)

    # Support remapping torch.Tensor.foo to _refs.foo
    for s in dir(torch.Tensor):
        if s in torch._refs.__all__:
            r[getattr(torch.Tensor, s)] = torch._refs.__dict__.get(s)
    return r


@functools.lru_cache(None)
def all_prims():
    """
    Set of all prim functions, e.g., torch._prims.add in all_prims() is True.
    """
    return {torch._prims.__dict__.get(s) for s in torch._prims.__all__}


class NvfuserPrimsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.ops.prims.* functions to
    use nvFuser's prims in torch.ops.nvprims.*

    >>> # xdoctest: +SKIP("undefined vars")
    >>> with NvfuserPrimsMode():
    ...     torch.ops.prims.add(x, y)  # calls torch.ops.nvprims.add(x, y)

    By default, this context manager falls back on torch.ops.prims.* if the
    corresponding nvprim does not exist.
    It's possible to skip certain prims by passing their names to the skip_ops
    argument. skip_ops is expected to be a sequence of strings, e.g.,
    ["prims.add.default"]. To check the expected name of a prim, one can use
    `torch.overrides.resolve_name`.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> with NvfuserPrimsMode(skip_ops=("prims.add.default",)):
    ...     torch.ops.prims.add.default(x, y)  # does not call torch.ops.nvprims.add.default(x, y)
    """

    def __init__(self, *, skip_ops=()):
        self.skip_ops = skip_ops

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}

        # If the function is in the skip list, then we don't want to
        # remap it to the nvprims.
        if torch.overrides.resolve_name(orig_func) in self.skip_ops:
            return orig_func(*args, **kwargs)

        if isinstance(
            orig_func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket)
        ):
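            # str() of an OpOverload/OpOverloadPacket is its qualified name,
            # e.g. "prims.add.default" or "prims.add"; the first two
            # components are the namespace and the op name.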
            namespace, name = str(orig_func).split(".")[:2]
            if namespace == "prims":
                nvfunc = getattr(torch.ops.nvprims, name, None)
                if nvfunc is not None:
                    return nvfunc(*args, **kwargs)
        return orig_func(*args, **kwargs)


class TorchRefsMode(torch.overrides.TorchFunctionMode):
    """
    Switches the interpretation of torch.* functions and Tensor methods to
    use PrimTorch refs in torch._refs.  (Direct calls to _refs are unaffected.)

    >>> # xdoctest: +SKIP
    >>> with TorchRefsMode():
    ...     torch.add(x, y)  # calls torch._refs.add(x, y)

    By default, this context manager falls back on torch.* if the
    ref does not exist; set strict=True to raise an error in that case.
    Even when a ref exists, falling back on torch.* is sometimes desirable;
    this behavior can be customized by passing a function to should_fallback_fn.
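
    A hedged sketch of a custom fallback (fallback_on_add is a hypothetical
    predicate, not part of the API):

    >>> # xdoctest: +SKIP("undefined vars")
    >>> def fallback_on_add(mode, orig_func, func, args, kwargs):
    ...     return orig_func is torch.add
    >>> with TorchRefsMode(should_fallback_fn=fallback_on_add):
    ...     torch.add(x, y)  # falls back to torch.add, not torch._refs.add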
    """

    def __init__(
        self,
        strict=False,
        should_fallback_fn=lambda *_: False,
        prims_mode_cls=nullcontext,
    ):
        self.strict = strict
        self.should_fallback_fn = should_fallback_fn
        self.prims_mode_cls = prims_mode_cls

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}
        # For primitive operations, run them as is without interception,
        # unless a prims mode is provided, in which case prims are executed
        # under it (e.g., remapped to nvprims)
        if orig_func in torch_function_passthrough or orig_func in all_prims():
            with self.prims_mode_cls():
                return orig_func(*args, **kwargs)
        mapping = torch_to_refs_map()
        func = mapping.get(orig_func, None)

        # For torch.ops.aten.*, use registered decompositions from torch._decomp.
        # torch._decomp.decomposition_table provides a mapping from
        # torch.ops.aten.* to torch._refs or torch._decomp.decompositions
        # implementations.
        # There are other ways to implement this functionality,
        # see https://github.com/pytorch/pytorch/pull/82657#discussion_r939776417
        if func is None and isinstance(orig_func, torch._ops.OpOverload):
            func = torch._decomp.decomposition_table.get(orig_func, None)

        if func is not None:
            # If the ref exists, query whether we should use it or fall back
            if self.should_fallback_fn(self, orig_func, func, args, kwargs):
                return orig_func(*args, **kwargs)
            # torch calls inside func should be interpreted as refs calls
            with self:
                return func(*args, **kwargs)
        if self.strict:
            raise RuntimeError(
                f"no _refs support for {torch.overrides.resolve_name(orig_func)}"
            )
        return orig_func(*args, **kwargs)


def _is_node_supported_nvfuser(node):
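    """
    Checks whether a graph node is a call_function whose target has an
    nvFuser implementation (exposed via the `impl_nvfuser` attribute).
    """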
    return (
        node.op == "call_function"
        and getattr(node.target, "impl_nvfuser", None) is not None
    )


def _is_func_unsupported_nvfuser(
    torch_function_mode, orig_func, func, args, kwargs, *, skip_ops=()
):
    """
    This function traces the `func` under `torch_function_mode` and checks if
    any of the traced nodes are not supported by nvFuser. If so, we should
    fallback to the original function.

    The `skip_ops` argument is expected to be a sequence of strings of function
    names, as produced by `torch.overrides.resolve_name`.

    Args:
        torch_function_mode: The torch function mode context manager.
        orig_func: The original function; its resolved name is used to check
            whether it should be skipped.
        func: The function to be traced.
        args: The args to be passed to the function.
        kwargs: The kwargs to be passed to the function.

    Keyword args:
        skip_ops: A sequence of ops to skip when checking if the function is
            supported.
    """
    # One case is easy to check: if the resolved name of the original
    # function is in the skip list, report it as unsupported to force fallback.
    if torch.overrides.resolve_name(orig_func) in skip_ops:
        return True

    with torch_function_mode:
        try:
            gm = get_isolated_graphmodule(func, args, kwargs)
        except Exception as e:
            warn(
                f"get_isolated_graphmodule failed on decomposition: {func.__name__} "
                f"with error message: {e}"
            )
            # Report the function as unsupported when tracing fails.
            return True

    supported_ops = NvfuserPrimOperatorSupport()
    call_function_nodes = filter(lambda n: n.op == "call_function", gm.graph.nodes)
    any_unsupported = any(
        not supported_ops.is_node_supported(None, node) for node in call_function_nodes
    )
    return any_unsupported
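
# A hedged wiring sketch: this predicate is meant to be partially applied and
# passed to TorchRefsMode as should_fallback_fn, exactly as
# TorchRefsNvfuserCapabilityMode does below:
#
#     fallback_fn = functools.partial(_is_func_unsupported_nvfuser, skip_ops=())
#     mode = TorchRefsMode(
#         should_fallback_fn=fallback_fn,
#         prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=()),
#     )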


class TorchRefsNvfuserCapabilityMode(TorchRefsMode):
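    """
    Variant of TorchRefsMode that uses nvFuser-compatible decompositions:
    each intercepted call is traced, and if the resulting graph contains ops
    that nvFuser does not support, the original function is used instead.
    Prims inside the decompositions are remapped to torch.ops.nvprims.* via
    NvfuserPrimsMode.

    >>> # xdoctest: +SKIP("undefined vars")
    >>> with TorchRefsNvfuserCapabilityMode():
    ...     torch.add(x, y)  # executed via nvprims when supported
    """
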
    def __init__(self, *, skip_ops=()):
        super().__init__(
            strict=False,
            should_fallback_fn=functools.partial(
                _is_func_unsupported_nvfuser, skip_ops=skip_ops
            ),
            prims_mode_cls=functools.partial(NvfuserPrimsMode, skip_ops=skip_ops),
        )

    def _is_var_mean(self, func):
        return "torch.var_mean" == torch.overrides.resolve_name(func) or (
            isinstance(func, (torch._ops.OpOverload, torch._ops.OpOverloadPacket))
            and "aten.var_mean" in str(func)
        )

    def _is_native_batch_norm(self, func):
        return "torch.native_batch_norm" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.native_batch_norm.default
            or func == torch.ops.aten.native_batch_norm
        )

    def _is_rand_like(self, func):
        return "torch.rand_like" == torch.overrides.resolve_name(func) or (
            func == torch.ops.aten.rand_like
            or func == torch.ops.aten.rand_like.default
        )

    def __torch_function__(
        self,
        orig_func: Callable,
        types: Sequence,
        args: Sequence[Any] = (),
        kwargs: Optional[Dict] = None,
    ):
        if kwargs is None:
            kwargs = {}
        # First we intercept calls for nvfuser-specific prims, bypassing generic torch._refs
        if self._is_var_mean(orig_func):
            return torch.ops.nvprims.var_mean(*args, **kwargs)

        if self._is_native_batch_norm(orig_func):
            return torch.ops.nvprims.native_batch_norm(*args, **kwargs)

        if self._is_rand_like(orig_func):
            if len(kwargs) > 0:
                warn("rand_like has ignored kwars!")
            return torch.ops.nvprims.rand_like(*args)

        # Then we use TorchRefsMode to interpret the rest
        return super().__torch_function__(orig_func, types, args, kwargs)
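

# A minimal end-to-end sketch (hedged: x and y are hypothetical tensors, and
# actual nvFuser coverage depends on the PyTorch build):
#
#     x, y = torch.randn(3), torch.randn(3)
#     with TorchRefsNvfuserCapabilityMode():
#         gm = get_isolated_graphmodule(torch.add, (x, y), {})
#     # call_function nodes in gm.graph target torch.ops.nvprims.* when supported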