Gourisankar Padihary committed

Commit: 5184c29
Parent(s): e234b58

Multiple data set support

Files changed:
- app.py +102 -39
- generator/compute_metrics.py +43 -8
- generator/compute_rmse_auc_roc_metrics.py +3 -2
- generator/generate_metrics.py +7 -3
- generator/initialize_llm.py +3 -1
- main.py +36 -21
- retriever/retrieve_documents.py +77 -1
app.py
CHANGED
@@ -1,71 +1,134 @@
 import gradio as gr
 import logging
+import threading
+import time
+from generator.compute_metrics import get_attributes_text
+from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
+from io import StringIO
 
-def launch_gradio(vector_store,
+def launch_gradio(vector_store, gen_llm, val_llm):
     """
     Launch the Gradio app with pre-initialized objects.
     """
+    logger = logging.getLogger()
+    logger.setLevel(logging.INFO)
+
+    # Create a list to store logs
+    logs = []
+
+    # Custom log handler to capture logs and add them to the logs list
+    class LogHandler(logging.Handler):
+        def emit(self, record):
+            log_entry = self.format(record)
+            logs.append(log_entry)
+
+    # Add custom log handler to the logger
+    log_handler = LogHandler()
+    log_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
+    logger.addHandler(log_handler)
+
+    def log_updater():
+        """Background function to add logs."""
+        while True:
+            time.sleep(2) # Update logs every 2 seconds
+            pass # Log capture is now handled by the logging system
+
+    def get_logs():
+        """Retrieve logs for display."""
+        return "\n".join(logs[-50:]) # Only show the last 50 logs for example
+
+    # Start the logging thread
+    threading.Thread(target=log_updater, daemon=True).start()
+
+    def answer_question(query, state):
         try:
+            # Generate response using the passed objects
+            response, source_docs = retrieve_and_generate_response(gen_llm, vector_store, query)
 
+            # Update state with the response and source documents
+            state["query"] = query
+            state["response"] = response
+            state["source_docs"] = source_docs
 
             response_text = f"Response: {response}\n\n"
-            for key, value in metrics.items():
-                if key != 'response':
-                    metrics_text += f"{key}: {value}\n"
-
-            return response_text, metrics_text
+            return response_text, state
         except Exception as e:
             logging.error(f"Error processing query: {e}")
-            return f"An error occurred: {e}"
+            return f"An error occurred: {e}", state
 
+    def compute_metrics(state):
         try:
-            result = (
-                f"Relevance RMSE Score: {relevance_rmse}\n"
-                f"Utilization RMSE Score: {utilization_rmse}\n"
-                f"Overall Adherence AUC-ROC: {adherence_auc}\n"
-            )
-            return result
+            logging.info(f"Computing metrics")
+
+            # Retrieve response and source documents from state
+            response = state.get("response", "")
+            source_docs = state.get("source_docs", {})
+            query = state.get("query", "")
+
+            # Generate metrics using the passed objects
+            attributes, metrics = generate_metrics(val_llm, response, source_docs, query, 1)
+
+            attributes_text = get_attributes_text(attributes)
+
+            metrics_text = "Metrics:\n"
+            for key, value in metrics.items():
+                if key != 'response':
+                    metrics_text += f"{key}: {value}\n"
 
+            return attributes_text, metrics_text
         except Exception as e:
-            return f"An error occurred: {e}"
+            logging.error(f"Error computing metrics: {e}")
+            return f"An error occurred: {e}", ""
 
     # Define Gradio Blocks layout
     with gr.Blocks() as interface:
         interface.title = "Real Time RAG Pipeline Q&A"
         gr.Markdown("### Real Time RAG Pipeline Q&A") # Heading
-        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.") # Description
 
+        # Section to display LLM names
+        with gr.Row():
+            model_info = f"Generation LLM: {gen_llm.name if hasattr(gen_llm, 'name') else 'Unknown'}\n"
+            model_info += f"Validation LLM: {val_llm.name if hasattr(val_llm, 'name') else 'Unknown'}\n"
+            gr.Textbox(value=model_info, label="Model Information", interactive=False) # Read-only textbox
+
+        # State to store response and source documents
+        state = gr.State(value={"query": "", "response": "", "source_docs": {}})
+        gr.Markdown("Ask a question and get a response with metrics calculated from the RAG pipeline.") # Description
         with gr.Row():
             query_input = gr.Textbox(label="Ask a question", placeholder="Type your query here")
         with gr.Row():
+            submit_button = gr.Button("Submit", variant="primary") # Submit button
             clear_query_button = gr.Button("Clear") # Clear button
-            submit_button = gr.Button("Submit", variant="primary") # Submit button
         with gr.Row():
             answer_output = gr.Textbox(label="Response", placeholder="Response will appear here")
+
         with gr.Row():
+            compute_metrics_button = gr.Button("Compute metrics", variant="primary")
+            attr_output = gr.Textbox(label="Attributes", placeholder="Attributes will appear here")
             metrics_output = gr.Textbox(label="Metrics", placeholder="Metrics will appear here")
+
+        #with gr.Row():
+
         # Define button actions
         submit_button.click(
+            fn=answer_question,
+            inputs=[query_input, state],
+            outputs=[answer_output, state]
+        )
+        clear_query_button.click(fn=lambda: "", outputs=[query_input]) # Clear query input
+        compute_metrics_button.click(
+            fn=compute_metrics,
+            inputs=[state],
+            outputs=[attr_output, metrics_output]
+        )
+
+        # Section to display logs
+        with gr.Row():
+            start_log_button = gr.Button("Start Log Update", elem_id="start_btn") # Button to start log updates
+        with gr.Row():
+            log_section = gr.Textbox(label="Logs", interactive=False, visible=True, lines=10) # Log section
+
+        # Set button click to trigger log updates
+        start_log_button.click(fn=get_logs, outputs=log_section)
 
     interface.launch()
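Note: the log streaming above hinges on a small, reusable pattern: a custom logging.Handler appends formatted records to an in-memory list, and a UI callback tails that list. A minimal self-contained sketch of just that pattern (illustrative names, not code from this commit):

    import logging

    logs = []

    class ListHandler(logging.Handler):
        def emit(self, record):
            logs.append(self.format(record))

    handler = ListHandler()
    handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
    root = logging.getLogger()
    root.addHandler(handler)
    root.setLevel(logging.INFO)

    logging.info("pipeline started") # captured by ListHandler
    print("\n".join(logs[-50:])) # the same last-50 window get_logs() uses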
generator/compute_metrics.py
CHANGED
@@ -32,18 +32,53 @@ def compute_metrics(attributes, total_sentences):
 
 def get_metrics(attributes, total_sentences):
     if attributes.content:
-        #print(attributes)
-        result_content = attributes.content # Access the content attribute
-        # Extract the JSON part from the result_content
-        json_start = result_content.find("{")
-        json_end = result_content.rfind("}") + 1
-        json_str = result_content[json_start:json_end]
-
         try:
+            result_content = attributes.content # Access the content attribute
+            # Extract the JSON part from the result_content
+            json_start = result_content.find("{")
+            json_end = result_content.rfind("}") + 1
+            json_str = result_content[json_start:json_end]
             result_json = json.loads(json_str)
             # Compute metrics using the extracted attributes
             metrics = compute_metrics(result_json, total_sentences)
             logging.info(metrics)
+
             return metrics
         except json.JSONDecodeError as e:
             logging.error(f"JSONDecodeError: {e}")
+
+def get_attributes_text(attributes):
+    try:
+        result_content = attributes.content # Access the content attribute
+        # Extract the JSON part from the result_content
+        json_start = result_content.find("{")
+        json_end = result_content.rfind("}") + 1
+        json_str = result_content[json_start:json_end]
+        result_json = json.loads(json_str)
+
+        # Extract the required fields from json
+        relevance_explanation = result_json.get("relevance_explanation", "N/A")
+        all_relevant_sentence_keys = result_json.get("all_relevant_sentence_keys", [])
+        overall_supported_explanation = result_json.get("overall_supported_explanation", "N/A")
+        overall_supported = result_json.get("overall_supported", "N/A")
+        sentence_support_information = result_json.get("sentence_support_information", [])
+        all_utilized_sentence_keys = result_json.get("all_utilized_sentence_keys", [])
+
+        # Format the attributes for display
+        attributes_text = "Attributes:\n"
+        attributes_text += f"### Relevance Explanation:\n{relevance_explanation}\n\n"
+        attributes_text += f"### All Relevant Sentence Keys:\n{', '.join(all_relevant_sentence_keys)}\n\n"
+        attributes_text += f"### Overall Supported Explanation:\n{overall_supported_explanation}\n\n"
+        attributes_text += f"### Overall Supported:\n{overall_supported}\n\n"
+        attributes_text += "### Sentence Support Information:\n"
+        for info in sentence_support_information:
+            attributes_text += f"- Response Sentence Key: {info.get('response_sentence_key', 'N/A')}\n"
+            attributes_text += f"  Explanation: {info.get('explanation', 'N/A')}\n"
+            attributes_text += f"  Supporting Sentence Keys: {', '.join(info.get('supporting_sentence_keys', []))}\n"
+            attributes_text += f"  Fully Supported: {info.get('fully_supported', 'N/A')}\n"
+        attributes_text += f"\n### All Utilized Sentence Keys:\n{', '.join(all_utilized_sentence_keys)}"
+
+        return attributes_text
+    except Exception as e:
+        logging.error(f"Error extracting attributes: {e}")
+        return f"An error occurred while extracting attributes: {e}"
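Both get_metrics and get_attributes_text rely on the same brace-slicing step to pull a JSON object out of free-form LLM output. A minimal sketch of that step in isolation (the raw string is a made-up stand-in for attributes.content):

    import json

    raw = 'Here is the evaluation: {"overall_supported": true, "all_relevant_sentence_keys": ["a", "b"]} Hope that helps.'
    json_str = raw[raw.find("{"):raw.rfind("}") + 1] # slice from the first "{" to the last "}"
    result_json = json.loads(json_str)
    print(result_json["all_relevant_sentence_keys"]) # ['a', 'b']

This assumes the output contains no stray braces outside the JSON object; a malformed payload falls through to the JSONDecodeError handler.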
generator/compute_rmse_auc_roc_metrics.py
CHANGED
@@ -1,6 +1,6 @@
 
 from sklearn.metrics import roc_auc_score, root_mean_squared_error
-from generator.generate_metrics import generate_metrics
+from generator.generate_metrics import generate_metrics, retrieve_and_generate_response
 import logging
 
 def compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, num_question):
@@ -25,7 +25,8 @@ def compute_rmse_auc_roc_metrics(gen_llm, val_llm, dataset, vector_store, num_question):
         query = document['question']
         logging.info(f'Query number: {i + 1}')
         # Call the generate_metrics for each query
-        response,
+        response, source_docs = retrieve_and_generate_response(gen_llm, vector_store, query)
+        attributes, metrics = generate_metrics(val_llm, response, source_docs, query, 25)
 
         # Extract predicted metrics (ensure these are continuous if possible)
         predicted_relevance = metrics.get('Context Relevance', 0) if metrics else 0
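For reference, the two imported metrics in one tiny runnable sketch with made-up numbers (root_mean_squared_error is available in scikit-learn 1.4+):

    from sklearn.metrics import roc_auc_score, root_mean_squared_error

    # RMSE between ground-truth and predicted relevance/utilization scores
    print(root_mean_squared_error([0.9, 0.2, 0.7], [0.8, 0.1, 0.6])) # ~0.1

    # AUC-ROC for binary adherence labels against predicted scores
    print(roc_auc_score([1, 0, 1], [0.85, 0.3, 0.7])) # 1.0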
generator/generate_metrics.py
CHANGED
@@ -5,7 +5,7 @@ from retriever.retrieve_documents import retrieve_top_k_documents
 from generator.compute_metrics import get_metrics
 from generator.extract_attributes import extract_attributes
 
-def generate_metrics(gen_llm, val_llm, vector_store, query, time_to_wait):
+def retrieve_and_generate_response(gen_llm, vector_store, query):
     logging.info(f'Query: {query}')
 
     # Step 1: Retrieve relevant documents for given query
@@ -21,6 +21,10 @@ def generate_metrics(gen_llm, val_llm, vector_store, query, time_to_wait):
 
     logging.info(f"Response from LLM: {response}")
 
+    return response, source_docs
+
+def generate_metrics(val_llm, response, source_docs, query, time_to_wait):
+
     # Add a sleep interval to avoid hitting the rate limit
     time.sleep(time_to_wait) # Adjust the sleep time as needed
 
@@ -28,8 +32,8 @@ def generate_metrics(gen_llm, val_llm, vector_store, query, time_to_wait):
     logging.info(f"Extracting attributes through validation LLM")
     attributes, total_sentences = extract_attributes(val_llm, query, source_docs, response)
     logging.info(f"Extracted attributes successfully")
+
     # Step 4: Call get_metrics to calculate the metrics
     metrics = get_metrics(attributes, total_sentences)
 
-    return
+    return attributes, metrics
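With this refactor, callers compose the two halves explicitly, as app.py and compute_rmse_auc_roc_metrics.py now do. A sketch of the intended call sequence (gen_llm, val_llm, and vector_store are assumed to come from initialize_llm.py and embed_documents):

    from generator.generate_metrics import generate_metrics, retrieve_and_generate_response

    # Step A: retrieval + generation with the generation LLM
    response, source_docs = retrieve_and_generate_response(gen_llm, vector_store, query)

    # Step B: attribute extraction + metric computation with the validation LLM
    attributes, metrics = generate_metrics(val_llm, response, source_docs, query, 1)

This split is what lets the Gradio app answer first and compute metrics later, on demand.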
generator/initialize_llm.py
CHANGED
@@ -4,8 +4,9 @@ from langchain_groq import ChatGroq
 
 def initialize_generation_llm():
     os.environ["GROQ_API_KEY"] = "<redacted>"
-    model_name = "
+    model_name = "mixtral-8x7b-32768"
     llm = ChatGroq(model=model_name, temperature=0.7)
+    llm.name = model_name
     logging.info(f'Generation LLM {model_name} initialized')
     return llm
 
@@ -13,5 +14,6 @@ def initialize_validation_llm():
     os.environ["GROQ_API_KEY"] = "<redacted>"
     model_name = "llama3-70b-8192"
     llm = ChatGroq(model=model_name, temperature=0.7)
+    llm.name = model_name
     logging.info(f'Validation LLM {model_name} initialized')
     return llm
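A hardcoded API key in source control is unsafe (the literals are redacted in the listing above). A common alternative, sketched here as an assumption rather than code from this commit, reads the key from the environment and fails fast when it is missing:

    import os

    def get_groq_api_key():
        # Expect the key via the environment; never commit it
        api_key = os.environ.get("GROQ_API_KEY")
        if not api_key:
            raise RuntimeError("Set the GROQ_API_KEY environment variable")
        return api_key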
main.py
CHANGED
@@ -3,7 +3,6 @@ from data.load_dataset import load_data
 from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics
 from retriever.chunk_documents import chunk_documents
 from retriever.embed_documents import embed_documents
-from generator.generate_metrics import generate_metrics
 from generator.initialize_llm import initialize_generation_llm
 from generator.initialize_llm import initialize_validation_llm
 from app import launch_gradio
@@ -13,21 +12,43 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
 
 def main():
     logging.info("Starting the RAG pipeline")
-    logging.info("Dataset loaded")
+
+    # Load single dataset
+    #dataset = load_data(data_set_name)
+    #logging.info("Dataset loaded")
+    # List of datasets to load
+    data_set_names = ['covidqa', 'techqa', 'cuad']
 
+    default_chunk_size = 1000
+    chunk_overlap = 200
 
+    # List to store chunked documents
+    all_chunked_documents = []
+    # Load multiple datasets
+    datasets = {}
+    for data_set_name in data_set_names:
+        logging.info(f"Loading dataset: {data_set_name}")
+        datasets[data_set_name] = load_data(data_set_name)
 
+        # Set chunk size based on dataset name
+        chunk_size = default_chunk_size
+        if data_set_name == 'cuad':
+            chunk_size = 4000 # Custom chunk size for 'cuad'
+
+        # Chunk documents
+        chunked_documents = chunk_documents(datasets[data_set_name], chunk_size=chunk_size, chunk_overlap=chunk_overlap)
+        all_chunked_documents.extend(chunked_documents) # Combine all chunks
+
+    # Access individual datasets
+    #for name, dataset in datasets.items():
+    #    logging.info(f"Loaded {name} with {dataset.num_rows} rows")
+
+    # Logging final count
+    logging.info(f"Total chunked documents: {len(all_chunked_documents)}")
+
     # Embed the documents
-    vector_store = embed_documents(
+    vector_store = embed_documents(all_chunked_documents)
     logging.info("Documents embedded")
 
     # Initialize the Generation LLM
@@ -36,18 +57,12 @@ def main():
     # Initialize the Validation LLM
     val_llm = initialize_validation_llm()
 
-    # Sample question
-    #row_num = 30
-    #query = dataset[row_num]['question']
-
-    # Call generate_metrics for above sample question
-    #generate_metrics(gen_llm, val_llm, vector_store, query)
-
     #Compute RMSE and AUC-ROC for entire dataset
+    data_set_name = 'covidqa'
+    #compute_rmse_auc_roc_metrics(gen_llm, val_llm, datasets[data_set_name], vector_store, 10)
 
     # Launch the Gradio app
-    launch_gradio(vector_store,
+    launch_gradio(vector_store, gen_llm, val_llm)
 
     logging.info("Finished!!!")
 
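The per-dataset chunk size (4000 for 'cuad', presumably because CUAD contains long legal contracts; 1000 otherwise) could also be written as a lookup table. A sketch of the equivalent logic (the dict itself is hypothetical, not in the commit):

    chunk_sizes = {'covidqa': 1000, 'techqa': 1000, 'cuad': 4000}

    for data_set_name in ['covidqa', 'techqa', 'cuad']:
        chunk_size = chunk_sizes.get(data_set_name, 1000) # fall back to the default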
retriever/retrieve_documents.py
CHANGED
@@ -1,2 +1,78 @@
+import numpy as np
+from transformers import pipeline
+
 def retrieve_top_k_documents(vector_store, query, top_k=5):
+    documents = vector_store.similarity_search(query, k=top_k)
+    documents = rerank_documents(query, documents)
+    return documents
+
+# Reranking: Cross-Encoder for refining top-k results
+def rerank_documents(query, documents, reranker_model_name="cross-encoder/ms-marco-electra-base"):
+    """
+    Re-rank documents using a cross-encoder model.
+
+    Parameters:
+        query (str): The user's query.
+        documents (list): List of LangChain Document objects.
+        reranker_model_name (str): Hugging Face model name for re-ranking.
+
+    Returns:
+        list: Re-ranked list of Document objects with updated scores.
+    """
+    # Initialize the cross-encoder model
+    reranker = pipeline("text-classification", model=reranker_model_name, return_all_scores=False)
+
+    # Pair the query with each document's text
+    rerank_inputs = [{"text": query, "text_pair": doc.page_content} for doc in documents]
+
+    # Get relevance scores for each query-document pair
+    scores = reranker(rerank_inputs)
+
+    # Attach the new scores to the documents
+    for doc, score in zip(documents, scores):
+        doc.metadata["rerank_score"] = score["score"] # Add score to document metadata
+
+    # Sort documents by the rerank_score in descending order
+    documents = sorted(documents, key=lambda x: x.metadata.get("rerank_score", 0), reverse=True)
+    return documents
+
+
+# Query Handling: Retrieve top-k candidates manually from the FAISS index (not used in the pipeline; kept only for learning)
+def retrieve_top_k_documents_manual(vector_store, query, top_k=5):
+    """
+    Retrieve top-k documents using the FAISS index directly and rerank them.
+
+    Parameters:
+        vector_store (FAISS): The vector store containing the FAISS index and docstore.
+        query (str): The user's query string.
+        top_k (int): The number of top results to retrieve.
+
+    Returns:
+        list: Top-k retrieved and reranked documents.
+    """
+    # Encode the query into a dense vector
+    embedding_model = vector_store.embedding_function
+    query_vector = embedding_model.embed_query(query) # Encode the query
+    query_vector = np.array([query_vector]).astype('float32')
+
+    # Search the FAISS index for top_k results
+    distances, indices = vector_store.index.search(query_vector, top_k)
+
+    # Retrieve documents from the docstore
+    documents = []
+    for idx in indices.flatten():
+        if idx == -1: # FAISS can return -1 for invalid indices
+            continue
+        doc_id = vector_store.index_to_docstore_id[idx]
+
+        # Access the internal dictionary of InMemoryDocstore
+        internal_docstore = getattr(vector_store.docstore, "_dict", None)
+        if internal_docstore and doc_id in internal_docstore: # Check if doc_id exists
+            document = internal_docstore[doc_id]
+            documents.append(document)
+
+    # Rerank the documents
+    documents = rerank_documents(query, documents)
+
+    return documents
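A small driver for the new rerank step (a sketch: the Doc class below only mimics the .page_content/.metadata interface of LangChain documents, and the cross-encoder weights download on first use):

    from retriever.retrieve_documents import rerank_documents

    # Stand-in for a LangChain Document: only .page_content and .metadata are needed
    class Doc:
        def __init__(self, text):
            self.page_content = text
            self.metadata = {}

    docs = [Doc("SARS-CoV-2 causes COVID-19."), Doc("An indemnification clause allocates risk.")]
    for doc in rerank_documents("What causes COVID-19?", docs):
        print(round(doc.metadata["rerank_score"], 3), doc.page_content)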