cha0smagick commited on
Commit
c878823
1 Parent(s): 6c73801

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +274 -106
app.py CHANGED
@@ -1,109 +1,277 @@
1
  import streamlit as st
 
 
 
2
  from PIL import Image
3
- from google.generativeai import GenerativeModel # Assuming the actual module from Gemini Vision API
4
- from pathlib import Path
5
- import logging # Add logging import
6
-
7
- # Initialize session state
8
- def init_session_state():
9
- if 'api_key' not in st.session_state:
10
- st.session_state['api_key'] = ''
11
- if 'temperature' not in st.session_state:
12
- st.session_state['temperature'] = 0.1
13
- if 'top_k' not in st.session_state:
14
- st.session_state['top_k'] = 32
15
- if 'top_p' not in st.session_state:
16
- st.session_state['top_p'] = 1.0
17
- if 'gemini_model' not in st.session_state:
18
- st.session_state['gemini_model'] = None
19
- if 'uploaded_image' not in st.session_state:
20
- st.session_state['uploaded_image'] = None
21
- if 'image_description' not in st.session_state:
22
- st.session_state['image_description'] = ''
23
- if 'text_prompt' not in st.session_state:
24
- st.session_state['text_prompt'] = ''
25
- if 'logger' not in st.session_state or st.session_state['logger'] is None: # Add logger initialization
26
- st.session_state['logger'] = logging.getLogger(__name__)
27
- st.session_state['logger'].setLevel(logging.INFO)
28
-
29
- # Display support
30
- def display_support():
31
- st.markdown("<div style='text-align: center;'>Share and Support</div>", unsafe_allow_html=True)
32
-
33
- st.write("""
34
- <div style="display: flex; flex-direction: column; align-items: center; justify-content: center;">
35
- <ul style="list-style-type: none; margin: 0; padding: 0; display: flex;">
36
- <li style="margin-right: 10px;"><a href="https://twitter.com/haseeb_heaven" target="_blank"><img src="https://img.icons8.com/color/32/000000/twitter--v1.png"/></a></li>
37
- <li style="margin-right: 10px;"><a href="https://www.buymeacoffee.com/haseebheaven" target="_blank"><img src="https://img.icons8.com/color/32/000000/coffee-to-go--v1.png"/></a></li>
38
- <li style="margin-right: 10px;"><a href="https://www.youtube.com/@HaseebHeaven/videos" target="_blank"><img src="https://img.icons8.com/color/32/000000/youtube-play.png"/></a></li>
39
- <li><a href="https://github.com/haseeb-heaven/LangChain-Coder" target="_blank"><img src="https://img.icons8.com/color/32/000000/github--v1.png"/></a></li>
40
- </ul>
41
- </div>
42
- """, unsafe_allow_html=True)
43
-
44
- # Streamlit App
45
- def streamlit_app():
46
- # Google Logo and Title
47
- st.write('<div style="display: flex; flex-direction: row; align-items: center; justify-content: center;"><a style="margin-right: 10px;" href="https://www.google.com" target="_blank"><img src="https://img.icons8.com/color/32/000000/google-logo.png"/></a><h1 style="margin-left: 10px;">Google - Gemini Vision</h1></div>', unsafe_allow_html=True)
48
-
49
- # Display support
50
- display_support()
51
-
52
- # Display the Gemini Sidebar settings
53
- with st.sidebar.title("Gemini Settings"):
54
- st.session_state.api_key = st.sidebar.text_input("API Key", type="password")
55
- st.session_state.temperature = st.sidebar.slider("Temperature", 0.0, 1.0, 0.3)
56
- st.session_state.top_k = st.sidebar.number_input("Top K", value=32)
57
- st.session_state.top_p = st.sidebar.slider("Top P", 0.0, 1.0, 1.0)
58
-
59
- if (st.session_state.api_key is not None and st.session_state.api_key != '') \
60
- and (st.session_state.temperature is not None and st.session_state.temperature != '') \
61
- and (st.session_state.top_k is not None and st.session_state.top_k != '') \
62
- and (st.session_state.top_p is not None and st.session_state.top_p != ''):
63
- st.toast("Settings updated successfully!", icon="👍")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
64
  else:
65
- st.toast("Please enter all the settings.\nAPI Key is required", icon="❌")
66
- raise ValueError("Please enter all values the settings.\nAPI Key is required")
67
-
68
- # Initialize services once
69
- if st.session_state.gemini_model is None:
70
- st.session_state.gemini_model = GenerativeModel(model_name="gemini-pro-vision",
71
- api_key=st.session_state['api_key'],
72
- temperature=st.session_state['temperature'],
73
- top_p=st.session_state['top_p'],
74
- top_k=st.session_state['top_k'])
75
-
76
- uploaded_image = st.file_uploader("Upload an image", type=["jpg", "jpeg", "png"])
77
-
78
- if uploaded_image:
79
- # Display the uploaded image
80
- st.image(uploaded_image, caption="Uploaded Image", use_column_width=True)
81
-
82
- # Store the uploaded image in session state
83
- st.session_state['uploaded_image'] = uploaded_image
84
-
85
- if st.button("Generate Image Description"):
86
- try:
87
- # Process the uploaded image using Gemini Vision API
88
- image_description = st.session_state.gemini_model.generate_content(contents=[uploaded_image, st.session_state['text_prompt']])
89
- st.session_state['image_description'] = image_description.text
90
- st.success(f"Image description generated: {image_description.text}")
91
- except Exception as e:
92
- st.error(f"An error occurred: {e}")
93
-
94
- # Input field for text prompt/question
95
- st.session_state['text_prompt'] = st.text_area("Enter Text Prompt/Question", st.session_state['text_prompt'])
96
-
97
- # Display the generated image description
98
- if st.session_state['image_description']:
99
- st.code(f"Image Description: {st.session_state['image_description']}", language="plaintext")
100
-
101
- if __name__ == "__main__":
102
- try:
103
- init_session_state()
104
- streamlit_app()
105
- except Exception as exception:
106
- import traceback
107
- st.session_state.logger.error(f"An error occurred: {exception}")
108
- st.session_state.logger.error(traceback.format_exc())
109
- st.error(f"An error occurred: {exception}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import pandas as pd
3
+ import google.generativeai as genai
4
+ import re
5
  from PIL import Image
6
+ import requests
7
+
8
+ #Je t'aime plus que les mots,
9
+ #Plus que les sentiments,
10
+ #Plus que la vie elle-même
11
+
12
+ st.set_page_config(
13
+ page_title="Google AI Chat",
14
+ page_icon="https://seeklogo.com/images/G/google-ai-logo-996E85F6FD-seeklogo.com.png",
15
+ layout="wide",
16
+ )
17
+ # Path: Main.py
18
+ #Author: Sergio Demis Lopez Martinez
19
+ #------------------------------------------------------------
20
+ #HEADER
21
+ st.markdown('''
22
+ Powered by Google AI <img src="https://seeklogo.com/images/G/google-ai-logo-996E85F6FD-seeklogo.com.png" width="20" height="20">
23
+ , Streamlit and Python''', unsafe_allow_html=True)
24
+ st.caption("By Sergio Demis Lopez Martinez")
25
+
26
+ #------------------------------------------------------------
27
+ #LANGUAGE
28
+ langcols = st.columns([0.2,0.8])
29
+ with langcols[0]:
30
+ lang = st.selectbox('Select your language',
31
+ ('English', 'Español', 'Français', 'Deutsch',
32
+ 'Italiano', 'Português', 'Polski', 'Nederlands',
33
+ 'Русский', '日本語', '한국어', '中文', 'العربية',
34
+ 'हिन्दी', 'Türkçe', 'Tiếng Việt', 'Bahasa Indonesia',
35
+ 'ภาษาไทย', 'Română', 'Ελληνικά', 'Magyar', 'Čeština',
36
+ 'Svenska', 'Norsk', 'Suomi', 'Dansk', 'हिन्दी', 'हिन्�'),index=1)
37
+
38
+ if 'lang' not in st.session_state:
39
+ st.session_state.lang = lang
40
+ st.divider()
41
+
42
+ #------------------------------------------------------------
43
+ #FUNCTIONS
44
+ def extract_graphviz_info(text: str) -> list[str]:
45
+ """
46
+ The function `extract_graphviz_info` takes in a text and returns a list of graphviz code blocks found in the text.
47
+
48
+ :param text: The `text` parameter is a string that contains the text from which you want to extract Graphviz information
49
+ :return: a list of strings that contain either the word "graph" or "digraph". These strings are extracted from the input
50
+ text.
51
+ """
52
+
53
+ graphviz_info = text.split('```')
54
+
55
+ return [graph for graph in graphviz_info if ('graph' in graph or 'digraph' in graph) and ('{' in graph and '}' in graph)]
56
+
57
+ def append_message(message: dict) -> None:
58
+ """
59
+ The function appends a message to a chat session.
60
+
61
+ :param message: The `message` parameter is a dictionary that represents a chat message. It typically contains
62
+ information such as the user who sent the message and the content of the message
63
+ :type message: dict
64
+ :return: The function is not returning anything.
65
+ """
66
+ st.session_state.chat_session.append({'user': message})
67
+ return
68
+
69
+ @st.cache_resource
70
+ def load_model() -> genai.GenerativeModel:
71
+ """
72
+ The function `load_model()` returns an instance of the `genai.GenerativeModel` class initialized with the model name
73
+ 'gemini-pro'.
74
+ :return: an instance of the `genai.GenerativeModel` class.
75
+ """
76
+ model = genai.GenerativeModel('gemini-pro')
77
+ return model
78
+
79
+ @st.cache_resource
80
+ def load_modelvision() -> genai.GenerativeModel:
81
+ """
82
+ The function `load_modelvision` loads a generative model for vision tasks using the `gemini-pro-vision` model.
83
+ :return: an instance of the `genai.GenerativeModel` class.
84
+ """
85
+ model = genai.GenerativeModel('gemini-pro-vision')
86
+ return model
87
+
88
+
89
+
90
+ #------------------------------------------------------------
91
+ #CONFIGURATION
92
+ genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])
93
+
94
+ model = load_model()
95
+
96
+ vision = load_modelvision()
97
+
98
+ if 'chat' not in st.session_state:
99
+ st.session_state.chat = model.start_chat(history=[])
100
+
101
+ if 'chat_session' not in st.session_state:
102
+ st.session_state.chat_session = []
103
+
104
+ #st.session_state.chat_session
105
+
106
+ #------------------------------------------------------------
107
+ #CHAT
108
+
109
+ if 'messages' not in st.session_state:
110
+ st.session_state.messages = []
111
+
112
+ if 'welcome' not in st.session_state or lang != st.session_state.lang:
113
+ st.session_state.lang = lang
114
+ welcome = model.generate_content(f'''
115
+ Da un saludo de bienvenida al usuario y sugiere que puede hacer
116
+ (Puedes describir imágenes, responder preguntas, leer archivos texto, leer tablas,generar gráficos con graphviz, etc)
117
+ eres un chatbot en una aplicación de chat creada en streamlit y python. generate the answer in {lang}''')
118
+ welcome.resolve()
119
+ st.session_state.welcome = welcome
120
+
121
+ with st.chat_message('ai'):
122
+ st.write(st.session_state.welcome.text)
123
+ else:
124
+ with st.chat_message('ai'):
125
+ st.write(st.session_state.welcome.text)
126
+
127
+ if len(st.session_state.chat_session) > 0:
128
+ count = 0
129
+ for message in st.session_state.chat_session:
130
+
131
+ if message['user']['role'] == 'model':
132
+ with st.chat_message('ai'):
133
+ st.write(message['user']['parts'])
134
+ graphs = extract_graphviz_info(message['user']['parts'])
135
+ if len(graphs) > 0:
136
+ for graph in graphs:
137
+ st.graphviz_chart(graph,use_container_width=False)
138
+ if lang == 'Español':
139
+ view = "Ver texto"
140
+ else:
141
+ view = "View text"
142
+ with st.expander(view):
143
+ st.code(graph, language='dot')
144
  else:
145
+ with st.chat_message('user'):
146
+ st.write(message['user']['parts'][0])
147
+ if len(message['user']['parts']) > 1:
148
+ st.image(message['user']['parts'][1], width=200)
149
+ count += 1
150
+
151
+
152
+
153
+ #st.session_state.chat.history
154
+
155
+ cols=st.columns(4)
156
+
157
+ with cols[0]:
158
+ if lang == 'Español':
159
+ image_atachment = st.toggle("Adjuntar imagen", value=False, help="Activa este modo para adjuntar una imagen y que el chatbot pueda leerla")
160
+ else:
161
+ image_atachment = st.toggle("Attach image", value=False, help="Activate this mode to attach an image and let the chatbot read it")
162
+
163
+ with cols[1]:
164
+ if lang == 'Español':
165
+ txt_atachment = st.toggle("Adjuntar archivo de texto", value=False, help="Activa este modo para adjuntar un archivo de texto y que el chatbot pueda leerlo")
166
+ else:
167
+ txt_atachment = st.toggle("Attach text file", value=False, help="Activate this mode to attach a text file and let the chatbot read it")
168
+ with cols[2]:
169
+ if lang == 'Español':
170
+ csv_excel_atachment = st.toggle("Adjuntar CSV o Excel", value=False, help="Activa este modo para adjuntar un archivo CSV o Excel y que el chatbot pueda leerlo")
171
+ else:
172
+ csv_excel_atachment = st.toggle("Attach CSV or Excel", value=False, help="Activate this mode to attach a CSV or Excel file and let the chatbot read it")
173
+ with cols[3]:
174
+ if lang == 'Español':
175
+ graphviz_mode = st.toggle("Modo graphviz", value=False, help="Activa este modo para generar un grafo con graphviz en .dot a partir de tu mensaje")
176
+ else:
177
+ graphviz_mode = st.toggle("Graphviz mode", value=False, help="Activate this mode to generate a graph with graphviz in .dot from your message")
178
+ if image_atachment:
179
+ if lang == 'Español':
180
+ image = st.file_uploader("Sube tu imagen", type=['png', 'jpg', 'jpeg'])
181
+ url = st.text_input("O pega la url de tu imagen")
182
+ else:
183
+ image = st.file_uploader("Upload your image", type=['png', 'jpg', 'jpeg'])
184
+ url = st.text_input("Or paste your image url")
185
+ else:
186
+ image = None
187
+ url = ''
188
+
189
+
190
+
191
+ if txt_atachment:
192
+ if lang == 'Español':
193
+ txtattachment = st.file_uploader("Sube tu archivo de texto", type=['txt'])
194
+ else:
195
+ txtattachment = st.file_uploader("Upload your text file", type=['txt'])
196
+ else:
197
+ txtattachment = None
198
+
199
+ if csv_excel_atachment:
200
+ if lang == 'Español':
201
+ csvexcelattachment = st.file_uploader("Sube tu archivo CSV o Excel", type=['csv', 'xlsx'])
202
+ else:
203
+ csvexcelattachment = st.file_uploader("Upload your CSV or Excel file", type=['csv', 'xlsx'])
204
+ else:
205
+ csvexcelattachment = None
206
+ if lang == 'Español':
207
+ prompt = st.chat_input("Escribe tu mensaje")
208
+ else:
209
+ prompt = st.chat_input("Write your message")
210
+
211
+ if prompt:
212
+ txt = ''
213
+ if txtattachment:
214
+ txt = txtattachment.getvalue().decode("utf-8")
215
+ if lang == 'Español':
216
+ txt = ' Archivo de texto: \n' + txt
217
+ else:
218
+ txt = ' Text file: \n' + txt
219
+
220
+ if csvexcelattachment:
221
+ try:
222
+ df = pd.read_csv(csvexcelattachment)
223
+ except:
224
+ df = pd.read_excel(csvexcelattachment)
225
+ txt += ' Dataframe: \n' + str(df)
226
+
227
+ if graphviz_mode:
228
+ if lang == 'Español':
229
+ txt += ' Genera un grafo con graphviz en .dot \n'
230
+ else:
231
+ txt += ' Generate a graph with graphviz in .dot \n'
232
+
233
+ if len(txt) > 5000:
234
+ txt = txt[:5000] + '...'
235
+ if image or url != '':
236
+ if url != '':
237
+ img = Image.open(requests.get(url, stream=True).raw)
238
+ else:
239
+ img = Image.open(image)
240
+ prmt = {'role': 'user', 'parts':[prompt+txt, img]}
241
+ else:
242
+ prmt = {'role': 'user', 'parts':[prompt+txt]}
243
+
244
+ append_message(prmt)
245
+
246
+ if lang == 'Español':
247
+ spinertxt = 'Espera un momento, estoy pensando...'
248
+ else:
249
+ spinertxt = 'Wait a moment, I am thinking...'
250
+ with st.spinner(spinertxt):
251
+ if len(prmt['parts']) > 1:
252
+ response = vision.generate_content(prmt['parts'],stream=True,safety_settings=[
253
+ {
254
+ "category": "HARM_CATEGORY_HARASSMENT",
255
+ "threshold": "BLOCK_LOW_AND_ABOVE",
256
+ },
257
+ {
258
+ "category": "HARM_CATEGORY_HATE_SPEECH",
259
+ "threshold": "BLOCK_LOW_AND_ABOVE",
260
+ },
261
+ ]
262
+ )
263
+ response.resolve()
264
+ else:
265
+ response = st.session_state.chat.send_message(prmt['parts'][0])
266
+
267
+ try:
268
+ append_message({'role': 'model', 'parts':response.text})
269
+ except Exception as e:
270
+ append_message({'role': 'model', 'parts':f'{type(e).__name__}: {e}'})
271
+
272
+
273
+ st.rerun()
274
+
275
+
276
+
277
+ #st.session_state.chat_session