SkalskiP committed
Commit ef38c60 · 1 Parent(s): b7463e4

Multi-image upload support

Files changed (1)
  1. app.py +62 -28
app.py CHANGED
@@ -1,12 +1,12 @@
 import os
 import time
-from typing import List, Tuple, Optional, Dict
+import uuid
+from typing import List, Tuple, Optional, Dict, Union
 
 import google.generativeai as genai
 import gradio as gr
 from PIL import Image
 
-
 print("google-generativeai:", genai.__version__)
 
 GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
@@ -29,7 +29,9 @@ AVATAR_IMAGES = (
     "https://media.roboflow.com/spaces/gemini-icon.png"
 )
 
+IMAGE_CACHE_DIRECTORY = "/tmp"
 IMAGE_WIDTH = 512
+CHAT_HISTORY = List[Tuple[Optional[Union[Tuple[str], str]], Optional[str]]]
 
 
 def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
@@ -43,39 +45,62 @@ def preprocess_image(image: Image.Image) -> Optional[Image.Image]:
     return image.resize((IMAGE_WIDTH, image_height))
 
 
+def cache_pil_image(image: Image.Image) -> str:
+    image_filename = f"{uuid.uuid4()}.jpeg"
+    os.makedirs(IMAGE_CACHE_DIRECTORY, exist_ok=True)
+    image_path = os.path.join(IMAGE_CACHE_DIRECTORY, image_filename)
+    image.save(image_path, "JPEG")
+    return image_path
+
+
 def preprocess_chat_history(
-    history: List[Tuple[Optional[str], Optional[str]]]
-) -> List[Dict[str, List[str]]]:
+    history: CHAT_HISTORY
+) -> List[Dict[str, Union[str, List[str]]]]:
     messages = []
     for user_message, model_message in history:
-        if user_message is not None:
+        if isinstance(user_message, tuple):
+            pass
+        elif user_message is not None:
             messages.append({'role': 'user', 'parts': [user_message]})
         if model_message is not None:
             messages.append({'role': 'model', 'parts': [model_message]})
     return messages
 
 
-def user(text_prompt: str, chatbot: List[Tuple[str, str]]):
-    return "", chatbot + [[text_prompt, None]]
+def upload(files: Optional[List[str]], chatbot: CHAT_HISTORY) -> CHAT_HISTORY:
+    for file in files:
+        image = Image.open(file).convert('RGB')
+        image = preprocess_image(image)
+        image_path = cache_pil_image(image)
+        chatbot.append(((image_path,), None))
+    return chatbot
+
+
+def user(text_prompt: str, chatbot: CHAT_HISTORY):
+    if text_prompt:
+        chatbot.append((text_prompt, None))
+    return "", chatbot
 
 
 def bot(
     google_key: str,
-    image_prompt: Optional[Image.Image],
+    files: Optional[List[str]],
     temperature: float,
     max_output_tokens: int,
     stop_sequences: str,
     top_k: int,
     top_p: float,
-    chatbot: List[Tuple[str, str]]
+    chatbot: CHAT_HISTORY
 ):
+    if len(chatbot) == 0:
+        return chatbot
+
     google_key = google_key if google_key else GOOGLE_API_KEY
     if not google_key:
         raise ValueError(
             "GOOGLE_API_KEY is not set. "
             "Please follow the instructions in the README to set it up.")
 
-    text_prompt = chatbot[-1][0]
     genai.configure(api_key=google_key)
     generation_config = genai.types.GenerationConfig(
         temperature=temperature,
@@ -84,21 +109,23 @@ def bot(
         top_k=top_k,
         top_p=top_p)
 
-    if image_prompt is None:
-        model = genai.GenerativeModel('gemini-pro')
+    if files:
+        text_prompt = [chatbot[-1][0]] \
+            if chatbot[-1][0] and isinstance(chatbot[-1][0], str) \
+            else []
+        image_prompt = [Image.open(file).convert('RGB') for file in files]
+        model = genai.GenerativeModel('gemini-pro-vision')
         response = model.generate_content(
-            preprocess_chat_history(chatbot),
+            text_prompt + image_prompt,
             stream=True,
             generation_config=generation_config)
-        response.resolve()
     else:
-        image_prompt = preprocess_image(image_prompt)
-        model = genai.GenerativeModel('gemini-pro-vision')
+        messages = preprocess_chat_history(chatbot)
+        model = genai.GenerativeModel('gemini-pro')
         response = model.generate_content(
-            contents=[text_prompt, image_prompt],
+            messages,
             stream=True,
             generation_config=generation_config)
-        response.resolve()
 
     # streaming effect
     chatbot[-1][1] = ""
@@ -118,8 +145,6 @@ google_key_component = gr.Textbox(
     info="You have to provide your own GOOGLE_API_KEY for this app to function properly",
     visible=GOOGLE_API_KEY is None
 )
-
-image_prompt_component = gr.Image(type="pil", label="Image", scale=1, height=400)
 chatbot_component = gr.Chatbot(
     label='Gemini',
     bubble_full_width=False,
@@ -128,10 +153,12 @@ chatbot_component = gr.Chatbot(
     height=400
 )
 text_prompt_component = gr.Textbox(
-    placeholder="Hi there!",
-    label="Ask me anything and press Enter"
+    placeholder="Hi there! [press Enter]", show_label=False, autofocus=True, scale=8
+)
+upload_button_component = gr.UploadButton(
+    label="Upload Images", file_count="multiple", file_types=["image"], scale=1
 )
-run_button_component = gr.Button()
+run_button_component = gr.Button(value="Run", variant="primary", scale=1)
 temperature_component = gr.Slider(
     minimum=0,
     maximum=1.0,
@@ -197,7 +224,7 @@ user_inputs = [
 
 bot_inputs = [
     google_key_component,
-    image_prompt_component,
+    upload_button_component,
     temperature_component,
     max_output_tokens_component,
     stop_sequences_component,
@@ -212,11 +239,11 @@ with gr.Blocks() as demo:
     gr.HTML(DUPLICATE)
    with gr.Column():
        google_key_component.render()
+        chatbot_component.render()
        with gr.Row():
-            image_prompt_component.render()
-            chatbot_component.render()
-            text_prompt_component.render()
-            run_button_component.render()
+            text_prompt_component.render()
+            upload_button_component.render()
+            run_button_component.render()
        with gr.Accordion("Parameters", open=False):
            temperature_component.render()
            max_output_tokens_component.render()
@@ -243,4 +270,11 @@ with gr.Blocks() as demo:
         fn=bot, inputs=bot_inputs, outputs=[chatbot_component],
     )
 
+    upload_button_component.upload(
+        fn=upload,
+        inputs=[upload_button_component, chatbot_component],
+        outputs=[chatbot_component],
+        queue=False
+    )
+
 demo.queue(max_size=99).launch(debug=False, show_error=True)
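
For reference, a minimal sketch (not part of the commit) of the chat-history shape the updated handlers exchange: text turns are stored as plain strings, uploaded images as one-element tuples holding the cached file path, and preprocess_chat_history drops the image entries when serializing history for the text-only gemini-pro branch. The function body mirrors the diff; the sample history and file path below are hypothetical.

from typing import Dict, List, Optional, Tuple, Union

CHAT_HISTORY = List[Tuple[Optional[Union[Tuple[str], str]], Optional[str]]]

def preprocess_chat_history(
    history: CHAT_HISTORY
) -> List[Dict[str, Union[str, List[str]]]]:
    # Mirrors the function from the diff: image tuples are skipped,
    # text turns become Gemini-style message dicts.
    messages = []
    for user_message, model_message in history:
        if isinstance(user_message, tuple):
            pass  # image entry, e.g. ('/tmp/<uuid>.jpeg',), not sent to gemini-pro
        elif user_message is not None:
            messages.append({'role': 'user', 'parts': [user_message]})
        if model_message is not None:
            messages.append({'role': 'model', 'parts': [model_message]})
    return messages

# Hypothetical history: one image appended by upload(), then a text turn from user().
history: CHAT_HISTORY = [
    (("/tmp/0b1c2d3e.jpeg",), None),
    ("What is shown in this image?", None),
]
print(preprocess_chat_history(history))
# -> [{'role': 'user', 'parts': ['What is shown in this image?']}]

This also explains why the gemini-pro-vision branch in bot() builds its request from the freshly uploaded files plus only the last text turn: image entries in the history are never forwarded as serialized messages.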