import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import gradio as gr
from threading import Thread

base_model_id = "NousResearch/Meta-Llama-3-8B-Instruct"
new_model_id = "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct"

# Load the shared tokenizer and both models (base vs. ORPO fine-tune)
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
new_model = AutoModelForCausalLM.from_pretrained(
    new_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
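
# Llama 3 can end a turn with either the EOS token or <|eot_id|>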
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
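
# ZeroGPU: each call may hold a GPU for up to 120 seconds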
@spaces.GPU(duration=120)
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
    base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    system_prompt_list = [{"role": "system", "content": system_prompt}]
    input_text_list = [{"role": "user", "content": input_text}]
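
    # Rebuild each model's chat history from its Gradio chatbot state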
    base_chat_history = []
    for user, assistant in base_chatbot:
        base_chat_history.append({"role": "user", "content": user})
        base_chat_history.append({"role": "assistant", "content": assistant})

    new_chat_history = []
    for user, assistant in new_chatbot:
        new_chat_history.append({"role": "user", "content": user})
        new_chat_history.append({"role": "assistant", "content": assistant})

    base_messages = system_prompt_list + base_chat_history + input_text_list
    new_messages = system_prompt_list + new_chat_history + input_text_list
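
    # Serialize each conversation with the Llama 3 chat template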
    base_input_ids = tokenizer.apply_chat_template(
        base_messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(base_model.device).long()
    new_input_ids = tokenizer.apply_chat_template(
        new_messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(new_model.device).long()
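
    # Identical decoding settings for both models; sampling is disabled
    # when temperature is 0, which makes the comparison deterministic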
    base_generation_kwargs = dict(
        input_ids=base_input_ids,
        streamer=base_text_streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    new_generation_kwargs = dict(
        input_ids=new_input_ids,
        streamer=new_text_streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
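
    # generate() runs in a background thread while the TextIteratorStreamer
    # yields decoded text; the base model streams to completion first,
    # then the fine-tuned model does the same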
    base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
    base_thread.start()

    base_chatbot.append([input_text, ""])
    new_chatbot.append([input_text, ""])

    for base_text in base_text_streamer:
        if "<|eot_id|>" in base_text:
            eot_location = base_text.find("<|eot_id|>")
            base_text = base_text[:eot_location]
        base_chatbot[-1][-1] += base_text
        yield base_chatbot, new_chatbot
    new_thread = Thread(target=new_model.generate, kwargs=new_generation_kwargs)
    new_thread.start()

    for new_text in new_text_streamer:
        if "<|eot_id|>" in new_text:
            eot_location = new_text.find("<|eot_id|>")
            new_text = new_text[:eot_location]
        new_chatbot[-1][-1] += new_text
        yield base_chatbot, new_chatbot

    return base_chatbot, new_chatbot

def clear():
    return [], []
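
# Side-by-side UI: the same prompt is sent to both models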
with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
    with gr.Column():
        gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
        # Default system prompt (Arabic): "You are an eloquent speaker of Arabic!"
        system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
        with gr.Row(variant="panel"):
            base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
            new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                submit_btn = gr.Button(value="Generate", variant="primary")
                clear_btn = gr.Button(value="Clear", variant="secondary")
            # Default user message (Arabic): "Hello"
            input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)
        with gr.Accordion(label="Generation Configurations", open=False):
            max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=2048, label="Max New Tokens", step=128)
            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature", step=0.01)
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p", step=0.01)
            repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)
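
    # Textbox submit and the Generate button trigger the same streaming handler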
    input_text.submit(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
    submit_btn.click(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
    clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])
demo.launch()