Kang Suhyun committed on
Commit
486e533
2 Parent(s): 0ac094d 000d4f2

Merge pull request #22 from Y-IAB/19-fastchat

Browse files
Files changed (4) hide show
  1. .gitignore +1 -0
  2. README.md +4 -2
  3. app.py +52 -72
  4. requirments.txt +22 -41
.gitignore CHANGED
@@ -1,2 +1,3 @@
1
  venv
2
  *.log
 
 
1
  venv
2
  *.log
3
+ __pycache__
README.md CHANGED
@@ -19,7 +19,9 @@
19
  Set your OpenAI API key as an environment variable and start the application:
20
 
21
  ```shell
22
- GCP_PROJECT_ID=<your project id> OPENAI_API_KEY=<your key> python3 app.py
23
  ```
24
 
25
- Replace <your project id> and <your key> with your GCP project ID and OpenAI API key respectively.
 
 
 
19
  Set your OpenAI API key as an environment variable and start the application:
20
 
21
  ```shell
22
+ OPENAI_API_KEY=<your key> python3 app.py
23
  ```
24
 
25
+ Replace `<your key>` with your OpenAI API key.
26
+
27
+ > To run the app with [auto-reloading](https://www.gradio.app/guides/developing-faster-with-reload-mode), use `gradio app.py --demo-name app` instead of `python3 app.py`.
app.py CHANGED
@@ -3,21 +3,22 @@ It provides a platform for comparing the responses of two LLMs.
3
  """
4
 
5
  import enum
6
- import json
7
  from random import sample
8
  from uuid import uuid4
9
 
10
- from fastchat.serve import gradio_web_server
11
- from fastchat.serve.gradio_web_server import bot_response
12
  import firebase_admin
13
  from firebase_admin import firestore
14
  import gradio as gr
 
15
 
 
16
  db_app = firebase_admin.initialize_app()
17
  db = firestore.client()
18
 
19
  # TODO(#1): Add more models.
20
- SUPPORTED_MODELS = ["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gemini-pro"]
 
 
21
 
22
  # TODO(#4): Add more languages.
23
  SUPPORTED_TRANSLATION_LANGUAGES = ["Korean", "English"]
@@ -34,23 +35,20 @@ class VoteOptions(enum.Enum):
34
  TIE = "Tie"
35
 
36
 
37
- def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
 
38
  doc_id = uuid4().hex
39
  winner = VoteOptions(vote_button).name.lower()
40
 
41
- # The 'messages' field in the state is an array of arrays, which is
42
- # not supported by Firestore. Therefore, we convert it to a JSON string.
43
- model_a_conv = json.dumps(state_a.dict())
44
- model_b_conv = json.dumps(state_b.dict())
45
-
46
  if res_type == ResponseType.SUMMARIZE.value:
47
  doc_ref = db.collection("arena-summarizations").document(doc_id)
48
  doc_ref.set({
49
  "id": doc_id,
50
- "model_a": state_a.model_name,
51
- "model_b": state_b.model_name,
52
- "model_a_conv": model_a_conv,
53
- "model_b_conv": model_b_conv,
 
54
  "winner": winner,
55
  "timestamp": firestore.SERVER_TIMESTAMP
56
  })
@@ -60,10 +58,11 @@ def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
60
  doc_ref = db.collection("arena-translations").document(doc_id)
61
  doc_ref.set({
62
  "id": doc_id,
63
- "model_a": state_a.model_name,
64
- "model_b": state_b.model_name,
65
- "model_a_conv": model_a_conv,
66
- "model_b_conv": model_b_conv,
 
67
  "source_language": source_lang.lower(),
68
  "target_language": target_lang.lower(),
69
  "winner": winner,
@@ -71,42 +70,38 @@ def vote(state_a, state_b, vote_button, res_type, source_lang, target_lang):
71
  })
72
 
73
 
74
- def user(user_prompt):
75
- model_pair = sample(SUPPORTED_MODELS, 2)
76
- new_state_a = gradio_web_server.State(model_pair[0])
77
- new_state_b = gradio_web_server.State(model_pair[1])
78
-
79
- for state in [new_state_a, new_state_b]:
80
- state.conv.append_message(state.conv.roles[0], user_prompt)
81
- state.conv.append_message(state.conv.roles[1], None)
82
- state.skip_next = False
83
 
84
- return [
85
- new_state_a, new_state_b, new_state_a.model_name, new_state_b.model_name
86
- ]
87
 
88
 
89
- def bot(state_a, state_b, request: gr.Request):
90
- new_states = [state_a, state_b]
91
 
92
  generators = []
93
- for state in new_states:
94
  try:
95
  # TODO(#1): Allow user to set configuration.
96
- # bot_response returns a generator yielding states.
97
- generator = bot_response(state,
98
- temperature=0.9,
99
- top_p=0.9,
100
- max_new_tokens=100,
101
- request=request)
102
- generators.append(generator)
103
 
104
  # TODO(#1): Narrow down the exception type.
105
  except Exception as e: # pylint: disable=broad-except
106
  print(f"Error in bot_response: {e}")
107
  raise e
108
 
109
- new_responses = [None, None]
110
 
111
  # It simulates concurrent response generation from two models.
112
  while True:
@@ -116,19 +111,14 @@ def bot(state_a, state_b, request: gr.Request):
116
  try:
117
  yielded = next(generators[i])
118
 
119
- # The generator yields a tuple, with the new state as the first item.
120
- new_state = yielded[0]
121
- new_states[i] = new_state
122
-
123
- # The last item from 'messages' represents the response to the prompt.
124
- bot_message = new_state.conv.messages[-1]
125
-
126
- # Each message in conv.messages is structured as [role, message],
127
- # so we extract the last message component.
128
- new_responses[i] = bot_message[-1]
129
 
 
130
  stop = False
131
 
 
 
132
  except StopIteration:
133
  pass
134
 
@@ -137,8 +127,6 @@ def bot(state_a, state_b, request: gr.Request):
137
  print(f"Error in generator: {e}")
138
  raise e
139
 
140
- yield new_states + new_responses
141
-
142
  if stop:
143
  break
144
 
@@ -174,36 +162,22 @@ with gr.Blocks() as app:
174
  [source_language, target_language])
175
 
176
  model_names = [gr.State(None), gr.State(None)]
177
- responses = [gr.State(None), gr.State(None)]
178
-
179
- # states stores FastChat-specific conversation states.
180
- states = [gr.State(None), gr.State(None)]
181
 
182
  prompt = gr.TextArea(label="Prompt", lines=4)
183
  submit = gr.Button()
184
 
185
  with gr.Row():
186
- responses[0] = gr.Textbox(label="Model A", interactive=False)
187
- responses[1] = gr.Textbox(label="Model B", interactive=False)
188
 
189
  # TODO(#5): Display it only after the user submits the prompt.
190
  # TODO(#6): Block voting if the response_type is not set.
191
  # TODO(#6): Block voting if the user already voted.
192
  with gr.Row():
193
  option_a = gr.Button(VoteOptions.MODEL_A.value)
194
- option_a.click(
195
- vote, states +
196
- [option_a, response_type_radio, source_language, target_language])
197
-
198
  option_b = gr.Button("Model B is better")
199
- option_b.click(
200
- vote, states +
201
- [option_b, response_type_radio, source_language, target_language])
202
-
203
  tie = gr.Button("Tie")
204
- tie.click(
205
- vote,
206
- states + [tie, response_type_radio, source_language, target_language])
207
 
208
  # TODO(#7): Hide it until the user votes.
209
  with gr.Accordion("Show models", open=False):
@@ -211,8 +185,14 @@ with gr.Blocks() as app:
211
  model_names[0] = gr.Textbox(label="Model A", interactive=False)
212
  model_names[1] = gr.Textbox(label="Model B", interactive=False)
213
 
214
- submit.click(user, prompt, states + model_names,
215
- queue=False).then(bot, states, states + responses)
 
 
 
 
 
 
216
 
217
  if __name__ == "__main__":
218
  # We need to enable queue to use generators.
 
3
  """
4
 
5
  import enum
 
6
  from random import sample
7
  from uuid import uuid4
8
 
 
 
9
  import firebase_admin
10
  from firebase_admin import firestore
11
  import gradio as gr
12
+ from litellm import completion
13
 
14
+ # TODO(#21): Fix auto-reload issue related to the initialization of Firebase.
15
  db_app = firebase_admin.initialize_app()
16
  db = firestore.client()
17
 
18
  # TODO(#1): Add more models.
19
+ SUPPORTED_MODELS = [
20
+ "gpt-4", "gpt-4-0125-preview", "gpt-3.5-turbo", "gemini-pro"
21
+ ]
22
 
23
  # TODO(#4): Add more languages.
24
  SUPPORTED_TRANSLATION_LANGUAGES = ["Korean", "English"]
 
35
  TIE = "Tie"
36
 
37
 
38
+ def vote(vote_button, response_a, response_b, model_a_name, model_b_name,
39
+ user_prompt, res_type, source_lang, target_lang):
40
  doc_id = uuid4().hex
41
  winner = VoteOptions(vote_button).name.lower()
42
 
 
 
 
 
 
43
  if res_type == ResponseType.SUMMARIZE.value:
44
  doc_ref = db.collection("arena-summarizations").document(doc_id)
45
  doc_ref.set({
46
  "id": doc_id,
47
+ "prompt": user_prompt,
48
+ "model_a": model_a_name,
49
+ "model_b": model_b_name,
50
+ "model_a_response": response_a,
51
+ "model_b_response": response_b,
52
  "winner": winner,
53
  "timestamp": firestore.SERVER_TIMESTAMP
54
  })
 
58
  doc_ref = db.collection("arena-translations").document(doc_id)
59
  doc_ref.set({
60
  "id": doc_id,
61
+ "prompt": user_prompt,
62
+ "model_a": model_a_name,
63
+ "model_b": model_b_name,
64
+ "model_a_response": response_a,
65
+ "model_b_response": response_b,
66
  "source_language": source_lang.lower(),
67
  "target_language": target_lang.lower(),
68
  "winner": winner,
 
70
  })
71
 
72
 
73
+ def response_generator(response: str):
74
+ for part in response:
75
+ content = part.choices[0].delta.content
76
+ if content is None:
77
+ continue
 
 
 
 
78
 
79
+ # To simulate a stream, we yield each character of the response.
80
+ for character in content:
81
+ yield character
82
 
83
 
84
+ def get_responses(user_prompt):
85
+ models = sample(SUPPORTED_MODELS, 2)
86
 
87
  generators = []
88
+ for model in models:
89
  try:
90
  # TODO(#1): Allow user to set configuration.
91
+ response = completion(model=model,
92
+ messages=[{
93
+ "content": user_prompt,
94
+ "role": "user"
95
+ }],
96
+ stream=True)
97
+ generators.append(response_generator(response))
98
 
99
  # TODO(#1): Narrow down the exception type.
100
  except Exception as e: # pylint: disable=broad-except
101
  print(f"Error in bot_response: {e}")
102
  raise e
103
 
104
+ responses = ["", ""]
105
 
106
  # It simulates concurrent response generation from two models.
107
  while True:
 
111
  try:
112
  yielded = next(generators[i])
113
 
114
+ if yielded is None:
115
+ continue
 
 
 
 
 
 
 
 
116
 
117
+ responses[i] += yielded
118
  stop = False
119
 
120
+ yield responses + models
121
+
122
  except StopIteration:
123
  pass
124
 
 
127
  print(f"Error in generator: {e}")
128
  raise e
129
 
 
 
130
  if stop:
131
  break
132
 
 
162
  [source_language, target_language])
163
 
164
  model_names = [gr.State(None), gr.State(None)]
165
+ response_boxes = [gr.State(None), gr.State(None)]
 
 
 
166
 
167
  prompt = gr.TextArea(label="Prompt", lines=4)
168
  submit = gr.Button()
169
 
170
  with gr.Row():
171
+ response_boxes[0] = gr.Textbox(label="Model A", interactive=False)
172
+ response_boxes[1] = gr.Textbox(label="Model B", interactive=False)
173
 
174
  # TODO(#5): Display it only after the user submits the prompt.
175
  # TODO(#6): Block voting if the response_type is not set.
176
  # TODO(#6): Block voting if the user already voted.
177
  with gr.Row():
178
  option_a = gr.Button(VoteOptions.MODEL_A.value)
 
 
 
 
179
  option_b = gr.Button("Model B is better")
 
 
 
 
180
  tie = gr.Button("Tie")
 
 
 
181
 
182
  # TODO(#7): Hide it until the user votes.
183
  with gr.Accordion("Show models", open=False):
 
185
  model_names[0] = gr.Textbox(label="Model A", interactive=False)
186
  model_names[1] = gr.Textbox(label="Model B", interactive=False)
187
 
188
+ submit.click(get_responses, prompt, response_boxes + model_names)
189
+
190
+ common_inputs = response_boxes + model_names + [
191
+ prompt, response_type_radio, source_language, target_language
192
+ ]
193
+ option_a.click(vote, [option_a] + common_inputs)
194
+ option_b.click(vote, [option_b] + common_inputs)
195
+ tie.click(vote, [tie] + common_inputs)
196
 
197
  if __name__ == "__main__":
198
  # We need to enable queue to use generators.
requirments.txt CHANGED
@@ -1,4 +1,3 @@
1
- accelerate==0.26.1
2
  aiofiles==23.2.1
3
  aiohttp==3.9.3
4
  aiosignal==1.3.1
@@ -6,9 +5,9 @@ altair==5.2.0
6
  annotated-types==0.6.0
7
  anyio==4.2.0
8
  attrs==23.2.0
9
- CacheControl==0.13.1
10
  cachetools==5.3.2
11
- certifi==2023.11.17
12
  cffi==1.16.0
13
  charset-normalizer==3.3.2
14
  click==8.1.7
@@ -17,76 +16,67 @@ contourpy==1.2.0
17
  cryptography==42.0.2
18
  cycler==0.12.1
19
  distro==1.9.0
20
- fastapi==0.109.0
21
  ffmpy==0.3.1
22
  filelock==3.13.1
23
  firebase-admin==6.4.0
24
  fonttools==4.47.2
25
  frozenlist==1.4.1
26
- fschat==0.2.35
27
- fsspec==2023.12.2
28
- google-api-core==2.16.1
29
  google-api-python-client==2.116.0
30
  google-auth==2.27.0
31
  google-auth-httplib2==0.2.0
32
- google-cloud-aiplatform==1.40.0
33
- google-cloud-bigquery==3.17.1
34
  google-cloud-core==2.4.1
35
  google-cloud-firestore==2.14.0
36
- google-cloud-resource-manager==1.11.0
37
  google-cloud-storage==2.14.0
38
  google-crc32c==1.5.0
39
  google-resumable-media==2.7.0
40
  googleapis-common-protos==1.62.0
41
- gradio==3.50.2
42
- gradio_client==0.6.1
43
- grpc-google-iam-v1==0.13.0
44
- grpcio==1.60.0
45
- grpcio-status==1.60.0
46
  h11==0.14.0
47
  httpcore==1.0.2
48
  httplib2==0.22.0
49
  httpx==0.26.0
50
  huggingface-hub==0.20.3
51
  idna==3.6
 
52
  importlib-resources==6.1.1
53
  Jinja2==3.1.3
54
  jsonschema==4.21.1
55
  jsonschema-specifications==2023.12.1
56
  kiwisolver==1.4.5
 
57
  markdown-it-py==3.0.0
58
- markdown2==2.4.12
59
- MarkupSafe==2.1.4
60
  matplotlib==3.8.2
61
  mdurl==0.1.2
62
- mpmath==1.3.0
63
  msgpack==1.0.7
64
- multidict==6.0.4
65
- networkx==3.2.1
66
- nh3==0.2.15
67
  numpy==1.26.3
68
- openai==0.28.0
69
- orjson==3.9.12
70
  packaging==23.2
71
  pandas==2.2.0
72
- peft==0.8.1
73
  pillow==10.2.0
74
- prompt-toolkit==3.0.43
75
  proto-plus==1.23.0
76
  protobuf==4.25.2
77
- psutil==5.9.8
78
  pyasn1==0.5.1
79
  pyasn1-modules==0.3.0
80
  pycparser==2.21
81
- pydantic==1.10.14
82
  pydantic_core==2.16.1
83
  pydub==0.25.1
84
  Pygments==2.17.2
85
  PyJWT==2.8.0
86
  pyparsing==3.1.1
87
  python-dateutil==2.8.2
88
- python-multipart==0.0.6
89
- pytz==2023.4
 
90
  PyYAML==6.0.1
91
  referencing==0.33.0
92
  regex==2023.12.25
@@ -94,32 +84,23 @@ requests==2.31.0
94
  rich==13.7.0
95
  rpds-py==0.17.1
96
  rsa==4.9
97
- ruff==0.1.15
98
- safetensors==0.4.2
99
  semantic-version==2.10.0
100
- sentencepiece==0.1.99
101
- shapely==2.0.2
102
  shellingham==1.5.4
103
- shortuuid==1.0.11
104
  six==1.16.0
105
  sniffio==1.3.0
106
- starlette==0.35.1
107
- svgwrite==1.4.3
108
- sympy==1.12
109
  tiktoken==0.5.2
110
  tokenizers==0.15.1
111
  tomlkit==0.12.0
112
  toolz==0.12.1
113
- torch==2.2.0
114
  tqdm==4.66.1
115
- transformers==4.37.2
116
  typer==0.9.0
117
  typing_extensions==4.9.0
118
  tzdata==2023.4
119
  uritemplate==4.1.1
120
  urllib3==2.2.0
121
  uvicorn==0.27.0.post1
122
- wavedrom==2.0.3.post3
123
- wcwidth==0.2.13
124
  websockets==11.0.3
125
  yarl==1.9.4
 
 
 
1
  aiofiles==23.2.1
2
  aiohttp==3.9.3
3
  aiosignal==1.3.1
 
5
  annotated-types==0.6.0
6
  anyio==4.2.0
7
  attrs==23.2.0
8
+ CacheControl==0.14.0
9
  cachetools==5.3.2
10
+ certifi==2024.2.2
11
  cffi==1.16.0
12
  charset-normalizer==3.3.2
13
  click==8.1.7
 
16
  cryptography==42.0.2
17
  cycler==0.12.1
18
  distro==1.9.0
19
+ fastapi==0.109.2
20
  ffmpy==0.3.1
21
  filelock==3.13.1
22
  firebase-admin==6.4.0
23
  fonttools==4.47.2
24
  frozenlist==1.4.1
25
+ fsspec==2024.2.0
26
+ google-api-core==2.16.2
 
27
  google-api-python-client==2.116.0
28
  google-auth==2.27.0
29
  google-auth-httplib2==0.2.0
 
 
30
  google-cloud-core==2.4.1
31
  google-cloud-firestore==2.14.0
 
32
  google-cloud-storage==2.14.0
33
  google-crc32c==1.5.0
34
  google-resumable-media==2.7.0
35
  googleapis-common-protos==1.62.0
36
+ gradio==4.16.0
37
+ gradio_client==0.8.1
38
+ grpcio==1.60.1
39
+ grpcio-status==1.60.1
 
40
  h11==0.14.0
41
  httpcore==1.0.2
42
  httplib2==0.22.0
43
  httpx==0.26.0
44
  huggingface-hub==0.20.3
45
  idna==3.6
46
+ importlib-metadata==7.0.1
47
  importlib-resources==6.1.1
48
  Jinja2==3.1.3
49
  jsonschema==4.21.1
50
  jsonschema-specifications==2023.12.1
51
  kiwisolver==1.4.5
52
+ litellm==1.22.3
53
  markdown-it-py==3.0.0
54
+ MarkupSafe==2.1.5
 
55
  matplotlib==3.8.2
56
  mdurl==0.1.2
 
57
  msgpack==1.0.7
58
+ multidict==6.0.5
 
 
59
  numpy==1.26.3
60
+ openai==1.11.1
61
+ orjson==3.9.13
62
  packaging==23.2
63
  pandas==2.2.0
 
64
  pillow==10.2.0
 
65
  proto-plus==1.23.0
66
  protobuf==4.25.2
 
67
  pyasn1==0.5.1
68
  pyasn1-modules==0.3.0
69
  pycparser==2.21
70
+ pydantic==2.6.0
71
  pydantic_core==2.16.1
72
  pydub==0.25.1
73
  Pygments==2.17.2
74
  PyJWT==2.8.0
75
  pyparsing==3.1.1
76
  python-dateutil==2.8.2
77
+ python-dotenv==1.0.1
78
+ python-multipart==0.0.7
79
+ pytz==2024.1
80
  PyYAML==6.0.1
81
  referencing==0.33.0
82
  regex==2023.12.25
 
84
  rich==13.7.0
85
  rpds-py==0.17.1
86
  rsa==4.9
87
+ ruff==0.2.0
 
88
  semantic-version==2.10.0
 
 
89
  shellingham==1.5.4
 
90
  six==1.16.0
91
  sniffio==1.3.0
92
+ starlette==0.36.3
 
 
93
  tiktoken==0.5.2
94
  tokenizers==0.15.1
95
  tomlkit==0.12.0
96
  toolz==0.12.1
 
97
  tqdm==4.66.1
 
98
  typer==0.9.0
99
  typing_extensions==4.9.0
100
  tzdata==2023.4
101
  uritemplate==4.1.1
102
  urllib3==2.2.0
103
  uvicorn==0.27.0.post1
 
 
104
  websockets==11.0.3
105
  yarl==1.9.4
106
+ zipp==3.17.0