import spaces
import gradio as gr
from huggingface_hub import hf_hub_download
from llama_cpp_cuda_tensorcores import Llama
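
# Streaming chat demo for a GGUF model served through the llama.cpp Python
# bindings (CUDA tensor-core build), intended for a Hugging Face ZeroGPU Space.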

REPO_ID = "keitokei1994/shisa-v1-qwen2-7b-GGUF"
MODEL_NAME = "shisa-v1-qwen2-7b.Q8_0.gguf"
MAX_CONTEXT_LENGTH = 32768
CUDA = True
SYSTEM_PROMPT = "You are a helpful, smart, kind, and efficient AI assistant. You always fulfill the user's requests to the best of your ability."
TOKEN_STOP = ["<|eot_id|>"]
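
# Prompt-template fragments using Llama-3-style special tokens. The literal
# strings SYSTEM_PROMPT and USER_PROMPT act as placeholders that
# apply_chat_template() substitutes verbatim; generation stops at <|eot_id|>.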
SYS_MSG = "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nSYSTEM_PROMPT<|eot_id|>\n"
USER_PROMPT = (
    "<|start_header_id|>user<|end_header_id|>\n\nUSER_PROMPT<|eot_id|>\n"
)
ASSIS_PROMPT = "<|start_header_id|>assistant<|end_header_id|>\n\n"
END_ASSIS_PREVIOUS_RESPONSE = "<|eot_id|>\n"
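
# For illustration, a one-turn history [["Hi", ""]] with the default system
# prompt renders roughly as:
#
#   <|begin_of_text|><|start_header_id|>system<|end_header_id|>
#
#   You are a helpful, smart, kind, and efficient AI assistant. ...<|eot_id|>
#   <|start_header_id|>user<|end_header_id|>
#
#   Hi<|eot_id|>
#   <|start_header_id|>assistant<|end_header_id|>
#
# The assistant header is left open so the model generates the reply.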

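# Preset system prompts: the keys populate the "Prompt selector" dropdown and
# the selected value is loaded into the editable system-message box.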
TASK_PROMPT = {
    "Assistant": SYSTEM_PROMPT,
    "Translate": "You are an expert translator. Translate the following text into English.",
    "Summarization": "Summarizing information is my specialty. Let me know what you'd like summarized.",
    "Grammar correction": "Grammar is my forte! Feel free to share the text you'd like me to proofread and correct.",
    "Stable diffusion prompt generator": "You are a stable diffusion prompt generator. Break down the user's text and create a more elaborate prompt.",
    "Play Trivia": "Engage the user in a trivia game on various topics.",
    "Share Fun Facts": "Share interesting and fun facts on various topics.",
    "Explain code": "You are an expert programmer guiding someone through a piece of code step by step, explaining each line and its function in detail.",
    "Paraphrase Master": "You have the knack for transforming complex or verbose text into simpler, clearer language while retaining the original meaning and essence.",
    "Recommend Movies": "Recommend movies based on the user's preferences.",
    "Offer Motivational Quotes": "Offer motivational quotes to inspire the user.",
    "Recommend Books": "Recommend books based on the user's favorite genres or interests.",
    "Philosophical discussion": "Engage the user in a philosophical discussion",
    "Music recommendation": "Tune time! What kind of music are you in the mood for? I'll find the perfect song for you.",
    "Generate a Joke": "Generate a witty joke suitable for a stand-up comedy routine.",
    "Roleplay as a Detective": "Roleplay as a detective interrogating a suspect in a murder case.",
    "Act as a News Reporter": "Act as a news reporter covering breaking news about an alien invasion.",
    "Play as a Space Explorer": "Play as a space explorer encountering a new alien civilization.",
    "Be a Medieval Knight": "Imagine yourself as a medieval knight embarking on a quest to rescue a princess.",
    "Act as a Superhero": "Act as a superhero saving a city from a supervillain's evil plot.",
    "Play as a Pirate Captain": "Play as a pirate captain searching for buried treasure on a remote island.",
    "Be a Famous Celebrity": "Imagine yourself as a famous celebrity attending a glamorous red-carpet event.",
    "Design a New Invention": "Imagine you're an inventor tasked with designing a revolutionary new invention that will change the world.",
    "Act as a Time Traveler": "You've just discovered time travel! Describe your adventures as you journey through different eras.",
    "Play as a Magical Girl": "You are a magical girl with extraordinary powers, battling dark forces to protect your city and friends.",
    "Act as a Shonen Protagonist": "You are a determined and spirited shonen protagonist on a quest for strength, friendship, and victory.",
    "Roleplay as a Tsundere Character": "You are a tsundere character, initially cold and aloof but gradually warming up to others through unexpected acts of kindness.",
}

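# Full-page background image. Note that Gradio serves "file=" asset URLs only
# from directories passed to launch(allowed_paths=...), as done in __main__.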
css = ".gradio-container {background-image: url('file=./assets/background.png'); background-size: cover; background-position: center; background-repeat: no-repeat;}"


class ChatLLM:
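    """Lazy-loading wrapper around a llama.cpp `Llama` instance.

    The model is created on the first request so that it is instantiated
    inside the GPU context allocated by `@spaces.GPU` rather than at startup.
    """
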
    def __init__(self, config_model):
        self.llm = None
        self.config_model = config_model

    def load_cpp_model(self):
        self.llm = Llama(**self.config_model)

    def apply_chat_template(
        self,
        history,
        system_message,
    ):
        history = history or []

        # Render the system message, then each (user, assistant) turn.
        messages = SYS_MSG.replace("SYSTEM_PROMPT", system_message.strip())
        for user_msg, assistant_msg in history:
            messages += (
                USER_PROMPT.replace("USER_PROMPT", user_msg)
                + ASSIS_PROMPT
                + assistant_msg
            )
            # Close the turn only if the assistant has already answered; the
            # final (pending) turn stays open for generation.
            messages += END_ASSIS_PREVIOUS_RESPONSE if assistant_msg else ""

        print(messages)  # debug: log the fully rendered prompt

        return messages

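    # @spaces.GPU requests a ZeroGPU device for the duration of this call
    # (up to 120 s), which is why the model is loaded lazily inside it.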
    @spaces.GPU(duration=120)
    def response(
        self,
        history,
        system_message,
        max_tokens,
        temperature,
        top_p,
        top_k,
        repeat_penalty,
    ):

        messages = self.apply_chat_template(history, system_message)

        # Reset the placeholder that `user()` appended for this turn.
        history[-1][1] = ""

        if not self.llm:
            print("Loading model")
            self.load_cpp_model()

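        # Stream completion chunks, appending each to the last history entry
        # and yielding so the Chatbot updates incrementally.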
        for output in self.llm(
            messages,
            echo=False,
            stream=True,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            repeat_penalty=repeat_penalty,
            stop=TOKEN_STOP,
        ):
            answer = output["choices"][0]["text"]
            history[-1][1] += answer
            # stream the response
            yield history, history


def user(message, history):
    """Append the user's message to the history with an empty assistant slot."""
    history = history or []
    history.append([message, ""])
    return "", history


def clear_chat(chat_history_state, chat_message):
    # Return fresh values for the state and the textbox to start a new chat.
    return [], ""


def gui(llm_chat):
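    """Build the Gradio Blocks UI and wire it to the ChatLLM backend."""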
    with gr.Blocks(theme="NoCrypt/miku", css=css) as app:
        gr.Markdown("# shisa-v1-qwen2-7b.Q8_0.gguf")
        gr.Markdown(
            f"""
                ### This demo utilizes the repository ID {REPO_ID} with the model {MODEL_NAME}, powered by the LLaMA.cpp backend.
                """
        )
        with gr.Row():
            with gr.Column(scale=2):
                chatbot = gr.Chatbot(
                    label="Chat",
                    height=700,
                    avatar_images=(
                        "assets/avatar_user.jpeg",
                        "assets/avatar_llama.jpeg",
                    ),
                )
            with gr.Column(scale=1):
                with gr.Row():
                    message = gr.Textbox(
                        label="Message",
                        placeholder="Ask me anything.",
                        lines=3,
                    )
                with gr.Row():
                    submit = gr.Button(value="Send message", variant="primary")
                    clear = gr.Button(value="New chat", variant="primary")
                    stop = gr.Button(value="Stop", variant="secondary")

                with gr.Accordion("Contextual Prompt Editor"):
                    default_task = "Assistant"
                    task_prompts_gui = gr.Dropdown(
                        choices=list(TASK_PROMPT.keys()),
                        value=default_task,
                        label="Prompt selector",
                        visible=True,
                        interactive=True,
                    )
                    system_msg = gr.Textbox(
                        value=TASK_PROMPT[default_task],
                        label="System Message",
                        placeholder="system prompt",
                        lines=4,
                    )

                    def task_selector(choice):
                        return gr.update(value=TASK_PROMPT[choice])

                    task_prompts_gui.change(
                        task_selector,
                        [task_prompts_gui],
                        [system_msg],
                    )

                with gr.Accordion("Advanced settings", open=False):
                    with gr.Column():
                        max_tokens = gr.Slider(
                            20, 4096, label="Max Tokens", step=20, value=400
                        )
                        temperature = gr.Slider(
                            0.2, 2.0, label="Temperature", step=0.1, value=0.8
                        )
                        top_p = gr.Slider(
                            0.0, 1.0, label="Top P", step=0.05, value=0.95
                        )
                        top_k = gr.Slider(
                            0, 100, label="Top K", step=1, value=40
                        )
                        repeat_penalty = gr.Slider(
                            0.0,
                            2.0,
                            label="Repetition Penalty",
                            step=0.1,
                            value=1.1,
                        )

                # Per-session conversation state: a list of [user, assistant] pairs.
                chat_history_state = gr.State()
                clear.click(
                    clear_chat,
                    inputs=[chat_history_state, message],
                    outputs=[chat_history_state, message],
                    queue=False,
                )
                # Also wipe the visible Chatbot component.
                clear.click(lambda: None, None, chatbot, queue=False)

                submit_click_event = submit.click(
                    fn=user,
                    inputs=[message, chat_history_state],
                    outputs=[message, chat_history_state],
                    queue=True,
                ).then(
                    fn=llm_chat.response,
                    inputs=[
                        chat_history_state,
                        system_msg,
                        max_tokens,
                        temperature,
                        top_p,
                        top_k,
                        repeat_penalty,
                    ],
                    outputs=[chatbot, chat_history_state],
                    queue=True,
                )
                # Stop cancels the in-flight submit chain, aborting the stream.
                stop.click(
                    fn=None,
                    inputs=None,
                    outputs=None,
                    cancels=[submit_click_event],
                    queue=False,
                )
    return app


if __name__ == "__main__":

    model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_NAME)

    config_model = {
        "model_path": model_path,
        "n_ctx": MAX_CONTEXT_LENGTH,
        # -1 offloads all layers to the GPU; 0 keeps everything on the CPU.
        "n_gpu_layers": -1 if CUDA else 0,
    }

    llm_chat = ChatLLM(config_model)

    app = gui(llm_chat)

    app.queue(default_concurrency_limit=40)

    app.launch(
        max_threads=40,
        share=False,
        show_error=True,
        quiet=False,
        debug=True,
        allowed_paths=["./assets/"],
    )