timeki committed
Commit ecc6c98 · 1 Parent(s): 190826e

Switch vectorstore to Azure Search

app.py CHANGED
@@ -7,7 +7,7 @@ from azure.storage.fileshare import ShareServiceClient
 # Import custom modules
 from climateqa.engine.embeddings import get_embeddings_function
 from climateqa.engine.llm import get_llm
-from climateqa.engine.vectorstore import get_pinecone_vectorstore
+from climateqa.engine.vectorstore import get_vectorstore
 from climateqa.engine.reranker import get_reranker
 from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
 from climateqa.engine.chains.retrieve_papers import find_papers
@@ -66,17 +66,11 @@ user_id = create_user_id()
 
 # Create vectorstore and retriever
 embeddings_function = get_embeddings_function()
-vectorstore = get_pinecone_vectorstore(
-    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
-)
-vectorstore_graphs = get_pinecone_vectorstore(
-    embeddings_function,
-    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
-    text_key="description",
-)
-vectorstore_region = get_pinecone_vectorstore(
-    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
-)
+
+vectorstore = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-ipx")
+vectorstore_graphs = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-owid", text_key="description")
+vectorstore_region = get_vectorstore(provider="azure_search", embeddings=embeddings_function, index_name="climateqa-v2")
+
 
 llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)
 if os.environ["GRADIO_ENV"] == "local":
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -19,7 +19,7 @@ from ..llm import get_llm
 from .prompts import retrieve_chapter_prompt_template
 from langchain_core.prompts import ChatPromptTemplate
 from langchain_core.output_parsers import StrOutputParser
-from ..vectorstore import get_pinecone_vectorstore
+from ..vectorstore import get_vectorstore
 from ..embeddings import get_embeddings_function
 import ast
 
@@ -134,7 +134,7 @@ def get_ToCs(version: str) :
         "version": version
     }
     embeddings_function = get_embeddings_function()
-    vectorstore = get_pinecone_vectorstore(embeddings_function, index_name="climateqa-v2")
+    vectorstore = get_vectorstore(provider="qdrant", embeddings=embeddings_function, index_name="climateqa")
     tocs = vectorstore.similarity_search_with_score(query="",filter = filters_text)
 
     # remove duplicates or almost duplicates
@@ -236,7 +236,7 @@ async def get_POC_documents_by_ToC_relevant_documents(
     filters_text_toc = {
         **filters,
         "chunk_type":"text",
-        "toc_level0": {"$in": toc_filters},
+        "toc_level0": toc_filters, # Changed from {"$in": toc_filters} to direct list
        "version": version
        # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
     }
@@ -273,6 +273,22 @@
         "docs_images" : docs_images
     }
 
+def filter_for_full_report_documents(filters: dict) -> dict:
+    """
+    Filter for full report documents.
+    Returns a dictionary format compatible with all vectorstore providers.
+    """
+    # Start with the base filters
+    full_filters = filters.copy()
+
+    # Add chunk_type filter
+    full_filters["chunk_type"] = "text"
+
+    # Add report_type exclusion using the new _exclude suffix format
+    # This will be converted to appropriate OData filter by Azure Search wrapper
+    full_filters["report_type_exclude"] = ["SPM"]
+
+    return full_filters
 
 async def get_IPCC_relevant_documents(
     query: str,
@@ -299,9 +315,9 @@
     filters = {}
 
     if len(reports) > 0:
-        filters["short_name"] = {"$in":reports}
+        filters["short_name"] = reports # Changed from {"$in":reports} to direct list
     else:
-        filters["source"] = { "$in": sources}
+        filters["source"] = sources # Changed from {"$in": sources} to direct list
 
     # INIT
     docs_summaries = []
@@ -323,18 +339,16 @@
     filters_summaries = {
         **filters,
         "chunk_type":"text",
-        "report_type": { "$in":["SPM"]},
+        "report_type": ["SPM"], # Changed from {"$in":["SPM"]} to direct list
     }
 
     docs_summaries = vectorstore.similarity_search_with_score(query=query,filter = filters_summaries,k = k_summary)
     docs_summaries = [x for x in docs_summaries if x[1] > threshold]
 
     # Search for k_total - k_summary documents in the full reports dataset
-    filters_full = {
-        **filters,
-        "chunk_type":"text",
-        "report_type": { "$nin":["SPM"]},
-    }
+    filters_full = filter_for_full_report_documents(filters)
+
+
     docs_full = vectorstore.similarity_search_with_score(query=query,filter = filters_full,k = k_total)
 
     if search_figures:
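
For illustration, the filters above now stay as plain dicts and are only translated to OData once they reach the Azure wrapper (defined in climateqa/engine/vectorstore.py below). A short sketch of the round trip, using example source values that are not taken from this commit:

# Example inputs; conversion rules follow AzureSearchWrapper._convert_dict_filter_to_odata
filters_full = filter_for_full_report_documents({"source": ["IPCC", "IPBES"]})
# -> {"source": ["IPCC", "IPBES"], "chunk_type": "text", "report_type_exclude": ["SPM"]}
# which the wrapper renders as the OData string:
# "(source eq 'IPCC' or source eq 'IPBES') and chunk_type eq 'text' and report_type ne 'SPM'"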
climateqa/engine/graph_retriever.py CHANGED
@@ -60,10 +60,9 @@ async def retrieve_graphs(
     assert sources
     assert any([x in ["OWID"] for x in sources])
 
-    # Prepare base search kwargs
-    filters = {}
-
-    filters["source"] = {"$in": sources}
+    # Prepare base search kwargs for Azure AI Search
+    # Azure expects a filter string, e.g. "source eq 'OWID' or source eq 'IEA'"
+    filters = {"source":"OWID"}
 
     docs = vectorstore.similarity_search_with_score(query=query, filter=filters, k=k_total)
 
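A standalone sketch of the single-value conversion this dict goes through (the same rule as the wrapper's single-value branch in vectorstore.py below):

filters = {"source": "OWID"}
odata = " and ".join(f"{key} eq '{value}'" for key, value in filters.items())
print(odata)  # source eq 'OWID'

Note that the old {"$in": sources} form could match several sources, while the new dict is hard-coded to OWID, consistent with the assert above.
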
climateqa/engine/llm/openai.py CHANGED
@@ -8,7 +8,6 @@ except Exception:
     pass
 
 def get_llm(model="gpt-4o-mini",max_tokens=1024, temperature=0.0, streaming=True,timeout=30, **kwargs):
-
     llm = ChatOpenAI(
         model=model,
         api_key=os.environ.get("THEO_API_KEY", None),
climateqa/engine/vectorstore.py CHANGED
@@ -1,11 +1,11 @@
-# Pinecone
-# More info at https://docs.pinecone.io/docs/langchain
-# And https://python.langchain.com/docs/integrations/vectorstores/pinecone
+# Azure AI Search: https://python.langchain.com/docs/integrations/vectorstores/azuresearch
 import os
-from pinecone import Pinecone
-from langchain_community.vectorstores import Pinecone as PineconeVectorstore
 
-# LOAD ENVIRONMENT VARIABLES
+# Azure AI Search imports
+from langchain_community.vectorstores.azuresearch import AzureSearch
+
+
+# Load environment variables
 try:
     from dotenv import load_dotenv
     load_dotenv()
@@ -13,44 +13,136 @@ except:
     pass
 
 
-
-
-def get_pinecone_vectorstore(embeddings,text_key = "content", index_name = os.getenv("PINECONE_API_INDEX")):
-
-    # # initialize pinecone
-    # pinecone.init(
-    #     api_key=os.getenv("PINECONE_API_KEY"),  # find at app.pinecone.io
-    #     environment=os.getenv("PINECONE_API_ENVIRONMENT"),  # next to api key in console
-    # )
-
-    # index_name = os.getenv("PINECONE_API_INDEX")
-    # vectorstore = Pinecone.from_existing_index(index_name, embeddings,text_key = text_key)
-
-    # return vectorstore
-
-    pc = Pinecone(api_key=os.getenv("PINECONE_API_KEY"))
-    index = pc.Index(index_name)
-
-    vectorstore = PineconeVectorstore(
-        index, embeddings, text_key,
-    )
-    return vectorstore
-
-
-
-# def get_pinecone_retriever(vectorstore,k = 10,namespace = "vectors",sources = ["IPBES","IPCC"]):
-
-#     assert isinstance(sources,list)
-
-#     # Check if all elements in the list are either IPCC or IPBES
-#     filter = {
-#         "source": { "$in":sources},
-#     }
-
-#     retriever = vectorstore.as_retriever(search_kwargs={
-#         "k": k,
-#         "namespace":"vectors",
-#         "filter":filter
-#     })
-
-#     return retriever
+class AzureSearchWrapper:
+    """
+    Wrapper class for Azure AI Search vectorstore to handle filter conversion.
+
+    This wrapper automatically converts dictionary-style filters to Azure Search OData filter format,
+    ensuring seamless compatibility when switching from other providers.
+    """
+
+    def __init__(self, azure_search_vectorstore):
+        self.vectorstore = azure_search_vectorstore
+
+    def __getattr__(self, name):
+        """Delegate all other attributes to the wrapped vectorstore."""
+        return getattr(self.vectorstore, name)
+
+    def _convert_dict_filter_to_odata(self, filter_dict):
+        """
+        Convert dictionary-style filters to Azure Search OData filter format.
+
+        Args:
+            filter_dict (dict): Dictionary-style filter
+
+        Returns:
+            str: OData filter string
+        """
+        if not filter_dict:
+            return None
+
+        conditions = []
+
+        for key, value in filter_dict.items():
+            if key.endswith('_exclude'):
+                # Handle exclusion filters (e.g., report_type_exclude)
+                base_key = key.replace('_exclude', '')
+                if isinstance(value, list):
+                    if len(value) == 1:
+                        conditions.append(f"{base_key} ne '{value[0]}'")
+                    else:
+                        exclude_conditions = [f"{base_key} ne '{v}'" for v in value]
+                        conditions.append(f"({' and '.join(exclude_conditions)})")
+                else:
+                    conditions.append(f"{base_key} ne '{value}'")
+            elif isinstance(value, list):
+                # Handle list values (equivalent to $in operator)
+                if len(value) == 1:
+                    conditions.append(f"{key} eq '{value[0]}'")
+                else:
+                    list_conditions = [f"{key} eq '{v}'" for v in value]
+                    conditions.append(f"({' or '.join(list_conditions)})")
+            else:
+                # Handle single values
+                conditions.append(f"{key} eq '{value}'")
+
+        return " and ".join(conditions) if conditions else None
+
+    def similarity_search_with_score(self, query, k=4, filter=None, **kwargs):
+        """Override similarity_search_with_score to convert filters."""
+        if filter is not None:
+            filter = self._convert_dict_filter_to_odata(filter)
+
+        return self.vectorstore.hybrid_search_with_score(
+            query=query, k=k, filters=filter, **kwargs
+        )
+
+    def similarity_search(self, query, k=4, filter=None, **kwargs):
+        """Override similarity_search to convert filters."""
+        if filter is not None:
+            filter = self._convert_dict_filter_to_odata(filter)
+
+        return self.vectorstore.similarity_search(
+            query=query, k=k, filter=filter, **kwargs
+        )
+
+    def similarity_search_by_vector(self, embedding, k=4, filter=None, **kwargs):
+        """Override similarity_search_by_vector to convert filters."""
+        if filter is not None:
+            filter = self._convert_dict_filter_to_odata(filter)
+
+        return self.vectorstore.similarity_search_by_vector(
+            embedding=embedding, k=k, filter=filter, **kwargs
+        )
+
+    def as_retriever(self, search_type="similarity", search_kwargs=None, **kwargs):
+        """Override as_retriever to handle filter conversion in search_kwargs."""
+        if search_kwargs and "filter" in search_kwargs:
+            # Convert the filter in search_kwargs
+            search_kwargs = search_kwargs.copy()  # Don't modify the original
+            if search_kwargs["filter"] is not None:
+                search_kwargs["filter"] = self._convert_dict_filter_to_odata(search_kwargs["filter"])
+
+        return self.vectorstore.as_retriever(
+            search_type=search_type, search_kwargs=search_kwargs, **kwargs
+        )
+
+
+def get_azure_search_vectorstore(embeddings, text_key="content", index_name=None):
+    """
+    Create an Azure AI Search vectorstore instance.
+
+    Args:
+        embeddings: The embeddings function to use
+        text_key: The key for text content in the payload (default: "content")
+        index_name: The name of the Azure Search index
+
+    Returns:
+        AzureSearchWrapper: A wrapped Azure AI Search vectorstore instance with filter compatibility
+    """
+    # Get Azure AI Search configuration from environment variables
+    azure_search_endpoint = os.getenv("AI_SEARCH_INDEX_ENDPOINT")
+    azure_search_key = os.getenv("AI_SEARCH_KEY")
+
+    if not azure_search_endpoint:
+        raise ValueError("AI_SEARCH_INDEX_ENDPOINT environment variable is required")
+
+    if not azure_search_key:
+        raise ValueError("AI_SEARCH_KEY environment variable is required")
+
+    if not index_name:
+        raise ValueError("index_name must be provided for Azure Search")
+
+    # Create Azure Search vectorstore
+    vectorstore = AzureSearch(
+        azure_search_endpoint=azure_search_endpoint,
+        azure_search_key=azure_search_key,
+        index_name=index_name,
+        embedding_function=embeddings.embed_query,
+        content_key=text_key,
+    )
+
+    # Wrap the vectorstore to handle filter conversion
+    return AzureSearchWrapper(vectorstore)
 
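A minimal usage sketch of the new module, assuming AI_SEARCH_INDEX_ENDPOINT and AI_SEARCH_KEY are set and that the "climateqa-ipx" index from app.py exists; the query and filter values are examples only:

from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.vectorstore import get_azure_search_vectorstore

embeddings_function = get_embeddings_function()
vectorstore = get_azure_search_vectorstore(embeddings_function, index_name="climateqa-ipx")

# Dict filters are converted to OData by the wrapper, e.g.
# {"source": ["IPCC"], "chunk_type": "text"} -> "source eq 'IPCC' and chunk_type eq 'text'"
docs = vectorstore.similarity_search_with_score(
    query="sea level rise",
    k=5,
    filter={"source": ["IPCC"], "chunk_type": "text"},
)
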
climateqa/utils.py CHANGED
@@ -25,7 +25,7 @@ def remove_duplicates_keep_highest_score(documents):
     unique_docs = {}
 
     for doc in documents:
-        doc_id = doc.metadata.get('doc_id')
+        doc_id = doc.metadata.get('id')
         if doc_id in unique_docs:
             if doc.metadata['reranking_score'] > unique_docs[doc_id].metadata['reranking_score']:
                 unique_docs[doc_id] = doc
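
The dedup key moves from 'doc_id' to 'id', presumably the identifier field the Azure Search index returns in document metadata. A small sketch of the expected behaviour with stand-in documents (illustrative values):

from langchain_core.documents import Document

docs = [
    Document(page_content="a", metadata={"id": "doc-1", "reranking_score": 0.4}),
    Document(page_content="a", metadata={"id": "doc-1", "reranking_score": 0.9}),  # kept
    Document(page_content="b", metadata={"id": "doc-2", "reranking_score": 0.5}),  # kept
]
# remove_duplicates_keep_highest_score(docs) keeps one Document per metadata["id"],
# preferring the highest reranking_score.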
requirements.txt CHANGED
@@ -1,6 +1,9 @@
 gradio==5.0.2
 azure-storage-file-share==12.11.1
 azure-storage-blob==12.23.0
+# Azure AI Search support
+azure-search-documents>=11.4.0
+azure-core>=1.29.0
 python-dotenv==1.0.0
 langchain==0.2.1
 langchain_openai==0.1.7
sandbox/20241104 - CQA - StepByStep CQA.ipynb CHANGED
The diff for this file is too large to render. See raw diff