Cartinoe5930 commited on
Commit
a71589c
1 Parent(s): 01f28a4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +227 -33
app.py CHANGED
@@ -1,77 +1,271 @@
1
  import gradio as gr
 
 
 
 
 
2
 
3
- def inference(model_list, cot):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
  if len(model_list) != 3:
5
  raise gr.Error("Please choose just '3' models! Neither more nor less!")
6
 
7
- if cot:
8
- b = "CoT was also used."
9
- else:
10
- b = ""
11
 
12
- a = f"Hello, {model_list[0]}, {model_list[1]}, and {model_list[2]}!! {b}"
13
  return {
14
  output_msg: gr.update(visible=True),
15
  output_col: gr.update(visible=True),
16
- model1_output1: a
 
 
 
 
 
 
 
 
 
 
17
  }
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  TITLE = """<h1 align="center">LLM Agora 🗣️🏦</h1>"""
20
 
21
  INTRODUCTION_TEXT = """
22
  The **LLM Agora** 🗣️🏦 aims to improve the quality of open-source LMs' responses through debate & revision introduced in [Improving Factuality and Reasoning in Language Models through Multiagent Debate](https://arxiv.org/abs/2305.14325).
23
 
24
  Do you know that? 🤔 **LLMs can also improve their responses by debating with other LLMs**! 😮 We applied this concept to several open-source LMs to verify that the open-source model, not the proprietary one, can sufficiently improve the response through discussion. 🤗
25
- For more details, please refer to the GitHub Repository below.
26
 
27
  You can use LLM Agora with your own questions if the response of open-source LM is not satisfactory and you want to improve the quality!
28
- The Math, GSM8K, and MMLU Tabs show the results of the experiment, and for inference, please use the 'Inference' tab.
29
- Please check the more specific information in [GitHub Repository](https://github.com/gauss5930/LLM-Agora)!
 
 
 
 
 
 
 
 
 
30
  """
31
 
 
 
32
  RESPONSE_TEXT = """<h1 align="center">🤗 Here are the responses to each model!! 🤗</h1>"""
33
 
34
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  gr.HTML(TITLE)
36
  gr.Markdown(INTRODUCTION_TEXT)
37
  with gr.Column():
38
  with gr.Tab("Inference"):
39
- with gr.Row():
 
 
40
  with gr.Column():
41
- model_list = gr.CheckboxGroup(["Llama2", "Alpaca", "Vicuna", "Koala", "Falcon", "Baize", "WizardLM", "Orca", "phi-1.5"], label="Model Selection", info="Choose 3 LMs to participate in LLM Agora.", type="value")
42
  cot = gr.Checkbox(label="CoT", info="Do you want to use CoT for inference?")
43
- with gr.Column():
44
  API_KEY = gr.Textbox(label="OpenAI API Key", value="", info="Please fill in your OpenAI API token.", placeholder="sk..", type="password")
45
- auth_token = gr.Textbox(label="Huggingface Authentication Token", value="", info="Please fill in your HuggingFace Authentication token.", placeholder="hf..", type="password")
46
- with gr.Column():
47
- question = gr.Textbox(value="", info="Please type your question!", placeholder="")
48
- output = gr.Textbox()
49
  submit = gr.Button("Submit")
 
50
  with gr.Row(visible=False) as output_msg:
51
  gr.HTML(RESPONSE_TEXT)
52
 
53
- with gr.Row(visible=False) as output_col:
54
- with gr.Column():
55
  model1_output1 = gr.Textbox(label="1️⃣ model's initial response")
56
- model1_output2 = gr.Textbox(label="1️⃣ model's revised response")
57
- model1_output3 = gr.Textbox(label="1️⃣ model's final response")
58
- with gr.Column():
59
  model2_output1 = gr.Textbox(label="2️⃣ model's initial response")
 
 
 
 
60
  model2_output2 = gr.Textbox(label="2️⃣ model's revised response")
 
 
 
 
61
  model2_output3 = gr.Textbox(label="2️⃣ model's final response")
62
- with gr.Column():
63
- model2_output1 = gr.Textbox(label="3️⃣ model's initial response")
64
- model2_output2 = gr.Textbox(label="3️⃣ model's revised response")
65
- model2_output3 = gr.Textbox(label="3️⃣ model's final response")
66
-
67
 
 
 
 
 
68
  with gr.Tab("Math"):
69
- output_math = gr.Textbox()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
  with gr.Tab("GSM8K"):
71
- output_gsm = gr.Textbox()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
  with gr.Tab("MMLU"):
73
- output_mmlu = gr.Textbox()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
- submit.click(inference, [model_list], [output_msg, output_col, model1_output1])
 
76
 
77
- demo.launch(debug=True)
 
1
import json
import os
import time

import requests

import gradio as gr

from model_inference import Inference
6
+ HF_TOKEN = os.environ.get("HF_TOKEN")
7
 
8
+ model_list = ["llama", "llama-chat", "vicuna", "falcon", "falcon-instruct", "orca", "wizardlm"]
9
+
10
+ with open("src/inference_endpoint.json", "r") as f:
11
+ inference_endpoint = json.load(f)
12
+
13
+ for i in range(len(model_list)):
14
+ inference_endpoint[model_list[i]]["headers"]["Authorization"] += HF_TOKEN
15
+
16
+ def warmup(model_list=model_list, model_inference_endpoints=inference_endpoint):
17
+ for i in range(len(model_list)):
18
+ API_URL = model_inference_endpoints[model_list[i]["API_URL"]]
19
+ headers = model_inference_endpoints[model_list[i]["headers"]]
20
+
21
+ def query(payload):
22
+ response = requests.post(API_URL, headers=headers, json=payload)
23
+ return response.json()
24
+
25
+ output = query({
26
+ "inputs": "Hello."
27
+ })
28
+ time.sleep(300)
29
+ return {
30
+ options: gr.update(visible=True),
31
+ inputbox: gr.update(visible=True),
32
+ warmup_button: gr.update(visible=False),
33
+ welcome_message: gr.update(visible=True)
34
+ }
35
+
36
+ def inference(model_list, API_KEY, cot):
37
  if len(model_list) != 3:
38
  raise gr.Error("Please choose just '3' models! Neither more nor less!")
39
 
40
+ model_response = Inference(model_list, API_KEY, cot)
 
 
 
41
 
 
42
  return {
43
  output_msg: gr.update(visible=True),
44
  output_col: gr.update(visible=True),
45
+ model1_output1: model_response["agent_response"][model_list[0]][0],
46
+ model2_output1: model_response["agent_response"][model_list[1]][0],
47
+ model3_output1: model_response["agent_response"][model_list[2]][0],
48
+ summarization_text1: model_response["summarization"][0],
49
+ model1_output2: model_response["agent_response"][model_list[0]][1],
50
+ model2_output2: model_response["agent_response"][model_list[1]][1],
51
+ model3_output2: model_response["agent_response"][model_list[2]][1],
52
+ summarization_text2: model_response["summarization"][1],
53
+ model1_output3: model_response["agent_response"][model_list[0]][2],
54
+ model2_output3: model_response["agent_response"][model_list[1]][2],
55
+ model3_output3: model_response["agent_response"][model_list[2]][2]
56
  }
57
 
58
+ def load_responses():
59
+ with open("result/Math/math_result.json", "r") as math_file:
60
+ math_responses = json.load(math_file)
61
+
62
+ with open("result/Math/math_result_cot.json", "r") as math_cot_file:
63
+ math_cot_responses = json.load(math_cot_file)
64
+
65
+ with open("result/GSM8K/gsm_result.json", "r") as gsm_file:
66
+ gsm_responses = json.load(gsm_file)
67
+
68
+ with open("result/GSM8K/gsm_result_cot.json", "r") as gsm_cot_file:
69
+ gsm_cot_responses = json.load(gsm_cot_file)
70
+
71
+ # with open("result/MMLU/mmlu_result.json", "r") as mmlu_file:
72
+ # mmlu_responses = json.load(mmlu_file)
73
+
74
+ # with open("result/MMLU/mmlu_result_cot.json", "r") as mmlu_cot_file:
75
+ # mmlu_cot_responses = json.load(mmlu_cot_file)
76
+
77
+ return math_responses, gsm_responses#, mmlu_responses
78
+
79
+ def load_questions(math, gsm, mmlu):
80
+ math_questions = []
81
+ gsm_auestions = []
82
+ # mmlu_questions = []
83
+ for i in range(100):
84
+ math_questions.append(f"{i+1}. " + math[i]["question"])
85
+ gsm_questions.append(f"{i+1}. " + gsm[i]["question"])
86
+ # mmlu_questions.append(f"{i+1}. " + mmlu[i]["question"])
87
+
88
+ return math_questions, gsm_questions#, mmlu_questions
89
+
90
+ math_result, gsm_result#, mmlu_result = load_responses()
91
+
92
+ math_questions, gsm_questions = load_questions(math_result, gsm_result)
93
+
94
  TITLE = """<h1 align="center">LLM Agora 🗣️🏦</h1>"""
95
 
96
  INTRODUCTION_TEXT = """
97
  The **LLM Agora** 🗣️🏦 aims to improve the quality of open-source LMs' responses through debate & revision introduced in [Improving Factuality and Reasoning in Language Models through Multiagent Debate](https://arxiv.org/abs/2305.14325).
98
 
99
  Do you know that? 🤔 **LLMs can also improve their responses by debating with other LLMs**! 😮 We applied this concept to several open-source LMs to verify that the open-source model, not the proprietary one, can sufficiently improve the response through discussion. 🤗
100
+ For more details, please refer to the [GitHub Repository](https://github.com/gauss5930/LLM-Agora).
101
 
102
  You can use LLM Agora with your own questions if the response of open-source LM is not satisfactory and you want to improve the quality!
103
+ The Math, GSM8K, and MMLU Tabs show the results of the experiment(Llama2, WizardLM2, Orca2), and for inference, please use the 'Inference' tab.
104
+
105
+ Here's how to use LLM Agora!
106
+
107
+ 1. Before start, click the 'Warm-up LLM Agora 🔥' button and wait until 'LLM Agora Ready!!' appears. (Go grab a coffee☕ since it takes 5 minutes!)
108
+ 2. Choose just 3 models! Neither more nor less!
109
+ 3. Check the CoT box if you want to utilize the Chain-of-Thought while inferencing.
110
+ 4. Please fill in your OpenAI API KEY, it will be used to use ChatGPT to summarize the responses.
111
+ 5. Type your question to Question box and click the 'Submit' button! If you do so, LLM Agora will show you improved answers! 🤗 (It will spend roughly a minute! Please wait for an answer!)
112
+
113
+ For more detailed information, please check '※ Specific information about LLM Agora' at the bottom of the page.
114
  """
115
 
116
+ WELCOME_TEXT = """<h1 align="center">🤗🔥 Welcome to LLM Agora 🔥🤗</h1>"""
117
+
118
  RESPONSE_TEXT = """<h1 align="center">🤗 Here are the responses to each model!! 🤗</h1>"""
119
 
120
+ SPECIFIC_INFORMATION = """
121
+ This is the specific information about LLM Agora!
122
+
123
+ **Model size**
124
+
125
+ |Model name|Model size|
126
+ |---|---|
127
+ |Llama2|13B|
128
+ |Llama2-Chat|13B|
129
+ |Vicuna|13B|
130
+ |Falcon|7B|
131
+ |Falcon-Instruct|7B|
132
+ |WizardLM|13B|
133
+ |Orca|13B|
134
+
135
+ **Agent numbers & Debate rounds**
136
+
137
+ - We limit the number of agents and debate rounds because of limitation of resources. As a result, we decided to use 3 agents and 2 rounds of debate!
138
+
139
+ **GitHub Repository**
140
+
141
+ - If you want to see more specific information, please check the [GitHub Repository](https://github.com/gauss5930/LLM-Agora) of LLM Agora!
142
+
143
+ **Citation**
144
+
145
+ ```
146
+ @article{du2023improving,
147
+ title={Improving Factuality and Reasoning in Language Models through Multiagent Debate},
148
+ author={Du, Yilun and Li, Shuang and Torralba, Antonio and Tenenbaum, Joshua B and Mordatch, Igor},
149
+ journal={arXiv preprint arXiv:2305.14325},
150
+ year={2023}
151
+ }
152
+ ```
153
+ """
154
+
155
+ with gr.Blocks(css=block_css) as demo:
156
  gr.HTML(TITLE)
157
  gr.Markdown(INTRODUCTION_TEXT)
158
  with gr.Column():
159
  with gr.Tab("Inference"):
160
+ warmup_button = gr.Button("Warm-up LLM Agora 🔥", visible=True)
161
+ welcome_message = gr.HTML(WELCOME_TEXT, visible=False)
162
+ with gr.Row(visible=False) as options:
163
  with gr.Column():
164
+ model_list = gr.CheckboxGroup(["Llama2🦙", "Llama2-Chat🦙", "Vicuna🦙", "Falcon🦅", "Falcon-Instruct🦅", "WizardLM🧙‍♂️", "Orca🐬"], label="Model Selection", info="Choose 3 LMs to participate in LLM Agora.", type="value")
165
  cot = gr.Checkbox(label="CoT", info="Do you want to use CoT for inference?")
166
+ with gr.Column() as inputbox:
167
  API_KEY = gr.Textbox(label="OpenAI API Key", value="", info="Please fill in your OpenAI API token.", placeholder="sk..", type="password")
168
+ with gr.Column(visible=False) as inputbox:
169
+ question = gr.Textbox(label="Question", value="", info="Please type your question!", placeholder="")
 
 
170
  submit = gr.Button("Submit")
171
+
172
  with gr.Row(visible=False) as output_msg:
173
  gr.HTML(RESPONSE_TEXT)
174
 
175
+ with gr.Column(visible=False) as output_col:
176
+ with gr.Row(elem_id="model1_response"):
177
  model1_output1 = gr.Textbox(label="1️⃣ model's initial response")
 
 
 
178
  model2_output1 = gr.Textbox(label="2️⃣ model's initial response")
179
+ model3_output1 = gr.Textbox(label="3️⃣ model's initial response")
180
+ summarization_text1 = gr.Textbox(lebel="Summarization 1")
181
+ with gr.Row(elem_id="model2_response"):
182
+ model1_output2 = gr.Textbox(label="1️⃣ model's revised response")
183
  model2_output2 = gr.Textbox(label="2️⃣ model's revised response")
184
+ model3_output2 = gr.Textbox(label="3️⃣ model's revised response")
185
+ summarization_text2 = gr.Textbox(label="Summarization 2")
186
+ with gr.Row(elem_id="model3_response"):
187
+ model1_output3 = gr.Textbox(label="1️⃣ model's final response")
188
  model2_output3 = gr.Textbox(label="2️⃣ model's final response")
189
+ model3_output3 = gr.Textbox(label="3️⃣ model's final response")
 
 
 
 
190
 
191
+ with gr.Accordion("※ Specific information about LLM Agora", open=False):
192
+ gr.Markdown(SPECIFIC_INFORMATION)
193
+
194
+
195
  with gr.Tab("Math"):
196
+ math_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
197
+ math_question_list = gr.Dropdown(math_questions, label="Math Question", every=0.1)
198
+
199
+ with gr.Column():
200
+ with gr.Row(elem_id="model1_response"):
201
+ math_model1_output1 = gr.Textbox(label="Llama2🦙's initial response") # value=[int(math_question_list[0])-1]["agent_response"][]
202
+ math_model2_output1 = gr.Textbox(label="WizardLM🧙‍♂️'s initial response")
203
+ math_model3_output1 = gr.Textbox(label="Orca🐬's initial response")
204
+ math_summarization_text1 = gr.Textbox(lebel="Summarization 1")
205
+ with gr.Row(elem_id="model2_response"):
206
+ math_model1_output2 = gr.Textbox(label="Llama2🦙's revised response")
207
+ math_model2_output2 = gr.Textbox(label="WizardLM🧙‍♂️'s revised response")
208
+ math_model3_output2 = gr.Textbox(label="Orca🐬's revised response")
209
+ math_summarization_text2 = gr.Textbox(label="Summarization 2")
210
+ with gr.Row(elem_id="model3_response"):
211
+ math_model1_output3 = gr.Textbox(label="Llama2🦙's final response")
212
+ math_model2_output3 = gr.Textbox(label="WizardLM🧙‍♂️'s final response")
213
+ math_model3_output3 = gr.Textbox(label="Orca🐬's final response")
214
+
215
+ gr.HTML("""<h1 align="center"> The result of Math </h1>""")
216
+ gr.Image(value="result/Math/math_result.png")
217
+
218
+
219
  with gr.Tab("GSM8K"):
220
+ gsm_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
221
+ gsm_question_list = gr.Dropdown(gsm_questions, label="Math Question")
222
+
223
+ with gr.Column():
224
+ with gr.Row(elem_id="model1_response"):
225
+ gsm_model1_output1 = gr.Textbox(label="Llama2🦙's initial response")
226
+ gsm_model2_output1 = gr.Textbox(label="WizardLM🧙‍♂️'s initial response")
227
+ gsm_model3_output1 = gr.Textbox(label="Orca🐬's initial response")
228
+ gsm_summarization_text1 = gr.Textbox(lebel="Summarization 1")
229
+ with gr.Row(elem_id="model2_response"):
230
+ gsm_model1_output2 = gr.Textbox(label="Llama2🦙's revised response")
231
+ gsm_model2_output2 = gr.Textbox(label="WizardLM🧙‍♂️'s revised response")
232
+ gsm_model3_output2 = gr.Textbox(label="Orca🐬's revised response")
233
+ gsm_summarization_text2 = gr.Textbox(label="Summarization 2")
234
+ with gr.Row(elem_id="model3_response"):
235
+ gsm_model1_output3 = gr.Textbox(label="Llama2🦙's final response")
236
+ gsm_model2_output3 = gr.Textbox(label="WizardLM🧙‍♂️'s final response")
237
+ gsm_model3_output3 = gr.Textbox(label="Orca🐬's final response")
238
+
239
+ gr.HTML("""<h1 align="center"> The result of GSM8K </h1>""")
240
+ gr.Image(value="result/GSM8K/gsm_result.png")
241
+
242
+
243
  with gr.Tab("MMLU"):
244
+ mmlu_cot = gr.Checkbox(label="CoT", info="If you want to see CoT result, please check the box.")
245
+ # mmlu_question_list = gr.Dropdown(mmlu_questions, label="Math Question")
246
+ with gr.Column():
247
+ with gr.Row(elem_id="model1_response"):
248
+ mmlu_model1_output1 = gr.Textbox(label="Llama2🦙's initial response")
249
+ mmlu_model2_output1 = gr.Textbox(label="WizardLM🧙‍♂️'s initial response")
250
+ mmlu_model3_output1 = gr.Textbox(label="Orca🐬's initial response")
251
+ mmlu_summarization_text1 = gr.Textbox(lebel="Summarization 1")
252
+ with gr.Row(elem_id="model2_response"):
253
+ mmlu_model1_output2 = gr.Textbox(label="Llama2🦙's revised response")
254
+ mmlu_model2_output2 = gr.Textbox(label="WizardLM🧙‍♂️'s revised response")
255
+ mmlu_model3_output2 = gr.Textbox(label="Orca🐬's revised response")
256
+ mmlu_summarization_text2 = gr.Textbox(label="Summarization 2")
257
+ with gr.Row(elem_id="model3_response"):
258
+ mmlu_model1_output3 = gr.Textbox(label="Llama2🦙's final response")
259
+ mmlu_model2_output3 = gr.Textbox(label="WizardLM🧙‍♂️'s final response")
260
+ mmlu_model3_output3 = gr.Textbox(label="Orca🐬's final response")
261
+
262
+ gr.HTML("""<h1 align="center"> The result of MMLU </h1>""")
263
+ gr.Image(value="result/MMLU/mmlu_result.png")
264
+
265
+ with gr.Accordion("※ Specific information about LLM Agora", open=False):
266
+ gr.Markdown(SPECIFIC_INFORMATION)
267
 
268
+ warmup_button.click(warmup, [], [options, inputbox, warmup_button, welcome_message])
269
+ submit.click(inference, [model_list, API_KEY, cot], [output_msg, output_col, model1_output1, model2_output1, model3_output1, summarization_text1, model1_output2, model2_output2, model3_output2, summarization_text2, model1_output3, model2_output3, model3_output3])
270
 
271
+ demo.launch()