import os
import random

import huggingface_hub
from datasets import load_dataset, Dataset
from dotenv import load_dotenv
from pytorch_lightning import LightningDataModule
from pytorch_lightning.utilities.types import TRAIN_DATALOADERS, EVAL_DATALOADERS
from torch.utils.data import DataLoader, IterableDataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from transformers import TrainingArguments, Trainer
import torch
import logging
import wandb

# Configure the root logger once for the whole script.
# Switch the level to logging.DEBUG to print more logs.
logging.basicConfig(level=logging.INFO)
# Script-wide logger handle; this is the root logger itself.
timber = logging.getLogger()

# ANSI escape sequences for coloured terminal text (foreground codes 30-37).
black, red, green, yellow, blue, magenta, cyan, white = (f"\x1b[{code}m" for code in range(30, 38))

# Input-direction tags (not referenced elsewhere in this file).
FORWARD = "FORWARD_INPUT"
BACKWARD = "BACKWARD_INPUT"

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hugging Face Hub id of the checkpoint the tokenizer and classifier are loaded from.
PRETRAINED_MODEL_NAME: str = "LongSafari/hyenadna-small-32k-seqlen-hf"


def insert_debug_motif_at_random_position(seq, DEBUG_MOTIF):
  """Overwrite a random window of ``seq`` with ``DEBUG_MOTIF``, keeping the length.

  Intended as a pipeline sanity check: a fixed motif is planted into the
  sequence so a working pipeline has an easy signal to pick up.

  Args:
    seq: The input sequence (e.g. a DNA string); must support slicing and ``+``.
    DEBUG_MOTIF: The motif written into ``seq``.

  Returns:
    A sequence of the same length as ``seq``. In it, ``len(DEBUG_MOTIF)``
    consecutive characters, starting at a uniformly random offset, are
    replaced by ``DEBUG_MOTIF``.

  Raises:
    ValueError: If ``DEBUG_MOTIF`` is longer than ``seq``.
  """
  motif_len = len(DEBUG_MOTIF)
  seq_len = len(seq)
  if motif_len > seq_len:
    raise ValueError(
      f"debug motif of length {motif_len} does not fit in a sequence of length {seq_len}")
  # Valid start offsets are 0 .. seq_len - motif_len inclusive. randrange's stop
  # is exclusive, hence the +1. This also makes a motif as long as the
  # sequence legal (it then replaces the whole sequence).
  rand_pos = random.randrange(seq_len - motif_len + 1)
  return seq[:rand_pos] + DEBUG_MOTIF + seq[rand_pos + motif_len:]


class PagingMQTLDataset(IterableDataset):
  """Stream rows from ``m_dataset`` and turn each into a model-ready dict.

  Each row is expected to expose ``'sequence'`` and ``'label'`` keys. Rows
  whose sequence length differs from ``seq_len`` are dropped. Every kept row
  becomes ``{"input_ids": <tensor of token ids>, "labels": <label tensor>}``.
  """

  def __init__(self,
               m_dataset,
               seq_len,
               tokenizer,
               max_length=512,
               check_if_pipeline_is_ok_by_inserting_debug_motif=False):
    # Source of rows; only iterated, never indexed.
    self.dataset = m_dataset
    # Required sequence length; anything else is skipped in preprocess().
    self.seq_len = seq_len
    # Callable returning a mapping with an "input_ids" entry.
    self.bert_tokenizer = tokenizer
    # NOTE(review): stored but not used by this class; the tokenizer is called
    # without truncation or max_length.
    self.max_length = max_length
    # When enabled, label-1 rows get self.debug_motif planted at a random offset.
    self.check_if_pipeline_is_ok_by_inserting_debug_motif = check_if_pipeline_is_ok_by_inserting_debug_motif
    self.debug_motif = "ATCGCCTA"

  def __iter__(self):
    """Lazily yield preprocessed rows, skipping those ``preprocess`` rejects."""
    candidates = map(self.preprocess, self.dataset)
    yield from (sample for sample in candidates if sample is not None)

  def preprocess(self, row):
    """Convert one row into ``{"input_ids", "labels"}`` tensors, or ``None`` to skip it."""
    sequence = row['sequence']
    if len(sequence) != self.seq_len:
      # Skip rows of the wrong length; presumably so that default batching can
      # stack equal-length tensors.
      return None
    label = row['label']
    plant_motif = self.check_if_pipeline_is_ok_by_inserting_debug_motif and label == 1
    if plant_motif:
      sequence = insert_debug_motif_at_random_position(seq=sequence, DEBUG_MOTIF=self.debug_motif)
    token_ids = self.bert_tokenizer(sequence)["input_ids"]
    return {
      "input_ids": torch.tensor(token_ids),
      "labels": torch.tensor(label),
    }


class MqtlDataModule(LightningDataModule):
  """Lightning data module that serves pre-built train/validate/test DataLoaders.

  All three loaders are created eagerly in ``__init__``. They share the same
  ``batch_size``, use no shuffling and run with a single worker process.
  """

  def __init__(self, train_ds, val_ds, test_ds, batch_size=16):
    super().__init__()
    self.batch_size = batch_size

    def make_loader(ds):
      # shuffle must stay False for IterableDataset inputs: DataLoader rejects
      # shuffle=True for them.
      return DataLoader(ds, batch_size=self.batch_size, shuffle=False, num_workers=1)

    self.train_loader = make_loader(train_ds)
    self.validate_loader = make_loader(val_ds)
    self.test_loader = make_loader(test_ds)

  def prepare_data(self):
    """No-op: there is nothing to download or prepare; the loaders are ready-made."""

  def setup(self, stage: str) -> None:
    """Log the requested stage; the loaders were already built in ``__init__``."""
    timber.info(f"inside setup: {stage = }")

  def train_dataloader(self) -> TRAIN_DATALOADERS:
    """Return the training DataLoader built in ``__init__``."""
    return self.train_loader

  def val_dataloader(self) -> EVAL_DATALOADERS:
    """Return the validation DataLoader built in ``__init__``."""
    return self.validate_loader

  def test_dataloader(self) -> EVAL_DATALOADERS:
    """Return the test DataLoader built in ``__init__``."""
    return self.test_loader


def create_paging_train_val_test_datasets(tokenizer, WINDOW, is_debug, batch_size=1000):
  """Build streaming train/validate/test ``PagingMQTLDataset``s for one window size.

  The CSV splits are streamed from the local checkout when every listed file
  is present. Otherwise the same split names are streamed from the Hugging
  Face Hub dataset ``fahimfarhan/mqtl-classification-datasets``.

  Args:
    tokenizer: Tokenizer passed through to every ``PagingMQTLDataset``.
    WINDOW: Window size. It selects the ``*_binned_{WINDOW}`` splits (200, 1000
      and 4000 are listed locally) and is also the required sequence length.
    is_debug: If True, the datasets plant a debug motif into label-1 rows.
    batch_size: Currently unused; kept so existing callers keep working.

  Returns:
    ``(train_dataset, val_dataset, test_dataset)``.
  """
  local_dir = "/home/soumic/Codes/mqtl-classification/src/inputdata"
  # Keys look like "train_binned_200"; files like dataset_200_train_binned.csv.
  # small = 200, medium = 1000, large = 4000.
  data_files = {
    f"{split}_binned_{size}": f"{local_dir}/dataset_{size}_{split}_binned.csv"
    for size in (200, 1000, 4000)
    for split in ("train", "validate", "test")
  }

  # Decide based on the very files that would be loaded. load_dataset needs all
  # of them, so checking some other path could pick a branch that then fails.
  is_my_laptop = all(os.path.isfile(path) for path in data_files.values())
  if is_my_laptop:
    dataset_map = load_dataset("csv", data_files=data_files, streaming=True)
  else:
    dataset_map = load_dataset("fahimfarhan/mqtl-classification-datasets", streaming=True)

  def make_split(split):
    # All three splits share the same tokenizer, debug flag and length filter.
    return PagingMQTLDataset(dataset_map[f"{split}_binned_{WINDOW}"],
                             check_if_pipeline_is_ok_by_inserting_debug_motif=is_debug,
                             tokenizer=tokenizer,
                             seq_len=WINDOW)

  return make_split("train"), make_split("validate"), make_split("test")


def login_inside_huggingface_virtualmachine():
  """Best-effort login to the Hugging Face Hub and Weights & Biases.

  Credentials are read from the environment, optionally populated from a
  local ``.env`` file: ``HF_TOKEN`` for Hugging Face and ``WAND_DB_API_KEY``
  for W&B. Every failure is printed and swallowed, so the caller continues
  even if one or both logins fail.
  """
  # Load the .env file, but don't crash if it's not found (e.g., in a Hugging Face Space).
  try:
    load_dotenv()  # Only useful on your laptop if .env exists
    print(".env file loaded successfully.")
  except Exception as e:
    print(f"Warning: Could not load .env file. Exception: {e}")

  # Hugging Face Hub login.
  try:
    token = os.getenv("HF_TOKEN")

    if not token:
      raise ValueError("HF_TOKEN not found. Make sure to set it in the environment variables or .env file.")

    huggingface_hub.login(token)
    print("Logged in to Hugging Face Hub successfully.")

  except Exception as e:
    print(f"Error during Hugging Face login: {e}")

  # Weights & Biases login.
  try:
    api_key = os.getenv("WAND_DB_API_KEY")
    # Never write the secret itself to the logs; only record whether it is set.
    timber.info(f"WAND_DB_API_KEY is set: {bool(api_key)}")

    if not api_key:
      raise ValueError("WAND_DB_API_KEY not found. Make sure to set it in the environment variables or .env file.")

    wandb.login(key=api_key)
    print("Logged in to wand db successfully.")

  except Exception as e:
    print(f"Error during wand db login: {e}")


def start() -> None:
  """Fine-tune the HyenaDNA sequence classifier on a synthetic smoke-test dataset.

  Steps, in order:
    1. Set ``PYTORCH_CUDA_ALLOC_CONF`` to enable expandable allocator segments.
    2. Log in to Hugging Face / W&B (best-effort; failures are only printed).
    3. Load the tokenizer and the classifier for ``PRETRAINED_MODEL_NAME``
       (model weights in bfloat16, ``device_map="auto"``).
    4. Train with ``Trainer`` on 8 copies of a repeated ``'ACTG'`` sequence
       with alternating 0/1 labels, then evaluate on that same dataset.

  The real mQTL CSV pipeline (``create_paging_train_val_test_datasets``) is
  commented out. This appears to be a smoke test of memory use and of the
  training loop, not a real experiment.
  """
  # NOTE: PyTorch only honours this setting if it is set before the first CUDA allocation.
  os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

  login_inside_huggingface_virtualmachine()
  # WINDOW and batch_size are only referenced by the commented-out real-data code below.
  WINDOW = 1000
  batch_size = 100
  tokenizer = AutoTokenizer.from_pretrained(PRETRAINED_MODEL_NAME, trust_remote_code=True)
  model = AutoModelForSequenceClassification.from_pretrained(PRETRAINED_MODEL_NAME, torch_dtype=torch.bfloat16,
                                                             device_map="auto",
                                                             trust_remote_code=True)
  args = {
    "output_dir": "output_hyena_dna-mqtl_classification",
    "num_train_epochs": 2,  # has no effect while max_steps > 0: max_steps overrides it
    "max_steps": 20,
    # Total optimizer steps to train for. It was originally 1000, which took too long,
    # so it is set to 20 to run quickly and check the code/pipeline.
    "run_name": "laptop_run_hyena_dna-mqtl_classification",  # Override run_name here
    "per_device_train_batch_size": 1,
    "gradient_accumulation_steps": 4,  # effective batch of 4 samples per device per optimizer step
    "gradient_checkpointing": True,  # trades extra compute for lower activation memory
    "learning_rate": 2e-5,
    "save_safetensors": False  # Works around the shared-tensor RuntimeError quoted below.
  }

  # """
  #   got this error at the end!
  #   raise RuntimeError(
  #   RuntimeError: The weights trying to be saved contained shared tensors [{'hyena.backbone.layers.0.mixer.filter_fn.implicit_filter.3.freq', 'hyena.backbone.layers.0.mixer.filter_fn.implicit_filter.1.freq', 'hyena.backbone.layers.0.mixer.filter_fn.implicit_filter.5.freq'}] that are mismatching the transformers base configuration. Try saving using `safe_serialization=False` or remove this tensor sharing.
  # """

  training_args = TrainingArguments(**args)
  # NOTE(review): create_data_module is not defined in this file;
  # create_paging_train_val_test_datasets (commented out further below) looks like its replacement.
  # train_dataset, eval_dataset, test_dataset = create_data_module(tokenizer=tokenizer, WINDOW=WINDOW,
  #                                                                batch_size=batch_size,
  #                                                                is_debug=False)
  # Synthetic input: one 32,000-character sequence; presumably sized to probe the
  # "32k-seqlen" checkpoint's upper limit.
  max_length = 32_000
  sequence = 'ACTG' * int(max_length / 4)
  # sequence = 'ACTG' * int(1000) # seq_len = 4000 it works!
  sequence = [sequence] * 8  # Create 8 identical samples
  tokenized = tokenizer(sequence)["input_ids"]
  labels = [0, 1] * 4  # alternate labels so both classes are present

  # Create a dataset for training; set_format("pt") makes rows come back as torch tensors.
  run_the_code_ds = Dataset.from_dict({"input_ids": tokenized, "labels": labels})
  run_the_code_ds.set_format("pt")

  # The same synthetic dataset is used for train, validation and test.
  # train_ds, val_ds, test_ds = create_paging_train_val_test_datasets(tokenizer, WINDOW=WINDOW, is_debug=False)
  train_ds, val_ds, test_ds = run_the_code_ds, run_the_code_ds, run_the_code_ds
  # train_ds.set_format("pt") # doesn't work!

  trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
  )
  # Train. args sets no eval_strategy, so presumably no evaluation runs during
  # training (the TrainingArguments default is "no").
  result = trainer.train()
  try:
    print(f"{result = }")
  except Exception as x:
    print(f"{x = }")

  # Testing: errors during evaluation (e.g. out-of-memory, judging by the
  # variable name) are printed instead of raised.
  try:
    # with torch.no_grad(): # didn't work :/
    test_results = trainer.evaluate(eval_dataset=test_ds)
    print(f"{test_results = }")
  except Exception as oome:
    print(f"{oome = }")



if __name__ == '__main__':
  # Script entry point: run the smoke-test training job.
  start()

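"""
Setup note: the Hugging Face Space for this classifier is tracked as a git submodule. It was added with:

git submodule add https://huggingface.co/spaces/fahimfarhan/hyenadna-sm-32k-mqtl-classifier-space src/huggingface-mqtl-classification-hyena-dna
"""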