erseux committed on
Commit
d92f48e
1 Parent(s): be85079

🔨 regex sub lowercase course codes to uppercase

Browse files
Files changed (1) hide show
  1. magic/self_query_retriever.py +26 -19
magic/self_query_retriever.py CHANGED
@@ -1,25 +1,30 @@
1
  """Retriever that generates and executes structured queries over its own data source.
2
 
3
- This code is adapted from the original implementation in the LangChain repo,
4
  but has been modified to work with the KTH QA system.
5
 
6
  """
7
 
 
 
 
 
 
 
 
 
 
8
  import re
9
  from typing import Any, Dict, List, Optional, Type, cast
 
 
10
 
11
- from pydantic import BaseModel, Field, root_validator
12
 
13
- from langchain import LLMChain
14
- from langchain.base_language import BaseLanguageModel
15
- from langchain.chains.query_constructor.base import load_query_constructor_chain
16
- from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
17
- from langchain.chains.query_constructor.schema import AttributeInfo
18
- from langchain.retrievers.self_query.pinecone import PineconeTranslator
19
- from langchain.schema import BaseRetriever, Document
20
- from langchain.vectorstores import Pinecone, VectorStore
21
 
22
- COURSE_PATTERN = r"\w{2,3}\d{3,4}\w?" # e.g. DD1315
 
 
23
 
24
 
25
  def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
@@ -76,25 +81,26 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
76
  List of relevant documents
77
  """
78
  if re.findall(COURSE_PATTERN, query):
 
79
  inputs = self.llm_chain.prep_inputs(query)
80
  structured_query = cast(
81
- StructuredQuery, self.llm_chain.predict_and_parse(callbacks=None, **inputs)
 
82
  )
83
  if self.verbose:
84
- print("Found course pattern in query, using structured query:")
85
- print(structured_query)
 
86
  new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
87
  structured_query
88
  )
89
  search_kwargs = {**self.search_kwargs, **new_kwargs}
90
  else:
91
  search_kwargs = self.search_kwargs
92
- docs = self.vectorstore.search(query, self.search_type, **search_kwargs)
 
93
  return docs
94
 
95
- async def aget_relevant_documents(self, query: str) -> List[Document]:
96
- raise NotImplementedError
97
-
98
  @classmethod
99
  def from_llm(
100
  cls,
@@ -107,7 +113,8 @@ class SelfQueryRetriever(BaseRetriever, BaseModel):
107
  **kwargs: Any,
108
  ) -> "SelfQueryRetriever":
109
  if structured_query_translator is None:
110
- structured_query_translator = _get_builtin_translator(vectorstore.__class__)
 
111
  chain_kwargs = chain_kwargs or {}
112
  if "allowed_comparators" not in chain_kwargs:
113
  chain_kwargs[
 
1
  """Retriever that generates and executes structured queries over its own data source.
2
 
3
+ NOTE: This code is adapted from the original implementation in the LangChain repo,
4
  but has been modified to work with the KTH QA system.
5
 
6
  """
7
 
8
+ from langchain.vectorstores import Pinecone, VectorStore
9
+ from langchain.schema import BaseRetriever, Document
10
+ from langchain.retrievers.self_query.pinecone import PineconeTranslator
11
+ from langchain.chains.query_constructor.schema import AttributeInfo
12
+ from langchain.chains.query_constructor.ir import StructuredQuery, Visitor
13
+ from langchain.chains.query_constructor.base import load_query_constructor_chain
14
+ from langchain.base_language import BaseLanguageModel
15
+ from langchain import LLMChain
16
+ from pydantic import BaseModel, Field, root_validator
17
  import re
18
  from typing import Any, Dict, List, Optional, Type, cast
19
+ import logging
20
+ logger = logging.getLogger()
21
 
 
22
 
23
+ COURSE_PATTERN = r"[a-zA-Z]{2,3}\d{3,4}\w?" # e.g. DD1315
 
 
 
 
 
 
 
24
 
25
+
26
+ def make_uppercase(matchobj):
27
+ return matchobj.group(0).upper()
28
 
29
 
30
  def _get_builtin_translator(vectorstore_cls: Type[VectorStore]) -> Visitor:
 
81
  List of relevant documents
82
  """
83
  if re.findall(COURSE_PATTERN, query):
84
+ query = re.sub(COURSE_PATTERN, make_uppercase, query)
85
  inputs = self.llm_chain.prep_inputs(query)
86
  structured_query = cast(
87
+ StructuredQuery, self.llm_chain.predict_and_parse(
88
+ callbacks=None, **inputs)
89
  )
90
  if self.verbose:
91
+ logger.info(
92
+ "Found course pattern in query, using structured query:")
93
+ logger.info(structured_query)
94
  new_query, new_kwargs = self.structured_query_translator.visit_structured_query(
95
  structured_query
96
  )
97
  search_kwargs = {**self.search_kwargs, **new_kwargs}
98
  else:
99
  search_kwargs = self.search_kwargs
100
+ docs = self.vectorstore.search(
101
+ query, self.search_type, **search_kwargs)
102
  return docs
103
 
 
 
 
104
  @classmethod
105
  def from_llm(
106
  cls,
 
113
  **kwargs: Any,
114
  ) -> "SelfQueryRetriever":
115
  if structured_query_translator is None:
116
+ structured_query_translator = _get_builtin_translator(
117
+ vectorstore.__class__)
118
  chain_kwargs = chain_kwargs or {}
119
  if "allowed_comparators" not in chain_kwargs:
120
  chain_kwargs[