import platform
from dataclasses import dataclass
import logging
from typing import Union, List, Optional, Tuple, Callable
import gc

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils.parametrize as P
from tqdm import tqdm
from transformers import LlamaModel, LlamaConfig
from transformers.cache_utils import Cache
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.utils import is_flash_attn_2_available

from ..utils import del_all
from .embed import Embed


class GPT(nn.Module):
    def __init__(
        self,
        gpt_config: dict,
        embed: Embed,
        use_flash_attn=False,
        use_vllm=False,
        device=torch.device("cpu"),
        device_gpt=torch.device("cpu"),
        logger=logging.getLogger(__name__),
    ):
        """Set up generation state; the transformer itself is loaded later.

        Args:
            gpt_config: mapping that must contain ``num_vq``,
                ``num_audio_tokens`` and ``num_text_tokens``; outside vLLM
                mode the whole mapping is also forwarded to
                ``_build_llama_config``.
            embed: object exposing ``emb_code``/``head_code`` (iterables of
                callables) and ``emb_text``/``head_text`` (callables).
            use_flash_attn: request flash-attention when building the config.
            use_vllm: when true, only the scalar settings are stored and the
                Llama config / embedding hooks are skipped.
            device: device used for the sampling generator and stored as
                ``self.device``.
            device_gpt: device stored as ``self.device_gpt``.
            logger: logger used for all messages emitted by this object.
        """
        super().__init__()

        self.logger = logger
        self.device, self.device_gpt = device, device_gpt
        self.generator = torch.Generator(device=device)

        # Vocabulary / codebook sizes, normalised to plain ints.
        self.num_vq, self.num_audio_tokens, self.num_text_tokens = (
            int(gpt_config[key])
            for key in ("num_vq", "num_audio_tokens", "num_text_tokens")
        )

        self.use_flash_attn = use_flash_attn
        self.is_te_llama = False
        self.is_vllm = use_vllm

        if not self.is_vllm:
            self.llama_config = self._build_llama_config(gpt_config)

            # Only the bound ``__call__`` methods are kept, not the objects
            # themselves -- presumably so the embedding layers are not
            # registered as sub-modules of this nn.Module (TODO confirm).
            self.emb_code = [layer.__call__ for layer in embed.emb_code]
            self.emb_text = embed.emb_text.__call__
            self.head_text = embed.head_text.__call__
            self.head_code = [layer.__call__ for layer in embed.head_code]

    def from_pretrained(
        self, gpt_folder: str, embed_file_path: str, experimental=False
    ):
        """Load the language-model weights into this object.

        In vLLM mode on Linux an ``LLM`` wrapper is created and stored as
        ``self.llm``. Otherwise a ``LlamaModel`` is loaded from ``gpt_folder``,
        moved to ``self.device_gpt`` and stored as ``self.gpt`` with its
        ``embed_tokens`` attribute removed. When ``experimental`` is set on
        Linux with a CUDA ``device_gpt``, the model is additionally converted
        to ``TELlamaModel``; any failure there is logged and the plain
        ``LlamaModel`` is kept.

        Args:
            gpt_folder: path passed to ``LlamaModel.from_pretrained`` (or to
                ``LLM`` as ``model``).
            embed_file_path: passed to ``LLM`` as ``post_model_path``; unused
                on the non-vLLM path.
            experimental: opt in to the ``TELlamaModel`` conversion.
        """
        on_linux = platform.system().lower() == "linux"

        if self.is_vllm and on_linux:
            from .velocity import LLM

            self.llm = LLM(
                model=gpt_folder,
                num_audio_tokens=self.num_audio_tokens,
                num_text_tokens=self.num_text_tokens,
                post_model_path=embed_file_path,
            )
            self.logger.info("vLLM model loaded")
            return

        # NOTE(review): with ``is_vllm`` set on a non-Linux platform execution
        # reaches this point even though ``__init__`` skipped building
        # ``llama_config`` and the embedding hooks in vLLM mode -- confirm
        # callers never rely on that combination.
        llama: LlamaModel = LlamaModel.from_pretrained(gpt_folder)
        self.gpt = llama.to(self.device_gpt)
        del self.gpt.embed_tokens

        wants_te = experimental and "cuda" in str(self.device_gpt) and on_linux
        if not wants_te:
            return

        # is TELlamaModel
        try:
            from .cuda import TELlamaModel

            self.logger.warning(
                "Linux with CUDA, try NVIDIA accelerated TELlamaModel because experimental is enabled"
            )
            weights = self.gpt.state_dict()
            te_model = TELlamaModel.from_state_dict(weights, self.llama_config)
            # Force mem release. Taken from huggingface code
            del weights, self.gpt
            gc.collect()
            self.gpt = te_model
            self.is_te_llama = True
        except Exception as e:
            # Best-effort acceleration: keep whatever model is loaded.
            self.logger.warning(
                f"use default LlamaModel for importing TELlamaModel error: {e}"
            )

    class Context:
        """Tiny mutable holder for a boolean "interrupt requested" flag."""

        def __init__(self):
            # The flag starts cleared.
            self._interrupt: bool = False

        def set(self, v: bool) -> None:
            """Store ``v`` as the current flag value."""
            self._interrupt = v

        def get(self) -> bool:
            """Return the most recently stored flag value."""
            return self._interrupt

    def _build_llama_config(
        self,
        config: dict,
    ) -> LlamaConfig:
        """Create the ``LlamaConfig`` for this model.

        Every entry of ``config`` is forwarded to ``LlamaConfig`` as a keyword
        argument. When flash-attention was requested (``self.use_flash_attn``)
        and ``is_flash_attn_2_available()`` reports it usable,
        ``attn_implementation="flash_attention_2"`` is added and a warning is
        logged.

        Args:
            config: keyword arguments for ``LlamaConfig``.

        Returns:
            The constructed ``LlamaConfig`` (a single object, not a tuple).
        """
        flash_enabled = self.use_flash_attn and is_flash_attn_2_available()

        extra_kwargs = (
            {"attn_implementation": "flash_attention_2"} if flash_enabled else {}
        )
        llama_config = LlamaConfig(**config, **extra_kwargs)

        if flash_enabled:
            self.logger.warning(
                "enabling flash_attention_2 may make gpt be even slower"
            )

        return llama_config

    def prepare(self, compile=False):
        """Apply post-load dtype conversion and optional compilation.

        * With flash-attention requested and available, ``self.gpt`` is cast
          to ``torch.float16``.
        * With ``compile`` set (and neither TE-Llama nor vLLM in use), both
          this module and ``self.gpt`` are compiled with the inductor
          backend; a ``RuntimeError`` during compilation is logged and
          ignored.

        In vLLM mode no ``self.gpt`` is created by ``__init__`` or by the vLLM
        branch of ``from_pretrained``, so there is nothing to prepare and the
        method returns immediately instead of failing on the missing
        attribute.

        Args:
            compile: whether to attempt ``torch.compile``-style compilation.
        """
        if self.is_vllm and not hasattr(self, "gpt"):
            # vLLM backend: no local LlamaModel exists to cast or compile.
            return

        if self.use_flash_attn and is_flash_attn_2_available():
            self.gpt = self.gpt.to(dtype=torch.float16)
        if compile and not self.is_te_llama and not self.is_vllm:
            try:
                self.compile(backend="inductor", dynamic=True)
                self.gpt.compile(backend="inductor", dynamic=True)
            except RuntimeError as e:
                self.logger.warning(f"compile failed: {e}. fallback to normal mode.")

    @dataclass(repr=False, eq=False)
    class _GenerationInputs:
        """Bundle of the tensors and flags handed to the model for one step."""

        position_ids: torch.Tensor
        cache_position: torch.Tensor
        use_cache: bool
        input_ids: Optional[torch.Tensor] = None
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
        attention_mask: Optional[torch.Tensor] = None
        inputs_embeds: Optional[torch.Tensor] = None

        def to(self, device: torch.device, dtype: torch.dtype):
            """Move the mask, positions, embeddings and cache positions in place.

            Each of ``attention_mask``, ``position_ids``, ``inputs_embeds``
            and ``cache_position`` that is not ``None`` is replaced by its
            ``.to(device, dtype=dtype)`` result. ``input_ids`` and
            ``past_key_values`` are left untouched. Returns ``None``.
            """
            for field_name in (
                "attention_mask",
                "position_ids",
                "inputs_embeds",
                "cache_position",
            ):
                tensor = getattr(self, field_name)
                if tensor is not None:
                    setattr(self, field_name, tensor.to(device, dtype=dtype))

    @torch.no_grad()
    def _prepare_generation_inputs(
        self,
        input_ids: torch.Tensor,
        past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        cache_position: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        use_cache=True,
    ) -> _GenerationInputs:
        """Assemble the model inputs for one decoding step.

        The method works out how many positions are already covered by the
        key/value cache, trims ``input_ids`` (and, for bounded caches, the
        attention mask) to the part that still has to be processed, derives
        ``position_ids`` from the attention mask when they are not supplied,
        and builds or trims ``cache_position``.

        Args:
            input_ids: token ids; dimension 1 is treated as the sequence axis.
            past_key_values: cache from the previous step, either a ``Cache``
                object or a legacy nested tuple; ``None`` on the first step.
            attention_mask: optional mask whose dimension 1 is the sequence
                axis.
            inputs_embeds: optional embeddings; only used when there is no
                cache to continue from.
            cache_position: optional explicit cache positions.
            position_ids: optional explicit position ids.
            use_cache: stored unchanged on the returned object.

        Returns:
            A ``_GenerationInputs`` holding ``position_ids``,
            ``cache_position``, ``use_cache``, ``past_key_values``,
            ``attention_mask`` and exactly one of ``inputs_embeds`` /
            ``input_ids``.
        """
        # With static cache, the `past_key_values` is None
        # TODO joao: standardize interface for the different Cache classes and remove of this if
        has_static_cache = False
        if past_key_values is None:
            # Look for a cache object attached to the first layer's attention module.
            if hasattr(self.gpt.layers[0], "self_attn"):
                past_key_values = getattr(
                    self.gpt.layers[0].self_attn, "past_key_value", None
                )
            has_static_cache = past_key_values is not None

        past_length = 0
        if past_key_values is not None:
            if isinstance(past_key_values, Cache):
                # Prefer the explicit cache position over the cache's own length.
                past_length = (
                    int(cache_position[0])
                    if cache_position is not None
                    else past_key_values.get_seq_length()
                )
                max_cache_length = past_key_values.get_max_length()
                cache_length = (
                    past_length
                    if max_cache_length is None
                    else min(max_cache_length, past_length)
                )
            # TODO joao: remove this `else` after `generate` prioritizes `Cache` objects
            else:
                # Legacy tuple cache: the length is read from dimension 2 of the first entry.
                cache_length = past_length = past_key_values[0][0].shape[2]
                max_cache_length = None

            # Keep only the unprocessed tokens:
            # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
            # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
            # input)
            if (
                attention_mask is not None
                and attention_mask.shape[1] > input_ids.shape[1]
            ):
                # Negative start: keep the trailing `start` positions of input_ids.
                start = attention_mask.shape[1] - past_length
                input_ids = input_ids.narrow(1, -start, start)
            # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
            # input_ids based on the past_length.
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids.narrow(
                    1, past_length, input_ids.size(1) - past_length
                )
            # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

            # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
            if (
                max_cache_length is not None
                and attention_mask is not None
                and cache_length + input_ids.shape[1] > max_cache_length
            ):
                attention_mask = attention_mask.narrow(
                    1, -max_cache_length, max_cache_length
                )

        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            # Masked-out positions get the placeholder value 1.
            position_ids.masked_fill_(attention_mask.eq(0), 1)
            if past_key_values:
                # Only the positions of the tokens still to be processed are needed.
                position_ids = position_ids.narrow(
                    1, -input_ids.shape[1], input_ids.shape[1]
                )

        input_length = (
            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        )
        if cache_position is None:
            cache_position = torch.arange(
                past_length, past_length + input_length, device=input_ids.device
            )
        else:
            # Keep the trailing `input_length` entries of the supplied positions.
            cache_position = cache_position.narrow(0, -input_length, input_length)

        if has_static_cache:
            # The cache was only discovered on the layer above; do not pass it to the model explicitly.
            past_key_values = None

        model_inputs = self._GenerationInputs(
            position_ids=position_ids,
            cache_position=cache_position,
            use_cache=use_cache,
        )

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs.inputs_embeds = inputs_embeds
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs.input_ids = input_ids.contiguous()

        model_inputs.past_key_values = past_key_values
        model_inputs.attention_mask = attention_mask

        return model_inputs

    @dataclass(repr=False, eq=False)
    class GenerationOutputs:
        """Per-sample results of a generation run."""

        ids: List[torch.Tensor]
        attentions: List[Optional[Tuple[torch.FloatTensor, ...]]]
        hiddens: List[torch.Tensor]

        def destroy(self):
            """Pass each of the three stored collections to ``del_all``."""
            for bucket in (self.ids, self.attentions, self.hiddens):
                del_all(bucket)

    @torch.no_grad()
    def _prepare_generation_outputs(
        self,
        inputs_ids: torch.Tensor,
        start_idx: int,
        end_idx: torch.Tensor,
        attentions: List[Optional[Tuple[torch.FloatTensor, ...]]],
        hiddens: List[torch.Tensor],
        infer_text: bool,
    ) -> GenerationOutputs:
        """Cut the generated part out of the id buffer, one entry per sample.

        Args:
            inputs_ids: id tensor indexed first by sample, then by sequence
                position.
            start_idx: sequence position where generated ids begin.
            end_idx: per-sample count of generated positions to keep.
            attentions: attention records, passed through unchanged.
            hiddens: per-step hidden states; when non-empty they are stacked
                along dimension 1 and truncated per sample to ``end_idx``.
            infer_text: when true, only index 0 of each id's last dimension
                is kept and that dimension is squeezed away.

        Returns:
            A ``GenerationOutputs`` with the per-sample ids, the given
            attentions and the (possibly truncated) hidden states.
        """
        ids_per_sample = []
        for row, length in enumerate(end_idx):
            generated = inputs_ids[row].narrow(0, start_idx, length)
            if infer_text:
                # Text ids are replicated across the last axis; keep one copy.
                generated = generated.narrow(1, 0, 1).squeeze_(1)
            ids_per_sample.append(generated)

        if hiddens:
            stacked = torch.stack(hiddens, 1)
            hiddens = [
                stacked[row].narrow(0, 0, length)
                for row, length in enumerate(end_idx.int())
            ]

        return self.GenerationOutputs(
            ids=ids_per_sample,
            attentions=attentions,
            hiddens=hiddens,
        )

    @torch.no_grad()
    def generate(
        self,
        emb: torch.Tensor,
        inputs_ids: torch.Tensor,
        temperature: torch.Tensor,
        eos_token: Union[int, torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        max_new_token=2048,
        min_new_token=0,
        logits_processors: Tuple[
            Callable[[torch.LongTensor, torch.FloatTensor], torch.FloatTensor]
        ] = (),
        infer_text=False,
        return_attn=False,
        return_hidden=False,
        stream=False,
        show_tqdm=True,
        ensure_non_empty=True,
        stream_batch=24,
        manual_seed: Optional[int] = None,
        context=Context(),
    ):
        """Autoregressively sample tokens; a generator of ``GenerationOutputs``.

        Args:
            emb: embeddings of the prompt, fed to the model on the first step.
            inputs_ids: prompt ids, a 3-D tensor (sample, position, codebook);
                the last dimension is indexed per ``num_vq`` codebook.
            temperature: tensor tiled across the batch and reshaped to a
                column that divides the logits row-wise.
            eos_token: id (or ids) whose appearance marks a sample finished.
            attention_mask: optional prompt mask; copied into the front of a
                pre-allocated all-ones mask covering prompt + new tokens.
            max_new_token: upper bound on decoding steps.
            min_new_token: EOS logits are set to ``-inf`` for this many steps.
            logits_processors: callables ``(generated_ids, logits) -> logits``
                applied in order before sampling.
            infer_text: use the text embedding/head instead of the code ones.
            return_attn: forwarded to the model as ``output_attentions``.
            return_hidden: collect the last-position hidden state every step.
            stream: additionally yield intermediate results whenever the
                step counter is a positive multiple of ``stream_batch``.
            show_tqdm: display a progress bar.
            ensure_non_empty: if any sample finishes on the very first step
                and ``manual_seed`` is ``None``, restart generation
                (recursively, without a retry limit).
            stream_batch: streaming interval, in steps.
            manual_seed: when given, ``self.generator`` is re-seeded with it
                before every sampling call.
            context: object whose ``get()`` is polled each step; a true value
                stops generation.

        Yields:
            ``GenerationOutputs`` -- intermediate ones when streaming, then a
            final one. If a sample finishes on the first step and no
            regeneration takes place, nothing is yielded.
        """

        attentions: List[Optional[Tuple[torch.FloatTensor, ...]]] = []
        hiddens = []
        stream_iter = 0

        # start_idx: where generated ids begin; end_idx: per-sample generated length.
        start_idx, end_idx = inputs_ids.shape[1], torch.zeros(
            inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long
        )
        finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool()

        # Kept untouched so a regeneration can be started with the original argument.
        old_temperature = temperature

        temperature = (
            temperature.unsqueeze(0)
            .expand(inputs_ids.shape[0], -1)
            .contiguous()
            .view(-1, 1)
        )

        # Mask for prompt + every possible new token; new positions stay unmasked.
        attention_mask_cache = torch.ones(
            (
                inputs_ids.shape[0],
                inputs_ids.shape[1] + max_new_token,
            ),
            dtype=torch.bool,
            device=inputs_ids.device,
        )
        if attention_mask is not None:
            attention_mask_cache.narrow(1, 0, attention_mask.shape[1]).copy_(
                attention_mask
            )

        progress = inputs_ids.size(1)
        # pre-allocate inputs_ids
        inputs_ids_buf = torch.zeros(
            inputs_ids.size(0),
            progress + max_new_token,
            inputs_ids.size(2),
            dtype=inputs_ids.dtype,
            device=inputs_ids.device,
        )
        inputs_ids_buf.narrow(1, 0, progress).copy_(inputs_ids)
        del inputs_ids
        # From here on `inputs_ids` is a growing view into the buffer.
        inputs_ids = inputs_ids_buf.narrow(1, 0, progress)

        pbar: Optional[tqdm] = None

        if show_tqdm:
            pbar = tqdm(
                total=max_new_token,
                desc="text" if infer_text else "code",
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt}(max) [{elapsed}, {rate_fmt}{postfix}]",
            )

        past_key_values = None

        for i in range(max_new_token):

            model_input = self._prepare_generation_inputs(
                inputs_ids,
                past_key_values,
                attention_mask_cache.narrow(1, 0, inputs_ids.shape[1]),
                use_cache=not self.is_te_llama,
            )

            if i > 0:
                # After the first step, embed only the ids selected for this step.
                del emb
                inputs_ids_emb = model_input.input_ids.to(self.device_gpt)
                if infer_text:
                    emb: torch.Tensor = self.emb_text(inputs_ids_emb[:, :, 0])
                else:
                    # Sum of the per-codebook embeddings.
                    code_emb = [
                        self.emb_code[i](inputs_ids_emb[:, :, i])
                        for i in range(self.num_vq)
                    ]
                    emb = torch.stack(code_emb, 3).sum(3)
                del inputs_ids_emb, model_input.input_ids
            model_input.inputs_embeds = emb

            model_input.to(self.device_gpt, self.gpt.dtype)

            outputs: BaseModelOutputWithPast = self.gpt(
                attention_mask=model_input.attention_mask,
                position_ids=model_input.position_ids,
                past_key_values=model_input.past_key_values,
                inputs_embeds=model_input.inputs_embeds,
                use_cache=model_input.use_cache,
                output_attentions=return_attn,
                cache_position=model_input.cache_position,
            )
            # `del_all` is a project helper; presumably it drops the references
            # held by the object to free memory early -- TODO confirm.
            del_all(model_input)
            attentions.append(outputs.attentions)
            hidden_states = outputs.last_hidden_state.to(
                self.device, dtype=torch.float
            )
            past_key_values = outputs.past_key_values
            del_all(outputs)
            if return_hidden:
                hiddens.append(hidden_states.narrow(1, -1, 1).squeeze_(1))

            with P.cached():
                if infer_text:
                    logits: torch.Tensor = self.head_text(hidden_states)
                else:
                    # Equivalent to stacking every code head's output along a
                    # new last axis, written into one pre-allocated tensor.
                    logits = torch.empty(
                        hidden_states.size(0),
                        hidden_states.size(1),
                        self.num_audio_tokens,
                        self.num_vq,
                        dtype=torch.float,
                        device=self.device,
                    )
                    for num_vq_iter in range(self.num_vq):
                        x: torch.Tensor = self.head_code[num_vq_iter](hidden_states)
                        logits[..., num_vq_iter] = x
                        del x

            del hidden_states

            # Only the last position's logits are sampled from.
            logits = logits.narrow(1, -1, 1).squeeze_(1).float()

            if not infer_text:
                # logits = rearrange(logits, "b c n -> (b n) c")
                logits = logits.permute(0, 2, 1)
                logits = logits.reshape(-1, logits.size(2))
                # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c")
                inputs_ids_sliced = inputs_ids.narrow(
                    1,
                    start_idx,
                    inputs_ids.size(1) - start_idx,
                ).permute(0, 2, 1)
                logits_token = inputs_ids_sliced.reshape(
                    inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1),
                    -1,
                ).to(self.device)
                del inputs_ids_sliced
            else:
                logits_token = (
                    inputs_ids.narrow(
                        1,
                        start_idx,
                        inputs_ids.size(1) - start_idx,
                    )
                    .narrow(2, 0, 1)
                    .to(self.device)
                )

            logits /= temperature

            for logitsProcessors in logits_processors:
                logits = logitsProcessors(logits_token, logits)

            del logits_token

            if i < min_new_token:
                # Forbid EOS until the minimum length has been produced.
                logits[:, eos_token] = -torch.inf

            scores = F.softmax(logits, dim=-1)

            del logits

            if manual_seed is None:
                idx_next = torch.multinomial(scores, num_samples=1).to(finish.device)
            else:
                idx_next = torch.multinomial(
                    scores,
                    num_samples=1,
                    generator=self.generator.manual_seed(manual_seed),
                ).to(finish.device)

            del scores

            if not infer_text:
                # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq)
                idx_next = idx_next.view(-1, self.num_vq)
                finish_or = idx_next.eq(eos_token).any(1)
                finish.logical_or_(finish_or)
                del finish_or
                inputs_ids_buf.narrow(1, progress, 1).copy_(idx_next.unsqueeze_(1))
            else:
                finish_or = idx_next.eq(eos_token).any(1)
                finish.logical_or_(finish_or)
                del finish_or
                # The single text id is replicated across all codebook slots.
                inputs_ids_buf.narrow(1, progress, 1).copy_(
                    idx_next.unsqueeze_(-1).expand(-1, -1, self.num_vq),
                )

            if i == 0 and finish.any():
                self.logger.warning(
                    "unexpected end at index %s",
                    str([unexpected_idx.item() for unexpected_idx in finish.nonzero()]),
                )
                # This branch always leaves the function, so release the bar
                # here whether or not a regeneration follows.
                if pbar is not None:
                    pbar.close()
                if ensure_non_empty and manual_seed is None:
                    self.logger.warning("regenerate in order to ensure non-empty")
                    del_all(attentions)
                    del_all(hiddens)
                    del (
                        start_idx,
                        end_idx,
                        finish,
                        temperature,
                        attention_mask_cache,
                        past_key_values,
                        idx_next,
                        inputs_ids_buf,
                    )
                    # `emb` and `inputs_ids` still hold the prompt at step 0.
                    new_gen = self.generate(
                        emb,
                        inputs_ids,
                        old_temperature,
                        eos_token,
                        attention_mask,
                        max_new_token,
                        min_new_token,
                        logits_processors,
                        infer_text,
                        return_attn,
                        return_hidden,
                        stream,
                        show_tqdm,
                        ensure_non_empty,
                        stream_batch,
                        manual_seed,
                        context,
                    )
                    for result in new_gen:
                        yield result
                    del inputs_ids
                return

            del idx_next
            progress += 1
            inputs_ids = inputs_ids_buf.narrow(1, 0, progress)

            # Only samples that have not finished grow by one token.
            not_finished = finish.logical_not().to(end_idx.device)
            end_idx.add_(not_finished.int())
            stream_iter += not_finished.any().int()
            if stream:
                if stream_iter > 0 and stream_iter % stream_batch == 0:
                    # `end_idx` has one entry per sample, so format it with %s;
                    # %d only works for a single-element tensor.
                    self.logger.debug("yield stream result, end: %s", end_idx)
                    yield self._prepare_generation_outputs(
                        inputs_ids,
                        start_idx,
                        end_idx,
                        attentions,
                        hiddens,
                        infer_text,
                    )
            del not_finished

            if finish.all() or context.get():
                break

            if pbar is not None:
                pbar.update(1)

        if pbar is not None:
            pbar.close()

        if not finish.all():
            if context.get():
                self.logger.warning("generation is interrupted")
            else:
                self.logger.warning(
                    f"incomplete result. hit max_new_token: {max_new_token}"
                )

        del finish, inputs_ids_buf

        yield self._prepare_generation_outputs(
            inputs_ids,
            start_idx,
            end_idx,
            attentions,
            hiddens,
            infer_text,
        )