suhyun.kang committed
Commit 3c495cc
1 Parent(s): b727183

[#27] Add instructions for each category


Changes:
- Extracted the functions related to LLM responses into response.py
- Renamed "response type" to "category"
- Added instructions for each category
- Blocked submission until the user has selected a category

Files changed (2)
  1. app.py +37 -109
  2. response.py +98 -0
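
The core of the change: each category now maps to a fixed instruction, and that instruction is sent to the models as a system message ahead of the user prompt (see get_instruction and get_responses in the response.py diff below). A minimal, self-contained sketch of that flow; build_messages is a made-up helper for illustration, not code from the commit:

import enum


class Category(enum.Enum):
  SUMMARIZE = "Summarize"
  TRANSLATE = "Translate"


def get_instruction(category, source_lang=None, target_lang=None):
  if category == Category.SUMMARIZE.value:
    return "Summarize the following text in its original language."
  if category == Category.TRANSLATE.value:
    return f"Translate the following text from {source_lang} to {target_lang}."


def build_messages(user_prompt, category, source_lang=None, target_lang=None):
  # Hypothetical helper: the instruction rides along as a system message
  # ahead of the user prompt, the same message shape get_responses builds.
  instruction = get_instruction(category, source_lang, target_lang)
  return [{"role": "system", "content": instruction},
          {"role": "user", "content": user_prompt}]


print(build_messages("Bonjour tout le monde", Category.TRANSLATE.value,
                     "French", "English"))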
app.py CHANGED
@@ -3,35 +3,25 @@ It provides a platform for comparing the responses of two LLMs.
 """
 
 import enum
-from random import sample
 from uuid import uuid4
 
 import firebase_admin
 from firebase_admin import firestore
 import gradio as gr
-from litellm import completion
 
 from leaderboard import build_leaderboard
+import response
+from response import get_responses
 
 # TODO(#21): Fix auto-reload issue related to the initialization of Firebase.
 db_app = firebase_admin.initialize_app()
 db = firestore.client()
 
-# TODO(#1): Add more models.
-SUPPORTED_MODELS = [
-    "gpt-4", "gpt-4-0125-preview", "gpt-3.5-turbo", "gemini-pro"
-]
-
 SUPPORTED_TRANSLATION_LANGUAGES = [
     "Korean", "English", "Chinese", "Japanese", "Spanish", "French"
 ]
 
 
-class ResponseType(enum.Enum):
-  SUMMARIZE = "Summarize"
-  TRANSLATE = "Translate"
-
-
 class VoteOptions(enum.Enum):
   MODEL_A = "Model A is better"
   MODEL_B = "Model B is better"
@@ -39,107 +29,40 @@ class VoteOptions(enum.Enum):
 
 
 def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
-         user_prompt, res_type, source_lang, target_lang):
+         user_prompt, instruction, category, source_lang, target_lang):
   doc_id = uuid4().hex
   winner = VoteOptions(vote_button).name.lower()
 
-  if res_type == ResponseType.SUMMARIZE.value:
+  doc = {
+      "id": doc_id,
+      "prompt": user_prompt,
+      "instruction": instruction,
+      "model_a": model_a_name,
+      "model_b": model_b_name,
+      "model_a_response": response_a,
+      "model_b_response": response_b,
+      "winner": winner,
+      "timestamp": firestore.SERVER_TIMESTAMP
+  }
+
+  if category == response.Category.SUMMARIZE.value:
     doc_ref = db.collection("arena-summarizations").document(doc_id)
-    doc_ref.set({
-        "id": doc_id,
-        "prompt": user_prompt,
-        "model_a": model_a_name,
-        "model_b": model_b_name,
-        "model_a_response": response_a,
-        "model_b_response": response_b,
-        "winner": winner,
-        "timestamp": firestore.SERVER_TIMESTAMP
-    })
+    doc_ref.set(doc)
     return
 
-  if res_type == ResponseType.TRANSLATE.value:
+  if category == response.Category.TRANSLATE.value:
     doc_ref = db.collection("arena-translations").document(doc_id)
-    doc_ref.set({
-        "id": doc_id,
-        "prompt": user_prompt,
-        "model_a": model_a_name,
-        "model_b": model_b_name,
-        "model_a_response": response_a,
-        "model_b_response": response_b,
-        "source_language": source_lang.lower(),
-        "target_language": target_lang.lower(),
-        "winner": winner,
-        "timestamp": firestore.SERVER_TIMESTAMP
-    })
-
-
-def response_generator(response: str):
-  for part in response:
-    content = part.choices[0].delta.content
-    if content is None:
-      continue
-
-    # To simulate a stream, we yield each character of the response.
-    for character in content:
-      yield character
-
-
-def get_responses(user_prompt):
-  models = sample(SUPPORTED_MODELS, 2)
-
-  generators = []
-  for model in models:
-    try:
-      # TODO(#1): Allow user to set configuration.
-      response = completion(model=model,
-                            messages=[{
-                                "content": user_prompt,
-                                "role": "user"
-                            }],
-                            stream=True)
-      generators.append(response_generator(response))
-
-    # TODO(#1): Narrow down the exception type.
-    except Exception as e:  # pylint: disable=broad-except
-      print(f"Error in bot_response: {e}")
-      raise e
-
-  responses = ["", ""]
-
-  # It simulates concurrent response generation from two models.
-  while True:
-    stop = True
-
-    for i in range(len(generators)):
-      try:
-        yielded = next(generators[i])
-
-        if yielded is None:
-          continue
-
-        responses[i] += yielded
-        stop = False
-
-        yield responses + models
-
-      except StopIteration:
-        pass
-
-      # TODO(#1): Narrow down the exception type.
-      except Exception as e:  # pylint: disable=broad-except
-        print(f"Error in generator: {e}")
-        raise e
-
-    if stop:
-      break
+    doc["source_lang"] = source_lang.lower()
+    doc["target_lang"] = target_lang.lower()
+    doc_ref.set(doc)
 
 
 with gr.Blocks(title="Arena") as app:
   with gr.Row():
-    response_type_radio = gr.Radio(
-        [response_type.value for response_type in ResponseType],
-        label="Response type",
-        info="Choose the type of response you want from the model.")
+    category_radio = gr.Radio(
+        [category.value for category in response.Category],
+        label="Category",
+        info="The chosen category determines the instruction sent to the LLMs.")
 
     source_language = gr.Dropdown(
         choices=SUPPORTED_TRANSLATION_LANGUAGES,
@@ -154,15 +77,15 @@ with gr.Blocks(title="Arena") as app:
         interactive=True,
         visible=False)
 
-  def update_language_visibility(response_type):
-    visible = response_type == ResponseType.TRANSLATE.value
+  def update_language_visibility(category):
+    visible = category == response.Category.TRANSLATE.value
     return {
        source_language: gr.Dropdown(visible=visible),
        target_language: gr.Dropdown(visible=visible)
    }
 
-  response_type_radio.change(update_language_visibility, response_type_radio,
-                             [source_language, target_language])
+  category_radio.change(update_language_visibility, category_radio,
+                        [source_language, target_language])
 
   model_names = [gr.State(None), gr.State(None)]
   response_boxes = [gr.State(None), gr.State(None)]
@@ -175,7 +98,7 @@ with gr.Blocks(title="Arena") as app:
     response_boxes[1] = gr.Textbox(label="Model B", interactive=False)
 
   # TODO(#5): Display it only after the user submits the prompt.
-  # TODO(#6): Block voting if the response_type is not set.
+  # TODO(#6): Block voting if the category is not set.
   # TODO(#6): Block voting if the user already voted.
   with gr.Row():
     option_a = gr.Button(VoteOptions.MODEL_A.value)
@@ -188,10 +111,15 @@ with gr.Blocks(title="Arena") as app:
     model_names[0] = gr.Textbox(label="Model A", interactive=False)
     model_names[1] = gr.Textbox(label="Model B", interactive=False)
 
-  submit.click(get_responses, prompt, response_boxes + model_names)
+  instruction_state = gr.State("")
+
+  submit.click(get_responses,
+               [prompt, category_radio, source_language, target_language],
+               response_boxes + model_names + [instruction_state])
 
   common_inputs = response_boxes + model_names + [
-      prompt, response_type_radio, source_language, target_language
+      prompt, instruction_state, category_radio, source_language,
+      target_language
   ]
   option_a.click(vote, [option_a] + common_inputs)
   option_b.click(vote, [option_b] + common_inputs)
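
To make the vote handler aware of which instruction was actually used, the diff threads a gr.State through the Gradio events: submit.click writes instruction_state as an extra output of get_responses, and the vote buttons read it back as an input. A standalone sketch of that hand-off, independent of this repo's code; respond and record_vote are invented names for illustration:

import gradio as gr

with gr.Blocks() as demo:
  prompt = gr.Textbox(label="Prompt")
  answer = gr.Textbox(label="Answer", interactive=False)
  instruction_state = gr.State("")  # hidden value carried between events

  submit = gr.Button("Submit")
  vote_button = gr.Button("Vote")

  def respond(user_prompt):
    # First handler: produce a visible answer and stash the instruction.
    instruction = "Summarize the following text in its original language."
    return f"(pretend answer to: {user_prompt})", instruction

  def record_vote(user_prompt, instruction):
    # Second handler: read the stashed instruction back, much like vote()
    # does when it writes the Firestore document in app.py.
    print(f"vote recorded; prompt={user_prompt!r}, instruction={instruction!r}")

  # respond's second return value lands in instruction_state...
  submit.click(respond, prompt, [answer, instruction_state])
  # ...and record_vote receives it as a plain input.
  vote_button.click(record_vote, [prompt, instruction_state])

if __name__ == "__main__":
  demo.launch()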
response.py ADDED
@@ -0,0 +1,98 @@
+"""
+This module contains functions for generating responses using LLMs.
+"""
+
+import enum
+from random import sample
+
+import gradio as gr
+from litellm import completion
+
+# TODO(#1): Add more models.
+SUPPORTED_MODELS = [
+    "gpt-4", "gpt-4-0125-preview", "gpt-3.5-turbo", "gemini-pro"
+]
+
+
+class Category(enum.Enum):
+  SUMMARIZE = "Summarize"
+  TRANSLATE = "Translate"
+
+
+def get_instruction(category, source_lang, target_lang):
+  if category == Category.SUMMARIZE.value:
+    return "Summarize the following text in its original language."
+  if category == Category.TRANSLATE.value:
+    return f"Translate the following text from {source_lang} to {target_lang}."
+
+
+def response_generator(response: str):
+  for part in response:
+    content = part.choices[0].delta.content
+    if content is None:
+      continue
+
+    # To simulate a stream, we yield each character of the response.
+    for character in content:
+      yield character
+
+
+def get_responses(user_prompt, category, source_lang, target_lang):
+  if not category:
+    raise gr.Error("Please select a category.")
+
+  if category == Category.TRANSLATE.value and (not source_lang or
+                                               not target_lang):
+    raise gr.Error("Please select source and target languages.")
+
+  models = sample(SUPPORTED_MODELS, 2)
+  instruction = get_instruction(category, source_lang, target_lang)
+
+  generators = []
+  for model in models:
+    try:
+      # TODO(#1): Allow user to set configuration.
+      response = completion(model=model,
+                            messages=[{
+                                "content": instruction,
+                                "role": "system"
+                            }, {
+                                "content": user_prompt,
+                                "role": "user"
+                            }],
+                            stream=True)
+      generators.append(response_generator(response))
+
+    # TODO(#1): Narrow down the exception type.
+    except Exception as e:  # pylint: disable=broad-except
+      print(f"Error in bot_response: {e}")
+      raise e
+
+  responses = ["", ""]
+
+  # It simulates concurrent response generation from two models.
+  while True:
+    stop = True
+
+    for i in range(len(generators)):
+      try:
+        yielded = next(generators[i])
+
+        if yielded is None:
+          continue
+
+        responses[i] += yielded
+        stop = False
+
+        yield responses + models + [instruction]
+
+      except StopIteration:
+        pass
+
+      # TODO(#1): Narrow down the exception type.
+      except Exception as e:  # pylint: disable=broad-except
+        print(f"Error in generator: {e}")
+        raise e
+
+    if stop:
+      break
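
get_responses interleaves the two model streams by pulling at most one chunk from each generator per pass of the while loop and yielding a fresh snapshot after every chunk, so both response boxes grow together. A toy version of that loop with plain string generators (no LLM calls, no Gradio) shows the pattern:

def char_stream(text):
  # Stand-in for response_generator: yields one character at a time.
  for character in text:
    yield character


def interleave(streams):
  outputs = ["" for _ in streams]
  while True:
    stop = True
    for i, stream in enumerate(streams):
      try:
        outputs[i] += next(stream)
        stop = False
        yield list(outputs)  # snapshot after every chunk, like get_responses
      except StopIteration:
        pass
    if stop:  # every stream is exhausted
      break


for snapshot in interleave([char_stream("Model A says hi"),
                            char_stream("Model B says hello")]):
  print(snapshot)

In the real code each snapshot also carries the two sampled model names and the instruction, which is how the UI textboxes, the hidden states, and later the vote handler all stay in sync.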