MohamedRashad commited on
Commit
e64d5d5
1 Parent(s): 412e761

Add Arabic-ORPO-Llama3 chatbot comparison functionality

Browse files
Files changed (2) hide show
  1. app.py +125 -0
  2. requirements.txt +3 -0
app.py ADDED
@@ -0,0 +1,125 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import torch
import gradio as gr
from threading import Thread

# Model identifiers: the official Llama-3 instruct checkpoint and its
# Arabic ORPO fine-tune, compared side by side in the UI below.
base_model_id = "meta-llama/Meta-Llama-3-8B-Instruct"
new_model_id = "MohamedRashad/Arabic-Orpo-Llama-3-8B-Instruct"

# The fine-tune shares the base Llama-3 tokenizer, so one tokenizer serves both.
tokenizer = AutoTokenizer.from_pretrained(base_model_id)


def _load_model(model_id):
    """Load a causal LM checkpoint in bfloat16, device-mapped, in eval mode."""
    return AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    ).eval()


base_model = _load_model(base_model_id)
new_model = _load_model(new_model_id)

# Llama-3 may end an assistant turn with either the EOS token or the
# dedicated <|eot_id|> end-of-turn token; accept both as stop tokens.
terminators = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]
25
+
26
+
27
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
    """Stream one user turn through both models, updating both chat panels.

    The base model generates first and the fine-tuned model second —
    sequentially, so only one model is generating at a time (presumably to
    limit peak GPU load; kept as in the original design). After every
    streamed text fragment the updated chat histories are yielded so Gradio
    can refresh both panels live.

    Args:
        system_prompt: System message prepended to each conversation.
        input_text: The new user message to answer.
        base_chatbot: Gradio history for the base model, a list of
            [user, assistant] pairs (mutated in place).
        new_chatbot: Gradio history for the fine-tuned model, same format.

    Yields:
        (base_chatbot, new_chatbot) tuples with assistant text growing
        incrementally.
    """

    def _to_messages(history):
        # Convert a Gradio [user, assistant] history into the chat-template
        # message list: system prompt, prior turns, then the new user turn.
        messages = [{"role": "system", "content": system_prompt}]
        for user_turn, assistant_turn in history:
            messages.append({"role": "user", "content": user_turn})
            messages.append({"role": "assistant", "content": assistant_turn})
        messages.append({"role": "user", "content": input_text})
        return messages

    def _prepare(model, history):
        # Tokenize the conversation for `model` and build its streamer plus
        # generation kwargs. Sampling settings match the original app.
        input_ids = tokenizer.apply_chat_template(
            _to_messages(history),
            add_generation_prompt=True,
            return_tensors="pt",
        ).to(model.device).long()
        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
        generation_kwargs = dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=2048,
            eos_token_id=terminators,
            pad_token_id=tokenizer.eos_token_id,
            do_sample=True,
            temperature=0.2,
            top_p=0.9,
        )
        return streamer, generation_kwargs

    # Tokenize both conversations BEFORE appending the new turn to the
    # histories, so the placeholder assistant turn is not templated.
    base_streamer, base_kwargs = _prepare(base_model, base_chatbot)
    new_streamer, new_kwargs = _prepare(new_model, new_chatbot)

    # Show the user turn immediately with empty assistant replies.
    base_chatbot.append([input_text, ""])
    new_chatbot.append([input_text, ""])

    for model, kwargs, streamer, chatbot in (
        (base_model, base_kwargs, base_streamer, base_chatbot),
        (new_model, new_kwargs, new_streamer, new_chatbot),
    ):
        thread = Thread(target=model.generate, kwargs=kwargs)
        thread.start()
        for piece in streamer:
            # The streamer does not strip special tokens; truncate at the
            # end-of-turn marker so it never reaches the UI.
            if "<|eot_id|>" in piece:
                piece = piece[: piece.find("<|eot_id|>")]
            chatbot[-1][-1] += piece
            yield base_chatbot, new_chatbot
        # Ensure the generate() thread has fully finished before starting
        # the next model (the original never joined its threads).
        thread.join()

    return base_chatbot, new_chatbot
104
+
105
def clear():
    """Reset both chat panels by returning a fresh empty history for each."""
    empty_base, empty_new = [], []
    return empty_base, empty_new
107
+
108
with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
    with gr.Column():
        gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
        system_prompt = gr.Textbox(
            lines=1,
            label="System Prompt",
            value="You are a pirate chatbot who always responds in pirate speak!",
        )
        with gr.Row(variant="panel"):
            # Left: base model (LTR); right: Arabic fine-tune (RTL).
            # NOTE(review): the rtl mismatch looks deliberate (the base model
            # mostly answers in English) — confirm before unifying.
            base_chatbot = gr.Chatbot(
                label=base_model_id, rtl=False, likeable=True, show_copy_button=True
            )
            new_chatbot = gr.Chatbot(
                label=new_model_id, rtl=True, likeable=True, show_copy_button=True
            )
        with gr.Row(variant="panel"):
            with gr.Column(scale=1):
                submit_btn = gr.Button(value="Generate", variant="primary")
                clear_btn = gr.Button(value="Clear", variant="secondary")
            input_text = gr.Textbox(
                lines=1,
                label="",
                value="مرحبا",
                rtl=True,
                text_align="right",
                scale=3,
                show_copy_button=True,
            )

    # Enter in the textbox and the Generate button trigger the same handler.
    generation_io = dict(
        inputs=[system_prompt, input_text, base_chatbot, new_chatbot],
        outputs=[base_chatbot, new_chatbot],
    )
    input_text.submit(generate_both, **generation_io)
    submit_btn.click(generate_both, **generation_io)
    clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])

demo.launch()
requirements.txt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ transformers
2
+ torch
3
+ accelerate