import math
import torch
import triton
import triton.language as tl

# Autotune is disabled for now; set num_warps=4 if headdim == 64 and num_warps=8 if headdim == 128.
# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 128, "BLOCK_N": 128}, num_warps=4, num_stages=1),
#         # This config has a race condition when EVEN_M == False, disabling it for now.
#         # triton.Config({"BLOCK_M": 64, "BLOCK_N": 64}, num_warps=4, num_stages=1),
#     ],
#     key=['CACHE_KEY_SEQLEN_Q', 'CACHE_KEY_SEQLEN_K', 'BIAS_TYPE', 'IS_CAUSAL', 'BLOCK_HEADDIM']
# )
@triton.heuristics(
    {
        "EVEN_M": lambda args: args["seqlen_q"] % args["BLOCK_M"] == 0,
        "EVEN_N": lambda args: args["seqlen_k"] % args["BLOCK_N"] == 0,
        "EVEN_C": lambda args: args["nchunks"] % args["BLOCK_N"] == 0,
        "EVEN_W": lambda args: args["WINDOW_SIZE"] % args["BLOCK_N"] == 0,
        "EVEN_HEADDIM": lambda args: args["headdim"] == args["BLOCK_HEADDIM"],
    }
)
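# The EVEN_* flags tell the compiler when a dimension tiles exactly, so the
# corresponding bounds masks in tl.load / tl.store can be elided:
#   EVEN_M: seqlen_q is a multiple of BLOCK_M
#   EVEN_N: seqlen_k is a multiple of BLOCK_N
#   EVEN_C: nchunks is a multiple of BLOCK_N
#   EVEN_W: WINDOW_SIZE is a multiple of BLOCK_N
#   EVEN_HEADDIM: headdim equals the padded BLOCK_HEADDIM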
@triton.jit
def _fwd_eva_agg_kernel(
    Q,
    K,
    V,
    RFA_K,
    RFA_V,
    WindowMask,
    Out,
    softmax_scale,
    stride_qb, stride_qh, stride_qm,
    stride_kb, stride_kh, stride_kn,
    stride_vb, stride_vh, stride_vn,
    stride_rfa_kb, stride_rfa_kh, stride_rfa_kc,
    stride_rfa_vb, stride_rfa_vh, stride_rfa_vc,
    stride_mb, stride_mm,
    stride_ob, stride_oh, stride_om,
    nheads,
    seqlen_q,
    seqlen_k,
    nchunks,
    headdim,
    CACHE_KEY_SEQLEN_Q,  # unused in the body; kept as a cache key for the (disabled) autotuner
    CACHE_KEY_SEQLEN_K,  # unused in the body; kept as a cache key for the (disabled) autotuner
    CACHE_KEY_NCHUNKS,  # unused in the body; kept as a cache key for the (disabled) autotuner
    CHUNKS_PER_WINDOW: tl.constexpr,
    WINDOW_SIZE: tl.constexpr,
    MASK_TYPE: tl.constexpr,
    EMPTY_RFA_KV: tl.constexpr,
    BLOCK_HEADDIM: tl.constexpr,
    EVEN_M: tl.constexpr,
    EVEN_N: tl.constexpr,
    EVEN_W: tl.constexpr,
    EVEN_C: tl.constexpr,
    EVEN_HEADDIM: tl.constexpr,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
):
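    """
    Each program computes one BLOCK_M tile of query rows for one (batch, head)
    pair. The output is an online softmax over two score sets:
      1. local singletons: keys/values inside the query tile's window, and
      2. (if EMPTY_RFA_KV == 0) chunk-level RFA keys/values for all chunks
         preceding the current window.
    Both passes share the running max m_i, denominator d_i, and accumulator acc_o.
    """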
    start_m = tl.program_id(0)
    off_bh = tl.program_id(1)
    off_h = off_bh % nheads
    off_b = off_bh // nheads
    # initialize offsets
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_w = (start_m * BLOCK_M) // WINDOW_SIZE
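    # offs_w is the index of the window containing this query tile; this
    # arithmetic implicitly assumes a BLOCK_M tile never straddles a window
    # boundary (i.e., WINDOW_SIZE is a multiple of BLOCK_M)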
    offs_n = tl.arange(0, BLOCK_N)
    offs_c = tl.arange(0, BLOCK_N)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    # TODO: add parentheses or not
    q_ptrs = (
        Q +
        off_b * stride_qb +
        off_h * stride_qh +
        (offs_m[:, None] * stride_qm + offs_d[None, :])
    )
    k_ptrs = (
        K +
        off_b * stride_kb +
        off_h * stride_kh +
        (offs_n[:, None] * stride_kn + offs_d[None, :])
    )
    v_ptrs = (
        V +
        off_b * stride_vb +
        off_h * stride_vh +
        (offs_n[:, None] * stride_vn + offs_d[None, :])
    )
    if EMPTY_RFA_KV == 0:
        rfa_k_ptrs = (
            RFA_K +
            off_b * stride_rfa_kb +
            off_h * stride_rfa_kh +
            (offs_c[:, None] * stride_rfa_kc + offs_d[None, :])
        )
        rfa_v_ptrs = (
            RFA_V +
            off_b * stride_rfa_vb +
            off_h * stride_rfa_vh +
            (offs_c[:, None] * stride_rfa_vc + offs_d[None, :])
        )

    qk_scale = softmax_scale
    qk_scale *= 1.4426950408889634  # log2(e)
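    # exp(x) == exp2(x * log2(e)), so folding log2(e) into the scale lets the
    # kernel use the faster tl.exp2 throughout; the running max m_i lives in
    # this scaled base-2 domain as well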
    if MASK_TYPE == 1:
        m_ptrs = (
            WindowMask +
            off_b * stride_mb +
            (offs_m[:, None] * stride_mm + offs_n[None, :])
        )
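    # online-softmax running state: m_i = running max of the scaled logits,
    # d_i = running softmax denominator, acc_o = unnormalized output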
    m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf")
    d_i = tl.zeros([BLOCK_M], dtype=tl.float32)
    acc_o = tl.zeros([BLOCK_M, BLOCK_HEADDIM], dtype=tl.float32)
    # load q: it will stay in SRAM throughout
    # [2022-10-30] TD: Triton bug - in the case of EVEN_M=True and EVEN_N=False, if we just call
    # tl.load(q_ptrs), we get the wrong output!
    if EVEN_M & EVEN_N:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=offs_d[None, :] < headdim,
                other=0.0
            )
    else:
        if EVEN_HEADDIM:
            q = tl.load(
                q_ptrs,
                mask=offs_m[:, None] < seqlen_q,
                other=0.0
            )
        else:
            q = tl.load(
                q_ptrs,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim),
                other=0.0
            )
    # loop over k, v and update accumulator
    # Iterate over local singletons;
    # so we only iterate over blocks within the current window
    start_idx_n = offs_w * WINDOW_SIZE
    end_idx_n = tl.minimum((start_m + 1) * BLOCK_M, seqlen_k)
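    # e.g. with BLOCK_M = 128 and WINDOW_SIZE = 256 (illustrative numbers),
    # program 3 owns rows 384..511, so offs_w = 1 and the loop scans keys
    # 256..511: from the window start up to the causal end of this query tile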
    for start_n in range(start_idx_n, end_idx_n, BLOCK_N):
        start_n = tl.multiple_of(start_n, BLOCK_N)
        # -- compute qk ----
        if EVEN_N & EVEN_M:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                k = tl.load(
                    k_ptrs + start_n * stride_kn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, tl.trans(k))
        # Trying to combine the two masks seems to make the result wrong
        if not EVEN_N:  # Need to mask out otherwise the softmax is wrong
            qk += tl.where((start_n + offs_n)[None, :] < seqlen_k, 0, float("-inf"))

        if MASK_TYPE == 1:
            if EVEN_M & EVEN_W:
                mask = tl.load(
                    m_ptrs + start_n - start_idx_n
                ).to(tl.float32)
            else:
                mask = tl.load(
                    m_ptrs + start_n - start_idx_n,
                    mask=(offs_m[:, None] < seqlen_q)
                    & ((start_n - start_idx_n + offs_n)[None, :] < WINDOW_SIZE),
                    other=0.0,
                ).to(tl.float32)
            # It is slightly faster to fold softmax_scale into the tl.exp2 below, since the
            # compiler can then fuse the multiply and add into an fma instruction. But if we
            # have a mask/bias, we need to multiply by softmax_scale here.
            # We assume the mask already encodes causal masking.
            qk = qk * qk_scale + mask
            m_ij = tl.maximum(tl.max(qk, 1), m_i)
            p = tl.exp2(qk - m_ij[:, None])
        else:
            qk += tl.where(offs_m[:, None] >= (start_n + offs_n)[None, :], 0, float("-inf"))
            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            p = tl.exp2(qk * qk_scale - m_ij[:, None])
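        # both branches compute the standard max-subtracted softmax numerator,
        # just in base 2 (qk_scale carries the log2(e) factor)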

        d_ij = tl.sum(p, 1)

        # -- rescale the output accumulator by exp2(old_max - new_max) --
        prev_scale = tl.exp2(m_i - m_ij)
        acc_o = acc_o * prev_scale[:, None]
        # -- load v and accumulate p @ v --
        if EVEN_N & EVEN_M:  # If we just do "if EVEN_N", there seems to be some race condition
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=offs_d[None, :] < headdim,
                    other=0.0
                )
        else:
            if EVEN_HEADDIM:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=(start_n + offs_n)[:, None] < seqlen_k,
                    other=0.0,
                )
            else:
                v = tl.load(
                    v_ptrs + start_n * stride_vn,
                    mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim),
                    other=0.0,
                )
        p = p.to(v.dtype)
        acc_o = tl.dot(p, v, acc_o)

        # -- update statistics
        d_i = d_i * prev_scale + d_ij
        m_i = m_ij

    if EMPTY_RFA_KV == 0:
        # Iterate over RFA chunks
        # we only iterate over chunks before the current local singleton window
        end_idx_c = tl.minimum(offs_w * CHUNKS_PER_WINDOW, nchunks)
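        # e.g. with CHUNKS_PER_WINDOW = 128 (illustrative), a query tile in
        # window 2 (offs_w = 2) attends to RFA chunks 0..255: exactly the
        # chunks summarizing windows 0 and 1, never its own window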
        for start_c in range(0, end_idx_c, BLOCK_N):
            start_c = tl.multiple_of(start_c, BLOCK_N)
            # -- compute qk ----
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_k = tl.load(
                        rfa_k_ptrs + start_c * stride_rfa_kc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
            qk += tl.dot(q, tl.trans(rfa_k))
            # Trying to combine the two masks seems to make the result wrong
            if not EVEN_C:  # Need to mask out otherwise the softmax is wrong
                qk += tl.where((start_c + offs_c)[None, :] < nchunks, 0, float("-inf"))

            m_ij = tl.maximum(tl.max(qk, 1) * qk_scale, m_i)
            p = tl.exp2(qk * qk_scale - m_ij[:, None])

            d_ij = tl.sum(p, 1)

            # -- rescale the output accumulator by exp2(old_max - new_max) --
            prev_scale = tl.exp2(m_i - m_ij)
            acc_o = acc_o * prev_scale[:, None]
            # -- load rfa_v and accumulate p @ rfa_v --
            # Checking EVEN_C alone seems to trigger a race condition, mirroring
            # the EVEN_N & EVEN_M case above.
            if EVEN_C & EVEN_M:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=offs_d[None, :] < headdim,
                        other=0.0
                    )
            else:
                if EVEN_HEADDIM:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=(start_c + offs_c)[:, None] < nchunks,
                        other=0.0,
                    )
                else:
                    rfa_v = tl.load(
                        rfa_v_ptrs + start_c * stride_rfa_vc,
                        mask=((start_c + offs_c)[:, None] < nchunks) & (offs_d[None, :] < headdim),
                        other=0.0,
                    )
            p = p.to(rfa_v.dtype)
            acc_o = tl.dot(p, rfa_v, acc_o)

            # -- update statistics
            d_i = d_i * prev_scale + d_ij
            m_i = m_ij

    # final normalization by the softmax denominator
    acc_o = acc_o / d_i[:, None]
    # rematerialize offsets to save registers (as in the flash-attention Triton kernel)
    start_m = tl.program_id(0)
    offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, BLOCK_HEADDIM)
    out_ptrs = (
        Out +
        off_b * stride_ob +
        off_h * stride_oh +
        (offs_m[:, None] * stride_om + offs_d[None, :])
    )
    if EVEN_M:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_d[None, :] < headdim
            )
    else:
        if EVEN_HEADDIM:
            tl.store(
                out_ptrs, acc_o,
                mask=offs_m[:, None] < seqlen_q
            )
        else:
            tl.store(
                out_ptrs, acc_o,
                mask=(offs_m[:, None] < seqlen_q) & (offs_d[None, :] < headdim)
            )

def triton_eva_agg_fwd(
    q, k, v, rfa_k, rfa_v, window_mask,
    softmax_scale, window_size, chunks_per_window,
):
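    """
    Launch the EVA aggregation forward kernel.

    q, k, v:      (batch, nheads, seqlen, head_dim)
    rfa_k, rfa_v: (batch, nheads, nchunks, head_dim), or both None to skip the
                  chunk-level pass entirely
    window_mask:  optional float mask of shape (batch, 1, seqlen_q, window_size);
                  if None, plain causal masking is applied inside each window
    Returns an output tensor with the same shape and dtype as q.
    """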
    if rfa_k is None and rfa_v is None:
        empty_rfa_kv = 1

        q, k, v = [
            x if x.stride(-1) == 1 else x.contiguous() 
            for x in [q, k, v]
        ]
    else:
        assert rfa_k is not None and rfa_v is not None, "rfa_k and rfa_v must both be None or both be provided"
        empty_rfa_kv = 0

        q, k, v, rfa_k, rfa_v = [
            x if x.stride(-1) == 1 else x.contiguous() 
            for x in [q, k, v, rfa_k, rfa_v]
        ]

    # shape constraints
    batch, nheads, seqlen_q, head_dim = q.shape
    _,     _,      seqlen_k, _        = k.shape
    if empty_rfa_kv == 0:
        nchunks = rfa_k.shape[-2]
        assert rfa_k.shape == (batch, nheads, nchunks, head_dim)
        assert rfa_v.shape == (batch, nheads, nchunks, head_dim)
        assert q.dtype == k.dtype == v.dtype == rfa_k.dtype == rfa_v.dtype
    else:
        nchunks = 0
        assert q.dtype == k.dtype == v.dtype, "All tensors must have the same type"
    assert k.shape == (batch, nheads, seqlen_k, head_dim)
    assert v.shape == (batch, nheads, seqlen_k, head_dim)

    assert head_dim <= 128, "We only test head dimensions up to 128"
    # assert q.dtype in [torch.float16, torch.bfloat16], "Only support fp16 and bf16"
    assert q.dtype in [torch.bfloat16, torch.float], "Only support bf16 and fp32 for now"
    assert q.is_cuda and k.is_cuda and v.is_cuda
    softmax_scale = softmax_scale or 1.0 / math.sqrt(head_dim)

    mask_type = 0
    if window_mask is not None:
        mask_type = 1
        assert window_mask.dtype == q.dtype, "window_mask must have the same dtype as q"
        assert window_mask.is_cuda
        assert window_mask.dim() == 4
        assert window_mask.shape == (batch, 1, seqlen_q, window_size)
        if window_mask.stride(-1) != 1:
            window_mask = window_mask.contiguous()
    mask_strides = (
        (window_mask.stride(0), window_mask.stride(2)) 
        if mask_type == 1 else 
        (0, 0)
    )

    rfa_k_strides = (
        (rfa_k.stride(0), rfa_k.stride(1), rfa_k.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    rfa_v_strides = (
        (rfa_v.stride(0), rfa_v.stride(1), rfa_v.stride(2))
        if empty_rfa_kv == 0 else
        (0, 0, 0)
    )
    assert chunks_per_window > 0, "chunks_per_window must be greater than 0"

    o = torch.empty_like(q)

    BLOCK_HEADDIM = max(triton.next_power_of_2(head_dim), 16)
    if q.dtype == torch.float:
        BLOCK = 64
    else:
        BLOCK = 128
    num_warps = 4 if head_dim <= 64 else 8
    assert chunks_per_window >= BLOCK, "chunks_per_window must be at least BLOCK"
    # MASK_TYPE:
    # - 0: regular causal mask; window_mask is simply None
    # - 1: window_mask is a float mask of shape (batch, 1, seqlen_q, window_size)

    grid = lambda META: (triton.cdiv(seqlen_q, META["BLOCK_M"]), batch * nheads)
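    # grid: one program per (query tile, batch * head) pair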
    _fwd_eva_agg_kernel[grid](
        q,
        k,
        v,
        rfa_k,
        rfa_v,
        window_mask,
        o,
        softmax_scale,
        q.stride(0), q.stride(1), q.stride(2),
        k.stride(0), k.stride(1), k.stride(2),
        v.stride(0), v.stride(1), v.stride(2),
        rfa_k_strides[0], rfa_k_strides[1], rfa_k_strides[2],
        rfa_v_strides[0], rfa_v_strides[1], rfa_v_strides[2],
        mask_strides[0], mask_strides[1],
        o.stride(0), o.stride(1), o.stride(2),
        nheads,
        seqlen_q,
        seqlen_k,
        nchunks,
        head_dim,
        seqlen_q // 32,
        seqlen_k // 32,
        nchunks // 32,
        chunks_per_window,
        window_size,
        mask_type,
        empty_rfa_kv,
        BLOCK_HEADDIM,
        BLOCK_M=BLOCK,
        BLOCK_N=BLOCK,
        num_warps=num_warps,
        num_stages=1,
    )
    return o
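

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative only). The sizes below are
    # arbitrary choices that satisfy the asserts above; in particular we assume
    # nchunks = (seqlen // window_size) * chunks_per_window, which matches how
    # the kernel bounds its chunk-level loop.
    batch, nheads, seqlen, head_dim = 2, 4, 1024, 64
    window_size, chunks_per_window = 256, 128  # chunks_per_window >= BLOCK (128 for bf16)
    nchunks = (seqlen // window_size) * chunks_per_window
    dev, dt = "cuda", torch.bfloat16
    q, k, v = (torch.randn(batch, nheads, seqlen, head_dim, device=dev, dtype=dt) for _ in range(3))
    rfa_k, rfa_v = (torch.randn(batch, nheads, nchunks, head_dim, device=dev, dtype=dt) for _ in range(2))
    out = triton_eva_agg_fwd(q, k, v, rfa_k, rfa_v, None, None, window_size, chunks_per_window)
    print(out.shape)  # expected: torch.Size([2, 4, 1024, 64])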