Update chains/arxiv_chains.py

#2
by mpsk - opened
Files changed (2) hide show
  1. app.py +2 -2
  2. chains/arxiv_chains.py +17 -0
app.py CHANGED
@@ -22,8 +22,8 @@ from langchain.chains import LLMChain
22
  from langchain_experimental.utilities.sql_database import SQLDatabase
23
  from langchain_experimental.retrievers.sql_database import SQLDatabaseChainRetriever
24
  from langchain_experimental.sql.base import SQLDatabaseChain
25
- from langchain_experimental.sql.parser import VectorSQLRetrieveAllOutputParser
26
 
 
27
  from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
28
  from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
29
  ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
@@ -155,7 +155,7 @@ def build_retriever():
155
  template=_myscale_prompt,
156
  )
157
 
158
- output_parser = VectorSQLRetrieveAllOutputParser.from_embeddings(
159
  model=embeddings)
160
  sql_query_chain = SQLDatabaseChain.from_llm(
161
  llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
 
22
  from langchain_experimental.utilities.sql_database import SQLDatabase
23
  from langchain_experimental.retrievers.sql_database import SQLDatabaseChainRetriever
24
  from langchain_experimental.sql.base import SQLDatabaseChain
 
25
 
26
+ from chains.arxiv_chains import VectorSQLRetrieveCustomOutputParser
27
  from chains.arxiv_chains import ArXivQAwithSourcesChain, ArXivStuffDocumentChain
28
  from callbacks.arxiv_callbacks import ChatDataSelfSearchCallBackHandler, \
29
  ChatDataSelfAskCallBackHandler, ChatDataSQLSearchCallBackHandler, \
 
155
  template=_myscale_prompt,
156
  )
157
 
158
+ output_parser = VectorSQLRetrieveCustomOutputParser.from_embeddings(
159
  model=embeddings)
160
  sql_query_chain = SQLDatabaseChain.from_llm(
161
  llm=OpenAI(openai_api_key=OPENAI_API_KEY, temperature=0),
chains/arxiv_chains.py CHANGED
@@ -16,6 +16,23 @@ from langchain.chains.combine_documents.base import BaseCombineDocumentsChain
16
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
17
 
18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
  class ArXivStuffDocumentChain(StuffDocumentsChain):
20
  """Combine arxiv documents with PDF reference number"""
21
 
 
16
  from langchain.chains.combine_documents.stuff import StuffDocumentsChain
17
 
18
 
19
+ class VectorSQLRetrieveCustomOutputParser(VectorSQLOutputParser):
20
+ """Based on VectorSQLOutputParser
21
+ It also modify the SQL to get all columns
22
+ """
23
+
24
+ @property
25
+ def _type(self) -> str:
26
+ return "vector_sql_retrieve_custom"
27
+
28
+ def parse(self, text: str) -> Dict[str, Any]:
29
+ text = text.strip()
30
+ start = text.upper().find("SELECT")
31
+ if start >= 0:
32
+ end = text.upper().find("FROM")
33
+ text = text.replace(text[start + len("SELECT") + 1 : end - 1], "title, abstract, authors, pubdate, categories, id")
34
+ return super().parse(text)
35
+
36
  class ArXivStuffDocumentChain(StuffDocumentsChain):
37
  """Combine arxiv documents with PDF reference number"""
38