Spaces:
Running
on
Zero
Running
on
Zero
MohamedRashad
committed on
Commit
•
f0ac041
1
Parent(s):
6253bc5
Add generation configurations to chatbot interface
Browse files
app.py
CHANGED
@@ -25,7 +25,7 @@ terminators = [
|
|
25 |
]
|
26 |
|
27 |
@spaces.GPU(duration=120)
|
28 |
-
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
|
29 |
base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
30 |
new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
31 |
|
@@ -60,22 +60,24 @@ def generate_both(system_prompt, input_text, base_chatbot, new_chatbot):
|
|
60 |
base_generation_kwargs = dict(
|
61 |
input_ids=base_input_ids,
|
62 |
streamer=base_text_streamer,
|
63 |
-
max_new_tokens=
|
64 |
eos_token_id=terminators,
|
65 |
pad_token_id=tokenizer.eos_token_id,
|
66 |
-
do_sample=True,
|
67 |
-
temperature=
|
68 |
-
top_p=
|
|
|
69 |
)
|
70 |
new_generation_kwargs = dict(
|
71 |
input_ids=new_input_ids,
|
72 |
streamer=new_text_streamer,
|
73 |
-
max_new_tokens=
|
74 |
eos_token_id=terminators,
|
75 |
pad_token_id=tokenizer.eos_token_id,
|
76 |
-
do_sample=True,
|
77 |
-
temperature=
|
78 |
-
top_p=
|
|
|
79 |
)
|
80 |
|
81 |
base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
|
@@ -111,16 +113,21 @@ with gr.Blocks(title="Arabic-ORPO-Llama3") as demo:
|
|
111 |
gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
|
112 |
system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
|
113 |
with gr.Row(variant="panel"):
|
114 |
-
base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True)
|
115 |
-
new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True)
|
116 |
with gr.Row(variant="panel"):
|
117 |
with gr.Column(scale=1):
|
118 |
submit_btn = gr.Button(value="Generate", variant="primary")
|
119 |
clear_btn = gr.Button(value="Clear", variant="secondary")
|
120 |
input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
124 |
clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])
|
125 |
|
126 |
demo.launch()
|
|
|
25 |
]
|
26 |
|
27 |
@spaces.GPU(duration=120)
|
28 |
+
def generate_both(system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens=2048, temperature=0.2, top_p=0.9, repetition_penalty=1.1):
|
29 |
base_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
30 |
new_text_streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
|
31 |
|
|
|
60 |
base_generation_kwargs = dict(
|
61 |
input_ids=base_input_ids,
|
62 |
streamer=base_text_streamer,
|
63 |
+
max_new_tokens=max_new_tokens,
|
64 |
eos_token_id=terminators,
|
65 |
pad_token_id=tokenizer.eos_token_id,
|
66 |
+
do_sample=True if temperature > 0 else False,
|
67 |
+
temperature=temperature,
|
68 |
+
top_p=top_p,
|
69 |
+
repetition_penalty=repetition_penalty,
|
70 |
)
|
71 |
new_generation_kwargs = dict(
|
72 |
input_ids=new_input_ids,
|
73 |
streamer=new_text_streamer,
|
74 |
+
max_new_tokens=max_new_tokens,
|
75 |
eos_token_id=terminators,
|
76 |
pad_token_id=tokenizer.eos_token_id,
|
77 |
+
do_sample=True if temperature > 0 else False,
|
78 |
+
temperature=temperature,
|
79 |
+
top_p=top_p,
|
80 |
+
repetition_penalty=repetition_penalty,
|
81 |
)
|
82 |
|
83 |
base_thread = Thread(target=base_model.generate, kwargs=base_generation_kwargs)
|
|
|
113 |
gr.HTML("<center><h1>Arabic Chatbot Comparison</h1></center>")
|
114 |
system_prompt = gr.Textbox(lines=1, label="System Prompt", value="أنت متحدث لبق باللغة العربية!", rtl=True, text_align="right", show_copy_button=True)
|
115 |
with gr.Row(variant="panel"):
|
116 |
+
base_chatbot = gr.Chatbot(label=base_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
|
117 |
+
new_chatbot = gr.Chatbot(label=new_model_id, rtl=True, likeable=True, show_copy_button=True, height=500)
|
118 |
with gr.Row(variant="panel"):
|
119 |
with gr.Column(scale=1):
|
120 |
submit_btn = gr.Button(value="Generate", variant="primary")
|
121 |
clear_btn = gr.Button(value="Clear", variant="secondary")
|
122 |
input_text = gr.Textbox(lines=1, label="", value="مرحبا", rtl=True, text_align="right", scale=3, show_copy_button=True)
|
123 |
+
with gr.Accordion(label="Generation Configurations", open=False):
|
124 |
+
max_new_tokens = gr.Slider(minimum=128, maximum=4096, value=2048, label="Max New Tokens", step=128)
|
125 |
+
temperature = gr.Slider(minimum=0.0, maximum=1.0, value=0.2, label="Temperature", step=0.01)
|
126 |
+
top_p = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, label="Top-p", step=0.01)
|
127 |
+
repetition_penalty = gr.Slider(minimum=0.1, maximum=2.0, value=1.1, label="Repetition Penalty", step=0.1)
|
128 |
+
|
129 |
+
input_text.submit(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
|
130 |
+
submit_btn.click(generate_both, inputs=[system_prompt, input_text, base_chatbot, new_chatbot, max_new_tokens, temperature, top_p, repetition_penalty], outputs=[base_chatbot, new_chatbot])
|
131 |
clear_btn.click(clear, outputs=[base_chatbot, new_chatbot])
|
132 |
|
133 |
demo.launch()
|