import os
import gradio as gr
from gradio.components import Textbox, Button, Slider, Checkbox
from AinaTheme import theme
from urllib.error import HTTPError

from rag import RAG
from utils import setup

MAX_NEW_TOKENS = 700
SHOW_MODEL_PARAMETERS_IN_UI = os.environ.get("SHOW_MODEL_PARAMETERS_IN_UI", default="True") == "True"

setup()


rag = RAG(
    hf_token=os.getenv("HF_TOKEN"),
    embeddings_model=os.getenv("EMBEDDINGS"), 
    repo_name=os.getenv("REPO_NAME"),
)

# model_name=os.getenv("MODEL"),   
# rerank_model=os.getenv("RERANK_MODEL"),
# rerank_number_contexts=int(os.getenv("RERANK_NUMBER_CONTEXTS"))

def generate(prompt, model_parameters):
    try:
        output, context, source = rag.get_response(prompt, model_parameters)
        return output, context, source
    except HTTPError as err:
        if err.code == 400:
            gr.Warning(
                "The inference endpoint is only available Monday through Friday, from 08:00 to 20:00 CET."
            )
    except:
        gr.Warning(
            "Inference endpoint is not available right now. Please try again later."
        )
    return None, None, None


def submit_input(input_, num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature):
    if input_.strip() == "":
        gr.Warning("Not possible to inference an empty input")
        return None


    model_parameters = {
        "NUM_CHUNKS": num_chunks,
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
        "top_k": top_k,
        "top_p": top_p,
        "do_sample": do_sample,
        "temperature": temperature
    }

    output, context, source = generate(input_, model_parameters)
    sources_markup = ""

    for url in source:
        sources_markup += f'<a href="{url}" target="_blank">{url}</a><br>'

    return output, sources_markup, context  
    # return output.strip(), sources_markup, context


def change_interactive(text):
    if len(text) == 0:
        return gr.update(interactive=True), gr.update(interactive=False)
    return gr.update(interactive=True), gr.update(interactive=True)


def clear():
    return (
        None, 
        None,
        None,
        None,
        gr.Slider(value=2.0),
        gr.Slider(value=MAX_NEW_TOKENS),
        gr.Slider(value=1.0),
        gr.Slider(value=50),
        gr.Slider(value=0.99),
        gr.Checkbox(value=False),
        gr.Slider(value=0.35),
    )


def gradio_app():
    with gr.Blocks(theme=theme) as demo:
        with gr.Row():
            with gr.Column(scale=0.1):
                gr.Image("rag_image.jpg", elem_id="flor-banner", scale=1, height=256, width=256, show_label=False, show_download_button = False, show_share_button = False)
            with gr.Column():
                gr.Markdown(
                    """# Demo de Retrieval-Augmented Generation per documents legals
                    🔍 **Retrieval-Augmented Generation** (RAG) és una tecnologia d'IA que permet interrogar un repositori de documents amb preguntes 
                    en llenguatge natural, i combina tècniques de recuperació d'informació avançades amb models generatius per redactar una resposta 
                    fent servir només la informació existent en els documents del repositori. 
                        
                    🎯 **Objectiu:** Aquest és un demostrador amb la normativa vigent publicada al Diari Oficial de la Generalitat de Catalunya, en el 
                    repositori del EADOP (Entitat Autònoma del Diari Oficial i de Publicacions). Aquesta versió explora prop de 2000 documents en català, 
                    i genera la resposta fent servir el model Salamandra-7b-aligned-EADOP, el model BSC-LT/salamandra-7b-instruct alineat amb el dataset de alinia/EADOP-RAG-out-of-domain. 
                    
                    ⚠️ **Advertencies**: Aquesta versió és experimental. El contingut generat per aquest model no està supervisat i pot ser incorrecte. 
                    Si us plau, tingueu-ho en compte quan exploreu aquest recurs.                 
                    """
                )
        with gr.Row(equal_height=True):
            with gr.Column(variant="panel"):
                input_ = Textbox(
                    lines=11,
                    label="Input",
                    placeholder="Quina és la finalitat del Servei Meteorològic de Catalunya?",
                    # value = "Quina és la finalitat del Servei Meteorològic de Catalunya?"
                )
                with gr.Row(variant="panel"):
                    clear_btn = Button(
                        "Clear",
                    )
                    submit_btn = Button("Submit", variant="primary", interactive=False)

                with gr.Row(variant="panel"):
                    with gr.Accordion("Model parameters", open=False, visible=SHOW_MODEL_PARAMETERS_IN_UI):
                        num_chunks = Slider(
                            minimum=1,
                            maximum=6,
                            step=1,
                            value=2,
                            label="Number of chunks"
                        )
                        max_new_tokens = Slider(
                            minimum=50,
                            maximum=2000,
                            step=1,
                            value=MAX_NEW_TOKENS,
                            label="Max tokens"
                        )
                        repetition_penalty = Slider(
                            minimum=0.1,
                            maximum=2.0,
                            step=0.1,
                            value=1.0,
                            label="Repetition penalty"
                        )
                        top_k = Slider(
                            minimum=1,
                            maximum=100,
                            step=1,
                            value=50,
                            label="Top k"
                        )
                        top_p = Slider(
                            minimum=0.01,
                            maximum=0.99,
                            value=0.99,
                            label="Top p"
                        )  
                        do_sample = Checkbox(
                            value=False, 
                            label="Do sample"
                        )
                        temperature = Slider(
                            minimum=0.1, 
                            maximum=1,
                            value=0.35,
                            label="Temperature"
                        )

                        parameters_compontents = [num_chunks, max_new_tokens, repetition_penalty, top_k, top_p, do_sample, temperature]

            with gr.Column(variant="panel"):
                output = Textbox(
                    lines=10, 
                    label="Output", 
                    interactive=False, 
                    show_copy_button=True
                )
                with gr.Accordion("Sources and context:", open=False):
                    source_context = gr.Markdown(
                        label="Sources",
                        show_label=False,
                    )
                    with gr.Accordion("See full context evaluation:", open=False):
                        context_evaluation = gr.Markdown(
                            label="Full context",
                            show_label=False,
                            # interactive=False, 
                            # autoscroll=False,
                            # show_copy_button=True
                        )
                

        input_.change(
            fn=change_interactive,
            inputs=[input_],
            outputs=[clear_btn, submit_btn],
            api_name=False,
        )

        input_.change(
            fn=None,
            inputs=[input_],
            api_name=False,
            js="""(i, m) => {
            document.getElementById('inputlenght').textContent = i.length + '  '
            document.getElementById('inputlenght').style.color =  (i.length > m) ? "#ef4444" : "";
        }""",
        )

        clear_btn.click(
            fn=clear, 
            inputs=[], 
            outputs=[input_, output, source_context, context_evaluation] + parameters_compontents,
              queue=False, 
              api_name=False
        )
        
        submit_btn.click(
            fn=submit_input, 
            inputs=[input_]+ parameters_compontents, 
            outputs=[output, source_context, context_evaluation],
            api_name="get-results"
        )

        with gr.Row():
            with gr.Column(scale=0.5):
                gr.Examples(
                    examples=[
                        ["""Què és l'EADOP (Entitat Autònoma del Diari Oficial i de Publicacions)?"""],
                    ],
                    inputs=input_,
                    outputs=[output, source_context, context_evaluation],
                    fn=submit_input,
                )
                gr.Examples(
                    examples=[
                        ["""Què diu el decret sobre la senyalització de les begudes alcohòliques i el tabac a Catalunya?"""],
                    ],
                    inputs=input_,
                    outputs=[output, source_context, context_evaluation],
                    fn=submit_input,
                )
                gr.Examples(
                    examples=[
                        ["""Com es pot inscriure una persona al Registre de catalans i catalanes residents a l'exterior?"""],
                    ],
                    inputs=input_,
                    outputs=[output, source_context, context_evaluation],
                    fn=submit_input,
                )
                gr.Examples(
                    examples=[
                        ["""Quina és la finalitat del Servei Meterològic de Catalunya ?"""],
                    ],
                    inputs=input_,
                    outputs=[output, source_context, context_evaluation],
                    fn=submit_input,
                )

        demo.launch(show_api=True)


if __name__ == "__main__":
    gradio_app()