Commit 4559323 (verified) by wu981526092
Parent(s): d8bc587

Upload 11 files

Files changed (11):
  1. EUAIACT.pdf +0 -0
  2. LL144.pdf +0 -0
  3. LL144_Definitions.pdf +0 -0
  4. README.md +5 -7
  5. app.py +348 -0
  6. holisticai.svg +76 -0
  7. policy.pdf +0 -0
  8. prompts.py +27 -0
  9. requirements.txt +18 -0
  10. retrievers.py +465 -0
  11. utils_code.py +212 -0
EUAIACT.pdf ADDED
Binary file (923 kB)
 
LL144.pdf ADDED
Binary file (492 kB)
 
LL144_Definitions.pdf ADDED
Binary file (175 kB)
 
README.md CHANGED
@@ -1,14 +1,12 @@
 ---
-title: HyPA RAG
-emoji: 🌍
-colorFrom: indigo
-colorTo: blue
+title: Ragdemo
+emoji: 🦀
+colorFrom: pink
+colorTo: green
 sdk: streamlit
-sdk_version: 1.39.0
+sdk_version: 1.36.0
 app_file: app.py
 pinned: false
-license: mit
-short_description: RAG demo to test queries against the NYC Local Law 144
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,348 @@
+import streamlit as st
+from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer
+from utils_code import create_chat_engine
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.core import Settings
+import os
+from llama_index.llms.azure_openai import AzureOpenAI
+from dotenv import load_dotenv, find_dotenv
+from retrievers import HyPARetriever, PARetriever
+from llama_index.vector_stores.pinecone import PineconeVectorStore
+from llama_index.core import VectorStoreIndex
+from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
+from llama_index.core import PropertyGraphIndex
+from llama_index.core.vector_stores import MetadataFilter, MetadataFilters, FilterOperator
+from llama_index.retrievers.bm25 import BM25Retriever
+
+# Load environment variables from the .env file
+dotenv_path = find_dotenv()
+# print(f"Dotenv Path: {dotenv_path}")
+load_dotenv(dotenv_path)
+
+embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+Settings.embed_model = embed_model
+
+# Set Azure OpenAI keys for Giskard if needed
+# os.environ["AZURE_OPENAI_API_KEY"] = os.getenv("GSK_AZURE_OPENAI_API_KEY")
+# os.environ["AZURE_OPENAI_ENDPOINT"] = os.getenv("GSK_AZURE_OPENAI_ENDPOINT")
+os.environ["GSK_LLM_MODEL"] = "gpt-4o-mini"
+
+# Pinecone and Neo4j credentials
+pinecone_api_key = os.getenv("PINECONE_API_KEY")
+ll144_index_name = 'll144'
+euaiact_index_name = 'euaiact'
+
+# Initialize Pinecone
+from pinecone import Pinecone
+pc = Pinecone(api_key=pinecone_api_key)
+
+
+def metadata_filter(corpus_name):
+    """Build a metadata filter restricting retrieval to the selected corpus."""
+    if corpus_name == "EUAIACT":
+        # Filter for 'EUAIACT.pdf'
+        filter = MetadataFilters(filters=[
+            MetadataFilter(key="filepath", value="'EUAIACT.pdf'", operator=FilterOperator.CONTAINS)
+        ])
+    elif corpus_name == "LL144":
+        # Filter for 'LL144.pdf' or 'LL144_Definitions.pdf'; condition="or" lets a
+        # node match either file (the default "and" could never match both)
+        filter = MetadataFilters(
+            filters=[
+                MetadataFilter(key="filepath", value="'LL144.pdf'", operator=FilterOperator.CONTAINS),
+                MetadataFilter(key="filepath", value="'LL144_Definitions.pdf'", operator=FilterOperator.CONTAINS)
+            ],
+            condition="or"
+        )
+    return filter
+
+
+# Load vector index
+def load_vector_index(corpus_name):
+    if corpus_name == "LL144":
+        pinecone_index = pc.Index(ll144_index_name)
+    elif corpus_name == "EUAIACT":
+        pinecone_index = pc.Index(euaiact_index_name)
+
+    vector_store = PineconeVectorStore(pinecone_index=pinecone_index)
+    vector_index = VectorStoreIndex.from_vector_store(vector_store)
+    return vector_index
+
+
+# Load property graph index
+def load_pg_index():
+    neo4j_username = os.getenv("NEO4J_USERNAME")
+    neo4j_password = os.getenv("NEO4J_PASSWORD")
+    neo4j_url = os.getenv("NEO4J_URI")
+
+    graph_store = Neo4jPropertyGraphStore(username=neo4j_username, password=neo4j_password, url=neo4j_url)
+    pg_index = PropertyGraphIndex.from_existing(property_graph_store=graph_store)
+    return pg_index
+
+
+# Initialize the retriever (HyPA or PA)
+def init_retriever(retriever_type, corpus_name, use_reranker, use_rewriter, classifier_model):
+    # Load and cache the vector index
+    if "vector_index" not in st.session_state:
+        st.session_state.vector_index = load_vector_index(corpus_name)
+
+    # Load and cache the property graph index
+    if "pg_index" not in st.session_state:
+        st.session_state.pg_index = load_pg_index()
+
+    vector_index = st.session_state.vector_index
+    graph_index = st.session_state.pg_index
+    llm = st.session_state.llm
+
+    filter = metadata_filter(corpus_name=corpus_name)
+    # Set the reranker model if selected
+    reranker_model_name = "BAAI/bge-reranker-large" if use_reranker else None
+
+    # Choose the appropriate retriever based on user selection
+    if retriever_type == "HyPA":
+        retriever = HyPARetriever(
+            llm=llm,
+            vector_retriever=vector_index.as_retriever(similarity_top_k=10),
+            bm25_retriever=None,  # BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10)
+            kg_index=graph_index,  # Include KG for HyPA
+            rewriter=use_rewriter,  # Set rewriter option
+            classifier_model=classifier_model,  # Use the selected classifier model
+            verbose=False,
+            property_index=True,  # Use property graph index
+            reranker_model_name=reranker_model_name,  # Use reranker if selected
+            pg_filters=filter
+        )
+    else:
+        retriever = PARetriever(
+            llm=llm,
+            vector_retriever=vector_index.as_retriever(similarity_top_k=10),
+            bm25_retriever=None,  # BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10)
+            rewriter=use_rewriter,  # Set rewriter option
+            classifier_model=classifier_model,  # Use the selected classifier model
+            verbose=False,
+            reranker_model_name=reranker_model_name  # Use reranker if selected
+        )
+
+    memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
+    chat_engine = create_chat_engine(retriever=retriever, memory=memory, llm=llm)
+    st.session_state.chat_engine = chat_engine
+
+
+def process_query(query):
+    """Process the input query and display it along with the response in the main chat area."""
+    # Append the user query to the message history and display it
+    st.session_state.messages.append({"role": "user", "content": query})
+    with st.chat_message("user"):
+        st.write(query)
+
+    # Ensure the chat engine is initialized
+    chat_engine = st.session_state.get('chat_engine', None)
+    if chat_engine:
+        # Process the query through the chat engine
+        with st.chat_message("assistant"):
+            with st.spinner("Retrieving Knowledge..."):
+                response = chat_engine.stream_chat(query)
+                response_str = ""
+                response_container = st.empty()
+                for token in response.response_gen:
+                    response_str += token
+                    response_container.write(response_str)
+                # Append the assistant's response to the message history once
+                # (a second append of str(response) would duplicate the message)
+                st.session_state.messages.append({"role": "assistant", "content": response_str})
+
+                # Expander for additional info
+                with st.expander("Source Nodes"):
+                    # Display source nodes
+                    if hasattr(response, 'source_nodes') and response.source_nodes:
+                        for idx, node in enumerate(response.source_nodes):
+                            st.markdown(f"#### Source Node {idx + 1}")
+                            st.write(f"**Node ID:** {node.node_id}")
+                            st.write(f"**Node Score:** {node.score}")
+
+                            st.write("**Metadata:**")
+                            for key, value in node.metadata.items():
+                                st.write(f"- **{key}:** {value}")
+
+                            st.write("**Content:**")
+                            st.write(node.node.get_content())
+
+                            # Add a horizontal line to separate nodes
+                            st.markdown("---")
+                    else:
+                        st.write("No additional source nodes available.")
+
+
+# Streamlit App
+def main():
+    # Sidebar for retriever options
+    with st.sidebar:
+        st.image('holisticai.svg', use_column_width=True)
+        st.title("Retriever Settings")
+
+        # Azure OpenAI credentials input fields (start with blank fields)
+        azure_api_key = st.text_input("Azure OpenAI API Key", value="", type="password")
+        azure_endpoint = st.text_input("Azure OpenAI Endpoint", value="", type="password")
+
+        llm_model_choice = st.selectbox("Select LLM Model", ["gpt-4o-mini", "gpt35"])
+
+        # Let the user make selections without updating session state yet
+        retriever_type = st.selectbox("Select Retriever Method", ["PA", "HyPA"])
+        corpus_name = st.selectbox("Select Corpus", ["LL144", "EUAIACT"])
+        temperature = st.slider("Set LLM Temperature", min_value=0.0, max_value=1.0, value=0.0, step=0.1)
+
+        # Display a red warning about non-zero temperature
+        if temperature > 0:
+            st.markdown(
+                "<p style='color:red;'>Warning: A non-zero temperature may lead to hallucinations in the generated responses.</p>",
+                unsafe_allow_html=True
+            )
+
+        # Checkboxes for reranker and rewriter options
+        use_reranker = st.checkbox("Use Reranker")
+        use_rewriter = st.checkbox("Use Rewriter")
+
+        # Radio buttons for classifier model
+        classifier_type = st.radio("Select Classifier Type", ["2-Class", "3-Class"])
+        classifier_model = "rk68/distilbert-q-classifier-2" if classifier_type == "2-Class" else "rk68/distilbert-q-classifier-3"
+
+        # When the user clicks "Initialize", store everything in session state
+        if st.button("Initialize"):
+            st.session_state.retriever_type = retriever_type
+            st.session_state.corpus_name = corpus_name
+            st.session_state.temperature = temperature
+            st.session_state.use_reranker = use_reranker
+            st.session_state.use_rewriter = use_rewriter
+            st.session_state.classifier_type = classifier_type
+            st.session_state.classifier_model = classifier_model
+
+            # Store the user inputs in session state
+            st.session_state.azure_api_key = azure_api_key
+            st.session_state.azure_endpoint = azure_endpoint
+
+            # Set the environment variables from user inputs
+            os.environ["AZURE_OPENAI_API_KEY"] = azure_api_key
+            os.environ["AZURE_OPENAI_ENDPOINT"] = azure_endpoint
+
+            llm = AzureOpenAI(
+                deployment_name=llm_model_choice, temperature=temperature,
+                api_key=azure_api_key, azure_endpoint=azure_endpoint,
+                api_version=os.getenv("AZURE_API_VERSION")
+            )
+            Settings.llm = llm
+            st.session_state.llm = llm
+
+            # Initialize retriever after storing the settings
+            init_retriever(retriever_type, corpus_name, use_reranker, use_rewriter, classifier_model)
+            st.success("Retriever Initialized")
+
+        # Example questions based on selected corpus
+        st.markdown("### Example Queries")
+        example_questions = {
+            "LL144": [
+                "What is a bias audit?",
+                "When does it come into effect?",
+                "Summarise Local Law 144"
+            ],
+            "EUAIACT": [
+                "What is an AI system?",
+                "What are the key takeaways?",
+                "Explain the key provisions of EUAIACT."
+            ]
+        }
+
+        # Display buttons for the example queries
+        for idx, question in enumerate(example_questions.get(corpus_name, [])):
+            if st.button(f"{question} [{idx}]"):
+                process_query(question)
+
+        # Add a disclaimer at the bottom
+        st.markdown("---")  # Horizontal line for separation
+        st.markdown(
+            """
+            <p style="color:grey; font-size:12px;">
+            <strong>Disclaimer:</strong> This system is an academic prototype demonstration of our hybrid parameter-adaptive retrieval-augmented generation system. It is <strong>NOT</strong> a production-ready application. All outputs should be considered experimental and may not be fully accurate. This system should not be used for making important legal decisions. For complete, specific, and tailored legal advice, please consult a licensed legal professional.<br><br>
+            </p>
+            """,
+            unsafe_allow_html=True
+        )
+
+    # Check if the retriever is initialized
+    chat_engine = st.session_state.get("chat_engine")
+    if chat_engine is None:
+        st.warning("Please initialize the retriever from the sidebar.")
+
+    # Initialize session state for chat messages
+    if "messages" not in st.session_state:
+        st.session_state.messages = [{"role": "assistant", "content": "How may I assist you?"}]
+
+    # Display chat messages
+    for message in st.session_state.messages:
+        with st.chat_message(message["role"]):
+            st.write(message["content"])
+
+    # User-provided prompt
+    if prompt := st.chat_input():
+        st.session_state.messages.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.write(prompt)
+
+    # Generate a response if the last message is from the user (guarded so an
+    # uninitialized chat engine does not raise a NameError)
+    if chat_engine and st.session_state.messages[-1]["role"] == "user":
+        with st.chat_message("assistant"):
+            with st.spinner("Retrieving Knowledge..."):
+                response = chat_engine.stream_chat(prompt)
+                response_str = ""
+                response_container = st.empty()
+                for token in response.response_gen:
+                    response_str += token
+                    response_container.write(response_str)
+                # Expander for additional info
+                with st.expander("Source Nodes"):
+                    # Display source nodes
+                    if hasattr(response, 'source_nodes') and response.source_nodes:
+                        for idx, node in enumerate(response.source_nodes):
+                            st.markdown(f"#### Source Node {idx + 1}")
+                            st.write(f"**Node ID:** {node.node_id}")
+                            st.write(f"**Node Score:** {node.score}")
+
+                            st.write("**Metadata:**")
+                            for key, value in node.metadata.items():
+                                st.write(f"- **{key}:** {value}")
+
+                            st.write("**Content:**")
+                            st.write(node.node.get_content())
+
+                            # Add a horizontal line to separate nodes
+                            st.markdown("---")
+                    else:
+                        st.write("No additional source nodes available.")
+
+        st.session_state.messages.append({"role": "assistant", "content": str(response)})
+
+
+if __name__ == "__main__":
+    main()
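
app.py reads the Pinecone and Neo4j credentials from a `.env` file at startup, while the Azure OpenAI key and endpoint are entered in the sidebar. Below is a minimal startup check for the variables the code above reads; the variable names are taken from app.py, and the check itself is illustrative, not part of the Space:

```python
import os
from dotenv import load_dotenv

load_dotenv()

# Variable names taken from app.py; AZURE_API_VERSION is read when the LLM
# is created from the sidebar inputs.
required = ["PINECONE_API_KEY", "NEO4J_USERNAME", "NEO4J_PASSWORD", "NEO4J_URI", "AZURE_API_VERSION"]
missing = [name for name in required if not os.getenv(name)]
if missing:
    raise SystemExit(f"Missing environment variables: {', '.join(missing)}")
```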
holisticai.svg ADDED
policy.pdf ADDED
Binary file (463 kB)
 
prompts.py ADDED
@@ -0,0 +1,27 @@
+from typing import List
+
+
+def get_classification_prompt(categories_list: List[str]) -> str:
+    """Generate the classification prompt based on the categories list."""
+    categories_str = ', '.join([f"'{category}'" for category in categories_list])
+    return (
+        f"Classify the following query into one of the following categories: {categories_str}. "
+        f"If it doesn't fit into any category, respond with 'None'. "
+        f"Return the classification; do not output absolutely anything else."
+    )
+
+
+def get_query_generation_prompt(query_str: str, num_queries: int) -> str:
+    """Generate the query generation prompt based on the query string and number of sub-queries."""
+    return (
+        f"You are an expert at distilling a user question into sub-questions that can be used to fully answer the original query. "
+        f"First, identify the key words from the original question below:\n"
+        f"{query_str}\n"
+        f"Generate {num_queries} sub-queries that cover the different aspects needed to fully address the user's query.\n\n"
+        f"Here is an example:\n"
+        f"Original Question: What does test data mean and what do I need to know about it?\n"
+        f"Output:\n"
+        f"definition of 'test data'\n"
+        f"test data requirements and conditions for a bias audit\n"
+        f"examples of the use of test data in a bias audit\n\n"
+        f"Output the rewritten sub-queries, one on each line; do not output absolutely anything else."
+    )
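
For reference, a small sketch of how these prompt builders render; the category strings are borrowed from the demo in retrievers.py, and no LLM call is involved:

```python
from prompts import get_classification_prompt, get_query_generation_prompt

categories = ['5-301 Bias Audit', '5-302 Data Requirements']
print(get_classification_prompt(categories))
# -> "Classify the following query into one of the following categories: '5-301 Bias Audit', ..."

print(get_query_generation_prompt("What is a bias audit?", num_queries=3))
# -> the sub-query generation instructions with the question and worked example embedded
```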
requirements.txt ADDED
@@ -0,0 +1,18 @@
+pinecone-client
+llama-index
+llama-index-core
+llama-index-llms-openai
+llama-index-llms-replicate
+llama-index-embeddings-huggingface
+llama-index-vector-stores-pinecone
+llama-index-readers-file
+llama-index-retrievers-bm25
+llama-index-llms-groq
+llama-index-llms-azure-openai
+llama-index-graph-stores-neo4j
+oauth2client
+gspread
+python-dotenv
+PyMuPDF==1.24.0
retrievers.py ADDED
@@ -0,0 +1,465 @@
+from prompts import get_classification_prompt, get_query_generation_prompt
+from utils_code import initialize_openai_creds, create_llm
+from llama_index.core.schema import QueryBundle, NodeWithScore
+from llama_index.core.retrievers import BaseRetriever, VectorIndexRetriever, KGTableRetriever
+from transformers import pipeline
+from typing import List, Optional, Tuple
+import asyncio
+from llama_index.core.postprocessor import SentenceTransformerRerank
+from llama_index.core.indices.property_graph import LLMSynonymRetriever
+from llama_index.core.indices.property_graph import VectorContextRetriever, PGRetriever
+import os
+
+
+class PARetriever(BaseRetriever):
+    """Custom retriever that performs query rewriting, vector search, and BM25 search without knowledge graph search."""
+
+    def __init__(
+        self,
+        llm,  # LLM for query generation
+        vector_retriever: Optional[VectorIndexRetriever] = None,
+        bm25_retriever: Optional[BaseRetriever] = None,
+        mode: str = "OR",
+        rewriter: bool = True,
+        classifier_model: Optional[str] = None,  # Optional classifier model
+        device: str = 'cpu',  # CPU device for the Hugging Face demo
+        reranker_model_name: Optional[str] = None,  # Model name for SentenceTransformerRerank
+        verbose: bool = False,  # Verbose flag
+        fixed_params: Optional[dict] = None,  # Fixed parameters that bypass the classifier
+        categories_list: Optional[List[str]] = None,  # List of categories for query classification
+        param_mappings: Optional[dict] = None  # Custom parameter mappings based on classifier labels
+    ) -> None:
+        """Initialize PARetriever parameters."""
+        self._vector_retriever = vector_retriever
+        self._bm25_retriever = bm25_retriever
+        self._llm = llm
+        self._rewriter = rewriter
+        self._mode = mode
+        self._reranker_model_name = reranker_model_name
+        self._reranker = None  # Initialize reranker as None
+        self.verbose = verbose
+        self.fixed_params = fixed_params
+        self.categories_list = categories_list
+        self.param_mappings = param_mappings or {
+            "label_0": {"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
+            "label_1": {"top_k": 7, "max_keywords_per_query": 4, "max_knowledge_sequence": 2},
+            "label_2": {"top_k": 10, "max_keywords_per_query": 5, "max_knowledge_sequence": 3}
+        }
+
+        # Initialize the classifier if provided
+        self.classifier = None
+        if classifier_model:
+            self.classifier = pipeline("text-classification", model=classifier_model, device=device)
+
+        if mode not in ("AND", "OR"):
+            raise ValueError("Invalid mode.")
+
+        super().__init__()  # Set up BaseRetriever internals (callback manager, etc.)
+
+    def classify_query_and_get_params(self, query: str) -> Tuple[Optional[str], dict]:
+        """Classify the query and determine adaptive parameters, or use fixed parameters."""
+        if self.fixed_params:
+            # Use fixed parameters from the dictionary if provided
+            params = self.fixed_params
+            classification_result = "Fixed"
+            if self.verbose:
+                print(f"Using fixed parameters: {params}")
+        else:
+            params = {
+                "top_k": 5,  # Default top-k
+                "max_keywords_per_query": 4,  # Default max keywords
+                "max_knowledge_sequence": 2  # Default max knowledge sequence
+            }
+            classification_result = None
+
+            if self.classifier:
+                classification = self.classifier(query)[0]
+                label = classification['label']  # Get the classification label directly
+                classification_result = label  # Store the classification result
+                if self.verbose:
+                    print(f"Query Classification: {classification['label']} with score {classification['score']}")
+
+                # Use custom mappings or default mappings
+                if label in self.param_mappings:
+                    params = self.param_mappings[label]
+                else:
+                    if self.verbose:
+                        print(f"Warning: No mapping found for label {label}, using default parameters.")
+
+        self._classification_result = classification_result
+        return classification_result, params
+
+    def classify_query(self, query_str: str) -> Optional[str]:
+        """Classify the query into one of the predefined categories using the LLM, or skip if no categories are provided."""
+        if not self.categories_list:
+            if self.verbose:
+                print("No categories provided, skipping query classification.")
+            return None
+
+        # Generate the classification prompt using the external function
+        classification_prompt = get_classification_prompt(self.categories_list) + f" Query: '{query_str}'"
+
+        response = self._llm.complete(classification_prompt)
+        category = response.text.strip()
+
+        # Return the category only if it's in the categories list
+        return category if category in self.categories_list else None
+
+    def generate_queries(self, query_str: str, category: Optional[str], num_queries: int = 3) -> List[str]:
+        """Generate query variations using the LLM, taking the category into account if applicable."""
+        # Generate the query generation prompt using the external function
+        query_gen_prompt = get_query_generation_prompt(query_str, num_queries)
+
+        response = self._llm.complete(query_gen_prompt)
+        queries = response.text.split("\n")
+        queries = [query.strip() for query in queries if query.strip()]
+
+        if category:
+            queries.append(category)
+
+        return queries
+
+    async def run_queries(self, queries: List[str], retrievers: List[BaseRetriever]) -> dict:
+        """Run every query against every retriever concurrently."""
+        tasks = []
+        for query in queries:
+            for retriever in retrievers:
+                tasks.append(retriever.aretrieve(query))
+
+        task_results = await asyncio.gather(*tasks)
+
+        # Pair each result with the query that produced it; tasks were created per
+        # (query, retriever) pair, so there are len(retrievers) results per query
+        # (a plain zip(queries, task_results) would drop and misalign results)
+        results_dict = {}
+        i = 0
+        for query in queries:
+            for _ in retrievers:
+                results_dict[(query, i)] = task_results[i]
+                i += 1
+        return results_dict
+
+    def fuse_vector_and_bm25_results(self, results_dict, similarity_top_k: int) -> List[NodeWithScore]:
+        """Fuse results from the vector and BM25 retrievers using reciprocal rank fusion."""
+        k = 60.0  # `k` is a parameter used to control the impact of outlier rankings
+        fused_scores = {}
+        text_to_node = {}
+
+        for nodes_with_scores in results_dict.values():
+            for rank, node_with_score in enumerate(
+                sorted(nodes_with_scores, key=lambda x: x.score or 0.0, reverse=True)
+            ):
+                text = node_with_score.node.get_content()
+                text_to_node[text] = node_with_score
+                if text not in fused_scores:
+                    fused_scores[text] = 0.0
+                fused_scores[text] += 1.0 / (rank + k)
+
+        reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))
+
+        reranked_nodes: List[NodeWithScore] = []
+        for text, score in reranked_results.items():
+            if text in text_to_node:
+                node = text_to_node[text]
+                node.score = score
+                reranked_nodes.append(node)
+            else:
+                if self.verbose:
+                    print(f"Warning: Text not found in `text_to_node`: {text}")
+
+        return reranked_nodes[:similarity_top_k]
+
+    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        """Retrieve nodes for the given query."""
+        if self._rewriter:
+            category = self.classify_query(query_bundle.query_str)
+            if self.verbose and category:
+                print(f"Classified Category: {category}")
+
+        classification_result, params = self.classify_query_and_get_params(query_bundle.query_str)
+        self._classification_result = classification_result
+
+        top_k = params["top_k"]
+
+        if self._reranker_model_name:
+            self._reranker = SentenceTransformerRerank(model=self._reranker_model_name, top_n=top_k)
+            if self.verbose:
+                print(f"Initialized reranker with top_n: {top_k}")
+
+        num_queries = 3 if top_k == 5 else 5 if top_k == 7 else 7
+        if self.verbose:
+            print(f"Number of Query Rewrites: {num_queries}")
+
+        if self._rewriter:
+            queries = self.generate_queries(query_bundle.query_str, category, num_queries=num_queries)
+            if self.verbose:
+                print(f"Generated Queries: {queries}")
+        else:
+            queries = [query_bundle.query_str]
+
+        active_retrievers = []
+        if self._vector_retriever:
+            active_retrievers.append(self._vector_retriever)
+        if self._bm25_retriever:
+            active_retrievers.append(self._bm25_retriever)
+
+        if not active_retrievers:
+            raise ValueError("No active retriever provided!")
+
+        results = asyncio.run(self.run_queries(queries, active_retrievers))
+        if self.verbose:
+            print(f"Fusion Results: {results}")
+
+        final_results = self.fuse_vector_and_bm25_results(results, similarity_top_k=top_k)
+
+        if self._reranker:
+            final_results = self._reranker.postprocess_nodes(final_results, query_bundle)
+            if self.verbose:
+                print(f"Reranked Results: {final_results}")
+        else:
+            final_results = final_results[:top_k]
+
+        if self._rewriter:
+            # Deduplicate nodes retrieved by multiple rewritten queries
+            unique_nodes = {}
+            for node in final_results:
+                content = node.node.get_content()
+                if content not in unique_nodes:
+                    unique_nodes[content] = node
+            final_results = list(unique_nodes.values())
+
+        if self.verbose:
+            print(f"Final Results: {final_results}")
+
+        return final_results
+
+    def get_classification_result(self) -> Optional[str]:
+        return getattr(self, "_classification_result", None)
+
+
+class HyPARetriever(PARetriever):
+    """Custom retriever that extends PARetriever with knowledge graph (KG) search."""
+
+    def __init__(
+        self,
+        llm,  # LLM for query generation
+        vector_retriever: Optional[VectorIndexRetriever] = None,
+        bm25_retriever: Optional[BaseRetriever] = None,
+        kg_index=None,  # Pass the knowledge graph index
+        property_index: bool = True,  # Whether to use the property graph for retrieval
+        pg_filters=None,
+        **kwargs,  # Pass any additional arguments to PARetriever
+    ):
+        # Initialize PARetriever to reuse all its functionality
+        super().__init__(
+            llm=llm,
+            vector_retriever=vector_retriever,
+            bm25_retriever=bm25_retriever,
+            **kwargs
+        )
+
+        # Initialize knowledge graph (KG) specific components
+        self._pg_filters = pg_filters
+        self._kg_index = kg_index
+        self.property_index = property_index
+
+    def _initialize_kg_retriever(self, params):
+        """Initialize the KG retriever based on the retrieval mode."""
+        graph_index = self._kg_index
+        filters = self._pg_filters
+
+        if self._kg_index and not self.property_index:
+            # If not using a property index, use KGTableRetriever
+            return KGTableRetriever(
+                index=self._kg_index,
+                retriever_mode='hybrid',
+                max_keywords_per_query=params["max_keywords_per_query"],
+                max_knowledge_sequence=params["max_knowledge_sequence"]
+            )
+
+        elif self._kg_index and self.property_index:
+            # If using a property index, combine a vector sub-retriever and an
+            # LLM synonym sub-retriever (this branch is used for the demo)
+            vector_retriever = VectorContextRetriever(
+                graph_store=graph_index.property_graph_store,
+                similarity_top_k=params["max_keywords_per_query"],
+                path_depth=params["max_knowledge_sequence"],
+                include_text=True,
+                filters=filters
+            )
+            synonym_retriever = LLMSynonymRetriever(
+                graph_store=graph_index.property_graph_store,
+                llm=self._llm,
+                include_text=True,
+                filters=filters
+            )
+            return graph_index.as_retriever(sub_retrievers=[vector_retriever, synonym_retriever])
+
+        return None
+
+    def _combine_with_kg_results(self, vector_bm25_results, kg_results):
+        """Combine KG results with vector and BM25 results."""
+        vector_ids = {n.node.id_ for n in vector_bm25_results}
+        kg_ids = {n.node.id_ for n in kg_results}
+
+        combined_results = {n.node.id_: n for n in vector_bm25_results}
+        combined_results.update({n.node.id_: n for n in kg_results})
+
+        if self._mode == "AND":
+            result_ids = vector_ids.intersection(kg_ids)
+        else:
+            result_ids = vector_ids.union(kg_ids)
+
+        return [combined_results[rid] for rid in result_ids]
+
+    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        """Retrieve nodes with KG integration."""
+        # Call PARetriever's _retrieve to get the vector and BM25 results
+        final_results = super()._retrieve(query_bundle)
+
+        # If we have a KG index, initialize the retriever
+        if self._kg_index:
+            kg_retriever = self._initialize_kg_retriever(self.classify_query_and_get_params(query_bundle.query_str)[1])
+
+            if kg_retriever:
+                kg_nodes = kg_retriever.retrieve(query_bundle)
+
+                # Only combine KG and vector/BM25 results if property_index is True
+                if self.property_index:
+                    final_results = self._combine_with_kg_results(final_results, kg_nodes)
+
+        return final_results
+
+
+# ----- Standalone demo: build indices locally and compare PA vs. HyPA -----
+from dotenv import load_dotenv
+from llama_index.llms.azure_openai import AzureOpenAI
+from llama_index.core import VectorStoreIndex, Settings
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+from llama_index.core.node_parser import SentenceSplitter
+from llama_index.retrievers.bm25 import BM25Retriever
+from llama_index.readers.file import PyMuPDFReader
+from llama_index.core.chat_engine import ContextChatEngine
+from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer
+from llama_index.core import KnowledgeGraphIndex
+
+
+def load_documents():
+    """Load and return documents from the specified file paths."""
+    loader = PyMuPDFReader()
+    documents1 = loader.load(file_path="../../legal_data/LL144/LL144.pdf")
+    documents2 = loader.load(file_path="../../legal_data/LL144/LL144_Definitions.pdf")
+    return documents1 + documents2
+
+
+def create_indices(documents, llm, embed_model):
+    """Create and return a VectorStoreIndex from the documents (the KG index build is commented out)."""
+    splitter = SentenceSplitter(chunk_size=512)
+
+    vector_index = VectorStoreIndex.from_documents(
+        documents,
+        embed_model=embed_model,
+        transformations=[splitter]
+    )
+
+    """graph_index = KnowledgeGraphIndex.from_documents(
+        documents,
+        max_triplets_per_chunk=5,
+        llm=llm,
+        embed_model=embed_model,
+        include_embeddings=True,
+        transformations=[splitter]
+    )"""
+
+    return vector_index
+
+
+def create_retrievers(vector_index, graph_index, llm, category_list):
+    """Create and return the PA and HyPA retrievers."""
+    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=10)
+    bm25_retriever = BM25Retriever.from_defaults(index=vector_index, similarity_top_k=10)
+
+    PA_retriever = PARetriever(
+        llm=llm,
+        categories_list=category_list,
+        rewriter=True,
+        vector_retriever=vector_retriever,
+        bm25_retriever=bm25_retriever,
+        classifier_model="rk68/distilbert-q-classifier-3",
+        verbose=False
+    )
+
+    HyPA_retriever = HyPARetriever(
+        llm=llm,
+        categories_list=category_list,
+        rewriter=True,
+        kg_index=graph_index,
+        vector_retriever=vector_retriever,
+        bm25_retriever=bm25_retriever,
+        classifier_model="rk68/distilbert-q-classifier-3",
+        verbose=False,
+        property_index=False
+    )
+
+    return PA_retriever, HyPA_retriever
+
+
+def create_chat_engine(retriever, memory):
+    """Create and return a ContextChatEngine using the provided retriever and memory."""
+    return ContextChatEngine.from_defaults(
+        retriever=retriever,
+        memory=memory
+    )
+
+
+def main():
+    # Initialize environment and LLM
+    gpt35_creds, gpt4o_mini_creds, gpt4o_creds = initialize_openai_creds()
+    llm_gpt35 = create_llm(model="gpt35", gpt35_creds=gpt35_creds, gpt4o_mini_creds=gpt4o_mini_creds, gpt4o_creds=gpt4o_creds)
+
+    # Set global settings for embedding model and LLM
+    embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+    Settings.embed_model = embed_model
+    Settings.llm = llm_gpt35
+
+    category_list = [
+        '5-301 Bias Audit',
+        '5-302 Data Requirements',
+        '§ 5-303 Published Results',
+        '§ 5-304 Notice to Candidates and Employees'
+    ]
+
+    # Load documents and create the vector index (the KG index build is
+    # commented out in create_indices, so no graph index is available here)
+    documents = load_documents()
+    vector_index = create_indices(documents, llm_gpt35, embed_model)
+    graph_index = None
+
+    # Create retrievers
+    PA_retriever, HyPA_retriever = create_retrievers(vector_index, graph_index, llm_gpt35, category_list)
+
+    # Initialize chat memory
+    memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
+
+    # Create chat engines
+    PA_chat_engine = create_chat_engine(PA_retriever, memory)
+    HyPA_chat_engine = create_chat_engine(HyPA_retriever, memory)
+
+    # Sample question and response
+    question = "What is a bias audit?"
+    PA_response = PA_chat_engine.chat(question)
+    HyPA_response = HyPA_chat_engine.chat(question)
+
+    # Output responses in a nicely formatted manner
+    print("\n" + "=" * 50)
+    print(f"Question: {question}")
+    print("=" * 50)
+
+    print("\n------- PA Retriever Response -------")
+    print(PA_response)
+
+    print("\n------- HyPA Retriever Response -------")
+    print(HyPA_response)
+    print("=" * 50 + "\n")
+
+
+if __name__ == '__main__':
+    main()
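
A minimal sketch of driving PARetriever directly, outside the Streamlit app. It assumes only that the embedding model can be downloaded; the one-document index and its text are illustrative, and the fixed_params path bypasses the DistilBERT classifier so no LLM is needed with rewriting off:

```python
from llama_index.core import Document, Settings, VectorStoreIndex
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from retrievers import PARetriever

Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")

# A toy one-document index purely for illustration
vector_index = VectorStoreIndex.from_documents(
    [Document(text="A bias audit is an impartial evaluation by an independent auditor.")]
)

# fixed_params bypasses the classifier; with rewriter=False no LLM call is
# made, so llm can be None for this sketch
retriever = PARetriever(
    llm=None,
    vector_retriever=vector_index.as_retriever(similarity_top_k=10),
    rewriter=False,
    fixed_params={"top_k": 5, "max_keywords_per_query": 3, "max_knowledge_sequence": 1},
)
nodes = retriever.retrieve("What is a bias audit?")
for node in nodes:
    print(round(node.score or 0.0, 4), node.node.get_content()[:80])
```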
utils_code.py ADDED
@@ -0,0 +1,212 @@
+import os
+from dotenv import load_dotenv, find_dotenv
+from llama_index.llms.azure_openai import AzureOpenAI
+from llama_index.readers.file import PyMuPDFReader
+from llama_index.core.chat_engine import ContextChatEngine
+from llama_index.core import KnowledgeGraphIndex
+from llama_index.core.node_parser import SentenceSplitter
+from llama_index.embeddings.huggingface import HuggingFaceEmbedding
+
+
+def initialize_openai_creds():
+    """Load environment variables and return per-model Azure OpenAI credentials."""
+    dotenv_path = find_dotenv()
+    if dotenv_path == "":
+        print("No .env file found. Make sure the .env file is in the correct directory.")
+    else:
+        print(f".env file found at: {dotenv_path}")
+
+    load_dotenv(dotenv_path)
+
+    # GPT-3.5 credentials
+    gpt35_creds = {
+        "api_key": os.getenv('AZURE_OPENAI_API_KEY_GPT35'),
+        "api_version": os.getenv("AZURE_API_VERSION"),
+        "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT_GPT35"),
+        "temperature": 0,  # Default temperature for models
+        "deployment_name": os.getenv("AZURE_DEPLOYMENT_NAME_GPT35")
+    }
+
+    # GPT-4o-mini credentials (shares the API version with GPT-3.5 but has its
+    # own API key, deployment name, and endpoint)
+    gpt4o_mini_creds = {
+        "api_key": os.getenv('AZURE_OPENAI_API_KEY_GPT4O_MINI'),
+        "api_version": os.getenv("AZURE_API_VERSION"),
+        "endpoint": os.getenv("AZURE_OPENAI_ENDPOINT_GPT4O_MINI"),
+        "temperature": 0,  # Default temperature for models
+        "deployment_name": os.getenv("GPT4O_MINI_DEPLOYMENT_NAME")
+    }
+
+    # GPT-4o specific credentials
+    gpt4o_creds = {
+        "api_key": os.getenv('GPT4O_API_KEY'),
+        "api_version": os.getenv("GPT4O_API_VERSION"),
+        "endpoint": os.getenv("GPT4O_AZURE_ENDPOINT"),
+        "deployment_name": os.getenv("GPT4O_DEPLOYMENT_NAME"),
+        "temperature": os.getenv("GPT4O_TEMPERATURE", 0)  # Default temperature for GPT-4o
+    }
+
+    return gpt35_creds, gpt4o_mini_creds, gpt4o_creds
+
+
+def create_llm(model: str, gpt35_creds: dict, gpt4o_mini_creds: dict, gpt4o_creds: dict):
+    """
+    Initialize and return the Azure OpenAI LLM for the selected model.
+
+    :param model: The model to initialize ("gpt35", "gpt4o", or "gpt-4o-mini").
+    :param gpt35_creds: Credentials for gpt35.
+    :param gpt4o_mini_creds: Credentials for gpt-4o-mini.
+    :param gpt4o_creds: Credentials for gpt4o.
+    """
+    if model == "gpt35":
+        return AzureOpenAI(
+            deployment_name=gpt35_creds["deployment_name"],
+            temperature=gpt35_creds["temperature"],
+            api_key=gpt35_creds["api_key"],
+            azure_endpoint=gpt35_creds["endpoint"],
+            api_version=gpt35_creds["api_version"]
+        )
+    elif model == "gpt-4o-mini":
+        return AzureOpenAI(
+            deployment_name=gpt4o_mini_creds["deployment_name"],
+            temperature=gpt4o_mini_creds["temperature"],
+            api_key=gpt4o_mini_creds["api_key"],
+            azure_endpoint=gpt4o_mini_creds["endpoint"],
+            api_version=gpt4o_mini_creds["api_version"]
+        )
+    elif model == "gpt4o":
+        return AzureOpenAI(
+            deployment_name=gpt4o_creds["deployment_name"],
+            temperature=gpt4o_creds["temperature"],
+            api_key=gpt4o_creds["api_key"],
+            azure_endpoint=gpt4o_creds["endpoint"],
+            api_version=gpt4o_creds["api_version"]
+        )
+    else:
+        raise ValueError(f"Invalid model: {model}. Choose from 'gpt35', 'gpt4o', or 'gpt-4o-mini'.")
+
+
+def create_chat_engine(retriever, memory, llm):
+    """Create and return the ContextChatEngine using the provided retriever, memory, and LLM."""
+    chat_engine = ContextChatEngine.from_defaults(
+        retriever=retriever,
+        memory=memory,
+        llm=llm
+    )
+    return chat_engine
+
+
+def load_documents(filepaths):
+    """
+    Load and return documents from the specified file paths.
+
+    :param filepaths: A string (single file path) or a list of strings (multiple file paths).
+    :return: A list of loaded documents.
+    """
+    loader = PyMuPDFReader()
+
+    # If a single string is passed, convert it to a list for consistent handling
+    if isinstance(filepaths, str):
+        filepaths = [filepaths]
+
+    # Load and accumulate documents
+    all_documents = []
+    for filepath in filepaths:
+        documents = loader.load(file_path=filepath)
+        all_documents += documents
+
+    return all_documents
+
+
+def create_kg_index(
+    documents,
+    storage_context,
+    llm,
+    max_triplets_per_chunk=10,
+    embed_model=None,
+    include_embeddings=True,
+    chunk_size=512
+):
+    """Build a KnowledgeGraphIndex from documents."""
+    # Instantiate the default embedding model lazily rather than in the
+    # signature, so it is not loaded as a side effect of importing this module
+    if embed_model is None:
+        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+
+    splitter = SentenceSplitter(chunk_size=chunk_size)
+    graph_index = KnowledgeGraphIndex.from_documents(
+        documents,
+        storage_context=storage_context,
+        max_triplets_per_chunk=max_triplets_per_chunk,
+        llm=llm,
+        embed_model=embed_model,
+        include_embeddings=include_embeddings,
+        transformations=[splitter]
+    )
+    return graph_index
+
+
+from llama_index.core.indices.property_graph import SimpleLLMPathExtractor
+from llama_index.core.indices.property_graph import DynamicLLMPathExtractor
+from llama_index.graph_stores.neo4j import Neo4jPropertyGraphStore
+from llama_index.core import PropertyGraphIndex
+
+
+def create_pg_index(
+    llm,
+    documents,
+    graph_store,
+    max_triplets_per_chunk=10,
+    num_workers=4,
+    embed_kg_nodes=True,
+    embed_model=None
+):
+    """Build a PropertyGraphIndex from documents using an LLM path extractor."""
+    # Same lazy default as create_kg_index
+    if embed_model is None:
+        embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")
+
+    splitter = SentenceSplitter(chunk_size=512)
+    # Initialize the LLM path extractor
+    kg_extractor = DynamicLLMPathExtractor(
+        llm=llm,
+        max_triplets_per_chunk=max_triplets_per_chunk,
+        num_workers=num_workers
+    )
+
+    # Create the Property Graph Index
+    graph_index = PropertyGraphIndex.from_documents(
+        documents,
+        property_graph_store=graph_store,
+        embed_model=embed_model,
+        embed_kg_nodes=embed_kg_nodes,
+        kg_extractors=[kg_extractor],
+        transformations=[splitter]
+    )
+
+    return graph_index
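
Putting the helpers together, a hedged end-to-end sketch: the PDF path is a placeholder, and the credential environment variables are assumed to be set as initialize_openai_creds expects.

```python
from llama_index.core import VectorStoreIndex, Settings
from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer
from llama_index.embeddings.huggingface import HuggingFaceEmbedding
from utils_code import initialize_openai_creds, create_llm, load_documents, create_chat_engine

gpt35_creds, gpt4o_mini_creds, gpt4o_creds = initialize_openai_creds()
llm = create_llm(model="gpt-4o-mini", gpt35_creds=gpt35_creds,
                 gpt4o_mini_creds=gpt4o_mini_creds, gpt4o_creds=gpt4o_creds)
Settings.llm = llm
Settings.embed_model = HuggingFaceEmbedding(model_name="BAAI/bge-large-en-v1.5")

documents = load_documents("LL144.pdf")  # placeholder path
index = VectorStoreIndex.from_documents(documents)
memory = ChatMemoryBuffer.from_defaults(token_limit=8192)
engine = create_chat_engine(retriever=index.as_retriever(), memory=memory, llm=llm)
print(engine.chat("What is a bias audit?"))
```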