Tomoniai committed on
Commit
7536f52
1 Parent(s): aa4a8aa

Create app.py

Files changed (1)
  1. app.py +216 -0
app.py ADDED
@@ -0,0 +1,216 @@
+ import os
+ import time
+ from typing import List, Tuple, Optional, Dict
+
+ import google.generativeai as genai
+ import gradio as gr
+ from PIL import Image
+
+ print("google-generativeai:", genai.__version__)
+
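+ # GG_API_KEY is the Google AI Studio API key; OAI_USR / OAI_PWD are the
+ # username and password for the Gradio basic-auth login used at launch.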
+ GG_API_KEY = os.environ.get("GG_API_KEY")
+ oaiusr = os.environ.get("OAI_USR")
+ oaipwd = os.environ.get("OAI_PWD")
+
+ TITLE = """<h2 align="center">Tomoniai's Gemini Pro Chat</h2>"""
+
+ AVATAR_IMAGES = ("./user.png", "./botg.png")
+
+ IMAGE_WIDTH = 512
+
+
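+ # Turn the comma-separated stop-sequence string into the list form the API expects.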
+ def preprocess_stop_sequences(stop_sequences: str) -> Optional[List[str]]:
+     if not stop_sequences:
+         return None
+     return [sequence.strip() for sequence in stop_sequences.split(",")]
+
+
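+ # Downscale uploaded images to IMAGE_WIDTH, preserving the aspect ratio.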
+ def preprocess_image(image: Image.Image) -> Optional[Image.Image]:
+     image_height = int(image.height * IMAGE_WIDTH / image.width)
+     return image.resize((IMAGE_WIDTH, image_height))
+
+
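+ # Convert Gradio's (user, bot) tuple history into the role/parts message
+ # dicts expected by google-generativeai.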
+ def preprocess_chat_history(
+     history: List[Tuple[Optional[str], Optional[str]]]
+ ) -> List[Dict[str, List[str]]]:
+     messages = []
+     for user_message, model_message in history:
+         if user_message is not None:
+             messages.append({'role': 'user', 'parts': [user_message]})
+         if model_message is not None:
+             messages.append({'role': 'model', 'parts': [model_message]})
+     return messages
+
+
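+ # Append the new prompt to the chat history and clear the input textbox.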
+ def user(text_prompt: str, chatbot: List[Tuple[str, str]]):
+     return "", chatbot + [[text_prompt, None]]
+
+
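+ # Generate the model reply. The signature matches `bot_inputs` below
+ # one-for-one; the API key is always taken from the GG_API_KEY env var.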
+ def bot(
+     image_prompt: Optional[Image.Image],
+     temperature: float,
+     max_output_tokens: int,
+     stop_sequences: str,
+     top_k: int,
+     top_p: float,
+     chatbot: List[Tuple[str, str]]
+ ):
+     text_prompt = chatbot[-1][0]
+     genai.configure(api_key=GG_API_KEY)
+     generation_config = genai.types.GenerationConfig(
+         temperature=temperature,
+         max_output_tokens=max_output_tokens,
+         stop_sequences=preprocess_stop_sequences(stop_sequences=stop_sequences),
+         top_k=top_k,
+         top_p=top_p)
+
+     if image_prompt is None:
+         # text-only turn: send the full chat history to gemini-pro
+         model = genai.GenerativeModel('gemini-pro')
+         response = model.generate_content(
+             preprocess_chat_history(chatbot),
+             stream=True,
+             generation_config=generation_config)
+         response.resolve()
+     else:
+         # image turn: send the prompt plus the resized image to gemini-pro-vision
+         image_prompt = preprocess_image(image_prompt)
+         model = genai.GenerativeModel('gemini-pro-vision')
+         response = model.generate_content(
+             contents=[text_prompt, image_prompt],
+             stream=True,
+             generation_config=generation_config)
+         response.resolve()
+
+     # streaming effect: re-emit the resolved reply in 10-character slices
+     chatbot[-1][1] = ""
+     for chunk in response:
+         for i in range(0, len(chunk.text), 10):
+             section = chunk.text[i:i + 10]
+             chatbot[-1][1] += section
+             time.sleep(0.01)
+             yield chatbot
+
+
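+ # UI components; they are created here and rendered inside the Blocks layout below.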
+ image_prompt_component = gr.Image(type="pil", label="Image", scale=1, height=400)
+ chatbot_component = gr.Chatbot(
+     label='Gemini',
+     bubble_full_width=False,
+     avatar_images=AVATAR_IMAGES,
+     scale=2,
+     height=400
+ )
+ text_prompt_component = gr.Textbox(
+     placeholder="Hi there!",
+     label="Ask me anything and press Enter"
+ )
+ run_button_component = gr.Button()
+ temperature_component = gr.Slider(
+     minimum=0,
+     maximum=1.0,
+     value=0.4,
+     step=0.05,
+     label="Temperature",
+     info=(
+         "Temperature controls the degree of randomness in token selection. Lower "
+         "temperatures are good for prompts that expect a true or correct response, "
+         "while higher temperatures can lead to more diverse or unexpected results."
+     ))
+ max_output_tokens_component = gr.Slider(
+     minimum=1,
+     maximum=2048,
+     value=1024,
+     step=1,
+     label="Token limit",
+     info=(
+         "Token limit determines the maximum amount of text output from one prompt. A "
+         "token is approximately four characters. The default value is 1024."
+     ))
+ stop_sequences_component = gr.Textbox(
+     label="Add stop sequence",
+     value="",
+     type="text",
+     placeholder="STOP, END",
+     info=(
+         "A stop sequence is a series of characters (including spaces) that stops "
+         "response generation if the model encounters it. The sequence is not included "
+         "as part of the response. You can add up to five stop sequences."
+     ))
+ top_k_component = gr.Slider(
+     minimum=1,
+     maximum=40,
+     value=32,
+     step=1,
+     label="Top-K",
+     info=(
+         "Top-k changes how the model selects tokens for output. A top-k of 1 means the "
+         "selected token is the most probable among all tokens in the model’s "
+         "vocabulary (also called greedy decoding), while a top-k of 3 means that the "
+         "next token is selected from among the 3 most probable tokens (using "
+         "temperature)."
+     ))
+ top_p_component = gr.Slider(
+     minimum=0,
+     maximum=1,
+     value=1,
+     step=0.01,
+     label="Top-P",
+     info=(
+         "Top-p changes how the model selects tokens for output. Tokens are selected "
+         "from most probable to least until the sum of their probabilities equals the "
+         "top-p value. For example, if tokens A, B, and C have a probability of .3, .2, "
+         "and .1 and the top-p value is .5, then the model will select either A or B as "
+         "the next token (using temperature)."
+     ))
+
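+ # Handler inputs: `user` reads the textbox and history; `bot` receives the
+ # image, the sampling parameters, and the updated history.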
+ user_inputs = [
+     text_prompt_component,
+     chatbot_component
+ ]
+
+ bot_inputs = [
+     image_prompt_component,
+     temperature_component,
+     max_output_tokens_component,
+     stop_sequences_component,
+     top_k_component,
+     top_p_component,
+     chatbot_component
+ ]
+
+ with gr.Blocks() as demo:
+     gr.HTML(TITLE)
+     with gr.Column():
+         with gr.Row():
+             image_prompt_component.render()
+             chatbot_component.render()
+         text_prompt_component.render()
+         run_button_component.render()
+         with gr.Accordion("Parameters", open=False):
+             temperature_component.render()
+             max_output_tokens_component.render()
+             stop_sequences_component.render()
+             with gr.Accordion("Advanced", open=False):
+                 top_k_component.render()
+                 top_p_component.render()
+
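+     # Both the button click and textbox submit run `user` first (append the
+     # message, clear the box), then `bot` streams the reply into the chat.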
+     run_button_component.click(
+         fn=user,
+         inputs=user_inputs,
+         outputs=[text_prompt_component, chatbot_component],
+         queue=False
+     ).then(
+         fn=bot, inputs=bot_inputs, outputs=[chatbot_component],
+     )
+
+     text_prompt_component.submit(
+         fn=user,
+         inputs=user_inputs,
+         outputs=[text_prompt_component, chatbot_component],
+         queue=False
+     ).then(
+         fn=bot, inputs=bot_inputs, outputs=[chatbot_component],
+     )
+
+ demo.queue(max_size=99).launch(auth=(oaiusr, oaipwd), show_api=False, debug=False, show_error=True)