MohamedRashad's picture
Update app.py
eaed44f verified
raw
history blame
No virus
4.81 kB
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import gradio as gr
from threading import Thread
# Hub identifiers of the two checkpoints shown side by side in the demo.
base_model_id = "NousResearch/Meta-Llama-3-8B-Instruct"
new_model_id = "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct"

# A single tokenizer (taken from the base checkpoint) is used for both models.
tokenizer = AutoTokenizer.from_pretrained(base_model_id)


def _load_for_inference(model_id):
    """Load *model_id* as a bfloat16 causal LM with automatic device mapping.

    The model is switched to evaluation mode before being returned.
    """
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    return model.eval()


base_model = _load_for_inference(base_model_id)
new_model = _load_for_inference(new_model_id)

# Token ids passed to generate() as eos_token_id: the tokenizer's own EOS id
# plus the id of the "<|eot_id|>" token.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
def _history_to_messages(system_prompt, chat_history, input_text):
    """Build the chat-template message list for one model.

    Args:
        system_prompt: Text placed in the leading ``system`` message.
        chat_history: Iterable of ``(user, assistant)`` pairs from earlier turns.
        input_text: The new user message, appended last.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts ordered as
        system message, alternating user/assistant history, new user message.
    """
    messages = [{"role": "system", "content": system_prompt}]
    for user, assistant in chat_history:
        messages.append({"role": "user", "content": user})
        messages.append({"role": "assistant", "content": assistant})
    messages.append({"role": "user", "content": input_text})
    return messages


def _stream_reply(model, messages):
    """Run ``model.generate`` on a worker thread and yield decoded text chunks.

    Args:
        model: The causal LM to generate with.
        messages: Chat messages as produced by ``_history_to_messages``.

    Yields:
        Each text chunk emitted by the streamer, cut off at the first
        ``"<|eot_id|>"`` marker if the chunk contains one.
    """
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
    input_ids = tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device).long()
    generation_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=2048,
        eos_token_id=terminators,
        pad_token_id=tokenizer.eos_token_id,
        do_sample=True,
        temperature=0.2,
        top_p=0.9,
    )

    # generate() blocks until completion, so it runs on its own thread while
    # this generator drains the streamer.
    worker = Thread(target=model.generate, kwargs=generation_kwargs)
    worker.start()
    for chunk in streamer:
        # Special tokens are not skipped while decoding, so the end-of-turn
        # marker can show up in the text; keep only what precedes it.
        yield chunk.partition("<|eot_id|>")[0]
    # The streamer is exhausted, so the worker is expected to be finishing;
    # join it so no generation thread is left dangling behind this call.
    worker.join()


@spaces.GPU(duration=120)
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
    """Stream a reply from the base model, then from the new model.

    Args:
        system_prompt: System message sent to both models.
        input_text: The user's new message.
        base_chatbot: History of the base-model pane as ``[user, assistant]``
            pairs; a new pair is appended and filled in place.
        new_chatbot: Same, for the new-model pane.

    Yields:
        ``(base_chatbot, new_chatbot)`` after every streamed chunk, first
        while the base model is answering and then while the new model is.

    Returns:
        The final ``(base_chatbot, new_chatbot)`` pair (identical to the last
        yielded value).
    """
    # Build both prompts before the new turn is appended, so the history loop
    # never sees the still-empty [input_text, ""] entry.
    base_messages = _history_to_messages(system_prompt, base_chatbot, input_text)
    new_messages = _history_to_messages(system_prompt, new_chatbot, input_text)

    base_chatbot.append([input_text, ""])
    new_chatbot.append([input_text, ""])

    # The two models answer one after the other, not concurrently.
    runs = (
        (base_chatbot, base_model, base_messages),
        (new_chatbot, new_model, new_messages),
    )
    for pane, model, messages in runs:
        for chunk in _stream_reply(model, messages):
            pane[-1][-1] += chunk
            yield base_chatbot, new_chatbot

    return base_chatbot, new_chatbot
def clear():
    """Return two fresh empty histories, one for each chatbot pane."""
    base_history, new_history = [], []
    return base_history, new_history
# Side-by-side comparison UI: one chatbot pane per model, a shared system
# prompt and a shared input box.
# NOTE(review): the widget nesting below was reconstructed from the scale=
# hints because the original indentation was lost; confirm the layout.
with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
    # Keyword arguments shared by the right-to-left widgets.
    _textbox_style = dict(rtl=True, text_align="right", show_copy_button=True)
    _chatbot_style = dict(rtl=True, likeable=True, show_copy_button=True)

    with gr.Column():
        gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
        system_prompt = gr.Textbox(
            lines=1,
            label="System Prompt",
            value="أنت متحدث لبق باللغة العربية!",
            **_textbox_style,
        )

        with gr.Row(variant="panel"):
            base_chatbot = gr.Chatbot(label=base_model_id, **_chatbot_style)
            new_chatbot = gr.Chatbot(label=new_model_id, **_chatbot_style)

        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                submit_btn = gr.Button(value="Generate", variant="primary")
                clear_btn = gr.Button(value="Clear", variant="secondary")
            input_text = gr.Textbox(
                lines=1,
                label="",
                value="مرحبا",
                scale=3,
                **_textbox_style,
            )

        # Pressing Enter in the input box and clicking "Generate" run the
        # same handler with the same wiring.
        _generate_inputs = [system_prompt, input_text, base_chatbot, new_chatbot]
        _chat_panes = [base_chatbot, new_chatbot]
        input_text.submit(generate_both, inputs=_generate_inputs, outputs=_chat_panes)
        submit_btn.click(generate_both, inputs=_generate_inputs, outputs=_chat_panes)
        clear_btn.click(clear, outputs=_chat_panes)

demo.launch()