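# Gradio Space that pits the base Meta-Llama-3-8B-Instruct against its
# Arabic ORPO fine-tune: the same prompt goes to both models and their
# streamed answers appear side by side for comparison.
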
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import gradio as gr
from threading import Thread

base_model_id = "NousResearch/Meta-Llama-3-8B-Instruct"
new_model_id = "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct"

# Load the tokenizer (shared by both models) and the two models
tokenizer = AutoTokenizer.from_pretrained(base_model_id)
base_model = AutoModelForCausalLM.from_pretrained(
    base_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
new_model = AutoModelForCausalLM.from_pretrained(
    new_model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
).eval()
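
# Llama-3 marks the end of each turn with <|eot_id|>, so stop on either it
# or the tokenizer's regular EOS token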
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
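
# `spaces` is imported but otherwise unused; on ZeroGPU hardware ("Running
# on Zero") the GPU-bound entry point is normally decorated with
# `spaces.GPU`, so this decorator is an assumed restoration:
@spaces.GPU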
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
    # skip_prompt=True makes the streamers yield only newly generated text
    base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)

    system_prompt_list = [{"role": "system", "content": system_prompt}]
    input_text_list = [{"role": "user", "content": input_text}]
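
    # Rebuild each model's message history from its Chatbot value
    # (a list of [user, assistant] pairs)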
    base_chat_history = []
    for user, assistant in base_chatbot:
        base_chat_history.append({"role": "user", "content": user})
        base_chat_history.append({"role": "assistant", "content": assistant})
    new_chat_history = []
    for user, assistant in new_chatbot:
        new_chat_history.append({"role": "user", "content": user})
        new_chat_history.append({"role": "assistant", "content": assistant})

    base_messages = system_prompt_list + base_chat_history + input_text_list
    new_messages = system_prompt_list + new_chat_history + input_text_list
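
    # Apply the Llama-3 chat template; add_generation_prompt=True appends
    # the assistant header so each model starts a fresh reply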
    base_input_ids = tokenizer.apply_chat_template(
        base_messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(base_model.device).long()
    new_input_ids = tokenizer.apply_chat_template(
        new_messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(new_model.device).long()
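
    # Identical decoding settings for both models; greedy decoding when
    # temperature is 0, nucleus sampling otherwise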
    base_generation_kwargs = dict(
        input_ids=base_input_ids,
        streamer=base_text_streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
    new_generation_kwargs = dict(
        input_ids=new_input_ids,
        streamer=new_text_streamer,
        max_new_tokens=max_new_tokens,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=temperature > 0,
        temperature=temperature,
        top_p=top_p,
        repetition_penalty=repetition_penalty,
    )
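
    # Each generate() call runs in a background thread while its streamer is
    # drained here; the base model streams to completion first, then the
    # fine-tuned model (sequential, not simultaneous)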
    base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
    base_thread.start()

    base_chatbot.append([input_text, ""])
    new_chatbot.append([input_text, ""])

    for base_text in base_text_streamer:
        # Strip the end-of-turn marker from the displayed text
        if "<|eot_id|>" in base_text:
            eot_location = base_text.find("<|eot_id|>")
            base_text = base_text[:eot_location]
        base_chatbot[-1][-1] += base_text
        yield base_chatbot, new_chatbot

    new_thread = Thread(target=new_model.generate, kwargs=new_generation_kwargs)
    new_thread.start()
    for new_text in new_text_streamer:
        if "<|eot_id|>" in new_text:
            eot_location = new_text.find("<|eot_id|>")
            new_text = new_text[:eot_location]
        new_chatbot[-1][-1] += new_text
        yield base_chatbot, new_chatbot

    return base_chatbot, new_chatbot

def clear():
    return [], []

with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
    with gr.Column():
        gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
        # Default system prompt: "You are an eloquent speaker of Arabic!"
        system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
        with gr.Row(variant="panel"):
            base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
            new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                submit_btn = gr.Button(value="Generate", variant="primary")
                clear_btn = gr.Button(value="Clear", variant="secondary")
            # Default user message: "Hello"
            input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)
        with gr.Accordion(label="Generation Configurations", open=False):
            max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=2048, label="Max New Tokens", step=128)
            temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature", step=0.01)
            top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p", step=0.01)
            repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)
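
    # Enter in the textbox and the Generate button both trigger the same
    # streaming comparison; Clear empties both panels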
    input_text.submit(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
    submit_btn.click(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
    clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])

demo.launch()