lindsay-qu commited on
Commit
6ec8a6d
1 Parent(s): 6c8a93e

Update core/retriever/chroma_retriever.py

Browse files
Files changed (1) hide show
  1. core/retriever/chroma_retriever.py +11 -5
core/retriever/chroma_retriever.py CHANGED
@@ -31,6 +31,7 @@ class ChromaRetriever(BaseRetriever):
31
  if not os.path.exists("persist"):
32
  os.mkdir("persist")
33
  client = PersistentClient(path="persist")
 
34
 
35
  try:
36
  collection = client.get_collection(name=collection_name)
@@ -41,8 +42,11 @@ class ChromaRetriever(BaseRetriever):
41
  docs = pdf_loader.load()
42
 
43
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=split_args["size"], chunk_overlap=split_args["overlap"])
44
- texts = text_splitter.split_documents(docs)
45
- texts = [text.page_content for text in texts]
 
 
 
46
 
47
  collection = client.create_collection(name=collection_name)
48
  if embed_model is not None:
@@ -50,12 +54,14 @@ class ChromaRetriever(BaseRetriever):
50
  collection.add(
51
  embeddings=embeddings,
52
  documents=texts,
53
- ids=[str(i+1) for i in range(len(texts))]
 
54
  )
55
  else:
56
  collection.add(
57
  documents=texts,
58
- ids=[str(i+1) for i in range(len(texts))]
 
59
  )
60
 
61
  self.collection = collection
@@ -82,4 +88,4 @@ class ChromaRetriever(BaseRetriever):
82
  query_texts=[query],
83
  n_results=k,
84
  )
85
- return results['documents'][0]
 
31
  if not os.path.exists("persist"):
32
  os.mkdir("persist")
33
  client = PersistentClient(path="persist")
34
+ print(client.list_collections())
35
 
36
  try:
37
  collection = client.get_collection(name=collection_name)
 
42
  docs = pdf_loader.load()
43
 
44
  text_splitter = RecursiveCharacterTextSplitter(chunk_size=split_args["size"], chunk_overlap=split_args["overlap"])
45
+ split_docs = text_splitter.split_documents(docs)
46
+ texts = [doc.page_content for doc in split_docs]
47
+
48
+ # TODO
49
+ titles = [doc.metadata["title"] for doc in split_docs]
50
 
51
  collection = client.create_collection(name=collection_name)
52
  if embed_model is not None:
 
54
  collection.add(
55
  embeddings=embeddings,
56
  documents=texts,
57
+ ids=[str(i+1) for i in range(len(texts))],
58
+ metadatas=[{"title": title} for title in titles]
59
  )
60
  else:
61
  collection.add(
62
  documents=texts,
63
+ ids=[str(i+1) for i in range(len(texts))],
64
+ metadatas=[{"title": title} for title in titles]
65
  )
66
 
67
  self.collection = collection
 
88
  query_texts=[query],
89
  n_results=k,
90
  )
91
+ return results['documents'][0], [result["title"] for result in results['metadatas'][0]]