Spaces:
Runtime error
Runtime error
Cartinoe5930
commited on
Commit
•
a71589c
1
Parent(s):
01f28a4
Update app.py
Browse files
app.py
CHANGED
@@ -1,77 +1,271 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
if len(model_list) != 3:
|
5 |
raise gr.Error("Please choose just '3' models! Neither more nor less!")
|
6 |
|
7 |
-
|
8 |
-
b = "CoT was also used."
|
9 |
-
else:
|
10 |
-
b = ""
|
11 |
|
12 |
-
a = f"Hello, {model_list[0]}, {model_list[1]}, and {model_list[2]}!! {b}"
|
13 |
return {
|
14 |
output_msg: gr.update(visible=True),
|
15 |
output_col: gr.update(visible=True),
|
16 |
-
model1_output1:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
}
|
18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
19 |
TITLE = """<h1 align="center">LLM Agora 🗣️🏦</h1>"""
|
20 |
|
21 |
INTRODUCTION_TEXT = """
|
22 |
The **LLM Agora** 🗣️🏦 aims to improve the quality of open-source LMs' responses through debate & revision introduced in [Improving Factuality and Reasoning in Language Models through Multiagent Debate](https://arxiv.org/abs/2305.14325).
|
23 |
|
24 |
Do you know that? 🤔 **LLMs can also improve their responses by debating with other LLMs**! 😮 We applied this concept to several open-source LMs to verify that the open-source model, not the proprietary one, can sufficiently improve the response through discussion. 🤗
|
25 |
-
For more details, please refer to the GitHub Repository
|
26 |
|
27 |
You can use LLM Agora with your own questions if the response of open-source LM is not satisfactory and you want to improve the quality!
|
28 |
-
The Math, GSM8K, and MMLU Tabs show the results of the experiment, and for inference, please use the 'Inference' tab.
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
"""
|
31 |
|
|
|
|
|
32 |
RESPONSE_TEXT = """<h1 align="center">🤗 Here are the responses to each model!! 🤗</h1>"""
|
33 |
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
gr.HTML(TITLE)
|
36 |
gr.Markdown(INTRODUCTION_TEXT)
|
37 |
with gr.Column():
|
38 |
with gr.Tab("Inference"):
|
39 |
-
|
|
|
|
|
40 |
with gr.Column():
|
41 |
-
model_list = gr.CheckboxGroup(["Llama2", "
|
42 |
cot = gr.Checkbox(label="CoT", info="Do you want to use CoT for inference?")
|
43 |
-
with gr.Column():
|
44 |
API_KEY = gr.Textbox(label="OpenAI API Key", value="", info="Please fill in your OpenAI API token.", placeholder="sk..", type="password")
|
45 |
-
|
46 |
-
|
47 |
-
question = gr.Textbox(value="", info="Please type your question!", placeholder="")
|
48 |
-
output = gr.Textbox()
|
49 |
submit = gr.Button("Submit")
|
|
|
50 |
with gr.Row(visible=False) as output_msg:
|
51 |
gr.HTML(RESPONSE_TEXT)
|
52 |
|
53 |
-
with gr.
|
54 |
-
with gr.
|
55 |
model1_output1 = gr.Textbox(label="1️⃣ model's initial response")
|
56 |
-
model1_output2 = gr.Textbox(label="1️⃣ model's revised response")
|
57 |
-
model1_output3 = gr.Textbox(label="1️⃣ model's final response")
|
58 |
-
with gr.Column():
|
59 |
model2_output1 = gr.Textbox(label="2️⃣ model's initial response")
|
|
|
|
|
|
|
|
|
60 |
model2_output2 = gr.Textbox(label="2️⃣ model's revised response")
|
|
|
|
|
|
|
|
|
61 |
model2_output3 = gr.Textbox(label="2️⃣ model's final response")
|
62 |
-
|
63 |
-
model2_output1 = gr.Textbox(label="3️⃣ model's initial response")
|
64 |
-
model2_output2 = gr.Textbox(label="3️⃣ model's revised response")
|
65 |
-
model2_output3 = gr.Textbox(label="3️⃣ model's final response")
|
66 |
-
|
67 |
|
|
|
|
|
|
|
|
|
68 |
with gr.Tab("Math"):
|
69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
with gr.Tab("GSM8K"):
|
71 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
72 |
with gr.Tab("MMLU"):
|
73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
74 |
|
75 |
-
|
|
|
76 |
|
77 |
-
demo.launch(
|
|
|
import json
import os
import time

import gradio as gr
import requests

from model_inference import Inference
# Hugging Face access token used to authorize calls to the inference endpoints.
HF_TOKEN = os.environ.get("HF_TOKEN")

# Open-source models LLM Agora can debate with; these are also the keys of
# src/inference_endpoint.json.
model_list = ["llama", "llama-chat", "vicuna", "falcon", "falcon-instruct", "orca", "wizardlm"]

with open("src/inference_endpoint.json", "r") as f:
    inference_endpoint = json.load(f)

# BUG FIX: if HF_TOKEN is unset, the original `... += HF_TOKEN` raised an
# opaque `TypeError: can only concatenate str (not "NoneType")`. Fail fast
# with an explicit message instead.
if HF_TOKEN is None:
    raise RuntimeError("HF_TOKEN environment variable is not set; it is required to authorize the inference endpoints.")

# Complete each endpoint's Authorization header with the token.
for name in model_list:
    inference_endpoint[name]["headers"]["Authorization"] += HF_TOKEN
15 |
+
|
def warmup(model_list=model_list, model_inference_endpoints=inference_endpoint):
    """Ping every inference endpoint once to spin it up, then reveal the main UI.

    Sends a trivial "Hello." request to each model's endpoint so the (scale-to-
    zero) endpoints start booting, waits, and returns Gradio visibility updates
    that swap the warm-up button for the input widgets.

    Returns a dict mapping Gradio components to `gr.update` objects, consumed
    by `warmup_button.click(...)`.
    """
    for model in model_list:
        # BUG FIX: the original indexed `model_inference_endpoints[model_list[i]["API_URL"]]`,
        # i.e. it subscripted the model-name *string* — the endpoint config is
        # keyed by model name, then by "API_URL"/"headers".
        endpoint = model_inference_endpoints[model]
        api_url = endpoint["API_URL"]
        headers = endpoint["headers"]
        requests.post(api_url, headers=headers, json={"inputs": "Hello."})

    # Give the endpoints time to finish booting. NOTE(review): the original's
    # mangled indentation makes it ambiguous whether this slept per-endpoint;
    # sleeping once after pinging all of them achieves the same warm-up.
    time.sleep(300)

    return {
        options: gr.update(visible=True),
        inputbox: gr.update(visible=True),
        warmup_button: gr.update(visible=False),
        welcome_message: gr.update(visible=True),
    }
|
def inference(model_list, API_KEY, cot):
    """Run a 3-model LLM Agora debate on the user's question.

    Args:
        model_list: names of exactly three selected models (from the
            CheckboxGroup); any other count raises `gr.Error`.
        API_KEY: OpenAI API key, used by `Inference` to summarize responses.
        cot: whether to use Chain-of-Thought prompting.

    Returns a dict mapping Gradio components to their new values: each model's
    initial/revised/final responses plus the two round summaries.
    """
    if len(model_list) != 3:
        raise gr.Error("Please choose just '3' models! Neither more nor less!")

    # model_response: {"agent_response": {model: [r0, r1, r2]}, "summarization": [s0, s1]}
    # — assumed from the indexing below; verify against model_inference.Inference.
    model_response = Inference(model_list, API_KEY, cot)

    updates = {
        output_msg: gr.update(visible=True),
        output_col: gr.update(visible=True),
    }
    agent = model_response["agent_response"]
    widget_rows = (
        (model1_output1, model2_output1, model3_output1),
        (model1_output2, model2_output2, model3_output2),
        (model1_output3, model2_output3, model3_output3),
    )
    summaries = (summarization_text1, summarization_text2)
    for round_idx, row in enumerate(widget_rows):
        for widget, model in zip(row, model_list):
            updates[widget] = agent[model][round_idx]
        if round_idx < len(summaries):
            updates[summaries[round_idx]] = model_response["summarization"][round_idx]
    return updates
def load_responses():
    """Load the pre-computed experiment results from disk.

    Returns:
        (math_responses, gsm_responses): the parsed JSON result lists for the
        Math and GSM8K tabs.

    Note: the CoT result files are read but currently unused (likely intended
    for the CoT checkboxes), and the MMLU files remain disabled — behavior is
    kept identical to the original.
    """
    def _read_json(path):
        # Small helper: one place for the open/parse pattern.
        with open(path, "r") as fp:
            return json.load(fp)

    math_responses = _read_json("result/Math/math_result.json")
    math_cot_responses = _read_json("result/Math/math_result_cot.json")
    gsm_responses = _read_json("result/GSM8K/gsm_result.json")
    gsm_cot_responses = _read_json("result/GSM8K/gsm_result_cot.json")

    # mmlu_responses = _read_json("result/MMLU/mmlu_result.json")
    # mmlu_cot_responses = _read_json("result/MMLU/mmlu_result_cot.json")

    return math_responses, gsm_responses  # , mmlu_responses
def load_questions(math, gsm, mmlu=None):
    """Build numbered dropdown labels from the first 100 questions of each result set.

    Args:
        math: Math results; each item must have a "question" key.
        gsm: GSM8K results; same shape.
        mmlu: optional MMLU results, currently ignored (MMLU tab is disabled).
            BUG FIX: this was a required parameter, but the module-level caller
            passes only two arguments — made optional for compatibility.

    Returns:
        (math_questions, gsm_questions): lists of "1. <question>" style labels.
    """
    math_questions = []
    # BUG FIX: was misspelled `gsm_auestions`, so the append below raised NameError.
    gsm_questions = []
    # mmlu_questions = []
    for i in range(100):
        math_questions.append(f"{i+1}. " + math[i]["question"])
        gsm_questions.append(f"{i+1}. " + gsm[i]["question"])
        # mmlu_questions.append(f"{i+1}. " + mmlu[i]["question"])

    return math_questions, gsm_questions  # , mmlu_questions
# Load the pre-computed results and derive the dropdown question labels.
# BUG FIX: the original line was
#     math_result, gsm_result#, mmlu_result = load_responses()
# — the `#` commented out the `= load_responses()` part, turning the line into
# a bare expression and leaving both names unassigned (NameError below).
math_result, gsm_result = load_responses()  # , mmlu_result

math_questions, gsm_questions = load_questions(math_result, gsm_result)
# Banner rendered at the top of the page (raw HTML).
TITLE = """<h1 align="center">LLM Agora 🗣️🏦</h1>"""

# Markdown introduction: what LLM Agora is, plus step-by-step usage instructions.
INTRODUCTION_TEXT = """
The **LLM Agora** 🗣️🏦 aims to improve the quality of open-source LMs' responses through debate & revision introduced in [Improving Factuality and Reasoning in Language Models through Multiagent Debate](https://arxiv.org/abs/2305.14325).

Do you know that? 🤔 **LLMs can also improve their responses by debating with other LLMs**! 😮 We applied this concept to several open-source LMs to verify that the open-source model, not the proprietary one, can sufficiently improve the response through discussion. 🤗
For more details, please refer to the [GitHub Repository](https://github.com/gauss5930/LLM-Agora).

You can use LLM Agora with your own questions if the response of open-source LM is not satisfactory and you want to improve the quality!
The Math, GSM8K, and MMLU Tabs show the results of the experiment(Llama2, WizardLM2, Orca2), and for inference, please use the 'Inference' tab.

Here's how to use LLM Agora!

1. Before start, click the 'Warm-up LLM Agora 🔥' button and wait until 'LLM Agora Ready!!' appears. (Go grab a coffee☕ since it takes 5 minutes!)
2. Choose just 3 models! Neither more nor less!
3. Check the CoT box if you want to utilize the Chain-of-Thought while inferencing.
4. Please fill in your OpenAI API KEY, it will be used to use ChatGPT to summarize the responses.
5. Type your question to Question box and click the 'Submit' button! If you do so, LLM Agora will show you improved answers! 🤗 (It will spend roughly a minute! Please wait for an answer!)

For more detailed information, please check '※ Specific information about LLM Agora' at the bottom of the page.
"""

# Shown once warm-up completes (revealed by warmup()).
WELCOME_TEXT = """<h1 align="center">🤗🔥 Welcome to LLM Agora 🔥🤗</h1>"""

# Header displayed above the per-model response boxes after inference.
RESPONSE_TEXT = """<h1 align="center">🤗 Here are the responses to each model!! 🤗</h1>"""

# Markdown for the "※ Specific information" accordion: model sizes, debate
# configuration, repository link, and citation.
SPECIFIC_INFORMATION = """
This is the specific information about LLM Agora!

**Model size**

|Model name|Model size|
|---|---|
|Llama2|13B|
|Llama2-Chat|13B|
|Vicuna|13B|
|Falcon|7B|
|Falcon-Instruct|7B|
|WizardLM|13B|
|Orca|13B|

**Agent numbers & Debate rounds**

- We limit the number of agents and debate rounds because of limitation of resources. As a result, we decided to use 3 agents and 2 rounds of debate!

**GitHub Repository**

- If you want to see more specific information, please check the [GitHub Repository](https://github.com/gauss5930/LLM-Agora) of LLM Agora!

**Citation**

```
@article{du2023improving,
    title={Improving Factuality and Reasoning in Language Models through Multiagent Debate},
    author={Du, Yilun and Li, Shuang and Torralba, Antonio and Tenenbaum, Joshua B and Mordatch, Igor},
    journal={arXiv preprint arXiv:2305.14325},
    year={2023}
}
```
"""
# Build and launch the Gradio app.
# BUG FIX: the original passed `css=block_css`, but `block_css` is never
# defined anywhere in this file, so the app crashed with NameError on startup.
with gr.Blocks() as demo:
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT)
    with gr.Column():
        with gr.Tab("Inference"):
            warmup_button = gr.Button("Warm-up LLM Agora 🔥", visible=True)
            welcome_message = gr.HTML(WELCOME_TEXT, visible=False)
            with gr.Row(visible=False) as options:
                with gr.Column():
                    model_list = gr.CheckboxGroup(["Llama2🦙", "Llama2-Chat🦙", "Vicuna🦙", "Falcon🦅", "Falcon-Instruct🦅", "WizardLM🧙♂️", "Orca🐬"], label="Model Selection", info="Choose 3 LMs to participate in LLM Agora.", type="value")
                    cot = gr.Checkbox(label="CoT", info="Do you want to use CoT for inference?")
                # BUG FIX: this column was also bound to `inputbox`, shadowing
                # the question column below — warmup()'s visibility update then
                # targeted only the later binding.
                with gr.Column() as api_key_box:
                    API_KEY = gr.Textbox(label="OpenAI API Key", value="", info="Please fill in your OpenAI API token.", placeholder="sk..", type="password")
            with gr.Column(visible=False) as inputbox:
                question = gr.Textbox(label="Question", value="", info="Please type your question!", placeholder="")
                submit = gr.Button("Submit")

            with gr.Row(visible=False) as output_msg:
                gr.HTML(RESPONSE_TEXT)

            with gr.Column(visible=False) as output_col:
                # Round 1: each model's initial answer plus ChatGPT's summary.
                with gr.Row(elem_id="model1_response"):
                    model1_output1 = gr.Textbox(label="1️⃣ model's initial response")
                    model2_output1 = gr.Textbox(label="2️⃣ model's initial response")
                    model3_output1 = gr.Textbox(label="3️⃣ model's initial response")
                    summarization_text1 = gr.Textbox(label="Summarization 1")  # BUG FIX: was `lebel=`
                # Round 2: answers revised after seeing the debate summary.
                with gr.Row(elem_id="model2_response"):
                    model1_output2 = gr.Textbox(label="1️⃣ model's revised response")
                    model2_output2 = gr.Textbox(label="2️⃣ model's revised response")
                    model3_output2 = gr.Textbox(label="3️⃣ model's revised response")
                    summarization_text2 = gr.Textbox(label="Summarization 2")
                # Round 3: final answers.
                # BUG FIX: the third round's boxes were assigned to
                # model2_output1/2/3 again, clobbering the second model's
                # widgets — they are model3_* here.
                with gr.Row(elem_id="model3_response"):
                    model1_output3 = gr.Textbox(label="1️⃣ model's final response")
                    model2_output3 = gr.Textbox(label="2️⃣ model's final response")
                    model3_output3 = gr.Textbox(label="3️⃣ model's final response")

            with gr.Accordion("※ Specific information about LLM Agora", open=False):
                gr.Markdown(SPECIFIC_INFORMATION)

        with gr.Tab("Math"):
            math_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            # BUG FIX: dropped `every=0.1` — `every` periodically re-calls a
            # function-valued `value`; it is meaningless for a static list.
            math_question_list = gr.Dropdown(math_questions, label="Math Question")

            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    math_model1_output1 = gr.Textbox(label="Llama2🦙's initial response")
                    math_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s initial response")
                    math_model3_output1 = gr.Textbox(label="Orca🐬's initial response")
                    math_summarization_text1 = gr.Textbox(label="Summarization 1")  # BUG FIX: was `lebel=`
                with gr.Row(elem_id="model2_response"):
                    math_model1_output2 = gr.Textbox(label="Llama2🦙's revised response")
                    math_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s revised response")
                    math_model3_output2 = gr.Textbox(label="Orca🐬's revised response")
                    math_summarization_text2 = gr.Textbox(label="Summarization 2")
                with gr.Row(elem_id="model3_response"):
                    math_model1_output3 = gr.Textbox(label="Llama2🦙's final response")
                    math_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s final response")
                    math_model3_output3 = gr.Textbox(label="Orca🐬's final response")

            gr.HTML("""<h1 align="center"> The result of Math </h1>""")
            gr.Image(value="result/Math/math_result.png")

        with gr.Tab("GSM8K"):
            gsm_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            # BUG FIX: label said "Math Question" (copy-paste) on the GSM8K tab.
            gsm_question_list = gr.Dropdown(gsm_questions, label="GSM8K Question")

            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    gsm_model1_output1 = gr.Textbox(label="Llama2🦙's initial response")
                    gsm_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s initial response")
                    gsm_model3_output1 = gr.Textbox(label="Orca🐬's initial response")
                    gsm_summarization_text1 = gr.Textbox(label="Summarization 1")  # BUG FIX: was `lebel=`
                with gr.Row(elem_id="model2_response"):
                    gsm_model1_output2 = gr.Textbox(label="Llama2🦙's revised response")
                    gsm_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s revised response")
                    gsm_model3_output2 = gr.Textbox(label="Orca🐬's revised response")
                    gsm_summarization_text2 = gr.Textbox(label="Summarization 2")
                with gr.Row(elem_id="model3_response"):
                    gsm_model1_output3 = gr.Textbox(label="Llama2🦙's final response")
                    gsm_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s final response")
                    gsm_model3_output3 = gr.Textbox(label="Orca🐬's final response")

            gr.HTML("""<h1 align="center"> The result of GSM8K </h1>""")
            gr.Image(value="result/GSM8K/gsm_result.png")

        with gr.Tab("MMLU"):
            mmlu_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
            # mmlu_question_list = gr.Dropdown(mmlu_questions, label="MMLU Question")
            with gr.Column():
                with gr.Row(elem_id="model1_response"):
                    mmlu_model1_output1 = gr.Textbox(label="Llama2🦙's initial response")
                    mmlu_model2_output1 = gr.Textbox(label="WizardLM🧙♂️'s initial response")
                    mmlu_model3_output1 = gr.Textbox(label="Orca🐬's initial response")
                    mmlu_summarization_text1 = gr.Textbox(label="Summarization 1")  # BUG FIX: was `lebel=`
                with gr.Row(elem_id="model2_response"):
                    mmlu_model1_output2 = gr.Textbox(label="Llama2🦙's revised response")
                    mmlu_model2_output2 = gr.Textbox(label="WizardLM🧙♂️'s revised response")
                    mmlu_model3_output2 = gr.Textbox(label="Orca🐬's revised response")
                    mmlu_summarization_text2 = gr.Textbox(label="Summarization 2")
                with gr.Row(elem_id="model3_response"):
                    mmlu_model1_output3 = gr.Textbox(label="Llama2🦙's final response")
                    mmlu_model2_output3 = gr.Textbox(label="WizardLM🧙♂️'s final response")
                    mmlu_model3_output3 = gr.Textbox(label="Orca🐬's final response")

            gr.HTML("""<h1 align="center"> The result of MMLU </h1>""")
            gr.Image(value="result/MMLU/mmlu_result.png")

        with gr.Accordion("※ Specific information about LLM Agora", open=False):
            gr.Markdown(SPECIFIC_INFORMATION)

    # Wire up the event handlers; both callbacks return dicts keyed by the
    # components listed here.
    warmup_button.click(warmup, [], [options, inputbox, warmup_button, welcome_message])
    submit.click(inference, [model_list, API_KEY, cot], [output_msg, output_col, model1_output1, model2_output1, model3_output1, summarization_text1, model1_output2, model2_output2, model3_output2, summarization_text2, model1_output3, model2_output3, model3_output3])

demo.launch()