Rauhan committed on
Commit
b352af8
1 Parent(s): fe0b51f

UPDATE: ParentDocumentRetriever

Browse files
Files changed (1) hide show
  1. functions.py +28 -8
functions.py CHANGED
@@ -3,6 +3,9 @@ from langchain_text_splitters import RecursiveCharacterTextSplitter
3
  from langchain_qdrant import QdrantVectorStore
4
  from langchain_core.prompts.chat import ChatPromptTemplate
5
  from langchain_core.output_parsers import StrOutputParser
 
 
 
6
  from langchain_huggingface import HuggingFaceEmbeddings
7
  from supabase.client import create_client
8
  from qdrant_client import QdrantClient
@@ -40,6 +43,7 @@ Context:
40
  NOTE: generate responses WITHOUT prepending phrases like "Response:", "Output:", or "Answer:", etc
41
  """
42
  prompt = ChatPromptTemplate.from_template(prompt)
 
43
 
44
 
45
  def createUser(username: str, password: str) -> None:
@@ -100,19 +104,31 @@ def createTable(tablename: str):
100
  def addDocuments(text: str, vectorstore: str):
101
  try:
102
  global embeddings
103
- text_splitter = RecursiveCharacterTextSplitter(
104
- chunk_size = 1024,
105
- chunk_overlap = 200,
 
106
  add_start_index = True
107
  )
108
- texts = text_splitter.create_documents([text])
 
 
 
 
 
109
  vectorstore = QdrantVectorStore.from_existing_collection(
110
  embedding = embeddings,
111
  collection_name=vectorstore,
112
  url=os.environ["QDRANT_URL"],
113
  api_key=os.environ["QDRANT_API_KEY"]
114
  )
115
- vectorstore.add_documents(documents = texts)
 
 
 
 
 
 
116
  return {
117
  "output": "SUCCESS"
118
  }
@@ -139,11 +155,15 @@ def answerQuery(query: str, vectorstore: str, llmModel: str = "llama3-70b-8192")
139
  url=os.environ["QDRANT_URL"],
140
  api_key=os.environ["QDRANT_API_KEY"]
141
  )
142
- retriever = vectorstore.as_retriever()
 
 
 
 
143
  chain = (
144
  {"context": retriever | RunnableLambda(format_docs), "question": RunnablePassthrough()}
145
  | prompt
146
- | ChatGroq(model = llmModel, temperature = 0.3, max_tokens = 512)
147
  | StrOutputParser()
148
  )
149
  return {
@@ -168,7 +188,7 @@ def listTables(username: str):
168
  global qdrantClient
169
  qdrantCollections = qdrantClient.get_collections()
170
  return {
171
- "output": list(filter(lambda x: True if x.split("-")[1] == username else False, [x.name for x in qdrantCollections.collections]))
172
  }
173
  except Exception as e:
174
  return {
 
3
  from langchain_qdrant import QdrantVectorStore
4
  from langchain_core.prompts.chat import ChatPromptTemplate
5
  from langchain_core.output_parsers import StrOutputParser
6
+ from langchain.retrievers import ParentDocumentRetriever
7
+ from langchain.storage import InMemoryStore
8
+ from langchain.docstore.document import Document
9
  from langchain_huggingface import HuggingFaceEmbeddings
10
  from supabase.client import create_client
11
  from qdrant_client import QdrantClient
 
43
  NOTE: generate responses WITHOUT prepending phrases like "Response:", "Output:", or "Answer:", etc
44
  """
45
  prompt = ChatPromptTemplate.from_template(prompt)
46
+ store = InMemoryStore()
47
 
48
 
49
  def createUser(username: str, password: str) -> None:
 
104
  def addDocuments(text: str, vectorstore: str):
105
  try:
106
  global embeddings
107
+ global store
108
+ parentSplitter = RecursiveCharacterTextSplitter(
109
+ chunk_size = 2400,
110
+ chunk_overlap = 240,
111
  add_start_index = True
112
  )
113
+ childSplitter = RecursiveCharacterTextSplitter(
114
+ chunk_size = 300,
115
+ chunk_overlap = 90,
116
+ add_start_index = True
117
+ )
118
+ texts = [Document(page_content = text)]
119
  vectorstore = QdrantVectorStore.from_existing_collection(
120
  embedding = embeddings,
121
  collection_name=vectorstore,
122
  url=os.environ["QDRANT_URL"],
123
  api_key=os.environ["QDRANT_API_KEY"]
124
  )
125
+ retriever = ParentDocumentRetriever(
126
+ vectorstore=vectorstore,
127
+ docstore=store,
128
+ child_splitter=childSplitter,
129
+ parent_splitter=parentSplitter
130
+ )
131
+ retriever.add_documents(documents = texts)
132
  return {
133
  "output": "SUCCESS"
134
  }
 
155
  url=os.environ["QDRANT_URL"],
156
  api_key=os.environ["QDRANT_API_KEY"]
157
  )
158
+ retriever = ParentDocumentRetriever(
159
+ vectorstore=vectorstore,
160
+ docstore=store,
161
+ child_splitter=RecursiveCharacterTextSplitter()
162
+ )
163
  chain = (
164
  {"context": retriever | RunnableLambda(format_docs), "question": RunnablePassthrough()}
165
  | prompt
166
+ | ChatGroq(model = llmModel, temperature = 0.5, max_tokens = 512)
167
  | StrOutputParser()
168
  )
169
  return {
 
188
  global qdrantClient
189
  qdrantCollections = qdrantClient.get_collections()
190
  return {
191
+ "output": list(filter(lambda x: True if x.split("-")[1] == username else False, [x.name.split("-")[-1] for x in qdrantCollections.collections]))
192
  }
193
  except Exception as e:
194
  return {