AI-ANK commited on
Commit
ffa0700
1 Parent(s): 9a275cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -25
app.py CHANGED
@@ -8,6 +8,8 @@ from llama_index import ServiceContext, VectorStoreIndex, Document, StorageConte
8
  from llama_index.memory import ChatMemoryBuffer
9
  import os
10
  import datetime
 
 
11
 
12
  #imports for resnet
13
  from transformers import AutoFeatureExtractor, ResNetForImageClassification
@@ -45,8 +47,7 @@ This application, titled 'AInimal Go!', is a conceptual prototype designed to de
45
  cookie_manager = stx.CookieManager()
46
 
47
  #Function to init resnet
48
-
49
- @st.cache_resource()
50
  def load_model_and_labels():
51
  # Load animal labels as a dictionary
52
  animal_labels_dict = {}
@@ -81,9 +82,11 @@ def get_image_caption(image_data):
81
  return predicted_label_name, predicted_label_id
82
 
83
 
84
- @st.cache_resource
85
  def init_llm(api_key):
86
- llm = PaLM(api_key=api_key)
 
 
87
  service_context = ServiceContext.from_defaults(llm=llm, embed_model="local")
88
 
89
  storage_context = StorageContext.from_defaults(persist_dir="storage")
@@ -92,27 +95,34 @@ def init_llm(api_key):
92
 
93
  return llm, service_context, storage_context, index, chatmemory
94
 
95
- llm, service_context, storage_context, index, chatmemory = init_llm(st.secrets['GOOGLE_API_KEY'])
96
 
97
  def is_animal(predicted_label_id):
98
  # Check if the predicted label ID is within the animal classes range
99
  return 0 <= predicted_label_id <= 398
100
 
101
-
102
  # Function to create the chat engine.
103
  @st.cache_resource
104
  def create_chat_engine(img_desc, api_key):
 
 
 
105
  doc = Document(text=img_desc)
106
-
107
- chat_engine = index.as_chat_engine(
108
- chat_mode="react",
109
- verbose=True,
110
- memory=chatmemory
 
 
 
 
 
111
  )
112
 
113
- return chat_engine
 
114
 
115
-
116
  # Clear chat function
117
  def clear_chat():
118
  if "messages" in st.session_state:
@@ -149,7 +159,7 @@ else:
149
 
150
  col1, col2, col3 = st.columns([1, 2, 1])
151
  with col2: # Camera input will be in the middle column
152
- camera_image = st.camera_input("Take a picture")
153
 
154
 
155
  # Determine the source of the image (upload or camera)
@@ -162,17 +172,20 @@ else:
162
 
163
  if image_data:
164
  # Display the uploaded image at a standard width.
 
165
  st.image(image_data, caption='Uploaded Image.', width=200)
166
 
167
  # Process the uploaded image to get a caption.
 
168
  img_desc, label_id = get_image_caption(image_data)
169
 
170
  if not (is_animal(label_id)):
 
171
  st.error("Please upload image of an animal!")
172
  st.stop()
173
 
174
  # Initialize the chat engine with the image description.
175
- chat_engine = create_chat_engine(img_desc, st.secrets['GOOGLE_API_KEY'])
176
  st.write("Image Uploaded Successfully. Ask me anything about it.")
177
 
178
 
@@ -182,8 +195,9 @@ else:
182
 
183
  # Display previous messages
184
  for message in st.session_state.messages:
185
- with st.chat_message(message["role"]):
186
- st.markdown(message["content"])
 
187
 
188
  # Handle new user input
189
  user_input = st.chat_input("Ask me about the image:", key="chat_input")
@@ -193,27 +207,38 @@ else:
193
 
194
  # Display user message immediately
195
  with st.chat_message("user"):
196
- st.markdown(user_input)
197
 
198
  # Call the chat engine to get the response if an image has been uploaded
199
  if image_data and user_input:
200
  try:
201
  with st.spinner('Waiting for the chat engine to respond...'):
202
  # Get the response from your chat engine
203
- response = chat_engine.chat(f"""You are a chatbot that roleplays as an animal and also makes animal sounds when chatting.
204
- You always answer in great detail and are polite. Your responses always descriptive.
205
- Your job is to rolelpay as the animal that is mentioned in the image the user has uploaded. Image description: {img_desc}. User question
206
- {user_input}""")
207
-
 
 
 
 
 
 
 
 
208
  # Append assistant message to the session state
209
- st.session_state.messages.append({"role": "assistant", "content": response})
210
 
211
  # Display the assistant message
212
  with st.chat_message("assistant"):
213
- st.markdown(response)
 
214
 
215
  except Exception as e:
216
  st.error(f'An error occurred.')
 
 
217
 
218
  # Increment the message count and update the cookie
219
  message_count += 1
 
8
  from llama_index.memory import ChatMemoryBuffer
9
  import os
10
  import datetime
11
+ from llama_index.llms import Cohere
12
+ from llama_index.query_engine import CitationQueryEngine
13
 
14
  #imports for resnet
15
  from transformers import AutoFeatureExtractor, ResNetForImageClassification
 
47
  cookie_manager = stx.CookieManager()
48
 
49
  #Function to init resnet
50
+ @st.cache_resource(show_spinner="Initializing ResNet model for image classification. Please wait...")
 
51
  def load_model_and_labels():
52
  # Load animal labels as a dictionary
53
  animal_labels_dict = {}
 
82
  return predicted_label_name, predicted_label_id
83
 
84
 
85
+ @st.cache_resource(show_spinner="Initializing LLM and setting up service context. Please wait...")
86
  def init_llm(api_key):
87
+ # llm = PaLM(api_key=api_key)
88
+ llm = Cohere(model="command", api_key=st.secrets['COHERE_API_TOKEN'])
89
+
90
  service_context = ServiceContext.from_defaults(llm=llm, embed_model="local")
91
 
92
  storage_context = StorageContext.from_defaults(persist_dir="storage")
 
95
 
96
  return llm, service_context, storage_context, index, chatmemory
97
 
98
+ llm, service_context, storage_context, index, chatmemory = init_llm(os.environ["GOOGLE_API_KEY"])
99
 
100
  def is_animal(predicted_label_id):
101
  # Check if the predicted label ID is within the animal classes range
102
  return 0 <= predicted_label_id <= 398
103
 
 
104
  # Function to create the chat engine.
105
  @st.cache_resource
106
  def create_chat_engine(img_desc, api_key):
107
+
108
+ #llm = PaLM(api_key=api_key)
109
+ #service_context = ServiceContext.from_defaults(llm=llm,embed_model="local")
110
  doc = Document(text=img_desc)
111
+
112
+ # Now is_animal is a boolean indicating whether the image is of an animal
113
+ print("Is the image of an animal:", is_animal)
114
+
115
+ query_engine = CitationQueryEngine.from_args(
116
+ index,
117
+ similarity_top_k=3,
118
+ # here we can control how granular citation sources are, the default is 512
119
+ citation_chunk_size=512,
120
+ verbose=True
121
  )
122
 
123
+ return query_engine
124
+
125
 
 
126
  # Clear chat function
127
  def clear_chat():
128
  if "messages" in st.session_state:
 
159
 
160
  col1, col2, col3 = st.columns([1, 2, 1])
161
  with col2: # Camera input will be in the middle column
162
+ camera_image = st.camera_input("Take a picture", on_change=on_image_upload)
163
 
164
 
165
  # Determine the source of the image (upload or camera)
 
172
 
173
  if image_data:
174
  # Display the uploaded image at a standard width.
175
+ st.session_state['assistant_avatar'] = image_data
176
  st.image(image_data, caption='Uploaded Image.', width=200)
177
 
178
  # Process the uploaded image to get a caption.
179
+ #img_desc = get_image_caption(image_data)
180
  img_desc, label_id = get_image_caption(image_data)
181
 
182
  if not (is_animal(label_id)):
183
+ #st.error("Please upload image of an animal!")
184
  st.error("Please upload image of an animal!")
185
  st.stop()
186
 
187
  # Initialize the chat engine with the image description.
188
+ chat_engine = create_chat_engine(img_desc, os.environ["GOOGLE_API_KEY"])
189
  st.write("Image Uploaded Successfully. Ask me anything about it.")
190
 
191
 
 
195
 
196
  # Display previous messages
197
  for message in st.session_state.messages:
198
+ avatar = st.session_state['assistant_avatar'] if message["role"] == "assistant" else None
199
+ with st.chat_message(message["role"], avatar = avatar):
200
+ st.write(message["content"])
201
 
202
  # Handle new user input
203
  user_input = st.chat_input("Ask me about the image:", key="chat_input")
 
207
 
208
  # Display user message immediately
209
  with st.chat_message("user"):
210
+ st.write(user_input)
211
 
212
  # Call the chat engine to get the response if an image has been uploaded
213
  if image_data and user_input:
214
  try:
215
  with st.spinner('Waiting for the chat engine to respond...'):
216
  # Get the response from your chat engine
217
+ system_prompt=f"""
218
+ You are a chatbot, able to have normal interactions. Do not make up information.
219
+ You always answer in great detail and are polite. Your job is to roleplay as an {img_desc}.
220
+ Remember to make {img_desc} sounds while talking but dont overdo it.
221
+ """
222
+
223
+ response = chat_engine.query(f"{system_prompt}. {user_input}")
224
+
225
+ #response = chat_engine.chat(f"""You are a chatbot that roleplays as an animal and also makes animal sounds when chatting.
226
+ #You always answer in great detail and are polite. Your responses always descriptive.
227
+ #Your job is to rolelpay as the animal that is mentioned in the image the user has uploaded. Image description: {img_desc}. User question
228
+ #{user_input}""")
229
+
230
  # Append assistant message to the session state
231
+ st.session_state.messages.append({"role": "assistant", "content": response.response})
232
 
233
  # Display the assistant message
234
  with st.chat_message("assistant"):
235
+ st.write(response.response)
236
+ st.expander("hello")
237
 
238
  except Exception as e:
239
  st.error(f'An error occurred.')
240
+ # Optionally, you can choose to break the flow here if a critical error happens
241
+ # return
242
 
243
  # Increment the message count and update the cookie
244
  message_count += 1