# -*- coding: utf-8 -*-
"""MIXTRAL_Mixtral-8x7B (QLoRA)

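This notebook shows how to fine-tune a Mistral-family model with QLoRA. It was originally written for Mixtral-8x7B on a sample of UltraChat, which requires at least 32 GB of VRAM (at least 2×16 GB GPUs on consumer hardware; on Google Colab, use an A100). As configured below, however, it fine-tunes `mistralai/Mistral-7B-Instruct-v0.1` on local text files.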

First, we need all these dependencies:
"""

# Install dependencies. These are IPython shell magics: they only run inside a
# Jupyter/Colab notebook, not as a plain Python script.
# NOTE(review): versions are unpinned, so whatever release is current gets
# installed; consider pinning them if the TrainingArguments/SFTTrainer keyword
# arguments used below stop being accepted.
# NOTE(review): diffusers does not appear to be used anywhere in this script.
!pip install -q bitsandbytes
!pip install -q transformers
!pip install -q peft
!pip install -q accelerate
!pip install -q datasets
!pip install -q trl
!pip install -q huggingface_hub
!pip install -q diffusers

import torch
from datasets import Dataset
from peft import LoraConfig, PeftModel, prepare_model_for_kbit_training
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,
    TrainingArguments,
    Trainer
)
from trl import SFTTrainer

"""Load the tokenizer and configure padding"""

import os
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer
from peft import PeftModel, prepare_model_for_kbit_training, LoraConfig

# Hugging Face access token. Prefer exporting HF_TOKEN before starting the
# notebook. The placeholder below is only used when HF_TOKEN is not already
# set, so it no longer overwrites a real token coming from the environment.
os.environ.setdefault('HF_TOKEN', 'XXXX')

# Name of the model you want to load.
# NOTE: despite the notebook title, this is Mistral-7B-Instruct, not Mixtral-8x7B.
model_id = "mistralai/Mistral-7B-Instruct-v0.1"

# Set to True only to bypass a corrupted local Hugging Face cache; otherwise the
# cached files are reused instead of being re-downloaded on every run.
FORCE_DOWNLOAD = False

# Only the tokenizer is loaded here. The model is loaded once, 4-bit quantized,
# in the QLoRA cell further below; loading an extra full-precision copy here
# would only cost download time and a large amount of RAM before being discarded.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, force_download=FORCE_DOWNLOAD)
    # Reuse the UNK token as the padding token.
    tokenizer.pad_token = tokenizer.unk_token
    tokenizer.pad_token_id = tokenizer.unk_token_id
    # Pad on the right for training batches.
    tokenizer.padding_side = 'right'
    print("Tokenizer loaded successfully.")
except Exception as e:
    print(f"Error loading the tokenizer: {e}")
    # Re-raise: every later cell needs `tokenizer`, so continuing would only
    # fail later with a confusing NameError.
    raise

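"""Load the training and validation data from local UTF-8 text files: each non-empty line becomes one example in the `text` column. (The original notebook instead loaded the version of UltraChat prepared by Hugging Face; since each row there is a full dialog that can be very long, only the first two turns were kept to reduce the sequence length of the training examples.)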
"""

# Data loading
def load_custom_dataset(file_path):
    """Load a plain-text file as a single-column ``datasets.Dataset``.

    Every line is stripped of surrounding whitespace; each non-empty result
    becomes one row of the ``text`` column and blank lines are skipped. An
    example that spans several lines is therefore split into several rows, so
    keep one training example per line.

    Args:
        file_path: Path to a UTF-8 encoded text file.

    Returns:
        A ``Dataset`` with a single ``text`` column.
    """
    with open(file_path, "r", encoding="utf-8") as file:
        # Iterate the file lazily (no readlines() copy) and strip each line once.
        texts = [text for text in (line.strip() for line in file) if text]
    return Dataset.from_dict({"text": texts})

# Update these paths to point at the correct files.
dataset_train_sft = load_custom_dataset("MIXTRAL_DatosEntrenamiento.txt")
dataset_test_sft = load_custom_dataset("MIXTRAL_DatosValidacion.txt")

"""Load the model and prepare it to be fine-tuned with QLoRA."""

# Dtype used for the compute of the 4-bit quantized layers.
compute_dtype = torch.float16

# QLoRA quantization: 4-bit NF4 weights with double quantization.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=compute_dtype,
)

# Load the quantized model entirely onto GPU 0, then prepare it for k-bit
# (QLoRA) training.
model = prepare_model_for_kbit_training(
    AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map={"": 0},
        quantization_config=bnb_config,
    )
)

# Gradient checkpointing is used by default but is not compatible with the KV
# cache, so caching is switched off for training.
model.config.use_cache = False
# Keep the model's padding id consistent with the tokenizer configured above.
model.config.pad_token_id = tokenizer.pad_token_id

"""The following cell only prints the architecture of the model."""

# Print the module tree of the quantized, k-bit-prepared model for inspection.
print(model)

"""Define the configuration of LoRA."""

# LoRA adapter configuration: rank-64 adapters with lora_alpha=16 and 10%
# dropout in the LoRA layers. target_modules is not set, so PEFT chooses its
# default target layers for this architecture.
peft_config = LoraConfig(
    lora_alpha=16,
    lora_dropout=0.1,
    r=64,
    # Explicit (PEFT's default): do not train any bias parameters.
    bias="none",
    # Declaring the task makes PEFT wrap the model as a causal-LM PEFT model
    # rather than a generic PeftModel.
    task_type="CAUSAL_LM",
)

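"""The arguments below train for one full epoch (`num_train_epochs=1`, `max_steps=-1`). For the original demonstration I trained for only 300 steps; you should train for at least 3000 steps, and one full epoch is ideal.

Note: the `compute_metrics` sketch below lives inside this string, so it is never executed or passed to the trainer; enabling it would also require `import numpy as np`.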

from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {
        "accuracy": accuracy_score(labels, predictions),
        "f1": f1_score(labels, predictions, average='macro')
    }
"""

import dataclasses

# Transformers renamed the `evaluation_strategy` argument to `eval_strategy`,
# and later releases dropped the old name. Because the pip cell installs an
# unpinned version, either spelling may be rejected, so use whichever keyword
# this installed TrainingArguments dataclass declares.
_EVAL_STRATEGY_KWARG = (
    "eval_strategy"
    if "eval_strategy" in {f.name for f in dataclasses.fields(TrainingArguments)}
    else "evaluation_strategy"
)

# Trainer hyperparameters. The effective train batch size per device is
# per_device_train_batch_size * gradient_accumulation_steps = 4 * 2 = 8.
training_arguments = TrainingArguments(
    output_dir="./results_mixtral_sft/",
    **{_EVAL_STRATEGY_KWARG: "steps"},  # evaluate every `eval_steps` steps
    do_eval=True,
    optim="paged_adamw_8bit",
    num_train_epochs=1,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    per_device_eval_batch_size=4,
    log_level="debug",
    save_steps=1000,
    logging_steps=100,
    learning_rate=2e-4,
    eval_steps=500,
    max_steps=-1,  # -1: training length is set by num_train_epochs
    lr_scheduler_type="linear",
    report_to="tensorboard",  # Ensure TensorBoard is enabled
)

"""Start training:"""

# Supervised fine-tuning with TRL: SFTTrainer attaches the LoRA adapters from
# peft_config and trains on the `text` column, with max_seq_length=512.
# NOTE(review): recent TRL releases moved dataset_text_field / max_seq_length /
# tokenizer out of the SFTTrainer constructor (into SFTConfig and
# processing_class); pin trl if this call is rejected.
trainer = SFTTrainer(
    model=model,
    train_dataset=dataset_train_sft,
    eval_dataset=dataset_test_sft,
    peft_config=peft_config,
    dataset_text_field="text",
    max_seq_length=512,
    tokenizer=tokenizer,
    args=training_arguments,
)

trainer.train()

# Checkpoints are only written every `save_steps` (1000) steps. A shorter run,
# or one that does not end on a multiple of it, might therefore leave the final
# adapter unsaved, so save it explicitly to training_arguments.output_dir.
trainer.save_model()

# Commented out IPython magic to ensure Python compatibility.
# Enable TensorBoard to visualize the training graphs
# %load_ext tensorboard
# %tensorboard --logdir results_mixtral_sft/runs