cha0smagick committed on
Commit
5b568c0
1 Parent(s): 11cc9fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -275
app.py CHANGED
@@ -1,277 +1,41 @@
1
  import streamlit as st
2
- import pandas as pd
3
- import google.generativeai as genai
4
- import re
5
- from PIL import Image
6
  import requests
7
-
8
- #Je t'aime plus que les mots,
9
- #Plus que les sentiments,
10
- #Plus que la vie elle-même
11
-
12
- st.set_page_config(
13
- page_title="Google AI Chat",
14
- page_icon="https://seeklogo.com/images/G/google-ai-logo-996E85F6FD-seeklogo.com.png",
15
- layout="wide",
16
- )
17
- # Path: Main.py
18
- #Author: Sergio Demis Lopez Martinez
19
- #------------------------------------------------------------
20
- #HEADER
21
- st.markdown('''
22
- Powered by Google AI <img src="https://seeklogo.com/images/G/google-ai-logo-996E85F6FD-seeklogo.com.png" width="20" height="20">
23
- , Streamlit and Python''', unsafe_allow_html=True)
24
- st.caption("By Sergio Demis Lopez Martinez")
25
-
26
- #------------------------------------------------------------
27
- #LANGUAGE
28
- langcols = st.columns([0.2,0.8])
29
- with langcols[0]:
30
- lang = st.selectbox('Select your language',
31
- ('English', 'Español', 'Français', 'Deutsch',
32
- 'Italiano', 'Português', 'Polski', 'Nederlands',
33
- 'Русский', '日本語', '한국어', '中文', 'العربية',
34
- 'हिन्दी', 'Türkçe', 'Tiếng Việt', 'Bahasa Indonesia',
35
- 'ภาษาไทย', 'Română', 'Ελληνικά', 'Magyar', 'Čeština',
36
- 'Svenska', 'Norsk', 'Suomi', 'Dansk', 'हिन्दी', 'हिन्�'),index=1)
37
-
38
- if 'lang' not in st.session_state:
39
- st.session_state.lang = lang
40
- st.divider()
41
-
42
- #------------------------------------------------------------
43
- #FUNCTIONS
44
- def extract_graphviz_info(text: str) -> list[str]:
45
- """
46
- The function `extract_graphviz_info` takes in a text and returns a list of graphviz code blocks found in the text.
47
-
48
- :param text: The `text` parameter is a string that contains the text from which you want to extract Graphviz information
49
- :return: a list of strings that contain either the word "graph" or "digraph". These strings are extracted from the input
50
- text.
51
- """
52
-
53
- graphviz_info = text.split('```')
54
-
55
- return [graph for graph in graphviz_info if ('graph' in graph or 'digraph' in graph) and ('{' in graph and '}' in graph)]
56
-
57
- def append_message(message: dict) -> None:
58
- """
59
- The function appends a message to a chat session.
60
-
61
- :param message: The `message` parameter is a dictionary that represents a chat message. It typically contains
62
- information such as the user who sent the message and the content of the message
63
- :type message: dict
64
- :return: The function is not returning anything.
65
- """
66
- st.session_state.chat_session.append({'user': message})
67
- return
68
-
69
- @st.cache_resource
70
- def load_model() -> genai.GenerativeModel:
71
- """
72
- The function `load_model()` returns an instance of the `genai.GenerativeModel` class initialized with the model name
73
- 'gemini-pro'.
74
- :return: an instance of the `genai.GenerativeModel` class.
75
- """
76
- model = genai.GenerativeModel('gemini-pro')
77
- return model
78
-
79
- @st.cache_resource
80
- def load_modelvision() -> genai.GenerativeModel:
81
- """
82
- The function `load_modelvision` loads a generative model for vision tasks using the `gemini-pro-vision` model.
83
- :return: an instance of the `genai.GenerativeModel` class.
84
- """
85
- model = genai.GenerativeModel('gemini-pro-vision')
86
- return model
87
-
88
-
89
-
90
- #------------------------------------------------------------
91
- #CONFIGURATION
92
- genai.configure(api_key=st.secrets["GOOGLE_API_KEY"])
93
-
94
- model = load_model()
95
-
96
- vision = load_modelvision()
97
-
98
- if 'chat' not in st.session_state:
99
- st.session_state.chat = model.start_chat(history=[])
100
-
101
- if 'chat_session' not in st.session_state:
102
- st.session_state.chat_session = []
103
-
104
- #st.session_state.chat_session
105
-
106
- #------------------------------------------------------------
107
- #CHAT
108
-
109
- if 'messages' not in st.session_state:
110
- st.session_state.messages = []
111
-
112
- if 'welcome' not in st.session_state or lang != st.session_state.lang:
113
- st.session_state.lang = lang
114
- welcome = model.generate_content(f'''
115
- Da un saludo de bienvenida al usuario y sugiere que puede hacer
116
- (Puedes describir imágenes, responder preguntas, leer archivos texto, leer tablas,generar gráficos con graphviz, etc)
117
- eres un chatbot en una aplicación de chat creada en streamlit y python. generate the answer in {lang}''')
118
- welcome.resolve()
119
- st.session_state.welcome = welcome
120
-
121
- with st.chat_message('ai'):
122
- st.write(st.session_state.welcome.text)
123
- else:
124
- with st.chat_message('ai'):
125
- st.write(st.session_state.welcome.text)
126
-
127
- if len(st.session_state.chat_session) > 0:
128
- count = 0
129
- for message in st.session_state.chat_session:
130
-
131
- if message['user']['role'] == 'model':
132
- with st.chat_message('ai'):
133
- st.write(message['user']['parts'])
134
- graphs = extract_graphviz_info(message['user']['parts'])
135
- if len(graphs) > 0:
136
- for graph in graphs:
137
- st.graphviz_chart(graph,use_container_width=False)
138
- if lang == 'Español':
139
- view = "Ver texto"
140
- else:
141
- view = "View text"
142
- with st.expander(view):
143
- st.code(graph, language='dot')
144
- else:
145
- with st.chat_message('user'):
146
- st.write(message['user']['parts'][0])
147
- if len(message['user']['parts']) > 1:
148
- st.image(message['user']['parts'][1], width=200)
149
- count += 1
150
-
151
-
152
-
153
- #st.session_state.chat.history
154
-
155
- cols=st.columns(4)
156
-
157
- with cols[0]:
158
- if lang == 'Español':
159
- image_atachment = st.toggle("Adjuntar imagen", value=False, help="Activa este modo para adjuntar una imagen y que el chatbot pueda leerla")
160
- else:
161
- image_atachment = st.toggle("Attach image", value=False, help="Activate this mode to attach an image and let the chatbot read it")
162
-
163
- with cols[1]:
164
- if lang == 'Español':
165
- txt_atachment = st.toggle("Adjuntar archivo de texto", value=False, help="Activa este modo para adjuntar un archivo de texto y que el chatbot pueda leerlo")
166
- else:
167
- txt_atachment = st.toggle("Attach text file", value=False, help="Activate this mode to attach a text file and let the chatbot read it")
168
- with cols[2]:
169
- if lang == 'Español':
170
- csv_excel_atachment = st.toggle("Adjuntar CSV o Excel", value=False, help="Activa este modo para adjuntar un archivo CSV o Excel y que el chatbot pueda leerlo")
171
- else:
172
- csv_excel_atachment = st.toggle("Attach CSV or Excel", value=False, help="Activate this mode to attach a CSV or Excel file and let the chatbot read it")
173
- with cols[3]:
174
- if lang == 'Español':
175
- graphviz_mode = st.toggle("Modo graphviz", value=False, help="Activa este modo para generar un grafo con graphviz en .dot a partir de tu mensaje")
176
- else:
177
- graphviz_mode = st.toggle("Graphviz mode", value=False, help="Activate this mode to generate a graph with graphviz in .dot from your message")
178
- if image_atachment:
179
- if lang == 'Español':
180
- image = st.file_uploader("Sube tu imagen", type=['png', 'jpg', 'jpeg'])
181
- url = st.text_input("O pega la url de tu imagen")
182
- else:
183
- image = st.file_uploader("Upload your image", type=['png', 'jpg', 'jpeg'])
184
- url = st.text_input("Or paste your image url")
185
- else:
186
- image = None
187
- url = ''
188
-
189
-
190
-
191
- if txt_atachment:
192
- if lang == 'Español':
193
- txtattachment = st.file_uploader("Sube tu archivo de texto", type=['txt'])
194
- else:
195
- txtattachment = st.file_uploader("Upload your text file", type=['txt'])
196
- else:
197
- txtattachment = None
198
-
199
- if csv_excel_atachment:
200
- if lang == 'Español':
201
- csvexcelattachment = st.file_uploader("Sube tu archivo CSV o Excel", type=['csv', 'xlsx'])
202
- else:
203
- csvexcelattachment = st.file_uploader("Upload your CSV or Excel file", type=['csv', 'xlsx'])
204
- else:
205
- csvexcelattachment = None
206
- if lang == 'Español':
207
- prompt = st.chat_input("Escribe tu mensaje")
208
- else:
209
- prompt = st.chat_input("Write your message")
210
-
211
- if prompt:
212
- txt = ''
213
- if txtattachment:
214
- txt = txtattachment.getvalue().decode("utf-8")
215
- if lang == 'Español':
216
- txt = ' Archivo de texto: \n' + txt
217
- else:
218
- txt = ' Text file: \n' + txt
219
-
220
- if csvexcelattachment:
221
- try:
222
- df = pd.read_csv(csvexcelattachment)
223
- except:
224
- df = pd.read_excel(csvexcelattachment)
225
- txt += ' Dataframe: \n' + str(df)
226
-
227
- if graphviz_mode:
228
- if lang == 'Español':
229
- txt += ' Genera un grafo con graphviz en .dot \n'
230
- else:
231
- txt += ' Generate a graph with graphviz in .dot \n'
232
-
233
- if len(txt) > 5000:
234
- txt = txt[:5000] + '...'
235
- if image or url != '':
236
- if url != '':
237
- img = Image.open(requests.get(url, stream=True).raw)
238
- else:
239
- img = Image.open(image)
240
- prmt = {'role': 'user', 'parts':[prompt+txt, img]}
241
- else:
242
- prmt = {'role': 'user', 'parts':[prompt+txt]}
243
-
244
- append_message(prmt)
245
-
246
- if lang == 'Español':
247
- spinertxt = 'Espera un momento, estoy pensando...'
248
- else:
249
- spinertxt = 'Wait a moment, I am thinking...'
250
- with st.spinner(spinertxt):
251
- if len(prmt['parts']) > 1:
252
- response = vision.generate_content(prmt['parts'],stream=True,safety_settings=[
253
- {
254
- "category": "HARM_CATEGORY_HARASSMENT",
255
- "threshold": "BLOCK_LOW_AND_ABOVE",
256
- },
257
- {
258
- "category": "HARM_CATEGORY_HATE_SPEECH",
259
- "threshold": "BLOCK_LOW_AND_ABOVE",
260
- },
261
- ]
262
- )
263
- response.resolve()
264
- else:
265
- response = st.session_state.chat.send_message(prmt['parts'][0])
266
-
267
- try:
268
- append_message({'role': 'model', 'parts':response.text})
269
- except Exception as e:
270
- append_message({'role': 'model', 'parts':f'{type(e).__name__}: {e}'})
271
-
272
-
273
- st.rerun()
274
-
275
-
276
-
277
- #st.session_state.chat_session
 
1
  import streamlit as st
2
import io
from urllib.parse import urlparse

import numpy as np  # retained from the file's import list; unused by the captioning path
import requests
from PIL import Image
from transformers import AutoTokenizer, ViTImageProcessor, VisionEncoderDecoderModel

# A checkpoint that pairs a vision encoder with a text decoder. An encoder-only
# backbone such as BEiT has no language head and cannot generate captions.
CAPTION_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"

# Upper bound (seconds) on the image download so a slow host cannot hang the app.
REQUEST_TIMEOUT_SECONDS = 10

# Maximum number of tokens generated for one caption.
MAX_CAPTION_TOKENS = 20


@st.cache_resource
def _load_captioner():
    """Load the image processor, tokenizer and captioning model once.

    Streamlit re-executes the script on every interaction; caching the loaded
    objects avoids reloading the weights on each rerun.

    :return: a ``(image_processor, tokenizer, model)`` tuple.
    """
    processor = ViTImageProcessor.from_pretrained(CAPTION_CHECKPOINT)
    text_tokenizer = AutoTokenizer.from_pretrained(CAPTION_CHECKPOINT)
    caption_model = VisionEncoderDecoderModel.from_pretrained(CAPTION_CHECKPOINT)
    return processor, text_tokenizer, caption_model


# Initialize the image processor, tokenizer and model
image_processor, tokenizer, model = _load_captioner()


def _fetch_image(image_url: str) -> Image.Image:
    """Download ``image_url`` and return it as an RGB PIL image.

    :param image_url: http(s) URL of the image to download.
    :raises ValueError: if the URL is not an http or https URL.
    :raises requests.RequestException: on network errors or a non-2xx status.
    :raises OSError: if the downloaded bytes are not a decodable image.
    """
    parsed = urlparse(image_url)
    # The URL comes straight from the user; refuse anything that is not plain
    # http(s) (e.g. file://) before making a server-side request.
    if parsed.scheme not in ("http", "https") or not parsed.netloc:
        raise ValueError("only http(s) image URLs are supported")

    response = requests.get(image_url, timeout=REQUEST_TIMEOUT_SECONDS)
    response.raise_for_status()

    # Decode from the fully read body instead of ``response.raw`` so that
    # transfer/content encodings are already undone, then normalise the mode:
    # the image processor expects three colour channels.
    return Image.open(io.BytesIO(response.content)).convert("RGB")


def generate_caption(image_url):
    """Return a generated caption for the image located at ``image_url``.

    The image itself (not the URL text) is converted to pixel values and fed
    to the captioning model; the generated token ids are decoded to text.

    :param image_url: http(s) URL of the image to caption.
    :return: the caption string.
    :raises ValueError: if the URL is not an http or https URL.
    :raises requests.RequestException: if the image cannot be downloaded.
    :raises OSError: if the downloaded content is not a valid image.
    """
    image = _fetch_image(image_url)

    # Preprocess the image into the tensor layout the encoder expects
    pixel_values = image_processor(images=image, return_tensors="pt").pixel_values

    # Generate the caption
    output_ids = model.generate(pixel_values, max_length=MAX_CAPTION_TOKENS)
    captions = tokenizer.batch_decode(output_ids, skip_special_tokens=True)

    return captions[0].strip()


def main():
    """Render the sidebar form and show a caption for the submitted URL."""
    # Create a sidebar for the user to input the image URL
    st.sidebar.header("Image Caption Generator")
    image_url = st.sidebar.text_input("Enter the URL of an image:").strip()

    # Generate the caption if the user clicks the button
    if not st.sidebar.button("Generate Caption"):
        return

    if not image_url:
        st.error("Please enter a valid image URL.")
        return

    try:
        with st.spinner("Generating caption..."):
            caption = generate_caption(image_url)
    except (requests.RequestException, OSError, ValueError) as exc:
        # Download and decode failures are expected with user-supplied URLs;
        # report them in the page instead of crashing with a traceback.
        st.error(f"Could not caption this image: {exc}")
    else:
        st.success(f"Caption: {caption}")


# Run the main function
if __name__ == "__main__":
    main()