# Standard Python imports
import os
import re
import json
from typing import List, Dict, Any

# Data processing and visualization
from PIL import Image
from tqdm import tqdm

# Deep Learning & ML
import torch
from transformers import (
    AutoProcessor,
    AutoModelForVision2Seq,
    AutoTokenizer,
    AutoModelForCausalLM,
    TextStreamer,
    Idefics3ForConditionalGeneration,
    BitsAndBytesConfig,
)

from unsloth import FastVisionModel

# Dataset handling
from datasets import load_from_disk

# API & Authentication
from huggingface_hub import login

# UI & Environment
import gradio as gr
from dotenv import load_dotenv
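
# Load credentials from a local .env file and authenticate with the Hugging Face Hub
# so the fine-tuned model repos can be downloaded. The HF_TOKEN variable name is an
# assumption; adjust it to match your .env file.
load_dotenv()
if os.getenv("HF_TOKEN"):
    login(token=os.getenv("HF_TOKEN"))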

# Available models
MODELS = {
    "Blood Cell Classifier with Llama-3.2": "laurru01/Llama-3.2-11B-Vision-Instruct-ft-PeripherallBloodCells",
    "Blood Cell Classifier with Qwen2-VL": "laurru01/Qwen2-VL-2B-Instruct-ft-bloodcells-big",
    "Blood Cell Classifier with SmolVLM": "laurru01/SmolVLM-Instruct-ft-PeripherallBloodCells",
}

# Global dictionary to store loaded models
loaded_models = {}
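# Each entry maps the display name to its loaded components, e.g.
# {"model": ..., "processor": ..., "tokenizer": ..., "type": "smolvlm" or "standard"}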

def initialize_models():
    """Preload all models during startup"""
    print("Initializing models...")
    for model_name, model_path in MODELS.items():
        print(f"Loading {model_name}...")
        try:
            if "SmolVLM" in model_name:
                # Carga específica para SmolVLM
                base_model = Idefics3ForConditionalGeneration.from_pretrained(
                    "HuggingFaceTB/SmolVLM-Instruct",
                    device_map="auto",
                    torch_dtype=torch.bfloat16,
                    load_in_4bit=True,
                    max_memory={0: "12GB"}
                )
                base_model.load_adapter(model_path)
                processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
                
                loaded_models[model_name] = {
                    "model": base_model,
                    "processor": processor,
                    "type": "smolvlm"
                }
            else:
                # Original loading path for Llama and Qwen (unchanged)
                model, tokenizer = FastVisionModel.from_pretrained(
                    model_name=model_path,
                    load_in_4bit=True,
                    use_gradient_checkpointing="unsloth"
                )
                FastVisionModel.for_inference(model)
                processor = AutoProcessor.from_pretrained(model_path)

                loaded_models[model_name] = {
                    "model": model,
                    "tokenizer": tokenizer,
                    "processor": processor,
                    "type": "standard"
                }
            print(f"Successfully loaded {model_name}")

        except Exception as e:
            print(f"Error loading {model_name}: {str(e)}")

    print("Model initialization complete")

def extract_cell_type(text):
    """Extract cell type from generated description"""
    cell_types = ['neutrophil', 'lymphocyte', 'monocyte', 'eosinophil', 'basophil']
    text_lower = text.lower()
    for cell_type in cell_types:
        if cell_type in text_lower:
            return cell_type.capitalize()
    return "Unidentified Cell Type"

@torch.no_grad()
def generate_description_standard(model, tokenizer, image):
    """Generate description using standard models (Llama and Qwen)"""
    messages = [{
        "role": "user",
        "content": [
            {"type": "image"},
            {"type": "text", "text": "As a hematologist, carefully identify the type of blood cell in this image and describe its key characteristics."},
        ],
    }]

    # With unsloth's FastVisionModel, `tokenizer` doubles as a processor:
    # apply_chat_template builds the prompt string, and calling it on the image
    # plus text returns the tensors the vision model expects
    input_text = tokenizer.apply_chat_template(messages, add_generation_prompt=True)
    inputs = tokenizer(image, input_text, add_special_tokens=False, return_tensors="pt").to("cuda")

    text_streamer = TextStreamer(tokenizer, skip_prompt=True)
    output = model.generate(
        **inputs,
        streamer=text_streamer,
        max_new_tokens=1024,
        use_cache=True,
        do_sample=True,  # required for temperature/min_p to take effect
        temperature=1.5,
        min_p=0.1
    )

    raw_output = tokenizer.decode(output[0], skip_special_tokens=True)
    # The decoded text may echo the prompt; keep only the assistant turn onward
    if "The provided image" in raw_output:
        start_idx = raw_output.find("assistant")
        cleaned_output = raw_output[start_idx:]
    else:
        cleaned_output = raw_output

    return cleaned_output.strip()

@torch.no_grad()
def generate_description_smolvlm(model, processor, image):
    """Generate description using SmolVLM model with memory-efficient settings"""
    if image.mode != "RGB":
        image = image.convert("RGB")
    
    # Downscale to a smaller resolution to reduce GPU memory usage
    max_size = 192
    image.thumbnail((max_size, max_size), Image.Resampling.LANCZOS)
    
    sample = [{
        "role": "user",
        "content": [
            {"type": "image", "image": image},
            {"type": "text", "text": "As a hematologist, carefully identify the type of blood cell in this image and describe its key characteristics."}
        ]
    }]

    text_input = processor.apply_chat_template(
        sample, 
        add_generation_prompt=True
    )

    try:
        torch.cuda.empty_cache()
        
        # torch.cuda.amp.autocast() is deprecated; torch.autocast is the current API
        with torch.autocast("cuda"):
            model_inputs = processor(
                text=text_input,
                images=[[image]],
                return_tensors="pt",
            ).to("cuda")

            generated_ids = model.generate(
                **model_inputs,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
                repetition_penalty=1.5,
                no_repeat_ngram_size=3,
                num_beams=2,
                length_penalty=1.0,
                early_stopping=True,
                use_cache=True,
                pad_token_id=processor.tokenizer.pad_token_id,
            )

            # Slice off the prompt tokens so only newly generated text is decoded
            response_ids = generated_ids[0][len(model_inputs.input_ids[0]):]
            output_text = processor.decode(
                response_ids,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True
            ).strip()

            # Crude repetition guard: very few distinct words signals degenerate output
            if len(set(output_text.split())) < 5:
                output_text = "Error: Generated response was too repetitive. Please try again."

        del model_inputs, generated_ids, response_ids
        torch.cuda.empty_cache()

        return output_text

    except Exception:
        # Free VRAM before propagating; a bare raise preserves the original traceback
        torch.cuda.empty_cache()
        raise

def analyze_cell(image, model_name):
    """Main function to analyze cell images"""
    if not isinstance(image, Image.Image):
        return "Invalid image format. Please upload a valid image.", "", None

    try:
        if model_name not in loaded_models:
            return f"Model {model_name} not loaded.", "", None

        model_components = loaded_models[model_name]
        
        if model_components["type"] == "smolvlm":
            description = generate_description_smolvlm(
                model_components["model"],
                model_components["processor"],
                image
            )
        else:
            description = generate_description_standard(
                model_components["model"],
                model_components["tokenizer"],
                image
            )

        cell_type = extract_cell_type(description)
        return cell_type, description, image
    
    except Exception as e:
        return f"Error occurred: {str(e)}", "", None

# Initialize all models before starting the interface
initialize_models()
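
# Optional smoke test of the full pipeline (hypothetical sample path; uncomment to try):
#   img = Image.open("samples/neutrophil.jpg")
#   print(analyze_cell(img, list(MODELS.keys())[0]))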


# Enhanced CSS with modern color scheme
custom_css = """
.container {
    max-width: 1000px;
    margin: auto;
    padding: 30px;
    background: linear-gradient(135deg, #f6f9fc 0%, #ffffff 100%);
    border-radius: 20px;
    box-shadow: 0 10px 20px rgba(0,0,0,0.05);
}
.title {
    text-align: center;
    color: #2d3436;
    font-size: 3em;
    font-weight: 700;
    margin-bottom: 20px;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.1);
}
.subtitle {
    text-align: center;
    color: #636e72;
    font-size: 1.2em;
    margin-bottom: 40px;
}
.input-image {
    border: 2px dashed #74b9ff;
    border-radius: 15px;
    padding: 20px;
    transition: all 0.3s ease;
}
.input-image:hover {
    border-color: #0984e3;
    transform: translateY(-2px);
}
.model-dropdown {
    background: #f8f9fa;
    border-radius: 10px;
    border: 1px solid #dfe6e9;
    margin: 15px 0;
}
.submit-button {
    background: linear-gradient(45deg, #0984e3, #74b9ff);
    color: white;
    border: none;
    padding: 12px 25px;
    border-radius: 10px;
    font-weight: 600;
    transition: all 0.3s ease;
}
.submit-button:hover {
    transform: translateY(-2px);
    box-shadow: 0 5px 15px rgba(9, 132, 227, 0.3);
}
.result-box {
    background: white;
    border-radius: 10px;
    border: 1px solid #dfe6e9;
    padding: 15px;
    margin: 10px 0;
}
.output-image {
    border-radius: 15px;
    overflow: hidden;
    box-shadow: 0 5px 15px rgba(0,0,0,0.1);
}
"""

# Gradio interface
with gr.Blocks(css=custom_css) as iface:
    gr.HTML("<h1 class='title'>Blood Cell Classifier</h1>")
    gr.HTML("<p class='subtitle'>Upload a microscopic blood cell image for instant classification and detailed analysis</p>")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(
                label="Upload Blood Cell Image",
                type="pil",
                sources=["upload"],  # Only allow computer uploads
                elem_classes="input-image"
            )
            model_dropdown = gr.Dropdown(
                choices=list(MODELS.keys()),
                value=list(MODELS.keys())[0],
                label="Select Model Version",
                elem_classes="model-dropdown"
            )
            submit_btn = gr.Button(
                "Analyze Cell",
                variant="primary",
                elem_classes="submit-button"
            )
        with gr.Column():
            cell_type = gr.Textbox(
                label="Identified Cell Type",
                elem_classes="result-box"
            )
            description = gr.Textbox(
                label="Analysis Details",
                lines=8,
                elem_classes="result-box"
            )
            output_image = gr.Image(
                label="Analyzed Image",
                elem_classes="output-image"
            )
    submit_btn.click(
        fn=analyze_cell,
        inputs=[input_image, model_dropdown],
        outputs=[cell_type, description, output_image]
    )
    gr.HTML("""
        <div style="text-align: center; margin-top: 30px; padding: 20px;">
            <p style="color: #636e72;">Developed by Laura Ruiz | MSc Bioinformatics and Biostatistics</p>
            <a href="https://github.com/laurru01" target="_blank"
               style="color: #0984e3; text-decoration: none; font-weight: 600;">
               View on GitHub
            </a>
        </div>
    """)

# Launch the interface
iface.launch()
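
# To expose a temporary public URL (e.g. when running on a remote machine), use:
#   iface.launch(share=True)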