anand004 commited on
Commit
6c98e48
1 Parent(s): 17672b0

handle no images/text

Browse files
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -99,6 +99,7 @@ def get_vectordb(text, images):
99
  client.delete_collection("text_db")
100
  if "image_db" in [i.name for i in client.list_collections()]:
101
  client.delete_collection("image_db")
 
102
  text_collection = client.get_or_create_collection(
103
  name="text_db",
104
  embedding_function=sentence_transformer_ef,
@@ -114,22 +115,24 @@ def get_vectordb(text, images):
114
  image_descriptions = get_image_descriptions(images)
115
  image_dict = [{"image": image_to_bytes(img) for img in images}]
116
 
117
- image_collection.add(
118
- ids=[str(i) for i in range(len(images))],
119
- documents=image_descriptions,
120
- metadatas=image_dict,
121
- )
 
122
 
123
  splitter = RecursiveCharacterTextSplitter(
124
  chunk_size=500,
125
  chunk_overlap=10,
126
  )
127
 
128
- docs = splitter.create_documents([text])
129
- doc_texts = [i.page_content for i in docs]
130
- text_collection.add(
131
- ids=[str(i) for i in list(range(len(doc_texts)))], documents=doc_texts
132
- )
 
133
  return client
134
 
135
 
@@ -214,7 +217,7 @@ def conversation(vectordb_client, msg, num_context, img_context, history):
214
  results = text_collection.query(
215
  query_texts=[msg], include=["documents"], n_results=num_context
216
  )["documents"][0]
217
-
218
  similar_images = image_collection.query(
219
  query_texts=[msg],
220
  include=["metadatas", "distances", "documents"],
@@ -246,7 +249,7 @@ def conversation(vectordb_client, msg, num_context, img_context, history):
246
  context = "\n\n".join(results)
247
  # references = [gr.Textbox(i, visible=True, interactive=False) for i in results]
248
  response = llm(prompt.format(context=context, question=msg, images=img_desc))
249
- return history + [(msg, response)], results, images_and_locs
250
 
251
 
252
  def check_validity_and_llm(session_states):
 
99
  client.delete_collection("text_db")
100
  if "image_db" in [i.name for i in client.list_collections()]:
101
  client.delete_collection("image_db")
102
+
103
  text_collection = client.get_or_create_collection(
104
  name="text_db",
105
  embedding_function=sentence_transformer_ef,
 
115
  image_descriptions = get_image_descriptions(images)
116
  image_dict = [{"image": image_to_bytes(img) for img in images}]
117
 
118
+ if len(images)>0:
119
+ image_collection.add(
120
+ ids=[str(i) for i in range(len(images))],
121
+ documents=image_descriptions,
122
+ metadatas=image_dict,
123
+ )
124
 
125
  splitter = RecursiveCharacterTextSplitter(
126
  chunk_size=500,
127
  chunk_overlap=10,
128
  )
129
 
130
+ if len(text)>0:
131
+ docs = splitter.create_documents([text])
132
+ doc_texts = [i.page_content for i in docs]
133
+ text_collection.add(
134
+ ids=[str(i) for i in list(range(len(doc_texts)))], documents=doc_texts
135
+ )
136
  return client
137
 
138
 
 
217
  results = text_collection.query(
218
  query_texts=[msg], include=["documents"], n_results=num_context
219
  )["documents"][0]
220
+
221
  similar_images = image_collection.query(
222
  query_texts=[msg],
223
  include=["metadatas", "distances", "documents"],
 
249
  context = "\n\n".join(results)
250
  # references = [gr.Textbox(i, visible=True, interactive=False) for i in results]
251
  response = llm(prompt.format(context=context, question=msg, images=img_desc))
252
+ yield history + [(msg, response)], results, images_and_locs
253
 
254
 
255
  def check_validity_and_llm(session_states):