import spaces
import torch
import sys
import html
from transformers import AutoTokenizer, AutoModelForCausalLM, TextIteratorStreamer
from threading import Thread
import gradio as gr
from gradio_rich_textbox import RichTextbox

# Markdown header rendered at the top of the demo page.
# The model link points at the Hugging Face Hub page for the repo id that is
# loaded below; the examples source is given as plain text because its URL is
# not available in this file.
title = """# MetaMath - Tencent's Mistral DPO finetune for mathematics
Model: [TencentARC/Mistral_Pro_8B_v0.1](https://huggingface.co/TencentARC/Mistral_Pro_8B_v0.1)
Using examples from introspector/unimath
"""

# Repository id shared by the tokenizer and the model weights.
model_name = 'TencentARC/Mistral_Pro_8B_v0.1'
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Weights are loaded in bfloat16; device placement is delegated to
# `device_map="auto"`.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="auto",
    torch_dtype=torch.bfloat16,
)
# Disabled alternative for configuring the pad token on the model itself
# (would additionally require importing GenerationConfig):
# model.generation_config = GenerationConfig.from_pretrained(model_name)
# model.generation_config.pad_token_id = model.generation_config.eos_token_id

@torch.inference_mode()
@spaces.GPU
def predict_math_bot(user_message, system_message="", max_new_tokens=125, temperature=0.1, top_p=0.9, repetition_penalty=1.9, do_sample=False):
    """Generate text for a math query using the module-level model.

    Args:
        user_message: The query text. Used verbatim as the prompt when no
            system message is given.
        system_message: Optional instructions. When non-empty, the prompt is
            built from the ``<|user|>`` / ``<|system|>`` / ``<|assistant|>``
            template below.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Coerced to ``int`` because callers (e.g. a UI slider
            or an API client) may deliver a float.
        temperature: Sampling temperature; only forwarded when ``do_sample``
            is true.
        top_p: Nucleus-sampling threshold; only forwarded when ``do_sample``
            is true.
        repetition_penalty: Repetition penalty forwarded to ``generate``.
        do_sample: Sample when true; otherwise the sampling parameters are
            not forwarded at all.

    Returns:
        The decoded first output sequence with special tokens removed. The
        whole sequence returned by ``generate`` is decoded, not only the
        newly generated part.
    """
    if system_message:
        prompt = f"<|user|>{user_message}\n<|system|>{system_message}\n<|assistant|>\n"
    else:
        prompt = user_message

    inputs = tokenizer(prompt, return_tensors='pt', add_special_tokens=True)
    input_ids = inputs["input_ids"].to(model.device)
    # Forward the tokenizer's attention mask explicitly: the pad token is set
    # to the EOS token below, so the mask cannot be reliably inferred from
    # the ids alone.
    attention_mask = inputs.get("attention_mask")
    if attention_mask is not None:
        attention_mask = attention_mask.to(model.device)

    generation_kwargs = {
        # Same budget as `max_length = prompt_length + max_new_tokens`,
        # expressed directly and guaranteed to be an integer.
        "max_new_tokens": int(max_new_tokens),
        "repetition_penalty": repetition_penalty,
        "pad_token_id": tokenizer.eos_token_id,
        "do_sample": do_sample,
    }
    if do_sample:
        # Sampling-only knobs have no effect on non-sampling decoding, so
        # they are passed only when sampling is requested.
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    output_ids = model.generate(
        input_ids,
        attention_mask=attention_mask,
        **generation_kwargs,
    )
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)

def main():
    """Build the Gradio interface and launch it.

    The layout is: the markdown title, a row with the query editor and the
    optional system prompt, an accordion holding the generation settings,
    the output box, and a button that calls ``predict_math_bot`` with the
    widget values and writes the result to the output box.
    """
    with gr.Blocks() as demo:
        gr.Markdown(title)
        with gr.Row():
            # Raw string: in a normal literal `\f` is a form-feed escape
            # (turning `\frac` into a control character followed by "rac")
            # and `\i` is an invalid escape sequence.
            user_message = gr.Code(label="🫡Enter your math query here...", language="r", lines=3, value=r"""F(x) &= \int^a_b \frac{1}{3}x^3""")
            system_message = gr.Textbox(label="📉System Prompt", lines=2, placeholder="Optional: give precise instructions to resolve the problem provided above, produce complete answer in Latex format:")
        with gr.Accordion("Advanced Settings"):
            with gr.Row():
                max_new_tokens = gr.Slider(label="Max new tokens", value=125, minimum=25, maximum=1250)
                temperature = gr.Slider(label="Temperature", value=0.1, minimum=0.05, maximum=1.0)
                top_p = gr.Slider(label="Top-p (nucleus sampling)", value=0.90, minimum=0.01, maximum=0.99)
                repetition_penalty = gr.Slider(label="Repetition penalty", value=1.9, minimum=1.0, maximum=2.0)
                do_sample = gr.Checkbox(label="Uncheck for faster inference", value=False)
        output_text = RichTextbox(label="🫡📉MetaMath", interactive=True)
        # The input order must match the positional parameter order of
        # predict_math_bot.
        gr.Button("Try🫡📉MetaMath").click(
            predict_math_bot,
            inputs=[user_message, system_message, max_new_tokens, temperature, top_p, repetition_penalty, do_sample],
            outputs=output_text
        )
    demo.launch()

# Build and launch the UI only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()