from llama_cpp import Llama
from huggingface_hub import hf_hub_download
import gradio as gr
from typing import List

DESCRIPTION = f"""
# Chat with Arco 500M as GGUF on CPU
"""

MAX_MAX_NEW_TOKENS = 1024      # slider upper bound; also used as the context size below
DEFAULT_MAX_NEW_TOKENS = 200   # default number of tokens to generate per reply

# Download the GGUF file (hf_hub_download caches it under ~/.cache/huggingface
# by default, so subsequent runs reuse the local copy)
model_path = hf_hub_download(
    repo_id="ijohn07/arco-plus-Q8_0-GGUF",
    filename="arco-plus-q8_0.gguf",
    repo_type="model"
)
# Load the GGUF model
pipe = Llama(
    model_path=model_path,
    n_ctx=MAX_MAX_NEW_TOKENS,  # context window; must fit the prompt plus the generated tokens
    # n_threads=4,             # number of CPU threads to use; defaults to the number of cores
    # n_gpu_layers=1,          # offload this many layers to GPU; check supported layers and GPU size
    # n_batch=1,               # prompt-processing batch size
    # use_mlock=True,          # lock the model weights in RAM to avoid swapping (defaults to False)
)
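
# The predict() generator below sends only the latest message to the model.
# As an illustrative sketch (not part of the original app), a helper like the
# hypothetical format_prompt() could fold prior turns into the prompt. The
# plain "User:/Assistant:" layout is an assumption; the model's actual chat
# template is not specified in this file.
def format_prompt(message: str, history: List[List[str]], max_turns: int = 4) -> str:
    # Keep only the most recent turns so the prompt stays within n_ctx
    lines = []
    for user_msg, bot_msg in history[-max_turns:]:
        lines.append(f"User: {user_msg}")
        lines.append(f"Assistant: {bot_msg}")
    lines.append(f"User: {message}")
    lines.append("Assistant:")
    return "\n".join(lines)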

def predict(message: str, history: List[List[str]], max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS):
    if not message:
        # predict is a generator, so yield rather than return a value;
        # a bare return here would leave Gradio with nothing to display
        yield "", history
        return

    # Only the latest message is sent to the model; earlier turns in
    # `history` are shown in the UI but not included in the prompt
    prompt = message

    # Accumulate the streamed reply here
    reply = ""

    history.append([message, ""])
    
    # Request a token stream instead of waiting for the full completion
    stream = pipe(
        prompt,
        max_tokens=max_new_tokens,
        stop=["</s>"],
        stream=True
    )

    for output in stream:
        # Each chunk carries a small piece of the completion (roughly one token)
        new_text = output['choices'][0]['text']
        
        # Append to the current reply
        reply += new_text
        
        # Update the history
        history[-1][1] = reply
        
        # Yield for incremental display on chat
        yield "", history
    
with gr.Blocks() as demo:
    gr.Markdown(DESCRIPTION)
    chatbot = gr.Chatbot()  # history is a list of [user_message, bot_reply] pairs
    with gr.Row():
        textbox = gr.Textbox(placeholder="Type here and press enter")
    max_new_tokens_slider = gr.Slider(
        minimum=1,
        maximum=MAX_MAX_NEW_TOKENS,
        value=DEFAULT_MAX_NEW_TOKENS,
        label="Max New Tokens",
    )
    textbox.submit(predict, [textbox, chatbot, max_new_tokens_slider], [textbox, chatbot])
    
demo.queue().launch()  # queue() enables streaming generator output to the UI