rk68 committed on
Commit d0d09f7 · verified · 1 Parent(s): fa3fc07

Upload 3 files

Files changed (3):
  1. app.py +330 -334
  2. retrievers.py +465 -0
  3. utils_code.py +212 -0
app.py CHANGED
@@ -1,352 +1,348 @@
- import logging
- import json
- import pandas as pd
  import streamlit as st
- from pinecone import Pinecone
- from llama_index.vector_stores.pinecone import PineconeVectorStore
- from llama_index.core import (
-     StorageContext, VectorStoreIndex, SimpleDirectoryReader,
-     get_response_synthesizer, Settings
- )
- from llama_index.core.node_parser import SentenceSplitter
- from llama_index.core.retrievers import (
-     VectorIndexRetriever, RouterRetriever
- )
- from llama_index.retrievers.bm25 import BM25Retriever
- from llama_index.core.tools import RetrieverTool
- from llama_index.core.query_engine import (
-     RetrieverQueryEngine, FLAREInstructQueryEngine, MultiStepQueryEngine
- )
- from llama_index.core.indices.query.query_transform import (
-     StepDecomposeQueryTransform
- )
- from llama_index.llms.groq import Groq
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
- from llama_index.llms.azure_openai import AzureOpenAI
- from llama_index.embeddings.openai import OpenAIEmbedding
- from llama_index.readers.file import PyMuPDFReader
- import traceback
- from oauth2client.service_account import ServiceAccountCredentials
- import gspread
- import uuid
- from dotenv import load_dotenv
  import os
- from datetime import datetime
-
- # Load environment variables
- load_dotenv()
-
- # Configure logging
- logging.basicConfig(level=logging.INFO)
-
- # Google Sheets setup
- scope = ["https://spreadsheets.google.com/feeds", "https://www.googleapis.com/auth/drive"]
- creds_dict = {
-     "type": os.getenv("type"),
-     "project_id": os.getenv("project_id"),
-     "private_key_id": os.getenv("private_key_id"),
-     "private_key": os.getenv("private_key"),
-     "client_email": os.getenv("client_email"),
-     "client_id": os.getenv("client_id"),
-     "auth_uri": os.getenv("auth_uri"),
-     "token_uri": os.getenv("token_uri"),
-     "auth_provider_x509_cert_url": os.getenv("auth_provider_x509_cert_url"),
-     "client_x509_cert_url": os.getenv("client_x509_cert_url")
- }
- creds_dict['private_key'] = creds_dict['private_key'].replace('\\n', '\n')
- creds = ServiceAccountCredentials.from_json_keyfile_dict(creds_dict, scope)
- client = gspread.authorize(creds)
- sheet = client.open("RAG").sheet1
-
- # Fixed variables
- AZURE_DEPLOYMENT_NAME = os.getenv("AZURE_DEPLOYMENT_NAME")
- AZURE_API_VERSION = os.getenv("AZURE_API_VERSION")
- AZURE_OPENAI_ENDPOINT = os.getenv("AZURE_OPENAI_ENDPOINT")
-
- # Global variables for lazy loading
- llm = None
- pinecone_index = None
-
- def log_and_exit(message):
-     logging.error(message)
-     raise SystemExit(message)
-
- def initialize_apis(api, model, pinecone_api_key, groq_api_key, azure_api_key):
-     global llm, pinecone_index
-     try:
-         if llm is None:
-             llm = initialize_llm(api, model, groq_api_key, azure_api_key)
-         if pinecone_index is None:
-             pinecone_client = Pinecone(pinecone_api_key)
-             pinecone_index = pinecone_client.Index("ll144")
-         logging.info("Initialized LLM and Pinecone.")
-     except Exception as e:
-         log_and_exit(f"Error initializing APIs: {e}")
-
- def initialize_llm(api, model, groq_api_key, azure_api_key):
-     if api == 'groq':
-         model_mappings = {
-             'mixtral-8x7b': "mixtral-8x7b-32768",
-             'llama3-8b': "llama3-8b-8192",
-             'llama3-70b': "llama3-70b-8192",
-             'gemma-7b': "gemma-7b-it"
-         }
-         return Groq(model=model_mappings[model], api_key=groq_api_key)
-     elif api == 'azure':
-         if model == 'gpt35':
-             return AzureOpenAI(
-                 deployment_name=AZURE_DEPLOYMENT_NAME,
-                 temperature=0,
-                 api_key=azure_api_key,
-                 azure_endpoint=AZURE_OPENAI_ENDPOINT,
-                 api_version=AZURE_API_VERSION
-             )

- def load_pdf_data(chunk_size):
-     reader = PyMuPDFReader()
-     file_extractor = {".pdf": reader}
-     documents = SimpleDirectoryReader(input_files=['LL144.pdf', 'LL144_Definitions.pdf'], file_extractor=file_extractor).load_data()
-     return documents
-
- def create_index(documents, embedding_model_type="HF", embedding_model="BAAI/bge-large-en-v1.5", retriever_method="BM25", chunk_size=512):
-     global llm, pinecone_index
-     try:
-         embed_model = select_embedding_model(embedding_model_type, embedding_model)
-
-         Settings.llm = llm
-         Settings.embed_model = embed_model
-         Settings.chunk_size = chunk_size
-
-         if retriever_method in ["BM25", "BM25+Vector"]:
-             nodes = create_bm25_nodes(documents, chunk_size)
-             logging.info("Created BM25 nodes from documents.")
-             if retriever_method == "BM25+Vector":
-                 vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
-                 storage_context = StorageContext.from_defaults(vector_store=vector_store)
-                 index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
-                 logging.info("Created index for BM25+Vector from documents.")
-                 return index, nodes
-             return None, nodes
-         else:
-             vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
-             storage_context = StorageContext.from_defaults(vector_store=vector_store)
-             index = VectorStoreIndex.from_documents(documents, storage_context=storage_context)
-             logging.info("Created index from documents.")
-             return index, None
-     except Exception as e:
-         log_and_exit(f"Error creating index: {e}")
-
- def select_embedding_model(embedding_model_type, embedding_model):
-     if embedding_model_type == "HF":
-         return HuggingFaceEmbedding(model_name=embedding_model)
-     elif embedding_model_type == "OAI":
-         return OpenAIEmbedding()  # Implement OAI Embedding if needed
-
- def create_bm25_nodes(documents, chunk_size):
-     splitter = SentenceSplitter(chunk_size=chunk_size)
-     nodes = splitter.get_nodes_from_documents(documents)
-     return nodes
-
- def select_retriever(index, nodes, retriever_method, top_k):
-     logging.info(f"Selecting retriever with method: {retriever_method}")
-     if nodes is not None:
-         logging.info(f"Available document IDs: {list(range(len(nodes)))}")
-     else:
-         logging.warning("Nodes are None")

-     if retriever_method == 'BM25':
-         return BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k)
-     elif retriever_method == "BM25+Vector":
-         if index is None:
-             log_and_exit("Index must be initialized when using BM25+Vector retriever method.")
-
-         bm25_retriever = RetrieverTool.from_defaults(
-             retriever=BM25Retriever.from_defaults(nodes=nodes, similarity_top_k=top_k),
-             description="BM25 Retriever"
-         )
-
-         vector_retriever = RetrieverTool.from_defaults(
-             retriever=VectorIndexRetriever(index=index),
-             description="Vector Retriever"
-         )
-
-         router_retriever = RouterRetriever.from_defaults(
-             retriever_tools=[bm25_retriever, vector_retriever],
              llm=llm,
-             select_multi=True
          )
-         return router_retriever
-     elif retriever_method == "Vector Search":
-         if index is None:
-             log_and_exit("Index must be initialized when using Vector Search retriever method.")
-         return VectorIndexRetriever(index=index, similarity_top_k=top_k)
      else:
-         log_and_exit(f"Unsupported retriever method: {retriever_method}")

- def setup_query_engine(index, response_mode, nodes=None, query_engine_method=None, retriever_method=None, top_k=2):
-     global llm
-     try:
-         logging.info(f"Setting up query engine with retriever_method: {retriever_method} and query_engine_method: {query_engine_method}")
-         retriever = select_retriever(index, nodes, retriever_method, top_k)
-
-         if retriever is None:
-             log_and_exit("Failed to create retriever. Index or nodes might be None.")

-         response_synthesizer = get_response_synthesizer(response_mode=response_mode)
-         index_query_engine = index.as_query_engine(similarity_top_k=top_k) if index else None

-         if query_engine_method == "FLARE":
-             query_engine = FLAREInstructQueryEngine(
-                 query_engine=index_query_engine,
-                 max_iterations=4,
-                 verbose=False
-             )
-         elif query_engine_method == "MS":
-             query_engine = MultiStepQueryEngine(
-                 query_engine=index_query_engine,
-                 query_transform=StepDecomposeQueryTransform(llm=llm, verbose=False),
-                 index_summary="Used to answer questions about the regulation"
              )
-         else:
-             query_engine = RetrieverQueryEngine(retriever=retriever, response_synthesizer=response_synthesizer)

-         if query_engine is None:
-             log_and_exit("Failed to create query engine.")

-         return query_engine
-     except Exception as e:
-         logging.error(f"Error setting up query engine: {e}")
-         traceback.print_exc()
-         log_and_exit(f"Error setting up query engine: {e}")
-
- def log_to_google_sheets(data):
-     try:
-         sheet.append_row(data)
-         logging.info("Logged data to Google Sheets.")
-     except Exception as e:
-         logging.error(f"Error logging data to Google Sheets: {e}")
-
- def update_google_sheets(question_id, feedback=None, detailed_feedback=None, annotated_answer=None):
-     try:
-         existing_data = sheet.get_all_values()
-         headers = existing_data[0]
-         for i, row in enumerate(existing_data):
-             if row[0] == question_id:
-                 if feedback is not None:
-                     sheet.update_cell(i+1, headers.index("Feedback") + 1, feedback)
-                 if detailed_feedback is not None:
-                     sheet.update_cell(i+1, headers.index("Detailed Feedback") + 1, detailed_feedback)
-                 if annotated_answer is not None:
-                     sheet.update_cell(i+1, headers.index("annotated_answer") + 1, annotated_answer)
-                 logging.info("Updated data in Google Sheets.")
-                 return
-     except Exception as e:
-         logging.error(f"Error updating data in Google Sheets: {e}")
-
- def run_streamlit_app():
-     if 'query_engine' not in st.session_state:
-         st.session_state.query_engine = None
-
-     st.title("RAG Chat Application")
-
-     col1, col2 = st.columns(2)
-
-     with col1:
-         pinecone_api_key = st.text_input("Pinecone API Key")
-         azure_api_key = st.text_input("Azure API Key")
-         groq_api_key = st.text_input("Groq API Key")
-
-     def update_api_based_on_model():
-         selected_model = st.session_state['selected_model']
-         if selected_model == 'gpt35':
-             st.session_state['selected_api'] = 'azure'
-         else:
-             st.session_state['selected_api'] = 'groq'
-
-     with col2:
-         selected_model = st.selectbox("Select Model", ["llama3-8b", "llama3-70b", "mixtral-8x7b", "gemma-7b", "gpt35"], index=4, key='selected_model', on_change=update_api_based_on_model)
-         selected_api = st.selectbox("Select API", ["azure", "groq"], index=0, key='selected_api', disabled=True)
-         embedding_model_type = "HF"
-         embedding_model = st.selectbox("Select Embedding Model", ["BAAI/bge-large-en-v1.5", "other_model"])
-         retriever_method = st.selectbox("Select Retriever Method", ["Vector Search", "BM25", "BM25+Vector"])
-
-     col3, col4 = st.columns(2)
-     with col3:
-         chunk_size = st.selectbox("Select Chunk Size", [128, 256, 512, 1024], index=2)
-     with col4:
-         top_k = st.selectbox("Select Top K", [1, 2, 3, 5, 6], index=1)
-
-     if st.button("Initialize"):
-         initialize_apis(st.session_state['selected_api'], selected_model, pinecone_api_key, groq_api_key, azure_api_key)
-         documents = load_pdf_data(chunk_size)
-         index, nodes = create_index(documents, embedding_model_type=embedding_model_type, embedding_model=embedding_model, retriever_method=retriever_method, chunk_size=chunk_size)
-         st.session_state.query_engine = setup_query_engine(index, response_mode="compact", nodes=nodes, query_engine_method=None, retriever_method=retriever_method, top_k=top_k)
-         st.success("Initialization complete.")
-
-     if 'chat_history' not in st.session_state:
-         st.session_state.chat_history = []
-
-     for chat_index, chat in enumerate(st.session_state.chat_history):
-         with st.chat_message("user"):
-             st.markdown(chat['user'])
-         with st.chat_message("bot"):
-             st.markdown("### Retrieved Contexts")
-             for node in chat.get('contexts', []):
-                 st.markdown(
-                     f"<div style='border:1px solid #ccc; padding:10px; margin:10px 0; font-size:small;'>{node.text}</div>",
-                     unsafe_allow_html=True
-                 )
-             st.markdown("### Answer")
-             st.markdown(chat['response'])
-
-             col1, col2 = st.columns([1, 1])
-             with col1:
-                 if st.button("Annotate 👎", key=f"annotate_{chat_index}"):
-                     chat['annotate'] = True
-                     chat['feedback'] = -1
-                     st.session_state.chat_history[chat_index] = chat
-                     update_google_sheets(chat['id'], feedback=-1)
-                     st.rerun()
-             with col2:
-                 if st.button("Approve 👍", key=f"approve_{chat_index}"):
-                     chat['approved'] = True
-                     chat['feedback'] = 1
-                     st.session_state.chat_history[chat_index] = chat
-                     update_google_sheets(chat['id'], feedback=1, annotated_answer=chat['response'])
-
-             if chat.get('annotate', False):
-                 annotated_answer = st.text_area("Annotate Answer", value=chat['response'], key=f"annotate_text_{chat_index}")
-                 if st.button("Submit Annotated Answer", key=f"submit_annotate_{chat_index}"):
-                     chat['annotated_answer'] = annotated_answer
-                     chat['annotate'] = False
-                     st.session_state.chat_history[chat_index] = chat
-                     update_google_sheets(chat['id'], annotated_answer=annotated_answer)
-
-             feedback = st.text_area("How was the response? Does it match the context? Does it answer the question fully?", key=f"textarea_{chat_index}")
-             if st.button("Submit Feedback", key=f"submit_{chat_index}"):
-                 chat['detailed_feedback'] = feedback
-                 st.session_state.chat_history[chat_index] = chat
-                 update_google_sheets(chat['id'], detailed_feedback=feedback)
-
-     if question := st.chat_input("Enter your question"):
-         if st.session_state.query_engine:
-             with st.spinner('Generating response...'):
-                 # Compile chat history for context
-                 history = "\n".join([f"Q: {chat['user']}\nA: {chat['response']}" for chat in st.session_state.chat_history])
-                 full_query = f"{history}\nQ: {question}"
-                 response = st.session_state.query_engine.query(full_query)
-                 logging.info(f"Generated response: {response.response}")
-                 logging.info(f"Retrieved contexts: {[node.text for node in response.source_nodes]}")
-                 question_id = str(uuid.uuid4())
-                 timestamp = datetime.now().isoformat()
-                 st.session_state.chat_history.append({'id': question_id, 'user': question, 'response': response.response, 'contexts': response.source_nodes, 'feedback': 0, 'detailed_feedback': '', 'annotated_answer': '', 'timestamp': timestamp})
-
-                 # Log initial query and response to Google Sheets without feedback
-                 log_to_google_sheets([question_id, question, response.response, st.session_state['selected_api'], selected_model, embedding_model, retriever_method, chunk_size, top_k, 0, "", "", timestamp])

-             st.rerun()
-         else:
-             st.error("Query engine is not initialized. Please initialize it first.")

  if __name__ == "__main__":
-     run_streamlit_app()
  import streamlit as st
+ from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer
+ from retrievers import PARetriever
+ from utils_code import create_chat_engine
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.core import Settings
  import os
+ from llama_index.llms.azure_openai import AzureOpenAI
+ from dotenv import load_dotenv, find_dotenv
+ from retrievers import HyPARetriever, PARetriever
+ from llama_index.vector_stores.pinecone import PineconeVectorStore
+ from llama_index.core import VectorStoreIndex
+ from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
+ from llama_index.core import PropertyGraphIndex
+ from llama_index.core.vector_stores import MetadataFilter, MetadataFilters, FilterOperator
+ from llama_index.retrievers.bm25 import BM25Retriever

+ # Load environment variables from the .env file
+ dotenv_path = find_dotenv()
+ #print(f"Dotenv Path: {dotenv_path}")
+ load_dotenv(dotenv_path)

+
+ embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+ Settings.embed_model = embed_model
+
+ # Set Azure OpenAI keys for Giskard if needed
+ #os.environ["AZURE_OPENAI_API_KEY"] = os.getenv("GSK_AZURE_OPENAI_API_KEY")
+ #os.environ["AZURE_OPENAI_ENDPOINT"] = os.getenv("GSK_AZURE_OPENAI_ENDPOINT")
+ os.environ["GSK_LLM_MODEL"] = "gpt-4o-mini"
+
+ # Pinecone and Neo4j credentials
+ pinecone_api_key = os.getenv("PINECONE_API_KEY")
+ ll144_index_name = 'll144'
+ euaiact_index_name = 'euaiact'
+
+ # Initialize Pinecone
+ from pinecone import Pinecone
+ pc = Pinecone(api_key=pinecone_api_key)
+
+
+ def metadata_filter(corpus_name):
+     if corpus_name == "EUAIACT":
+         # Filter for 'EUAIACT.pdf'
+         filter = MetadataFilters(filters=[MetadataFilter(key="filepath", value="'EUAIACT.pdf'", operator=FilterOperator.CONTAINS)])
+     elif corpus_name == "LL144":
+         # Filter for 'LL144.pdf' or 'LL144_Definitions.pdf'
+         filter = MetadataFilters(filters=[
+             MetadataFilter(key="filepath", value="'LL144.pdf'", operator=FilterOperator.CONTAINS),
+             MetadataFilter(key="filepath", value="'LL144_Definitions.pdf'", operator=FilterOperator.CONTAINS)
+         ], condition="or")  # match either file; the default AND condition could never satisfy both CONTAINS filters
+
+     return filter

+
+ # Load vector index
+ #@st.cache_data(ttl=None, persist=None)
+ def load_vector_index(corpus_name):
+     if corpus_name == "LL144":
+         pinecone_index = pc.Index(ll144_index_name)
+     elif corpus_name == "EUAIACT":
+         pinecone_index = pc.Index(euaiact_index_name)
+
+     vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
+     vector_index = VectorStoreIndex.from_vector_store(vector_store)
+
+     return vector_index
+
+ # Load property graph index
+ #@st.cache_data(ttl=None, persist=None)
+ def load_pg_index():
+     neo4j_username = os.getenv("NEO4J_USERNAME")
+     neo4j_password = os.getenv("NEO4J_PASSWORD")
+     neo4j_url = os.getenv("NEO4J_URI")
+
+     graph_store = Neo4jPropertyGraphStore(username=neo4j_username, password=neo4j_password, url=neo4j_url)
+     pg_index = PropertyGraphIndex.from_existing(property_graph_store=graph_store)
+     return pg_index
+
+ # Initialize the retriever (HyPA or PA)
+ def init_retriever(retriever_type, corpus_name, use_reranker, use_rewriter, classifier_model):
+     # Check if the vector index is cached; if not, load it
+     if "vector_index" not in st.session_state:
+         st.session_state.vector_index = load_vector_index(corpus_name)
+
+     # Check if the property graph index is cached; if not, load it
+     if "pg_index" not in st.session_state:
+         st.session_state.pg_index = load_pg_index()
+
+     vector_index = st.session_state.vector_index
+     graph_index = st.session_state.pg_index
+     llm = st.session_state.llm
+
+     filter = metadata_filter(corpus_name=corpus_name)
+     # Set the reranker model if selected
+     reranker_model_name = "BAAI/bge-reranker-large" if use_reranker else None
+
+     # Choose the appropriate retriever based on user selection
+     if retriever_type == "HyPA":
+         retriever = HyPARetriever(
              llm=llm,
+             vector_retriever=vector_index.as_retriever(similarity_top_k=10),
+             bm25_retriever=None,  # BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10),
+             kg_index=graph_index,  # Include KG for HyPA
+             rewriter=use_rewriter,  # Set rewriter option
+             classifier_model=classifier_model,  # Use the selected classifier model
+             verbose=False,
+             property_index=True,  # Use property graph index
+             reranker_model_name=reranker_model_name,  # Use reranker if selected
+             pg_filters=filter
          )
      else:
+         retriever = PARetriever(
+             llm=llm,
+             vector_retriever=vector_index.as_retriever(similarity_top_k=10),
+             bm25_retriever=None,  # BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10),
+             rewriter=use_rewriter,  # Set rewriter option
+             classifier_model=classifier_model,  # Use the selected classifier model
+             verbose=False,
+             reranker_model_name=reranker_model_name  # Use reranker if selected
+         )

+     memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
+     chat_engine = create_chat_engine(retriever=retriever, memory=memory, llm=llm)
+     st.session_state.chat_engine = chat_engine
+     #return chat_engine
+
+
+ def process_query(query):
+     """Process the input query and display it along with the response in the main chat area."""
+     # Append the user query to the message history and display it
+     st.session_state.messages.append({"role": "user", "content": query})
+     with st.chat_message("user"):
+         st.write(query)
+
+     # Ensure the chat engine is initialized
+     chat_engine = st.session_state.get('chat_engine', None)
+     if chat_engine:
+         # Process the query through the chat engine
+         with st.chat_message("assistant"):
+             with st.spinner("Retrieving Knowledge..."):
+                 response = chat_engine.stream_chat(query)
+                 response_str = ""
+                 response_container = st.empty()
+                 for token in response.response_gen:
+                     response_str += token
+                     response_container.write(response_str)
+                 # Append the assistant's response to the message history
+                 st.session_state.messages.append({"role": "assistant", "content": response_str})
+
+                 # Expander for additional info
+                 with st.expander("Source Nodes"):
+                     # Display source nodes
+                     if hasattr(response, 'source_nodes') and response.source_nodes:
+                         for idx, node in enumerate(response.source_nodes):
+                             st.markdown(f"#### Source Node {idx + 1}")
+                             st.write(f"**Node ID:** {node.node_id}")
+                             st.write(f"**Node Score:** {node.score}")
+
+                             st.write("**Metadata:**")
+                             for key, value in node.metadata.items():
+                                 st.write(f"- **{key}:** {value}")
+
+                             st.write("**Content:**")
+                             st.write(node.node.get_content())
+
+                             # Add a horizontal line to separate nodes
+                             st.markdown("---")
+                     else:
+                         st.write("No additional source nodes available.")
+
+
+ # Streamlit App
+ def main():
+     # Sidebar for retriever options
+     with st.sidebar:
+         st.image('holisticai.svg', use_column_width=True)
+         st.title("Retriever Settings")
+
+         # Azure OpenAI credentials input fields (start with blank fields)
+         azure_api_key = st.text_input("Azure OpenAI API Key", value="", type="password")
+         azure_endpoint = st.text_input("Azure OpenAI Endpoint", value="", type="password")
+
+         llm_model_choice = st.selectbox("Select LLM Model", ["gpt-4o-mini", "gpt35"])
+
+         # Let the user make selections without updating session state yet
+         retriever_type = st.selectbox("Select Retriever Method", ["PA", "HyPA"])
+         corpus_name = st.selectbox("Select Corpus", ["LL144", "EUAIACT"])
+         temperature = st.slider("Set LLM Temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
+
+         # Display a red warning about non-zero temperature
+         if temperature > 0:
+             st.markdown(
+                 "<p style='color:red;'>Warning: A non-zero temperature may lead to hallucinations in the generated responses.</p>",
+                 unsafe_allow_html=True
+             )

+         # Checkboxes for reranker and rewriter options
+         use_reranker = st.checkbox("Use Reranker")
+         use_rewriter = st.checkbox("Use Rewriter")

+         # Radio buttons for classifier model
+         classifier_type = st.radio("Select Classifier Type", ["2-Class", "3-Class"])
+         classifier_model = "rk68/distilbert-q-classifier-2" if classifier_type == "2-Class" else "rk68/distilbert-q-classifier-3"
+
+         # When the user clicks "Initialize", store everything in session state
+         if st.button("Initialize"):
+             st.session_state.retriever_type = retriever_type
+             st.session_state.corpus_name = corpus_name
+             st.session_state.temperature = temperature
+             st.session_state.use_reranker = use_reranker
+             st.session_state.use_rewriter = use_rewriter
+             st.session_state.classifier_type = classifier_type
+             st.session_state.classifier_model = classifier_model
+
+             # Store the user inputs in session state
+             st.session_state.azure_api_key = azure_api_key
+             st.session_state.azure_endpoint = azure_endpoint
+
+             # Set the environment variables from user inputs
+             os.environ["AZURE_OPENAI_API_KEY"] = azure_api_key
+             os.environ["AZURE_OPENAI_ENDPOINT"] = azure_endpoint
+
+             llm = AzureOpenAI(
+                 deployment_name=llm_model_choice, temperature=temperature,
+                 api_key=azure_api_key, azure_endpoint=azure_endpoint,
+                 api_version=os.getenv("AZURE_API_VERSION")
              )
+             Settings.llm = llm
+             st.session_state.llm = llm
+
+             # Initialize retriever after storing the settings
+             init_retriever(retriever_type, corpus_name, use_reranker, use_rewriter, classifier_model)
+             st.success("Retriever Initialized")
+
+         # Example questions based on the selected corpus
+         st.markdown("### Example Queries")
+         # Example questions with unique button handling
+         example_questions = {
+             "LL144": [
+                 "What is a bias audit?",
+                 "When does it come into effect?",
+                 "Summarise Local Law 144"
+             ],
+             "EUAIACT": [
+                 "What is an AI system?",
+                 "What are the key takeaways?",
+                 "Explain the key provisions of EUAIACT."
+             ]
+         }
+
+         # Display buttons for the example queries
+         for idx, question in enumerate(example_questions.get(corpus_name, [])):
+             if st.button(f"{question} [{idx}]"):
+                 process_query(question)

+         # Add a disclaimer at the bottom
+         st.markdown("---")  # Horizontal line for separation

+         st.markdown(
+             """
+             <p style="color:grey; font-size:12px;">
+             <strong>Disclaimer:</strong> This system is an academic prototype demonstration of our hybrid parameter-adaptive retrieval-augmented generation system. It is <strong>NOT</strong> a production-ready application. All outputs should be considered experimental and may not be fully accurate. This system should not be used for making important legal decisions. For complete, specific, and tailored legal advice, please consult a licensed legal professional.<br><br>
+             </p>
+             """,
+             unsafe_allow_html=True
+         )
+
+     # Check if the retriever is initialized
+     if "chat_engine" in st.session_state:
+         chat_engine = st.session_state.chat_engine
+     else:
+         chat_engine = None
+         st.warning("Please initialize the retriever from the sidebar.")
+
+     # Initialize session state for chat messages
+     if "messages" not in st.session_state:
+         st.session_state.messages = [{"role": "assistant", "content": "How may I assist you?"}]
+
+     # Display chat messages
+     for message in st.session_state.messages:
+         with st.chat_message(message["role"]):
+             st.write(message["content"])

+
+     # User-provided prompt
+     if prompt := st.chat_input():
+         st.session_state.messages.append({"role": "user", "content": prompt})
+         with st.chat_message("user"):
+             st.write(prompt)
+
+     # Generate a response if the engine is ready and the last message is from the user
+     if chat_engine and st.session_state.messages[-1]["role"] == "user":
+         with st.chat_message("assistant"):
+             with st.spinner("Retrieving Knowledge..."):
+                 response = chat_engine.stream_chat(prompt)
+                 response_str = ""
+                 response_container = st.empty()
+                 for token in response.response_gen:
+                     response_str += token
+                     response_container.write(response_str)
+                 # Expander for additional info
+                 with st.expander("Source Nodes"):
+                     # Display source nodes
+                     if hasattr(response, 'source_nodes') and response.source_nodes:
+                         for idx, node in enumerate(response.source_nodes):
+                             st.markdown(f"#### Source Node {idx + 1}")
+                             st.write(f"**Node ID:** {node.node_id}")
+                             st.write(f"**Node Score:** {node.score}")
+
+                             st.write("**Metadata:**")
+                             for key, value in node.metadata.items():
+                                 st.write(f"- **{key}:** {value}")
+
+                             st.write("**Content:**")
+                             st.write(node.node.get_content())
+
+                             # Add a horizontal line to separate nodes
+                             st.markdown("---")
+                     else:
+                         st.write("No additional source nodes available.")
+
+         st.session_state.messages.append({"role": "assistant", "content": str(response)})

  if __name__ == "__main__":
+     main()
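
Both process_query and the chat-input handler above use the same streaming pattern: ContextChatEngine.stream_chat returns a response object whose response_gen yields tokens, and the UI rewrites one placeholder as the string grows. A minimal, self-contained sketch of that pattern in isolation — EchoEngine and EchoResponse are hypothetical stand-ins for the chat engine that mirror only the response_gen contract:

class EchoResponse:
    """Hypothetical stand-in for a streaming chat response."""
    def __init__(self, text):
        self._text = text
        self.response_gen = iter(text.split())  # yields one "token" at a time

    def __str__(self):
        return self._text

class EchoEngine:
    """Hypothetical stand-in for ContextChatEngine."""
    def stream_chat(self, query):
        return EchoResponse(f"You asked about: {query}")

response = EchoEngine().stream_chat("bias audits")
response_str = ""
for token in response.response_gen:
    response_str += token + " "   # accumulate tokens, as app.py does
    print(response_str)           # stand-in for response_container.write(...)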
retrievers.py ADDED
@@ -0,0 +1,465 @@
+ from prompts import get_classification_prompt, get_query_generation_prompt
+ from utils_code import initialize_openai_creds, create_llm
+ from llama_index.core.schema import QueryBundle, NodeWithScore
+ from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
+ from transformers import pipeline
+ from typing import List, Optional, Tuple
+ import asyncio
+ from llama_index.core.postprocessor import SentenceTransformerRerank
+ from llama_index.core.indices.property_graph import LLMSynonymRetriever
+ from llama_index.core.indices.property_graph import VectorContextRetriever, PGRetriever
+ import os
+
+
+ class PARetriever(BaseRetriever):
+     """Custom retriever that performs query rewriting, Vector search, and BM25 search without Knowledge Graph search."""
+
+     def __init__(
+         self,
+         llm,  # LLM for query generation
+         vector_retriever: Optional[VectorIndexRetriever] = None,
+         bm25_retriever: Optional[BaseRetriever] = None,
+         mode: str = "OR",
+         rewriter: bool = True,
+         classifier_model: Optional[str] = None,  # Optional classifier model
+         device: str = 'mps',  # Default device for the classifier pipeline
+         reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
+         verbose: bool = False,  # Verbose flag
+         fixed_params: Optional[dict] = None,  # Fixed parameters that bypass the classifier
+         categories_list: Optional[List[str]] = None,  # List of categories for query classification
+         param_mappings: Optional[dict] = None  # Custom parameter mappings based on classifier labels
+     ) -> None:
+         """Initialize PARetriever parameters."""
+         self._vector_retriever = vector_retriever
+         self._bm25_retriever = bm25_retriever
+         self._llm = llm
+         self._rewriter = rewriter
+         self._mode = mode
+         self._reranker_model_name = reranker_model_name
+         self._reranker = None  # Initialize reranker as None
+         self.verbose = verbose
+         self.fixed_params = fixed_params
+         self.categories_list = categories_list
+         self.param_mappings = param_mappings or {
+             "label_0": {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
+             "label_1": {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2},
+             "label_2": {"top_k": 10, "max_keywords_per_query": 5, "max_knowledge_sequence": 3}
+         }
+
+         # Initialize the classifier if provided
+         self.classifier = None
+         if classifier_model:
+             self.classifier = pipeline("text-classification", model=classifier_model, device=device)
+
+         if mode not in ("AND", "OR"):
+             raise ValueError("Invalid mode.")
+
+         super().__init__()  # set up BaseRetriever state (e.g. the callback manager)
+
+     def classify_query_and_get_params(self, query: str) -> Tuple[Optional[str], dict]:
+         """Classify the query and determine adaptive parameters, or use fixed parameters."""
+         if self.fixed_params:
+             # Use fixed parameters from the dictionary if provided
+             params = self.fixed_params
+             classification_result = "Fixed"
+             if self.verbose:
+                 print(f"Using fixed parameters: {params}")
+         else:
+             params = {
+                 "top_k": 5,  # Default top-k
+                 "max_keywords_per_query": 4,  # Default max keywords
+                 "max_knowledge_sequence": 2  # Default max knowledge sequence
+             }
+             classification_result = None
+
+             if self.classifier:
+                 classification = self.classifier(query)[0]
+                 label = classification['label']  # Get the classification label directly
+                 classification_result = label  # Store the classification result
+                 if self.verbose:
+                     print(f"Query Classification: {classification['label']} with score {classification['score']}")
+
+                 # Use custom mappings or default mappings
+                 if label in self.param_mappings:
+                     params = self.param_mappings[label]
+                 else:
+                     if self.verbose:
+                         print(f"Warning: No mapping found for label {label}, using default parameters.")
+
+         self._classification_result = classification_result
+         return classification_result, params
+
+     def classify_query(self, query_str: str) -> Optional[str]:
+         """Classify the query into one of the predefined categories using the LLM, or skip if no categories are provided."""
+         if not self.categories_list:
+             if self.verbose:
+                 print("No categories provided, skipping query classification.")
+             return None
+
+         # Generate the classification prompt using the external helper
+         classification_prompt = get_classification_prompt(self.categories_list) + f" Query: '{query_str}'"
+
+         response = self._llm.complete(classification_prompt)
+         category = response.text.strip()
+
+         # Return the category only if it's in the categories list
+         return category if category in self.categories_list else None
+
+     def generate_queries(self, query_str: str, category: Optional[str], num_queries: int = 3) -> List[str]:
+         """Generate query variations using the LLM, taking the category into account if applicable."""
+         # Build the query-generation prompt using the external helper
+         query_gen_prompt = get_query_generation_prompt(query_str, num_queries)
+
+         response = self._llm.complete(query_gen_prompt)
+         queries = response.text.split("\n")
+         queries = [query.strip() for query in queries if query.strip()]
+
+         if category:
+             category_query = f"{category}"
+             queries.append(category_query)
+
+         return queries
+
+     async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
+         """Run queries against retrievers."""
+         tasks = []
+         for query in queries:
+             for retriever in retrievers:
+                 tasks.append(retriever.aretrieve(query))
+
+         task_results = await asyncio.gather(*tasks)
+
+         results_dict = {}
+         for i, query_result in enumerate(task_results):
+             # Tasks are enqueued query-major, so map each result back to its query
+             query = queries[i // len(retrievers)]
+             results_dict[(query, i)] = query_result
+         return results_dict
+
+     def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
+         """Fuse results from the Vector and BM25 retrievers with reciprocal rank fusion."""
+         k = 60.0  # `k` is a parameter used to control the impact of outlier rankings.
+         fused_scores = {}
+         text_to_node = {}
+
+         for nodes_with_scores in results_dict.values():
+             for rank, node_with_score in enumerate(
+                 sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
+             ):
+                 text = node_with_score.node.get_content()
+                 text_to_node[text] = node_with_score
+                 if text not in fused_scores:
+                     fused_scores[text] = 0.0
+                 fused_scores[text] += 1.0 / (rank + k)
+
+         reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))
+
+         reranked_nodes: List[NodeWithScore] = []
+         for text, score in reranked_results.items():
+             if text in text_to_node:
+                 node = text_to_node[text]
+                 node.score = score
+                 reranked_nodes.append(node)
+             else:
+                 if self.verbose:
+                     print(f"Warning: Text not found in `text_to_node`: {text}")
+
+         return reranked_nodes[:similarity_top_k]
+
+     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+         """Retrieve nodes given a query."""
+         if self._rewriter:
+             category = self.classify_query(query_bundle.query_str)
+             if self.verbose and category:
+                 print(f"Classified Category: {category}")
+
+         classification_result, params = self.classify_query_and_get_params(query_bundle.query_str)
+         self._classification_result = classification_result
+
+         top_k = params["top_k"]
+
+         if self._reranker_model_name:
+             self._reranker = SentenceTransformerRerank(model=self._reranker_model_name, top_n=top_k)
+             if self.verbose:
+                 print(f"Initialized reranker with top_n: {top_k}")
+
+         num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
+         if self.verbose:
+             print(f"Number of Query Rewrites: {num_queries}")
+
+         if self._rewriter:
+             queries = self.generate_queries(query_bundle.query_str, category, num_queries=num_queries)
+             if self.verbose:
+                 print(f"Generated Queries: {queries}")
+         else:
+             queries = [query_bundle.query_str]
+
+         active_retrievers = []
+         if self._vector_retriever:
+             active_retrievers.append(self._vector_retriever)
+         if self._bm25_retriever:
+             active_retrievers.append(self._bm25_retriever)
+
+         if not active_retrievers:
+             raise ValueError("No active retriever provided!")
+
+         results = asyncio.run(self.run_queries(queries, active_retrievers))
+         if self.verbose:
+             print(f"Fusion Results: {results}")
+
+         final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)
+
+         if self._reranker:
+             final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
+             if self.verbose:
+                 print(f"Reranked Results: {final_results}")
+         else:
+             final_results = final_results[:top_k]
+
+         if self._rewriter:
+             unique_nodes = {}
+             for node in final_results:
+                 content = node.node.get_content()
+                 if content not in unique_nodes:
+                     unique_nodes[content] = node
+             final_results = list(unique_nodes.values())
+
+         if self.verbose:
+             print(f"Final Results: {final_results}")
+
+         return final_results
+
+     def get_classification_result(self) -> Optional[str]:
+         return getattr(self, "_classification_result", None)
+
+
+ class HyPARetriever(PARetriever):
+     """Custom retriever that extends PARetriever with knowledge graph (KG) search."""
+
+     def __init__(
+         self,
+         llm,  # LLM for query generation
+         vector_retriever: Optional[VectorIndexRetriever] = None,
+         bm25_retriever: Optional[BaseRetriever] = None,
+         kg_index=None,  # Pass the knowledge graph index
+         property_index: bool = True,  # Whether to use the property graph for retrieval
+         pg_filters=None,
+         **kwargs,  # Pass any additional arguments to PARetriever
+     ):
+         # Initialize PARetriever to reuse all its functionality
+         super().__init__(
+             llm=llm,
+             vector_retriever=vector_retriever,
+             bm25_retriever=bm25_retriever,
+             **kwargs
+         )
+
+         # Initialize knowledge graph (KG) specific components
+         self._pg_filters = pg_filters
+         self._kg_index = kg_index
+         self.property_index = property_index
+
+     def _initialize_kg_retriever(self, params):
+         """Initialize the KG retriever based on the retrieval mode."""
+         graph_index = self._kg_index
+         filters = self._pg_filters
+
+         if self._kg_index and not self.property_index:
+             # If not using a property index, use KGTableRetriever
+             return KGTableRetriever(
+                 index=self._kg_index,
+                 retriever_mode='hybrid',
+                 max_keywords_per_query=params["max_keywords_per_query"],
+                 max_knowledge_sequence=params["max_knowledge_sequence"]
+             )
+
+         elif self._kg_index and self.property_index:
+             # If using a property index, combine vector-context and synonym sub-retrievers
+             # (this is the path used by the demo)
+             vector_retriever = VectorContextRetriever(
+                 graph_store=graph_index.property_graph_store,
+                 similarity_top_k=params["max_keywords_per_query"],
+                 path_depth=params["max_knowledge_sequence"],
+                 include_text=True,
+                 filters=filters
+             )
+             synonym_retriever = LLMSynonymRetriever(
+                 graph_store=graph_index.property_graph_store,
+                 llm=self._llm,
+                 include_text=True,
+                 filters=filters
+             )
+             return graph_index.as_retriever(sub_retrievers=[vector_retriever, synonym_retriever])
+             #return graph_index.as_retriever(similarity_top_k=params["top_k"])
+
+         return None
+
+     def _combine_with_kg_results(self, vector_bm25_results, kg_results):
+         """Combine KG results with vector and BM25 results."""
+         vector_ids = {n.node.id_ for n in vector_bm25_results}
+         kg_ids = {n.node.id_ for n in kg_results}
+
+         combined_results = {n.node.id_: n for n in vector_bm25_results}
+         combined_results.update({n.node.id_: n for n in kg_results})
+
+         if self._mode == "AND":
+             result_ids = vector_ids.intersection(kg_ids)
+         else:
+             result_ids = vector_ids.union(kg_ids)
+
+         return [combined_results[rid] for rid in result_ids]
+
+     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+         """Retrieve nodes with KG integration."""
+         # Call PARetriever's _retrieve to get the vector and BM25 results
+         final_results = super()._retrieve(query_bundle)
+
+         # If we have a KG index, initialize the retriever
+         if self._kg_index:
+             kg_retriever = self._initialize_kg_retriever(self.classify_query_and_get_params(query_bundle.query_str)[1])
+
+             if kg_retriever:
+                 kg_nodes = kg_retriever.retrieve(query_bundle)
+
+                 # Only combine KG and vector/BM25 results if property_index is True
+                 if self.property_index:
+                     final_results = self._combine_with_kg_results(final_results, kg_nodes)
+
+         return final_results
+
+
+ import os
+ from dotenv import load_dotenv
+ from llama_index.llms.azure_openai import AzureOpenAI
+ from llama_index.core import VectorStoreIndex, Settings
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+ from llama_index.core.node_parser import SentenceSplitter
+ from llama_index.core.retrievers import KGTableRetriever, VectorIndexRetriever
+ from llama_index.retrievers.bm25 import BM25Retriever
+ from llama_index.readers.file import PyMuPDFReader
+ from llama_index.core.chat_engine import ContextChatEngine
+ from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer
+ from llama_index.core import KnowledgeGraphIndex
+
+
+ def load_documents():
+     """Load and return documents from specified file paths."""
+     loader = PyMuPDFReader()
+     documents1 = loader.load(file_path="../../legal_data/LL144/LL144.pdf")
+     documents2 = loader.load(file_path="../../legal_data/LL144/LL144_Definitions.pdf")
+     return documents1 + documents2
+
+ def create_indices(documents, llm, embed_model):
+     """Create and return a VectorStoreIndex (the KnowledgeGraphIndex build is commented out)."""
+     splitter = SentenceSplitter(chunk_size=512)
+
+     vector_index = VectorStoreIndex.from_documents(
+         documents,
+         embed_model=embed_model,
+         transformations=[splitter]
+     )
+
+     """graph_index = KnowledgeGraphIndex.from_documents(
+         documents,
+         max_triplets_per_chunk=5,
+         llm=llm,
+         embed_model=embed_model,
+         include_embeddings=True,
+         transformations=[splitter]
+     )"""
+
+     return vector_index, None  # graph index construction above is disabled
+
+ def create_retrievers(vector_index, graph_index, llm, category_list):
+     """Create and return the PA and HyPA retrievers."""
+     vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
+     bm25_retriever = BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10)
+
+     PA_retriever = PARetriever(
+         llm=llm,
+         categories_list=category_list,
+         rewriter=True,
+         vector_retriever=vector_retriever,
+         bm25_retriever=bm25_retriever,
+         classifier_model="rk68/distilbert-q-classifier-3",
+         verbose=False
+     )
+
+     HyPA_retriever = HyPARetriever(
+         llm=llm,
+         categories_list=category_list,
+         rewriter=True,
+         kg_index=graph_index,
+         vector_retriever=vector_retriever,
+         bm25_retriever=bm25_retriever,
+         classifier_model="rk68/distilbert-q-classifier-3",
+         verbose=False,
+         property_index=False
+     )
+
+     return PA_retriever, HyPA_retriever
+
+ def create_chat_engine(retriever, memory):
+     """Create and return the ContextChatEngine using the provided retriever and memory."""
+     return ContextChatEngine.from_defaults(
+         retriever=retriever,
+         verbose=False,
+         chat_mode="context",
+         memory_cls=memory,
+         memory=memory
+     )
+
+ def main():
+     # Initialize environment and LLM
+     gpt35_creds, gpt4o_mini_creds, gpt4o_creds = initialize_openai_creds()
+     llm_gpt35 = create_llm("gpt35", gpt35_creds=gpt35_creds, gpt4o_mini_creds=gpt4o_mini_creds, gpt4o_creds=gpt4o_creds)
+
+     # Set global settings for embedding model and LLM
+     embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+     Settings.embed_model = embed_model
+     Settings.llm = llm_gpt35
+
+     category_list = [
+         '§ 5-301 Bias Audit',
+         '§ 5-302 Data Requirements',
+         '§ 5-303 Published Results',
+         '§ 5-304 Notice to Candidates and Employees'
+     ]
+
+     # Load documents and create indices
+     documents = load_documents()
+     vector_index, graph_index = create_indices(documents, llm_gpt35, embed_model)
+
+     # Create retrievers
+     PA_retriever, HyPA_retriever = create_retrievers(vector_index, graph_index, llm_gpt35, category_list)
+
+     # Initialize chat memory
+     memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
+
+     # Create chat engines
+     PA_chat_engine = create_chat_engine(PA_retriever, memory)
+     HyPA_chat_engine = create_chat_engine(HyPA_retriever, memory)
+
+     # Sample question and response
+     question = "What is a bias audit?"
+     PA_response = PA_chat_engine.chat(question)
+     HyPA_response = HyPA_chat_engine.chat(question)
+
+     # Output responses in a nicely formatted manner
+     print("\n" + "="*50)
+     print(f"Question: {question}")
+     print("="*50)
+
+     print("\n------- PA Retriever Response -------")
+     print(PA_response)
+
+     print("\n------- HyPA Retriever Response -------")
+     print(HyPA_response)
+     print("="*50 + "\n")
+
+ if __name__ == '__main__':
+     main()
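
The fusion step in fuse_vector_and_bm25_results above is reciprocal rank fusion: every result list contributes 1/(rank + k) per chunk, with k = 60 and 0-based ranks, so a chunk ranked first by one retriever and second by the other scores 1/60 + 1/61 ≈ 0.0331. A toy sketch of the same computation with hypothetical chunk IDs, independent of LlamaIndex:

def rrf(result_lists, k=60.0, top_n=3):
    """Reciprocal rank fusion over plain ranked lists, mirroring the class above."""
    fused = {}
    for results in result_lists:
        for rank, doc in enumerate(results):  # rank is 0-based, as in fuse_vector_and_bm25_results
            fused[doc] = fused.get(doc, 0.0) + 1.0 / (rank + k)
    return sorted(fused.items(), key=lambda x: x[1], reverse=True)[:top_n]

vector_hits = ["chunk_a", "chunk_b", "chunk_c"]  # hypothetical vector ranking
bm25_hits = ["chunk_d", "chunk_a", "chunk_c"]    # hypothetical BM25 ranking
print(rrf([vector_hits, bm25_hits]))
# chunk_a leads: 1/60 + 1/61 ≈ 0.0331, since it sits near the top of both lists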
utils_code.py ADDED
@@ -0,0 +1,212 @@
+ import os
+ from dotenv import load_dotenv, find_dotenv
+ from llama_index.llms.azure_openai import AzureOpenAI
+ from llama_index.readers.file import PyMuPDFReader
+ from llama_index.core.chat_engine import ContextChatEngine
+ from llama_index.core import KnowledgeGraphIndex
+ from llama_index.core.node_parser import SentenceSplitter
+ from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+
+ def initialize_openai_creds():
+     """Load environment variables and set API keys."""
+     dotenv_path = find_dotenv()
+     if dotenv_path == "":
+         print("No .env file found. Make sure the .env file is in the correct directory.")
+     else:
+         print(f".env file found at: {dotenv_path}")
+
+     load_dotenv(dotenv_path)
+
+     # General Azure OpenAI settings for gpt35 and gpt-4o-mini
+     general_creds = {
+         "api_key": os.getenv('AZURE_OPENAI_API_KEY'),
+         "api_version": os.getenv("AZURE_API_VERSION"),
+         "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT"),
+         "temperature": 0,  # Default temperature for models
+         "gpt35_deployment_name": os.getenv("AZURE_DEPLOYMENT_NAME"),
+         "gpt4o_mini_deployment_name": os.getenv("GPT4O_MINI_DEPLOYMENT_NAME")
+     }
+
+     # GPT-4o specific settings
+     gpt4o_creds = {
+         "api_key": os.getenv('GPT4O_API_KEY'),
+         "api_version": os.getenv("GPT4O_API_VERSION"),
+         "endpoint": os.getenv("GPT4O_AZURE_ENDPOINT"),
+         "deployment_name": os.getenv("GPT4O_DEPLOYMENT_NAME"),
+         "temperature": os.getenv("GPT4O_TEMPERATURE", 0)  # Default temperature for GPT-4o
+     }
+
+     return general_creds, gpt4o_creds
+
+
+ # NOTE: this second definition shadows the two-credential version above;
+ # it is the one the rest of the code uses (see main() in retrievers.py).
+ def initialize_openai_creds():
+     """Load environment variables and set per-model API keys."""
+     dotenv_path = find_dotenv()
+     if dotenv_path == "":
+         print("No .env file found. Make sure the .env file is in the correct directory.")
+     else:
+         print(f".env file found at: {dotenv_path}")
+
+     load_dotenv(dotenv_path)
+
+     # GPT-3.5 Credentials
+     gpt35_creds = {
+         "api_key": os.getenv('AZURE_OPENAI_API_KEY_GPT35'),
+         "api_version": os.getenv("AZURE_API_VERSION"),
+         "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT_GPT35"),
+         "temperature": 0,  # Default temperature for models
+         "deployment_name": os.getenv("AZURE_DEPLOYMENT_NAME_GPT35")
+     }
+
+     # GPT-4o-mini Credentials (shares the same API key as GPT-3.5 but a different deployment name and endpoint)
+     gpt4o_mini_creds = {
+         "api_key": os.getenv('AZURE_OPENAI_API_KEY_GPT4O_MINI'),
+         "api_version": os.getenv("AZURE_API_VERSION"),
+         "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT_GPT4O_MINI"),
+         "temperature": 0,  # Default temperature for models
+         "deployment_name": os.getenv("GPT4O_MINI_DEPLOYMENT_NAME")
+     }
+
+     # GPT-4o specific credentials
+     gpt4o_creds = {
+         "api_key": os.getenv('GPT4O_API_KEY'),
+         "api_version": os.getenv("GPT4O_API_VERSION"),
+         "endpoint": os.getenv("GPT4O_AZURE_ENDPOINT"),
+         "deployment_name": os.getenv("GPT4O_DEPLOYMENT_NAME"),
+         "temperature": os.getenv("GPT4O_TEMPERATURE", 0)  # Default temperature for GPT-4o
+     }
+
+     return gpt35_creds, gpt4o_mini_creds, gpt4o_creds
+
+
+ def create_llm(model: str, gpt35_creds: dict, gpt4o_mini_creds: dict, gpt4o_creds: dict):
+     """
+     Initialize and return the Azure OpenAI LLM based on the selected model.
+
+     :param model: The model to initialize ("gpt35", "gpt4o", or "gpt-4o-mini").
+     :param gpt35_creds: Credentials for gpt35.
+     :param gpt4o_mini_creds: Credentials for gpt-4o-mini.
+     :param gpt4o_creds: Credentials for gpt4o.
+     """
+     if model == "gpt35":
+         return AzureOpenAI(
+             deployment_name=gpt35_creds["deployment_name"],
+             temperature=gpt35_creds["temperature"],
+             api_key=gpt35_creds["api_key"],
+             azure_endpoint=gpt35_creds["endpoint"],
+             api_version=gpt35_creds["api_version"]
+         )
+     elif model == "gpt-4o-mini":
+         return AzureOpenAI(
+             deployment_name=gpt4o_mini_creds["deployment_name"],
+             temperature=gpt4o_mini_creds["temperature"],
+             api_key=gpt4o_mini_creds["api_key"],
+             azure_endpoint=gpt4o_mini_creds["endpoint"],
+             api_version=gpt4o_mini_creds["api_version"]
+         )
+     elif model == "gpt4o":
+         return AzureOpenAI(
+             deployment_name=gpt4o_creds["deployment_name"],
+             temperature=gpt4o_creds["temperature"],
+             api_key=gpt4o_creds["api_key"],
+             azure_endpoint=gpt4o_creds["endpoint"],
+             api_version=gpt4o_creds["api_version"]
+         )
+     else:
+         raise ValueError(f"Invalid model: {model}. Choose from 'gpt35', 'gpt4o', or 'gpt-4o-mini'.")
+
+
+ def create_chat_engine(retriever, memory, llm):
+     """Create and return the ContextChatEngine using the provided retriever, memory, and LLM."""
+     chat_engine = ContextChatEngine.from_defaults(
+         retriever=retriever,
+         memory=memory,
+         llm=llm
+     )
+     return chat_engine
+
+
+ def load_documents(filepaths):
+     """
+     Load and return documents from specified file paths.
+
+     :param filepaths: A string (single file path) or a list of strings (multiple file paths).
+     :return: A list of loaded documents.
+     """
+     loader = PyMuPDFReader()
+
+     # If a single string is passed, convert it to a list for consistent handling
+     if isinstance(filepaths, str):
+         filepaths = [filepaths]
+
+     # Load and accumulate documents
+     all_documents = []
+     for filepath in filepaths:
+         documents = loader.load(file_path=filepath)
+         all_documents += documents
+
+     return all_documents
+
+
+ def create_kg_index(
+     documents,
+     storage_context,
+     llm,
+     max_triplets_per_chunk=10,
+     embed_model=None,  # defaults to BAAI/bge-large-en-v1.5 below, so the model is not loaded at import time
+     include_embeddings=True,
+     chunk_size=512
+ ):
+     if embed_model is None:
+         embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+     splitter = SentenceSplitter(chunk_size=chunk_size)
+     graph_index = KnowledgeGraphIndex.from_documents(
+         documents,
+         storage_context=storage_context,
+         max_triplets_per_chunk=max_triplets_per_chunk,
+         llm=llm,
+         embed_model=embed_model,
+         include_embeddings=include_embeddings,
+         transformations=[splitter]
+     )
+     return graph_index
+
+
+ from llama_index.core.indices.property_graph import SimpleLLMPathExtractor
+ from llama_index.core.indices.property_graph import DynamicLLMPathExtractor
+ from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
+ from llama_index.core import PropertyGraphIndex
+
+
+ def create_pg_index(
+     llm,
+     documents,
+     graph_store,
+     max_triplets_per_chunk=10,
+     num_workers=4,
+     embed_kg_nodes=True,
+     embed_model=None  # defaults to BAAI/bge-large-en-v1.5 below
+ ):
+     if embed_model is None:
+         embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+
+     splitter = SentenceSplitter(chunk_size=512)
+     # Initialize the LLM path extractor
+     kg_extractor = DynamicLLMPathExtractor(
+         llm=llm,
+         max_triplets_per_chunk=max_triplets_per_chunk,
+         num_workers=num_workers
+     )
+
+     # Create the Property Graph Index
+     graph_index = PropertyGraphIndex.from_documents(
+         documents,
+         property_graph_store=graph_store,
+         embed_model=embed_model,
+         embed_kg_nodes=embed_kg_nodes,
+         kg_extractors=[kg_extractor],
+         transformations=[splitter]
+     )
+
+     return graph_index
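
A minimal sketch of how these helpers compose end to end, assuming a .env file that defines the GPT-3.5 variables read by the second initialize_openai_creds (AZURE_OPENAI_API_KEY_GPT35, AZURE_OPENAI_ENDPOINT_GPT35, AZURE_DEPLOYMENT_NAME_GPT35, AZURE_API_VERSION) and that the two LL144 PDFs sit in the working directory (the paths are assumptions):

from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from utils_code import initialize_openai_creds, create_llm, create_chat_engine, load_documents

# Credentials come from the .env file; create_llm selects the deployment by name
gpt35_creds, gpt4o_mini_creds, gpt4o_creds = initialize_openai_creds()
llm = create_llm("gpt35", gpt35_creds, gpt4o_mini_creds, gpt4o_creds)
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")

# Build an in-memory index over the two statute PDFs (file names are assumptions)
documents = load_documents(["LL144.pdf", "LL144_Definitions.pdf"])
index = VectorStoreIndex.from_documents(documents)

# Wire a retriever, memory, and the LLM into a context chat engine
memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
engine = create_chat_engine(retriever=index.as_retriever(similarity_top_k=5), memory=memory, llm=llm)
print(engine.chat("What is a bias audit?"))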