SkalskiP commited on
Commit
d21820e
1 Parent(s): ea30afb

Input image rescaling added

Browse files
Files changed (1) hide show
  1. app.py +14 -14
app.py CHANGED
@@ -8,6 +8,8 @@ from PIL import Image
8
 
9
  print("google-generativeai:", genai.__version__)
10
 
 
 
11
  TITLE = """<h1 align="center">Gemini Playground 💬</h1>"""
12
  SUBTITLE = """<h2 align="center">Play with Gemini Pro and Gemini Pro Vision API</h2>"""
13
  DUPLICATE = """
@@ -20,13 +22,13 @@ DUPLICATE = """
20
  </span>
21
  </div>
22
  """
 
23
  AVATAR_IMAGES = (
24
  None,
25
  "https://media.roboflow.com/spaces/gemini-icon.png"
26
  )
27
 
28
-
29
- GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
30
 
31
 
32
  def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
@@ -35,6 +37,11 @@ def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
35
  return [sequence.strip() for sequence in stop_sequences.split(",")]
36
 
37
 
 
 
 
 
 
38
  def user(text_prompt: str, chatbot: List[Tuple[str, str]]):
39
  return "", chatbot + [[text_prompt, None]]
40
 
@@ -42,7 +49,6 @@ def user(text_prompt: str, chatbot: List[Tuple[str, str]]):
42
  def bot(
43
  google_key: str,
44
  image_prompt: Optional[Image.Image],
45
- image_prompt_2: Optional[Image.Image],
46
  temperature: float,
47
  max_output_tokens: int,
48
  stop_sequences: str,
@@ -65,7 +71,7 @@ def bot(
65
  top_k=top_k,
66
  top_p=top_p)
67
 
68
- if image_prompt is None and image_prompt_2 is None:
69
  model = genai.GenerativeModel('gemini-pro')
70
  response = model.generate_content(
71
  text_prompt,
@@ -73,11 +79,10 @@ def bot(
73
  generation_config=generation_config)
74
  response.resolve()
75
  else:
76
- contents = [text_prompt, image_prompt, image_prompt_2]
77
- contents = [content for content in contents if content is not None]
78
  model = genai.GenerativeModel('gemini-pro-vision')
79
  response = model.generate_content(
80
- contents=contents,
81
  stream=True,
82
  generation_config=generation_config)
83
  response.resolve()
@@ -101,8 +106,7 @@ google_key_component = gr.Textbox(
101
  visible=GOOGLE_API_KEY is None
102
  )
103
 
104
- image_prompt_component = gr.Image(type="pil", label="Image")
105
- image_prompt_2_component = gr.Image(type="pil", label="Image")
106
  chatbot_component = gr.Chatbot(
107
  label='Gemini',
108
  bubble_full_width=False,
@@ -180,7 +184,6 @@ user_inputs = [
180
  bot_inputs = [
181
  google_key_component,
182
  image_prompt_component,
183
- image_prompt_2_component,
184
  temperature_component,
185
  max_output_tokens_component,
186
  stop_sequences_component,
@@ -196,10 +199,7 @@ with gr.Blocks() as demo:
196
  with gr.Column():
197
  google_key_component.render()
198
  with gr.Row():
199
- with gr.Column(scale=1):
200
- image_prompt_component.render()
201
- with gr.Accordion("Multi Image", open=False):
202
- image_prompt_2_component.render()
203
  chatbot_component.render()
204
  text_prompt_component.render()
205
  run_button_component.render()
 
8
 
9
  print("google-generativeai:", genai.__version__)
10
 
11
+ GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
12
+
13
  TITLE = """<h1 align="center">Gemini Playground 💬</h1>"""
14
  SUBTITLE = """<h2 align="center">Play with Gemini Pro and Gemini Pro Vision API</h2>"""
15
  DUPLICATE = """
 
22
  </span>
23
  </div>
24
  """
25
+
26
  AVATAR_IMAGES = (
27
  None,
28
  "https://media.roboflow.com/spaces/gemini-icon.png"
29
  )
30
 
31
+ IMAGE_WIDTH = 512
 
32
 
33
 
34
  def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
 
37
  return [sequence.strip() for sequence in stop_sequences.split(",")]
38
 
39
 
40
+ def preprocess_image(image: Image.Image) -> Optional[Image.Image]:
41
+ image_height = int(image.height * IMAGE_WIDTH / image.width)
42
+ return image.resize((IMAGE_WIDTH, image_height))
43
+
44
+
45
  def user(text_prompt: str, chatbot: List[Tuple[str, str]]):
46
  return "", chatbot + [[text_prompt, None]]
47
 
 
49
  def bot(
50
  google_key: str,
51
  image_prompt: Optional[Image.Image],
 
52
  temperature: float,
53
  max_output_tokens: int,
54
  stop_sequences: str,
 
71
  top_k=top_k,
72
  top_p=top_p)
73
 
74
+ if image_prompt is None:
75
  model = genai.GenerativeModel('gemini-pro')
76
  response = model.generate_content(
77
  text_prompt,
 
79
  generation_config=generation_config)
80
  response.resolve()
81
  else:
82
+ image_prompt = preprocess_image(image_prompt)
 
83
  model = genai.GenerativeModel('gemini-pro-vision')
84
  response = model.generate_content(
85
+ contents=[text_prompt, image_prompt],
86
  stream=True,
87
  generation_config=generation_config)
88
  response.resolve()
 
106
  visible=GOOGLE_API_KEY is None
107
  )
108
 
109
+ image_prompt_component = gr.Image(type="pil", label="Image", scale=1)
 
110
  chatbot_component = gr.Chatbot(
111
  label='Gemini',
112
  bubble_full_width=False,
 
184
  bot_inputs = [
185
  google_key_component,
186
  image_prompt_component,
 
187
  temperature_component,
188
  max_output_tokens_component,
189
  stop_sequences_component,
 
199
  with gr.Column():
200
  google_key_component.render()
201
  with gr.Row():
202
+ image_prompt_component.render()
 
 
 
203
  chatbot_component.render()
204
  text_prompt_component.render()
205
  run_button_component.render()