tedslin committed on
Commit
ca02256
1 Parent(s): 264871b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +96 -38
app.py CHANGED
@@ -24,10 +24,7 @@ DUPLICATE = """
24
  </div>
25
  """
26
 
27
- AVATAR_IMAGES = (
28
- None,
29
- "https://media.roboflow.com/spaces/gemini-icon.png"
30
- )
31
 
32
  IMAGE_CACHE_DIRECTORY = "/tmp"
33
  IMAGE_WIDTH = 512
@@ -54,22 +51,22 @@ def cache_pil_image(image: Image.Image) -> str:
54
 
55
 
56
  def preprocess_chat_history(
57
- history: CHAT_HISTORY
58
  ) -> List[Dict[str, Union[str, List[str]]]]:
59
  messages = []
60
  for user_message, model_message in history:
61
  if isinstance(user_message, tuple):
62
  pass
63
  elif user_message is not None:
64
- messages.append({'role': 'user', 'parts': [user_message]})
65
  if model_message is not None:
66
- messages.append({'role': 'model', 'parts': [model_message]})
67
  return messages
68
 
69
 
70
  def upload(files: Optional[List[str]], chatbot: CHAT_HISTORY) -> CHAT_HISTORY:
71
  for file in files:
72
- image = Image.open(file).convert('RGB')
73
  image = preprocess_image(image)
74
  image_path = cache_pil_image(image)
75
  chatbot.append(((image_path,), None))
@@ -90,7 +87,9 @@ def bot(
90
  stop_sequences: str,
91
  top_k: int,
92
  top_p: float,
93
- chatbot: CHAT_HISTORY
 
 
94
  ):
95
  if len(chatbot) == 0:
96
  return chatbot
@@ -99,7 +98,12 @@ def bot(
99
  if not google_key:
100
  raise ValueError(
101
  "GOOGLE_API_KEY is not set. "
102
- "Please follow the instructions in the README to set it up.")
 
 
 
 
 
103
 
104
  genai.configure(api_key=google_key)
105
  generation_config = genai.types.GenerationConfig(
@@ -107,31 +111,38 @@ def bot(
107
  max_output_tokens=max_output_tokens,
108
  stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences),
109
  top_k=top_k,
110
- top_p=top_p)
 
111
 
112
  if files:
113
- text_prompt = [chatbot[-1][0]] \
114
- if chatbot[-1][0] and isinstance(chatbot[-1][0], str) \
 
115
  else []
116
- image_prompt = [Image.open(file).convert('RGB') for file in files]
117
- model = genai.GenerativeModel('gemini-pro-vision')
 
118
  response = model.generate_content(
119
  text_prompt + image_prompt,
120
  stream=True,
121
- generation_config=generation_config)
 
 
122
  else:
123
  messages = preprocess_chat_history(chatbot)
124
- model = genai.GenerativeModel('gemini-pro')
125
  response = model.generate_content(
126
  messages,
127
  stream=True,
128
- generation_config=generation_config)
 
 
129
 
130
  # streaming effect
131
  chatbot[-1][1] = ""
132
  for chunk in response:
133
  for i in range(0, len(chunk.text), 10):
134
- section = chunk.text[i:i + 10]
135
  chatbot[-1][1] += section
136
  time.sleep(0.01)
137
  yield chatbot
@@ -143,18 +154,19 @@ google_key_component = gr.Textbox(
143
  type="password",
144
  placeholder="...",
145
  info="You have to provide your own GOOGLE_API_KEY for this app to function properly",
146
- visible=GOOGLE_API_KEY is None
147
  )
148
  chatbot_component = gr.Chatbot(
149
- label='Gemini',
150
  bubble_full_width=False,
151
  avatar_images=AVATAR_IMAGES,
152
  scale=2,
153
- height=400
154
  )
155
  text_prompt_component = gr.Textbox(
156
  placeholder="Hi there! [press Enter]", show_label=False, autofocus=True, scale=8
157
  )
 
158
  upload_button_component = gr.UploadButton(
159
  label="Upload Images", file_count="multiple", file_types=["image"], scale=1
160
  )
@@ -169,7 +181,8 @@ temperature_component = gr.Slider(
169
  "Temperature controls the degree of randomness in token selection. Lower "
170
  "temperatures are good for prompts that expect a true or correct response, "
171
  "while higher temperatures can lead to more diverse or unexpected results. "
172
- ))
 
173
  max_output_tokens_component = gr.Slider(
174
  minimum=1,
175
  maximum=2048,
@@ -178,8 +191,9 @@ max_output_tokens_component = gr.Slider(
178
  label="Token limit",
179
  info=(
180
  "Token limit determines the maximum amount of text output from one prompt. A "
181
- "token is approximately four characters. The default value is 2048."
182
- ))
 
183
  stop_sequences_component = gr.Textbox(
184
  label="Add stop sequence",
185
  value="",
@@ -189,7 +203,8 @@ stop_sequences_component = gr.Textbox(
189
  "A stop sequence is a series of characters (including spaces) that stops "
190
  "response generation if the model encounters it. The sequence is not included "
191
  "as part of the response. You can add up to five stop sequences."
192
- ))
 
193
  top_k_component = gr.Slider(
194
  minimum=1,
195
  maximum=40,
@@ -202,7 +217,8 @@ top_k_component = gr.Slider(
202
  "vocabulary (also called greedy decoding), while a top-k of 3 means that the "
203
  "next token is selected from among the 3 most probable tokens (using "
204
  "temperature)."
205
- ))
 
206
  top_p_component = gr.Slider(
207
  minimum=0,
208
  maximum=1,
@@ -215,12 +231,44 @@ top_p_component = gr.Slider(
215
  "top-p value. For example, if tokens A, B, and C have a probability of .3, .2, "
216
  "and .1 and the top-p value is .5, then the model will select either A or B as "
217
  "the next token (using temperature). "
218
- ))
 
219
 
220
- user_inputs = [
221
- text_prompt_component,
222
- chatbot_component
223
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
  bot_inputs = [
226
  google_key_component,
@@ -230,7 +278,9 @@ bot_inputs = [
230
  stop_sequences_component,
231
  top_k_component,
232
  top_p_component,
233
- chatbot_component
 
 
234
  ]
235
 
236
  with gr.Blocks() as demo:
@@ -242,12 +292,16 @@ with gr.Blocks() as demo:
242
  chatbot_component.render()
243
  with gr.Row():
244
  text_prompt_component.render()
 
245
  upload_button_component.render()
246
  run_button_component.render()
247
  with gr.Accordion("Parameters", open=False):
248
  temperature_component.render()
249
  max_output_tokens_component.render()
250
  stop_sequences_component.render()
 
 
 
251
  with gr.Accordion("Advanced", open=False):
252
  top_k_component.render()
253
  top_p_component.render()
@@ -256,25 +310,29 @@ with gr.Blocks() as demo:
256
  fn=user,
257
  inputs=user_inputs,
258
  outputs=[text_prompt_component, chatbot_component],
259
- queue=False
260
  ).then(
261
- fn=bot, inputs=bot_inputs, outputs=[chatbot_component],
 
 
262
  )
263
 
264
  text_prompt_component.submit(
265
  fn=user,
266
  inputs=user_inputs,
267
  outputs=[text_prompt_component, chatbot_component],
268
- queue=False
269
  ).then(
270
- fn=bot, inputs=bot_inputs, outputs=[chatbot_component],
 
 
271
  )
272
 
273
  upload_button_component.upload(
274
  fn=upload,
275
  inputs=[upload_button_component, chatbot_component],
276
  outputs=[chatbot_component],
277
- queue=False
278
  )
279
 
280
  demo.queue(max_size=99).launch(debug=False, show_error=True)
 
24
  </div>
25
  """
26
 
27
+ AVATAR_IMAGES = (None, "https://media.roboflow.com/spaces/gemini-icon.png")
 
 
 
28
 
29
  IMAGE_CACHE_DIRECTORY = "/tmp"
30
  IMAGE_WIDTH = 512
 
51
 
52
 
53
def preprocess_chat_history(
    history: CHAT_HISTORY,
) -> List[Dict[str, Union[str, List[str]]]]:
    """Flatten Gradio chat history into Gemini-style message dicts.

    Turns whose user side is a tuple (cached image uploads) contribute no
    user message; text turns are emitted in order as 'user'/'model' parts.
    """
    messages: List[Dict[str, Union[str, List[str]]]] = []
    for user_turn, model_turn in history:
        # Tuples mark image uploads, which are sent separately via `files`.
        is_image_upload = isinstance(user_turn, tuple)
        if not is_image_upload and user_turn is not None:
            messages.append({"role": "user", "parts": [user_turn]})
        if model_turn is not None:
            messages.append({"role": "model", "parts": [model_turn]})
    return messages
65
 
66
 
67
  def upload(files: Optional[List[str]], chatbot: CHAT_HISTORY) -> CHAT_HISTORY:
68
  for file in files:
69
+ image = Image.open(file).convert("RGB")
70
  image = preprocess_image(image)
71
  image_path = cache_pil_image(image)
72
  chatbot.append(((image_path,), None))
 
87
  stop_sequences: str,
88
  top_k: int,
89
  top_p: float,
90
+ categories: Optional[List[str]],
91
+ threshold: str,
92
+ chatbot: CHAT_HISTORY,
93
  ):
94
  if len(chatbot) == 0:
95
  return chatbot
 
98
  if not google_key:
99
  raise ValueError(
100
  "GOOGLE_API_KEY is not set. "
101
+ "Please follow the instructions in the README to set it up."
102
+ )
103
+
104
+ safety_settings = []
105
+ for category in categories:
106
+ safety_settings.append({"category": category, "threshold": threshold})
107
 
108
  genai.configure(api_key=google_key)
109
  generation_config = genai.types.GenerationConfig(
 
111
  max_output_tokens=max_output_tokens,
112
  stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences),
113
  top_k=top_k,
114
+ top_p=top_p,
115
+ )
116
 
117
  if files:
118
+ text_prompt = (
119
+ [chatbot[-1][0]]
120
+ if chatbot[-1][0] and isinstance(chatbot[-1][0], str)
121
  else []
122
+ )
123
+ image_prompt = [Image.open(file).convert("RGB") for file in files]
124
+ model = genai.GenerativeModel("gemini-pro-vision")
125
  response = model.generate_content(
126
  text_prompt + image_prompt,
127
  stream=True,
128
+ generation_config=generation_config,
129
+ safety_settings=safety_settings,
130
+ )
131
  else:
132
  messages = preprocess_chat_history(chatbot)
133
+ model = genai.GenerativeModel("gemini-pro")
134
  response = model.generate_content(
135
  messages,
136
  stream=True,
137
+ generation_config=generation_config,
138
+ safety_settings=safety_settings,
139
+ )
140
 
141
  # streaming effect
142
  chatbot[-1][1] = ""
143
  for chunk in response:
144
  for i in range(0, len(chunk.text), 10):
145
+ section = chunk.text[i : i + 10]
146
  chatbot[-1][1] += section
147
  time.sleep(0.01)
148
  yield chatbot
 
154
  type="password",
155
  placeholder="...",
156
  info="You have to provide your own GOOGLE_API_KEY for this app to function properly",
157
+ visible=GOOGLE_API_KEY is None,
158
  )
159
  chatbot_component = gr.Chatbot(
160
+ label="Gemini",
161
  bubble_full_width=False,
162
  avatar_images=AVATAR_IMAGES,
163
  scale=2,
164
+ height=400,
165
  )
166
  text_prompt_component = gr.Textbox(
167
  placeholder="Hi there! [press Enter]", show_label=False, autofocus=True, scale=8
168
  )
169
+
170
  upload_button_component = gr.UploadButton(
171
  label="Upload Images", file_count="multiple", file_types=["image"], scale=1
172
  )
 
181
  "Temperature controls the degree of randomness in token selection. Lower "
182
  "temperatures are good for prompts that expect a true or correct response, "
183
  "while higher temperatures can lead to more diverse or unexpected results. "
184
+ ),
185
+ )
186
  max_output_tokens_component = gr.Slider(
187
  minimum=1,
188
  maximum=2048,
 
191
  label="Token limit",
192
  info=(
193
  "Token limit determines the maximum amount of text output from one prompt. A "
194
+ "token is approximately four characters. The max value is 2048."
195
+ ),
196
+ )
197
  stop_sequences_component = gr.Textbox(
198
  label="Add stop sequence",
199
  value="",
 
203
  "A stop sequence is a series of characters (including spaces) that stops "
204
  "response generation if the model encounters it. The sequence is not included "
205
  "as part of the response. You can add up to five stop sequences."
206
+ ),
207
+ )
208
  top_k_component = gr.Slider(
209
  minimum=1,
210
  maximum=40,
 
217
  "vocabulary (also called greedy decoding), while a top-k of 3 means that the "
218
  "next token is selected from among the 3 most probable tokens (using "
219
  "temperature)."
220
+ ),
221
+ )
222
  top_p_component = gr.Slider(
223
  minimum=0,
224
  maximum=1,
 
231
  "top-p value. For example, if tokens A, B, and C have a probability of .3, .2, "
232
  "and .1 and the top-p value is .5, then the model will select either A or B as "
233
  "the next token (using temperature). "
234
+ ),
235
+ )
236
 
237
# Multi-select dropdown choosing which Gemini harm categories the selected
# safety threshold applies to; all four listed categories are on by default.
# NOTE(review): the Gemini API also documents HARM_CATEGORY_DANGEROUS_CONTENT;
# confirm HARM_CATEGORY_DANGEROUS is the value accepted by the target model.
category_dropdown_component = gr.Dropdown(
    label="Category",
    choices=[
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    ],
    value=[
        "HARM_CATEGORY_DANGEROUS",
        "HARM_CATEGORY_HARASSMENT",
        "HARM_CATEGORY_HATE_SPEECH",
        "HARM_CATEGORY_SEXUALLY_EXPLICIT",
    ],
    info=(
        # Trailing space added: without it the two literals concatenate to
        # the user-visible text "…a rating.These categories…".
        "The category of a rating. "
        "These categories cover various kinds of harms that developers may wish to adjust."
    ),
    multiselect=True,
)
257
+
258
+ threshold_dropdown_component = gr.Dropdown(
259
+ label="Threshold",
260
+ choices=[
261
+ "BLOCK_LOW_AND_ABOVE",
262
+ "BLOCK_MEDIUM_AND_ABOVE",
263
+ "BLOCK_ONLY_HIGH",
264
+ "BLOCK_NONE",
265
+ ],
266
+ value="BLOCK_NONE",
267
+ info=("Block at and beyond a specified harm probability."),
268
+ )
269
+
270
+
271
+ user_inputs = [text_prompt_component, chatbot_component]
272
 
273
  bot_inputs = [
274
  google_key_component,
 
278
  stop_sequences_component,
279
  top_k_component,
280
  top_p_component,
281
+ category_dropdown_component,
282
+ threshold_dropdown_component,
283
+ chatbot_component,
284
  ]
285
 
286
  with gr.Blocks() as demo:
 
292
  chatbot_component.render()
293
  with gr.Row():
294
  text_prompt_component.render()
295
+ clear_component = gr.ClearButton([text_prompt_component, chatbot_component])
296
  upload_button_component.render()
297
  run_button_component.render()
298
  with gr.Accordion("Parameters", open=False):
299
  temperature_component.render()
300
  max_output_tokens_component.render()
301
  stop_sequences_component.render()
302
+ with gr.Accordion("Safe Setting", open=False):
303
+ category_dropdown_component.render()
304
+ threshold_dropdown_component.render()
305
  with gr.Accordion("Advanced", open=False):
306
  top_k_component.render()
307
  top_p_component.render()
 
310
  fn=user,
311
  inputs=user_inputs,
312
  outputs=[text_prompt_component, chatbot_component],
313
+ queue=False,
314
  ).then(
315
+ fn=bot,
316
+ inputs=bot_inputs,
317
+ outputs=[chatbot_component],
318
  )
319
 
320
  text_prompt_component.submit(
321
  fn=user,
322
  inputs=user_inputs,
323
  outputs=[text_prompt_component, chatbot_component],
324
+ queue=False,
325
  ).then(
326
+ fn=bot,
327
+ inputs=bot_inputs,
328
+ outputs=[chatbot_component],
329
  )
330
 
331
  upload_button_component.upload(
332
  fn=upload,
333
  inputs=[upload_button_component, chatbot_component],
334
  outputs=[chatbot_component],
335
+ queue=False,
336
  )
337
 
338
  demo.queue(max_size=99).launch(debug=False, show_error=True)