Spaces:
Running
Running
Kang Suhyun
committed on
Commit
•
43c8549
1
Parent(s):
2a0aa5a
[#71] Add custom prompt option (#77)
Browse files* [#71] Add custom prompt option
Changes:
- It will look for a summarizationInstruction or a translationInstruction config for each model and use it as a prompt
- A default prompt is used if no instruction is found
- The prompt requests a JSON format response
* Update prompts
* Reorder message content in check_models function
* prompt -> instruction
- app.py +8 -5
- model.py +23 -18
- response.py +17 -16
app.py
CHANGED
@@ -28,7 +28,7 @@ class VoteOptions(enum.Enum):
|
|
28 |
|
29 |
|
30 |
def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
|
31 |
-
|
32 |
doc_id = uuid4().hex
|
33 |
winner = VoteOptions(vote_button).name.lower()
|
34 |
|
@@ -37,7 +37,7 @@ def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
|
|
37 |
|
38 |
doc = {
|
39 |
"id": doc_id,
|
40 |
-
"prompt":
|
41 |
"instruction": instruction,
|
42 |
"model_a": model_a_name,
|
43 |
"model_b": model_b_name,
|
@@ -116,7 +116,7 @@ with gr.Blocks(title="Arena", css=css) as app:
|
|
116 |
model_names = [gr.State(None), gr.State(None)]
|
117 |
response_boxes = [gr.State(None), gr.State(None)]
|
118 |
|
119 |
-
|
120 |
submit = gr.Button()
|
121 |
|
122 |
with gr.Group():
|
@@ -166,7 +166,10 @@ with gr.Blocks(title="Arena", css=css) as app:
|
|
166 |
category_radio, source_language, target_language, submit, vote_row,
|
167 |
model_name_row
|
168 |
]).then(fn=get_responses,
|
169 |
-
inputs=[
|
|
|
|
|
|
|
170 |
outputs=response_boxes + model_names + [instruction_state])
|
171 |
submit_event.success(fn=lambda: gr.Row(visible=True), outputs=vote_row)
|
172 |
submit_event.then(
|
@@ -179,7 +182,7 @@ with gr.Blocks(title="Arena", css=css) as app:
|
|
179 |
outputs=[category_radio, source_language, target_language, submit])
|
180 |
|
181 |
common_inputs = response_boxes + model_names + [
|
182 |
-
|
183 |
target_language
|
184 |
]
|
185 |
common_outputs = [option_a, option_b, tie, model_name_row]
|
|
|
28 |
|
29 |
|
30 |
def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
|
31 |
+
prompt, instruction, category, source_lang, target_lang):
|
32 |
doc_id = uuid4().hex
|
33 |
winner = VoteOptions(vote_button).name.lower()
|
34 |
|
|
|
37 |
|
38 |
doc = {
|
39 |
"id": doc_id,
|
40 |
+
"prompt": prompt,
|
41 |
"instruction": instruction,
|
42 |
"model_a": model_a_name,
|
43 |
"model_b": model_b_name,
|
|
|
116 |
model_names = [gr.State(None), gr.State(None)]
|
117 |
response_boxes = [gr.State(None), gr.State(None)]
|
118 |
|
119 |
+
prompt_textarea = gr.TextArea(label="Prompt", lines=4)
|
120 |
submit = gr.Button()
|
121 |
|
122 |
with gr.Group():
|
|
|
166 |
category_radio, source_language, target_language, submit, vote_row,
|
167 |
model_name_row
|
168 |
]).then(fn=get_responses,
|
169 |
+
inputs=[
|
170 |
+
prompt_textarea, category_radio, source_language,
|
171 |
+
target_language
|
172 |
+
],
|
173 |
outputs=response_boxes + model_names + [instruction_state])
|
174 |
submit_event.success(fn=lambda: gr.Row(visible=True), outputs=vote_row)
|
175 |
submit_event.then(
|
|
|
182 |
outputs=[category_radio, source_language, target_language, submit])
|
183 |
|
184 |
common_inputs = response_boxes + model_names + [
|
185 |
+
prompt_textarea, instruction_state, category_radio, source_language,
|
186 |
target_language
|
187 |
]
|
188 |
common_outputs = [option_a, option_b, tie, model_name_row]
|
model.py
CHANGED
@@ -25,6 +25,9 @@ decoded_secret = models_secret.payload.data.decode("UTF-8")
|
|
25 |
|
26 |
supported_models_json = json.loads(decoded_secret)
|
27 |
|
|
|
|
|
|
|
28 |
|
29 |
class Model:
|
30 |
|
@@ -35,11 +38,25 @@ class Model:
|
|
35 |
# The JSON keys are in camelCase. To unpack these keys into
|
36 |
# Model attributes, we need to use the same camelCase names.
|
37 |
apiKey: str = None, # pylint: disable=invalid-name
|
38 |
-
apiBase: str = None
|
|
|
|
|
39 |
self.name = name
|
40 |
self.provider = provider
|
41 |
self.api_key = apiKey
|
42 |
self.api_base = apiBase
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
|
45 |
supported_models: List[Model] = [
|
@@ -48,27 +65,15 @@ supported_models: List[Model] = [
|
|
48 |
]
|
49 |
|
50 |
|
51 |
-
def completion(model: Model, messages: List, max_tokens: float = None) -> str:
|
52 |
-
response = litellm.completion(model=model.provider + "/" +
|
53 |
-
model.name if model.provider else model.name,
|
54 |
-
api_key=model.api_key,
|
55 |
-
api_base=model.api_base,
|
56 |
-
messages=messages,
|
57 |
-
max_tokens=max_tokens)
|
58 |
-
|
59 |
-
return response.choices[0].message.content
|
60 |
-
|
61 |
-
|
62 |
def check_models(models: List[Model]):
|
63 |
for model in models:
|
64 |
print(f"Checking model {model.name}...")
|
65 |
try:
|
66 |
-
completion(
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
max_tokens=5)
|
72 |
print(f"Model {model.name} is available.")
|
73 |
|
74 |
# This check is designed to verify the availability of the models
|
|
|
25 |
|
26 |
supported_models_json = json.loads(decoded_secret)
|
27 |
|
28 |
+
DEFAULT_SUMMARIZE_INSTRUCTION = "Summarize the following text, maintaining the language of the text." # pylint: disable=line-too-long
|
29 |
+
DEFAULT_TRANSLATE_INSTRUCTION = "Translate the following text from {source_lang} to {target_lang}." # pylint: disable=line-too-long
|
30 |
+
|
31 |
|
32 |
class Model:
|
33 |
|
|
|
38 |
# The JSON keys are in camelCase. To unpack these keys into
|
39 |
# Model attributes, we need to use the same camelCase names.
|
40 |
apiKey: str = None, # pylint: disable=invalid-name
|
41 |
+
apiBase: str = None, # pylint: disable=invalid-name
|
42 |
+
summarizeInstruction: str = None, # pylint: disable=invalid-name
|
43 |
+
translateInstruction: str = None): # pylint: disable=invalid-name
|
44 |
self.name = name
|
45 |
self.provider = provider
|
46 |
self.api_key = apiKey
|
47 |
self.api_base = apiBase
|
48 |
+
self.summarize_instruction = summarizeInstruction or DEFAULT_SUMMARIZE_INSTRUCTION # pylint: disable=line-too-long
|
49 |
+
self.translate_instruction = translateInstruction or DEFAULT_TRANSLATE_INSTRUCTION # pylint: disable=line-too-long
|
50 |
+
|
51 |
+
def completion(self, messages: List, max_tokens: float = None) -> str:
|
52 |
+
response = litellm.completion(model=self.provider + "/" +
|
53 |
+
self.name if self.provider else self.name,
|
54 |
+
api_key=self.api_key,
|
55 |
+
api_base=self.api_base,
|
56 |
+
messages=messages,
|
57 |
+
max_tokens=max_tokens)
|
58 |
+
|
59 |
+
return response.choices[0].message.content
|
60 |
|
61 |
|
62 |
supported_models: List[Model] = [
|
|
|
65 |
]
|
66 |
|
67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
def check_models(models: List[Model]):
|
69 |
for model in models:
|
70 |
print(f"Checking model {model.name}...")
|
71 |
try:
|
72 |
+
model.completion(messages=[{
|
73 |
+
"role": "user",
|
74 |
+
"content": "Hello."
|
75 |
+
}],
|
76 |
+
max_tokens=5)
|
|
|
77 |
print(f"Model {model.name} is available.")
|
78 |
|
79 |
# This check is designed to verify the availability of the models
|
response.py
CHANGED
@@ -11,7 +11,6 @@ from firebase_admin import firestore
|
|
11 |
import gradio as gr
|
12 |
|
13 |
from leaderboard import db
|
14 |
-
from model import completion
|
15 |
from model import Model
|
16 |
from model import supported_models
|
17 |
|
@@ -39,14 +38,18 @@ class Category(enum.Enum):
|
|
39 |
|
40 |
|
41 |
# TODO(#31): Let the model builders set the instruction.
|
42 |
-
def get_instruction(category,
|
|
|
43 |
if category == Category.SUMMARIZE.value:
|
44 |
-
return
|
|
|
45 |
if category == Category.TRANSLATE.value:
|
46 |
-
return
|
|
|
47 |
|
48 |
|
49 |
-
def get_responses(
|
|
|
50 |
if not category:
|
51 |
raise gr.Error("Please select a category.")
|
52 |
|
@@ -55,21 +58,19 @@ def get_responses(user_prompt, category, source_lang, target_lang):
|
|
55 |
raise gr.Error("Please select source and target languages.")
|
56 |
|
57 |
models: List[Model] = sample(list(supported_models), 2)
|
58 |
-
instruction = get_instruction(category, source_lang, target_lang)
|
59 |
-
|
60 |
responses = []
|
61 |
for model in models:
|
|
|
62 |
try:
|
63 |
# TODO(#1): Allow user to set configuration.
|
64 |
-
response = completion(
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
create_history(model.name, instruction, user_prompt, response)
|
73 |
responses.append(response)
|
74 |
|
75 |
# TODO(#1): Narrow down the exception type.
|
|
|
11 |
import gradio as gr
|
12 |
|
13 |
from leaderboard import db
|
|
|
14 |
from model import Model
|
15 |
from model import supported_models
|
16 |
|
|
|
38 |
|
39 |
|
40 |
# TODO(#31): Let the model builders set the instruction.
|
41 |
+
def get_instruction(category: str, model: Model, source_lang: str,
|
42 |
+
target_lang: str):
|
43 |
if category == Category.SUMMARIZE.value:
|
44 |
+
return model.summarize_instruction
|
45 |
+
|
46 |
if category == Category.TRANSLATE.value:
|
47 |
+
return model.translate_instruction.format(source_lang=source_lang,
|
48 |
+
target_lang=target_lang)
|
49 |
|
50 |
|
51 |
+
def get_responses(prompt: str, category: str, source_lang: str,
|
52 |
+
target_lang: str):
|
53 |
if not category:
|
54 |
raise gr.Error("Please select a category.")
|
55 |
|
|
|
58 |
raise gr.Error("Please select source and target languages.")
|
59 |
|
60 |
models: List[Model] = sample(list(supported_models), 2)
|
|
|
|
|
61 |
responses = []
|
62 |
for model in models:
|
63 |
+
instruction = get_instruction(category, model, source_lang, target_lang)
|
64 |
try:
|
65 |
# TODO(#1): Allow user to set configuration.
|
66 |
+
response = model.completion(messages=[{
|
67 |
+
"role": "system",
|
68 |
+
"content": instruction
|
69 |
+
}, {
|
70 |
+
"role": "user",
|
71 |
+
"content": prompt
|
72 |
+
}])
|
73 |
+
create_history(model.name, instruction, prompt, response)
|
|
|
74 |
responses.append(response)
|
75 |
|
76 |
# TODO(#1): Narrow down the exception type.
|