Kang Suhyun committed
Commit 43c8549
1 Parent(s): 2a0aa5a

[#71] Add custom prompt option (#77)


* [#71] Add custom prompt option

Changes:
- Looks for a summarizeInstruction or a translateInstruction config on each model and uses it as the prompt
- A default prompt is used if no instruction is found
- The prompt requests a JSON-format response

* Update prompts

* Reorder message content in check_models function

* prompt -> instruction
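
For reference, a minimal sketch of the fallback behavior this adds in model.py; the model names and the custom instruction text are hypothetical, not real entries from the models secret:

# Sketch only: hypothetical model entries, assuming Model's existing
# name/provider parameters alongside the new instruction ones.
from model import DEFAULT_SUMMARIZE_INSTRUCTION, Model

custom = Model(name="model-a", provider="openai",
               summarizeInstruction="Summarize in one sentence.")
plain = Model(name="model-b", provider="openai")

assert custom.summarize_instruction == "Summarize in one sentence."
assert plain.summarize_instruction == DEFAULT_SUMMARIZE_INSTRUCTION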

Files changed (3)
  1. app.py +8 -5
  2. model.py +23 -18
  3. response.py +17 -16
app.py CHANGED
@@ -28,7 +28,7 @@ class VoteOptions(enum.Enum):
 
 
 def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
-         user_prompt, instruction, category, source_lang, target_lang):
+         prompt, instruction, category, source_lang, target_lang):
     doc_id = uuid4().hex
     winner = VoteOptions(vote_button).name.lower()
 
@@ -37,7 +37,7 @@ def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
 
     doc = {
         "id": doc_id,
-        "prompt": user_prompt,
+        "prompt": prompt,
         "instruction": instruction,
         "model_a": model_a_name,
         "model_b": model_b_name,
@@ -116,7 +116,7 @@ with gr.Blocks(title="Arena", css=css) as app:
     model_names = [gr.State(None), gr.State(None)]
     response_boxes = [gr.State(None), gr.State(None)]
 
-    prompt = gr.TextArea(label="Prompt", lines=4)
+    prompt_textarea = gr.TextArea(label="Prompt", lines=4)
     submit = gr.Button()
 
     with gr.Group():
@@ -166,7 +166,10 @@ with gr.Blocks(title="Arena", css=css) as app:
         category_radio, source_language, target_language, submit, vote_row,
         model_name_row
     ]).then(fn=get_responses,
-            inputs=[prompt, category_radio, source_language, target_language],
+            inputs=[
+                prompt_textarea, category_radio, source_language,
+                target_language
+            ],
             outputs=response_boxes + model_names + [instruction_state])
     submit_event.success(fn=lambda: gr.Row(visible=True), outputs=vote_row)
     submit_event.then(
@@ -179,7 +182,7 @@ with gr.Blocks(title="Arena", css=css) as app:
         outputs=[category_radio, source_language, target_language, submit])
 
     common_inputs = response_boxes + model_names + [
-        prompt, instruction_state, category_radio, source_language,
+        prompt_textarea, instruction_state, category_radio, source_language,
        target_language
     ]
     common_outputs = [option_a, option_b, tie, model_name_row]
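
A condensed sketch of the Gradio event chaining app.py relies on here; every component except prompt_textarea is a stand-in for the real app:

import gradio as gr


# Stand-in handler; the real app chains get_responses here.
def handle(prompt, category):
    return f"[{category}] {prompt}"


with gr.Blocks() as demo:
    prompt_textarea = gr.TextArea(label="Prompt", lines=4)
    category_radio = gr.Radio(["summarize", "translate"], label="Category")
    submit = gr.Button()
    output = gr.Textbox()

    # Button.click() returns an event whose .then() chains the next step;
    # this is how prompt_textarea's value reaches the handler.
    submit.click(fn=lambda: None).then(fn=handle,
                                       inputs=[prompt_textarea, category_radio],
                                       outputs=[output])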
model.py CHANGED
@@ -25,6 +25,9 @@ decoded_secret = models_secret.payload.data.decode("UTF-8")
 
 supported_models_json = json.loads(decoded_secret)
 
+DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the following text, maintaining the language of the text."  # pylint: disable=line-too-long
+DEFAULT_TRANSLATE_INSTRUCTION = "Translate the following text from {source_lang} to {target_lang}."  # pylint: disable=line-too-long
+
 
 class Model:
 
@@ -35,11 +38,25 @@ class Model:
                  # The JSON keys are in camelCase. To unpack these keys into
                  # Model attributes, we need to use the same camelCase names.
                  apiKey: str = None,  # pylint: disable=invalid-name
-                 apiBase: str = None):  # pylint: disable=invalid-name
+                 apiBase: str = None,  # pylint: disable=invalid-name
+                 summarizeInstruction: str = None,  # pylint: disable=invalid-name
+                 translateInstruction: str = None):  # pylint: disable=invalid-name
         self.name = name
         self.provider = provider
         self.api_key = apiKey
         self.api_base = apiBase
+        self.summarize_instruction = summarizeInstruction or DEFAULT_SUMMARIZE_INSTRUCTION  # pylint: disable=line-too-long
+        self.translate_instruction = translateInstruction or DEFAULT_TRANSLATE_INSTRUCTION  # pylint: disable=line-too-long
+
+    def completion(self, messages: List, max_tokens: float = None) -> str:
+        response = litellm.completion(model=self.provider + "/" +
+                                      self.name if self.provider else self.name,
+                                      api_key=self.api_key,
+                                      api_base=self.api_base,
+                                      messages=messages,
+                                      max_tokens=max_tokens)
+
+        return response.choices[0].message.content
 
 
 supported_models: List[Model] = [
@@ -48,27 +65,15 @@ supported_models: List[Model] = [
 ]
 
 
-def completion(model: Model, messages: List, max_tokens: float = None) -> str:
-    response = litellm.completion(model=model.provider + "/" +
-                                  model.name if model.provider else model.name,
-                                  api_key=model.api_key,
-                                  api_base=model.api_base,
-                                  messages=messages,
-                                  max_tokens=max_tokens)
-
-    return response.choices[0].message.content
-
-
 def check_models(models: List[Model]):
     for model in models:
         print(f"Checking model {model.name}...")
         try:
-            completion(model=model,
-                       messages=[{
-                           "content": "Hello.",
-                           "role": "user"
-                       }],
-                       max_tokens=5)
+            model.completion(messages=[{
+                "role": "user",
+                "content": "Hello."
+            }],
+                             max_tokens=5)
             print(f"Model {model.name} is available.")
 
             # This check is designed to verify the availability of the models
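
A quick usage sketch of the completion method now on Model; the entry below is hypothetical (real entries come from the models secret):

from model import Model

# Hypothetical entry. Because provider is set, litellm receives
# "openai/gpt-4o"; without a provider it would get the bare name.
model = Model(name="gpt-4o", provider="openai", apiKey="...")
reply = model.completion(messages=[{"role": "user", "content": "Hello."}],
                         max_tokens=5)
print(reply)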
response.py CHANGED
@@ -11,7 +11,6 @@ from firebase_admin import firestore
 import gradio as gr
 
 from leaderboard import db
-from model import completion
 from model import Model
 from model import supported_models
 
@@ -39,14 +38,18 @@ class Category(enum.Enum):
 
 
 # TODO(#31): Let the model builders set the instruction.
-def get_instruction(category, source_lang, target_lang):
+def get_instruction(category: str, model: Model, source_lang: str,
+                    target_lang: str):
     if category == Category.SUMMARIZE.value:
-        return "Summarize the following text, maintaining the original language of the text in the summary."  # pylint: disable=line-too-long
+        return model.summarize_instruction
+
     if category == Category.TRANSLATE.value:
-        return f"Translate the following text from {source_lang} to {target_lang}."
+        return model.translate_instruction.format(source_lang=source_lang,
+                                                  target_lang=target_lang)
 
 
-def get_responses(user_prompt, category, source_lang, target_lang):
+def get_responses(prompt: str, category: str, source_lang: str,
+                  target_lang: str):
     if not category:
         raise gr.Error("Please select a category.")
 
@@ -55,21 +58,19 @@ def get_responses(user_prompt, category, source_lang, target_lang):
         raise gr.Error("Please select source and target languages.")
 
     models: List[Model] = sample(list(supported_models), 2)
-    instruction = get_instruction(category, source_lang, target_lang)
-
     responses = []
     for model in models:
+        instruction = get_instruction(category, model, source_lang, target_lang)
         try:
             # TODO(#1): Allow user to set configuration.
-            response = completion(model=model,
-                                  messages=[{
-                                      "role": "system",
-                                      "content": instruction
-                                  }, {
-                                      "role": "user",
-                                      "content": user_prompt
-                                  }])
-            create_history(model.name, instruction, user_prompt, response)
+            response = model.completion(messages=[{
+                "role": "system",
+                "content": instruction
+            }, {
+                "role": "user",
+                "content": prompt
+            }])
+            create_history(model.name, instruction, prompt, response)
             responses.append(response)
 
             # TODO(#1): Narrow down the exception type.
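
A short sketch of the per-model instruction resolution this introduces; the custom template below is hypothetical:

from model import Model
from response import Category, get_instruction

model = Model(name="model-a", provider="openai",
              translateInstruction="Translate {source_lang} text into {target_lang}.")

instruction = get_instruction(Category.TRANSLATE.value, model,
                              source_lang="Korean", target_lang="English")
# -> "Translate Korean text into English."

With no translateInstruction configured, the same call falls back to DEFAULT_TRANSLATE_INSTRUCTION and formats it the same way.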