import re
from dataclasses import dataclass, field
from typing import List, Optional

import torch

from transformers.modeling_flash_attention_utils import _flash_attention_forward


@dataclass
class IterStep:
    """A helper class for the iteration plan: one pass over a contiguous range of layers."""
    # range of layers covered by this step; defaults to all layers.
    # A default_factory is required: `slice` is unhashable before Python 3.12, and
    # dataclasses on Python 3.11 reject unhashable plain defaults as "mutable".
    layer_slice: slice = field(default_factory=lambda: slice(None))
    # flag consumed by the caller; iteration_plan sets it to False for the gradient-free passes
    requires_grad: bool = True
    # flag consumed by the caller; iteration_plan sets it to True only on the last step of a group
    update: bool = True


@dataclass
class LayerType:
    """A helper class to collect the layer type information of a single layer."""
    # index of this layer
    layer_idx: int
    # whether this layer uses sliding window attention (the `s` suffix in the layer type string)
    use_sliding_window: bool
    # index of the layer whose key-value pair this layer uses
    attends_to: int
    # True when attends_to is greater than layer_idx
    attends_top: bool
    # True when some layer (possibly this one) uses this layer's key-value pair
    computes_kv: bool


class LayerTypeParser:
    """
    A helper class to parse the layer type string and provide some useful methods.

    Arguments:
        layer_type (str): A string of integers separated by underscores. The i-th integer
            means the layer will use the key-value pair in the i-th layer as the kv cache.
            Special characters may be placed after the integers:
            - `s` means the layer will use sliding window attention.

    Raises:
        ValueError: if a token of `layer_type` is not an integer optionally followed by `s`.

    >>> layer_type = LayerTypeParser("0_0_0_5s_5s_5s_8_8_8")[3]
    >>> layer_type.attends_to
    5
    >>> layer_type.attends_top
    True
    >>> layer_type.use_sliding_window
    True
    """
    def __init__(self, layer_type: str):
        self._layer_type = layer_type

        # parse the layer type: every underscore-separated token is "<idx>" or "<idx>s"
        self.layer_indices: List[int] = []
        self.sliding_window: List[bool] = []
        for token in layer_type.split("_"):
            match = re.match(r"^(\d+)(s)?$", token)
            if match is None:
                raise ValueError(
                    f"Invalid layer type token {token!r} in {layer_type!r}: "
                    "expected an integer optionally followed by 's'."
                )
            layer_idx, sliding_window = match.groups()
            self.layer_indices.append(int(layer_idx))
            self.sliding_window.append(bool(sliding_window))

    def __len__(self):
        return len(self.layer_indices)

    def __getitem__(self, layer_idx: int) -> LayerType:
        """
        Return the layer type information for the given layer index.

        Negative indices count from the last layer, as with list indexing. An out-of-range
        index raises IndexError, which also lets `for t in parser` iterate over all layers.
        """
        num_layers = len(self)
        if not -num_layers <= layer_idx < num_layers:
            raise IndexError(f"layer index {layer_idx} out of range for {num_layers} layers")
        # normalize negative indices: attends_top and computes_kv compare absolute indices
        if layer_idx < 0:
            layer_idx += num_layers
        attends_to = self.layer_indices[layer_idx]
        return LayerType(
            layer_idx=layer_idx,
            use_sliding_window=self.sliding_window[layer_idx],
            attends_to=attends_to,
            attends_top=attends_to > layer_idx,
            computes_kv=layer_idx in self.layer_indices,
        )

    def use_sliding_window(self) -> bool:
        """whether there exists a layer that uses sliding window attention"""
        return any(self.sliding_window)

    def attends_top(self) -> bool:
        """whether there exists a layer that attends to layers above it"""
        return any(self.layer_indices[i] > i for i in range(len(self)))

    def iteration_plan(self, forward_passes: int = 7, backward_passes: int = 2) -> List[IterStep]:
        """
        Return an iteration plan for the layer types. The plan is a list of IterStep objects.

        Without any layer that attends to a higher layer, the plan is a single default step.
        Otherwise, consecutive layers that only attend to themselves or lower layers share one
        default step. A group of layers that (transitively) attends to higher layers forms a
        cyclic dependency. Each such group is expanded into:
        - `forward_passes` steps with requires_grad=False and update=False,
        - `backward_passes - 1` steps with update=False,
        - one final step with the default flags.

        Arguments:
            forward_passes (int): number of gradient-free passes per cyclic group.
            backward_passes (int): number of passes with requires_grad=True per cyclic group,
                including the final updating step.
        """
        # if there is no cyclic dependency, return the default plan
        if not self.attends_top():
            return [IterStep()]

        # otherwise, return the plan for the cyclic dependency
        plan: List[IterStep] = []
        i = 0
        while i < len(self):

            # if the layer attends to top layers, resolve the cyclic dependency
            if self[i].attends_top:

                # grow the group until no layer inside it attends above the group's top layer
                top = self[i].attends_to
                while top < max(self.layer_indices[i: top + 1]):
                    top = max(self.layer_indices[i: top + 1])
                top += 1

                # create iteration plan for this group; build distinct IterStep objects so that
                # mutating one step of the plan does not silently change the others
                layer_slice = slice(i, top)
                plan.extend(
                    IterStep(layer_slice, requires_grad=False, update=False)
                    for _ in range(forward_passes)
                )
                plan.extend(IterStep(layer_slice, update=False) for _ in range(backward_passes - 1))
                plan.append(IterStep(layer_slice))

            # otherwise, gather the run of non-cyclic layers into one default step
            else:

                top = i + 1
                while top < len(self) and not self[top].attends_top:
                    top += 1
                plan.append(IterStep(slice(i, top)))

            # continue after the group just planned
            i = top

        return plan

    def check(self, num_hidden_layers: int):
        """
        Check if the layer type is valid for a model with `num_hidden_layers` layers.

        Raises:
            ValueError: if the number of entries differs from `num_hidden_layers`, or any entry
                is outside `range(num_hidden_layers)`.
        """
        if len(self.layer_indices) != num_hidden_layers:
            raise ValueError("The number of layer types should be equal to the number of hidden layers.")
        if any(idx not in range(num_hidden_layers) for idx in self.layer_indices):
            raise ValueError("The layer type should be in the range of the number of hidden layers.")


def flash_attention_forward(
    query_states: torch.Tensor,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    attention_mask: torch.Tensor,
    query_length: int,
    is_causal: bool,
    dropout: float = 0.0,
    position_ids: Optional[torch.Tensor] = None,
    softmax_scale: Optional[float] = None,
    sliding_window: Optional[int] = None,
    use_top_left_mask: bool = False,
    softcap: Optional[float] = None,
    deterministic: Optional[bool] = None,
    no_diag: bool = False,
):
    """
    This function is a wrapper around the _flash_attention_forward function in the
    transformers library. It adds support to mask the diagonal elements of the attention
    matrix. The diagonal mask is used to resolve the cyclic dependencies in the LCKV model.

    Tensors are indexed with dimension 1 as the sequence axis; the 4-D unpacking below
    assumes a (batch, seq_len, num_heads, head_dim) layout (TODO confirm against callers).

    Arguments:
        no_diag (bool): if True, each query may not attend to the key at its own position.
            This is implemented by dropping the last key/value and, when query and key lengths
            match, the first query.
        All other arguments are forwarded unchanged to `_flash_attention_forward`, except
        `query_length`, `attention_mask` and `sliding_window`, which are adjusted when
        `no_diag` is set.

    Returns:
        The attention output. If a query was pruned, a zero row is prepended at sequence
        position 0 so that the output again has one row per original query. If only a
        single key is given with `no_diag`, an all-zero tensor is returned instead.
    """
    prune_query = False
    if no_diag:
        if key_states.size(1) == 1:
            # Only the diagonal key exists, so there is nothing left to attend to: return zeros.
            # The output is query-shaped (sequence length and heads from the queries),
            # with the head dim taken from the values.
            batch_size, q_len, num_heads, _ = query_states.size()
            head_dim = value_states.size(-1)
            return value_states.new_zeros((batch_size, q_len, num_heads, head_dim))

        if key_states.size(1) == query_states.size(1):
            # The first query would only see its own (masked) key, so drop it here and
            # re-insert a zero row for it after the attention call.
            prune_query = True
            query_states = query_states[:, 1:, :, :]
            query_length -= 1

            if attention_mask is not None:
                # Trim the mask so its length matches the pruned queries.
                # NOTE(review): keys are trimmed at the end while the mask is trimmed at the
                # front; confirm this alignment is intended for left-padded batches.
                attention_mask = attention_mask[:, 1:]

        # Drop the last key/value; with causal alignment this removes each query's diagonal.
        key_states = key_states[:, :-1, :, :]
        value_states = value_states[:, :-1, :, :]

        if sliding_window is not None:
            # Keys are now offset by one position from their queries; presumably shrinking the
            # window by one keeps the same left edge (confirm against the window convention
            # of _flash_attention_forward).
            sliding_window = sliding_window - 1

    result: torch.Tensor = _flash_attention_forward(
        query_states=query_states,
        key_states=key_states,
        value_states=value_states,
        attention_mask=attention_mask,
        query_length=query_length,
        is_causal=is_causal,
        dropout=dropout,
        position_ids=position_ids,
        softmax_scale=softmax_scale,
        sliding_window=sliding_window,
        use_top_left_mask=use_top_left_mask,
        softcap=softcap,
        deterministic=deterministic,
    )

    if prune_query:
        # restore the pruned first query position as an all-zero output row
        b, _, h, d = result.size()
        result = torch.cat([result.new_zeros((b, 1, h, d)), result], dim=1)

    return result