import logging
import os
import random
from typing import Any

import numpy as np
import pandas as pd
from pytorch_lightning import Trainer, LightningModule, LightningDataModule
from pytorch_lightning.utilities.types import OptimizerLRScheduler, STEP_OUTPUT, EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset
from torchmetrics.classification import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryPrecision, BinaryRecall
from transformers import BertModel, BatchEncoding, BertTokenizer, TrainingArguments, AutoModelForSequenceClassification
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
import torch
from torch import nn
from datasets import load_dataset, IterableDataset
from huggingface_hub import PyTorchModelHubMixin

from dotenv import load_dotenv
from huggingface_hub import login

# Root logger (logging.getLogger() with no name) used for all log output in this module.
timber = logging.getLogger()
# logging.basicConfig(level=logging.DEBUG)
logging.basicConfig(level=logging.INFO)  # change to level=logging.DEBUG to print more logs...

# ANSI escape sequences that colour terminal output. They are prepended to log
# messages, e.g. as the `log_color` of TorchMetrics.compute_and_reset_on_epoch_end.
black = "\u001b[30m"
red = "\u001b[31m"
green = "\u001b[32m"
yellow = "\u001b[33m"
blue = "\u001b[34m"
magenta = "\u001b[35m"
cyan = "\u001b[36m"
white = "\u001b[37m"

# Marker strings for forward / backward inputs.
# NOTE(review): not referenced by the code shown here. Presumably left over from a
# forward + reverse-complement input pipeline; confirm before removing.
FORWARD = "FORWARD_INPUT"
BACKWARD = "BACKWARD_INPUT"

# First CUDA device when one is available, otherwise the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub id of the pretrained HyenaDNA checkpoint used to load the backbone/tokenizer.
PRETRAINED_MODEL_NAME: str = "LongSafari/hyenadna-small-32k-seqlen-hf"


def login_inside_huggingface_virtualmachine():
  """Log in to the Hugging Face Hub using the ``HF_TOKEN`` environment variable.

  A local ``.env`` file, when present, is loaded first so ``HF_TOKEN`` can be kept
  there during local development (a Hugging Face Space sets it in the environment).
  Every failure is reported with ``print`` and never raised, so callers can carry
  on without Hub access.

  Returns:
    bool: True when the login succeeded, False otherwise. Existing callers may
    ignore the return value.
  """
  # load_dotenv() does not raise when no .env file exists; it returns False.
  # So the return value, not an exception, tells us whether anything was loaded.
  try:
    loaded = load_dotenv()  # Only useful on your laptop if .env exists
  except Exception as e:  # e.g. an unreadable or malformed .env file
    print(f"Warning: Could not load .env file. Exception: {e}")
  else:
    if loaded:
      print(".env file loaded successfully.")
    else:
      print("No variables loaded from a .env file; using the existing environment.")

  token = os.getenv("HF_TOKEN")
  if not token:
    print("Error during Hugging Face login: HF_TOKEN not found. "
          "Make sure to set it in the environment variables or .env file.")
    return False

  try:
    login(token)
  except Exception as e:
    # Best effort: report the problem and let the caller continue without Hub access.
    print(f"Error during Hugging Face login: {e}")
    return False

  print("Logged in to Hugging Face Hub successfully.")
  return True
    # Handle the error appropriately (e.g., exit or retry)


def one_hot_e(dna_seq: str) -> np.ndarray:
  """One-hot encode a DNA sequence as a ``(4, len(dna_seq))`` float64 array.

  Row order is A, C, G, T; matching is case-insensitive. The ambiguity/gap
  symbols ``N``, ``n``, ``H``, ``h`` and ``-`` encode as all-zero columns. An
  empty sequence yields an array of shape ``(4, 0)``.

  Raises:
    KeyError: if ``dna_seq`` contains any other character.
  """
  base_to_row = {'A': 0, 'C': 1, 'G': 2, 'T': 3,
                 'a': 0, 'c': 1, 'g': 2, 't': 3}
  # 'h' is accepted like 'H' for consistency with complement_dna_seq.
  zero_symbols = {'N', 'n', 'H', 'h', '-'}

  # Pre-allocating (4, L) keeps the shape right even for an empty sequence;
  # building (L, 4) and transposing collapsed to shape (0,) when L == 0.
  encoded = np.zeros((4, len(dna_seq)), dtype=np.float64)
  for col, base in enumerate(dna_seq):
    row = base_to_row.get(base)
    if row is not None:
      encoded[row, col] = 1.0
    elif base not in zero_symbols:
      raise KeyError(base)
  return encoded


def one_hot_e_column(column: pd.Series) -> np.ndarray:
  """One-hot encode every DNA sequence in *column* and combine the results.

  Each element is encoded with ``one_hot_e``. The per-sequence arrays are merged
  with ``np.asarray`` and cast to float32. Assuming ``one_hot_e`` returns
  ``(4, seq_len)``, the result is ``(n_sequences, 4, seq_len)`` when all sequences
  share one length.
  """
  per_sequence = list(map(one_hot_e, column))
  stacked = np.asarray(per_sequence)
  return stacked.astype(np.float32)


def reverse_dna_seq(dna_seq: str) -> str:
  """Return *dna_seq* with its characters in reverse order (no complementing)."""
  return "".join(reversed(dna_seq))


def complement_dna_seq(dna_seq: str) -> str:
  """Return the base-wise complement of *dna_seq* (not reversed).

  Case is preserved. ``N``/``n``, ``H``/``h`` and ``-`` map to themselves.

  Raises:
    KeyError: for any other character.
  """
  complement_of = dict(zip("ACTGactgNH-nh", "TGACtgacNH-nh"))
  return "".join(complement_of[base] for base in dna_seq)


def reverse_complement_dna_seq(dna_seq: str) -> str:
  """Return the reverse complement of *dna_seq*: complement each base, then reverse."""
  complemented = complement_dna_seq(dna_seq)
  return complemented[::-1]


def reverse_complement_column(column: pd.Series) -> list:
  """Reverse-complement every DNA sequence in *column*.

  Returns:
    A plain Python list of strings, in the same order as *column*. It is not an
    ndarray, despite the previous annotation.
  """
  return [reverse_complement_dna_seq(seq) for seq in column]


class TorchMetrics:
  """Bundle of torchmetrics binary-classification metrics.

  The bundle holds accuracy, AUROC, F1, precision and recall. Call
  ``update_on_each_step`` once per batch, and ``compute_and_reset_on_epoch_end``
  once per epoch to log the aggregated values and start fresh.
  """

  def __init__(self, device=DEVICE):
    # Every metric is moved to `device` so it can be updated with tensors living there.
    self.binary_accuracy = BinaryAccuracy().to(device)
    self.binary_auc = BinaryAUROC().to(device)
    self.binary_f1_score = BinaryF1Score().to(device)
    self.binary_precision = BinaryPrecision().to(device)
    self.binary_recall = BinaryRecall().to(device)

  def _named_metrics(self):
    """Return ``(log suffix, metric)`` pairs in a fixed order."""
    return (
      ("accuracy", self.binary_accuracy),
      ("auc", self.binary_auc),
      ("f1_score", self.binary_f1_score),
      ("precision", self.binary_precision),
      ("recall", self.binary_recall),
    )

  def update_on_each_step(self, batch_predicted_labels, batch_actual_labels):  # todo: Add log if needed
    """Accumulate one batch of predictions and targets into every metric."""
    for _, metric in self._named_metrics():
      metric.update(preds=batch_predicted_labels, target=batch_actual_labels)

  def compute_and_reset_on_epoch_end(self, log, log_prefix: str, log_color: str = green):
    """Compute all metrics, print and ``log`` them under ``{log_prefix}_<name>``, then reset.

    Args:
      log: callable ``log(name, value)``, e.g. ``LightningModule.log``.
      log_prefix: prefix for the logged names (e.g. "train", "validate", "test").
      log_color: ANSI colour code prepended to the console summary.
    """
    results = {name: metric.compute() for name, metric in self._named_metrics()}
    timber.info(
      log_color + f"{log_prefix}_acc = {results['accuracy']}, {log_prefix}_auc = {results['auc']}, {log_prefix}_f1_score = {results['f1_score']}, {log_prefix}_precision = {results['precision']}, {log_prefix}_recall = {results['recall']}")
    for name, value in results.items():
      log(f"{log_prefix}_{name}", value)

    for _, metric in self._named_metrics():
      metric.reset()


def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
  """Return a copy of *seq* with *DEBUG_MOTIF* overwriting a random window of equal length.

  The start offset is drawn uniformly from every position where the motif fits
  entirely, including the last one (``len(seq) - len(DEBUG_MOTIF)``). The output
  therefore always has the same length as *seq*.

  Raises:
    ValueError: if the motif is longer than the sequence.
  """
  motif_len = len(DEBUG_MOTIF)
  if motif_len > len(seq):
    raise ValueError(
      f"debug motif of length {motif_len} does not fit in a sequence of length {len(seq)}")
  # randint is inclusive on both ends, so the motif may also end flush with seq.
  rand_pos = random.randint(0, len(seq) - motif_len)
  return seq[:rand_pos] + DEBUG_MOTIF + seq[rand_pos + motif_len:]


class PagingMQTLDataset(torch.utils.data.IterableDataset):
  """Stream ``(input_ids, label)`` pairs from an iterable of DNA-sequence rows.

  Each row must provide a ``'sequence'`` string and a ``'label'``. Rows whose
  sequence length differs from ``seq_len`` are skipped.

  This derives from ``torch.utils.data.IterableDataset``, not the Hugging Face
  ``datasets.IterableDataset`` imported at module level. The HF class only
  acquires its torch parent inside its own ``__init__``, which this wrapper never
  called. Without that parent, ``DataLoader`` treats the object as a map-style
  dataset and fails asking for ``len()``.

  NOTE(review): no worker sharding is done, so with ``num_workers > 1`` every
  worker would replay the full stream.
  """

  def __init__(self,
               m_dataset,
               seq_len,
               tokenizer,
               max_length=512,
               check_if_pipeline_is_ok_by_inserting_debug_motif=False):
    """
    Args:
      m_dataset: iterable of row mappings (e.g. a streaming HF dataset split).
      seq_len: required sequence length; other rows are dropped.
      tokenizer: callable mapping a DNA string to a mapping with ``"input_ids"``.
      max_length: stored for callers. NOTE(review): not applied; truncating here
        would silently cut windows longer than 512.
      check_if_pipeline_is_ok_by_inserting_debug_motif: if True, positive
        (label == 1) sequences get a known motif inserted, as a pipeline sanity check.
    """
    super().__init__()
    self.dataset = m_dataset
    self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
    self.debug_motif = "ATCGCCTA"
    self.seq_len = seq_len

    self.bert_tokenizer = tokenizer
    self.max_length = max_length

  def __iter__(self):
    for row in self.dataset:
      processed = self.preprocess(row)
      if processed is not None:
        yield processed

  def preprocess(self, row):
    """Turn one raw row into ``(input_ids, label)``, or return None to skip it.

    ``input_ids`` is a 1-D ``torch.long`` tensor, so the default DataLoader
    collation stacks a batch into ``(batch, n_tokens)``.
    """
    sequence = row['sequence']  # Fetch the 'sequence' column
    if len(sequence) != self.seq_len:
      return None  # skip problematic row!
    label = row['label']  # Fetch the 'label' column (or whatever target you use)
    if label == 1 and self.check_if_pipeline_is_ok_by_inserting_debug_motif:
      sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
    # Tokenizers operate on text, so pass the raw DNA string. A one-hot int64
    # tensor is not valid tokenizer input.
    encoded = self.bert_tokenizer(sequence)
    # A list of ids would be transposed by default_collate; a tensor is stacked.
    input_ids = torch.as_tensor(encoded["input_ids"], dtype=torch.long)
    return input_ids, label


# def collate_fn(batch):
#   sequences, labels = zip(*batch)
#   ohe_seq, ohe_seq_rc = sequences[0], sequences[1]
#   # Pad sequences to the maximum length in this batch
#   padded_sequences = pad_sequence(ohe_seq, batch_first=True, padding_value=0)
#   padded_sequences_rc = pad_sequence(ohe_seq_rc, batch_first=True, padding_value=0)
#   # Convert labels to a tensor
#   labels = torch.stack(labels)
#   return [padded_sequences, padded_sequences_rc], labels


class MqtlDataModule(LightningDataModule):
  """Lightning data module wrapping ready-made train/validation/test datasets.

  All three DataLoaders are built eagerly in ``__init__`` with identical
  settings (no shuffling, a single worker, default collation). The
  ``*_dataloader`` hooks return them unchanged.
  """

  def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
    super().__init__()
    self.batch_size = batch_size
    self.train_loader = self._build_loader(train_ds)
    self.validate_loader = self._build_loader(val_ds)
    self.test_loader = self._build_loader(test_ds)

  def _build_loader(self, dataset) -> DataLoader:
    """Create a non-shuffling, single-worker DataLoader using ``self.batch_size``."""
    return DataLoader(dataset, batch_size=self.batch_size, shuffle=False, num_workers=1)

  def prepare_data(self):
    # Nothing to download or write: the datasets are supplied to __init__.
    return None

  def setup(self, stage: str) -> None:
    timber.info(f"inside setup: {stage = }")

  def train_dataloader(self) -> TRAIN_DATALOADERS:
    return self.train_loader

  def val_dataloader(self) -> EVAL_DATALOADERS:
    return self.validate_loader

  def test_dataloader(self) -> EVAL_DATALOADERS:
    return self.test_loader


class MQtlBertClassifierLightningModule(LightningModule):
  """Lightning wrapper that trains and evaluates a binary mQTL classifier.

  The wrapped ``classifier`` receives a tensor of token ids and presumably returns
  one logit per sample shaped ``(batch, 1)`` (TODO confirm for each classifier).
  Binary metrics are accumulated every step and logged at the end of each
  train/validation/test epoch.
  """

  def __init__(self,
               classifier: nn.Module,
               criterion=None,  # nn.BCEWithLogitsLoss(),
               regularization: int = 2,  # 1 == L1, 2 == L2, 3 (== 1 | 2) == both l1 and l2, else ignore / don't care
               l1_lambda=0.001,
               l2_wright_decay=0.001,
               *args: Any,
               **kwargs: Any):
    """
    Args:
      classifier: module mapping token ids to logits.
      criterion: loss ``criterion(logits, labels)``. ``None`` uses
        ``ReshapedBCEWithLogitsLoss()``, which accepts ``(batch, 1)`` logits with
        ``(batch,)`` integer labels.
      regularization: 1 = L1 penalty added to the training loss, 2 = L2 via Adam
        ``weight_decay``, 3 = both, anything else = none.
      l1_lambda: coefficient of the L1 penalty.
      l2_wright_decay: L2 weight decay (name, typo included, kept for callers).
    """
    super().__init__(*args, **kwargs)
    self.classifier = classifier
    # With no criterion every step used to crash calling None.
    self.criterion = criterion if criterion is not None else ReshapedBCEWithLogitsLoss()
    self.train_metrics = TorchMetrics()
    self.validate_metrics = TorchMetrics()
    self.test_metrics = TorchMetrics()

    self.regularization = regularization
    self.l1_lambda = l1_lambda
    self.l2_weight_decay = l2_wright_decay

  def forward(self, x, *args: Any, **kwargs: Any) -> Any:
    """Run the classifier on ``x``.

    ``x`` is either a token-id tensor or a mapping (e.g. a tokenizer
    ``BatchEncoding``) holding ``"input_ids"``.
    """
    # PagingMQTLDataset yields the ids themselves rather than a mapping, so accept both forms.
    input_ids: torch.Tensor = x if isinstance(x, torch.Tensor) else x["input_ids"]
    return self.classifier(input_ids)

  def configure_optimizers(self) -> OptimizerLRScheduler:
    # L2 regularisation is applied through the optimizer's weight decay.
    weight_decay = self.l2_weight_decay if self.regularization in (2, 3) else 0.0
    return torch.optim.Adam(self.parameters(), lr=1e-3, weight_decay=weight_decay)

  def _shared_step(self, batch, metrics, loss_name: str, add_l1: bool = False):
    """Forward one ``(x, y)`` batch, compute and log the loss, and update *metrics*."""
    x, y = batch
    preds = self.forward(x)
    loss = self.criterion(preds, y)

    if add_l1 and self.regularization in (1, 3):  # apply l1 regularization
      l1_norm = sum(p.abs().sum() for p in self.parameters())
      # Out-of-place, so the criterion's output tensor is never mutated.
      loss = loss + self.l1_lambda * l1_norm

    self.log(loss_name, loss)
    # Drop only the trailing logit dim. A bare squeeze() also removes the batch
    # dim when a batch holds a single sample, and the metric update then fails.
    metrics.update_on_each_step(batch_predicted_labels=preds.squeeze(-1), batch_actual_labels=y)
    return loss

  def training_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    return self._shared_step(batch, self.train_metrics, "train_loss", add_l1=True)

  def on_train_epoch_end(self) -> None:
    self.train_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="train")

  def validation_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    return self._shared_step(batch, self.validate_metrics, "valid_loss")

  def on_validation_epoch_end(self) -> None:
    self.validate_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="validate", log_color=blue)
    return None

  def test_step(self, batch, batch_idx, *args: Any, **kwargs: Any) -> STEP_OUTPUT:
    return self._shared_step(batch, self.test_metrics, "test_loss")

  def on_test_epoch_end(self) -> None:
    self.test_metrics.compute_and_reset_on_epoch_end(log=self.log, log_prefix="test", log_color=magenta)
    return None


def start_bert(classifier_model, criterion, m_optimizer=torch.optim.Adam, WINDOW=200,
               is_binned=True, is_debug=False, max_epochs=10, batch_size=8):
  """Train, test, save and publish an mQTL classifier for one window size.

  Steps:
  1. Load the ``*_binned_{WINDOW}`` train/validate/test splits. Local CSVs are used
     when they all exist; otherwise the ``fahimfarhan/mqtl-classification-datasets``
     Hub dataset is used.
  2. Tokenize the splits with the pretrained tokenizer.
  3. Train and test with a PyTorch Lightning ``Trainer``.
  4. ``save_pretrained`` locally, then ``push_to_hub``.

  Args:
    classifier_model: module to train. It must provide ``model_repository_name``,
      ``save_pretrained`` and ``push_to_hub``. ``from_pretrained`` is tried, best
      effort, to reload saved weights.
    criterion: loss ``criterion(logits, labels)``. ``None`` uses ``ReshapedBCEWithLogitsLoss()``.
    m_optimizer: currently unused; the Lightning module configures its own Adam optimizer.
    WINDOW: sequence length; selects the dataset splits and filters rows by length.
    is_binned: currently unused; only binned splits are configured.
    is_debug: insert a known motif into positive samples to sanity-check the pipeline.
    max_epochs: number of training epochs.
    batch_size: DataLoader batch size.
  """
  # Local import: only some transformers names are imported at module level.
  from transformers import AutoTokenizer

  data_dir = "/home/soumic/Codes/mqtl-classification/src/inputdata"
  data_files = {
    f"{split}_binned_{size}": f"{data_dir}/dataset_{size}_{split}_binned.csv"
    for size in (200, 1000, 4000)
    for split in ("train", "validate", "test")
  }

  # Check the very files handed to load_dataset. An unrelated "/src/inputdata/..."
  # path says nothing about whether these exist.
  is_my_laptop = all(os.path.isfile(path) for path in data_files.values())
  if is_my_laptop:
    dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
  else:
    dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)

  # PagingMQTLDataset calls this on every sequence, so it must be a tokenizer,
  # not a sequence-classification model.
  tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)

  def make_split(split):
    # Build the paging dataset for one split of the chosen window size.
    return PagingMQTLDataset(dataset_map[f"{split}_binned_{WINDOW}"],
                             check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
                             tokenizer=tokenizer,
                             seq_len=WINDOW)

  data_module = MqtlDataModule(train_ds=make_split("train"), val_ds=make_split("validate"),
                               test_ds=make_split("test"), batch_size=batch_size)

  try:
    # Try to reload weights saved under model_repository_name (e.g. by an earlier run).
    classifier_model = classifier_model.from_pretrained(classifier_model.model_repository_name)
  except Exception as x:
    # Best effort: fall back to the model exactly as passed in.
    timber.warning(f"Could not reload a saved model, training the given one instead: {x}")

  # Lightning's Trainer takes training settings as keyword arguments. The model and
  # datamodule go to fit()/test(); Hugging Face TrainingArguments do not apply here.
  classifier_module = MQtlBertClassifierLightningModule(
    classifier=classifier_model,
    regularization=2,
    criterion=criterion if criterion is not None else ReshapedBCEWithLogitsLoss())

  trainer = Trainer(max_epochs=max_epochs, precision="32")
  trainer.fit(model=classifier_module, datamodule=data_module)
  timber.info("\n\n")
  trainer.test(model=classifier_module, datamodule=data_module)
  timber.info("\n\n")

  # Save locally. The Lightning module trained classifier_model's own parameters.
  model_subdirectory = classifier_model.model_repository_name
  classifier_model.save_pretrained(model_subdirectory)

  # Push to the hub.
  commit_message = f":tada: Push model for window size {WINDOW} from huggingface space"
  if is_my_laptop:
    commit_message = f":tada: Push model for window size {WINDOW} from zephyrus"

  classifier_model.push_to_hub(
    repo_id=f"fahimfarhan/{classifier_model.model_repository_name}",
    commit_message=commit_message,
  )


class CommonAttentionLayer(nn.Module):
  """Attention pooling over dimension 1 of the hidden states.

  A single linear layer scores each position. The scores are softmax-normalised
  across dim 1 and used as weights for a weighted sum of the hidden states.
  """

  def __init__(self, hidden_size, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Maps each hidden vector to one unnormalised attention score.
    self.attention_linear = nn.Linear(hidden_size, 1)

  def forward(self, hidden_states):
    """Pool *hidden_states* along dim 1.

    Presumably ``hidden_states`` is ``(batch, seq_len, hidden_size)``; confirm with the caller.

    Returns:
      ``(context_vector, attn_weights)``. ``attn_weights`` sums to 1 along dim 1,
      and ``context_vector`` is the weighted sum with dim 1 removed.
    """
    scores = self.attention_linear(hidden_states)
    weights = scores.softmax(dim=1)
    pooled = (weights * hidden_states).sum(dim=1)
    return pooled, weights


class ReshapedBCEWithLogitsLoss(nn.BCEWithLogitsLoss):
  """``BCEWithLogitsLoss`` that accepts ``(batch, 1)`` logits with ``(batch,)`` integer labels.

  The logits are reshaped to the target's shape, and the target is cast to the
  logits' dtype, as the parent loss expects.
  """

  def forward(self, input, target):
    # reshape rather than squeeze(): squeeze() would turn the (1, 1) logits of a
    # single-sample batch into a 0-d tensor that no longer matches the (1,) target.
    return super().forward(input.reshape(target.shape), target.to(input.dtype))


class HyenaDnaMQTLClassifier(nn.Module):
  """Binary mQTL classifier: pretrained backbone -> attention pooling -> linear head.

  NOTE(review): ``hidden_size`` must equal the backbone's hidden size. The 768
  default fits a BERT-base backbone; confirm for the HyenaDNA checkpoint.
  """

  def __init__(self,
               seq_len: int, model_repository_name: str,
               bert_model=None,
               hidden_size=768,
               num_classes=1,
               *args,
               **kwargs
               ):
    """
    Args:
      seq_len: input window length (stored for reference).
      model_repository_name: name used when saving / pushing the model.
      bert_model: backbone returning an output with ``last_hidden_state``. When
        ``None``, ``BertModel.from_pretrained(PRETRAINED_MODEL_NAME)`` is loaded.
      hidden_size: width of the backbone's hidden states.
      num_classes: number of output logits.
    """
    super().__init__(*args, **kwargs)
    self.seq_len = seq_len
    self.model_repository_name = model_repository_name

    self.model_name = "MQtlDnaBERT6Classifier"

    if bert_model is None:
      # Load here rather than in the default argument. A default would be
      # downloaded at import time, and that one backbone instance (and its
      # weights) would be shared by every classifier built without an explicit model.
      bert_model = BertModel.from_pretrained(pretrained_model_name_or_path=PRETRAINED_MODEL_NAME)
    self.bert_model = bert_model
    self.attention = CommonAttentionLayer(hidden_size)
    self.classifier = nn.Linear(hidden_size, num_classes)

  def forward(self, input_ids: torch.Tensor):
    """Return classification logits for a batch of token ids.

    Presumably ``input_ids`` is ``(batch, n_tokens)`` and the result is
    ``(batch, num_classes)``; confirm with the data pipeline.
    """
    bert_output: BaseModelOutputWithPoolingAndCrossAttentions = self.bert_model(input_ids=input_ids)

    last_hidden_state = bert_output.last_hidden_state
    context_vector, _attention_weights = self.attention(last_hidden_state)
    return self.classifier(context_vector)


if __name__ == '__main__':
  # Authenticate first: the Hub is used for the dataset fallback and for push_to_hub.
  login_inside_huggingface_virtualmachine()

  WINDOW = 1000
  # NOTE(review): start_bert reads classifier_model.model_repository_name and calls
  # save_pretrained/push_to_hub. A plain BertModel presumably has no
  # model_repository_name, so this run is expected to fail at the save step. The
  # previously used alternative was:
  #   HyenaDnaMQTLClassifier(seq_len=WINDOW, model_repository_name="hyenadna-sm-32k-mqtl-classifier")
  some_model = BertModel.from_pretrained(pretrained_model_name_or_path=PRETRAINED_MODEL_NAME)
  criterion = None  # no explicit loss function is supplied

  run_options = {
    "WINDOW": WINDOW,
    "is_debug": False,
    "max_epochs": 20,
    "batch_size": 16,
  }
  start_bert(classifier_model=some_model, criterion=criterion, **run_options)