Gourisankar Padihary committed
Commit f7c2fa3 · Parent(s): b1b2c27

Compute RMSE and AUC-ROC

Files changed:
- data/load_dataset.py                        +1   -1
- generator/compute_metrics.py                +32  -7
- generator/compute_rmse_auc_roc_metrics.py   +101 -0
- generator/extract_attributes.py             +10  -3
- generator/initialize_llm.py                 +5   -0
- main.py                                     +11  -21
data/load_dataset.py
CHANGED
@@ -5,5 +5,5 @@ def load_data():
     logging.info("Loading dataset")
     dataset = load_dataset("rungalileo/ragbench", 'covidqa', split="test")
     logging.info("Dataset loaded successfully")
-    logging.info(dataset)
+    logging.info(f"Number of documents found: {dataset.num_rows}")
     return dataset
generator/compute_metrics.py
CHANGED
@@ -1,3 +1,6 @@
+import json
+import logging
+
 def compute_metrics(attributes, total_sentences):
     # Extract relevant information from attributes
     all_relevant_sentence_keys = attributes.get("all_relevant_sentence_keys", [])
@@ -8,17 +11,39 @@ def compute_metrics(attributes, total_sentences):
     context_relevance = len(all_relevant_sentence_keys) / total_sentences if total_sentences else 0
 
     # Compute Context Utilization
-    context_utilization = len(all_utilized_sentence_keys) /
-
-    # Compute Completeness
-    completeness = all(info.get("fully_supported", False) for info in sentence_support_information)
+    context_utilization = len(all_utilized_sentence_keys) / total_sentences if total_sentences else 0
 
+    # Compute Completeness score
+    Ri = set(all_relevant_sentence_keys)
+    Ui = set(all_utilized_sentence_keys)
+
+    completeness_score = len(Ri & Ui) / len(Ri) if len(Ri) else 0
+
     # Compute Adherence
-    adherence =
+    adherence = all(info.get("fully_supported", False) for info in sentence_support_information)
 
     return {
         "Context Relevance": context_relevance,
         "Context Utilization": context_utilization,
-        "Completeness":
+        "Completeness Score": completeness_score,
         "Adherence": adherence
-    }
+    }
+
+def get_metrics(attributes, total_sentences):
+    if attributes.content:
+        result_content = attributes.content  # Access the content attribute
+        # Extract the JSON part from the result_content
+        json_start = result_content.find("{")
+        json_end = result_content.rfind("}") + 1
+        json_str = result_content[json_start:json_end]
+
+        try:
+            result_json = json.loads(json_str)
+            print(json.dumps(result_json, indent=2))
+
+            # Compute metrics using the extracted attributes
+            metrics = compute_metrics(result_json, total_sentences)
+            print(metrics)
+            return metrics
+        except json.JSONDecodeError as e:
+            logging.error(f"JSONDecodeError: {e}")
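As a quick check on the new formulas, here is a minimal usage sketch of compute_metrics on a hand-made attributes dict. The sentence keys and the 10-sentence context are invented for illustration, and it assumes the elided lines of the function read the other keys with the same attributes.get pattern; get_metrics, by contrast, expects the raw LLM message and parses the JSON out of attributes.content before calling compute_metrics.

from generator.compute_metrics import compute_metrics

# Hypothetical extraction result for a context of 10 sentences (values are made up).
attributes = {
    "all_relevant_sentence_keys": ["0a", "0b", "1c"],
    "all_utilized_sentence_keys": ["0a", "1c"],
    "sentence_support_information": [{"fully_supported": True}],
}
print(compute_metrics(attributes, total_sentences=10))
# Context Relevance   = 3 / 10 = 0.3
# Context Utilization = 2 / 10 = 0.2
# Completeness Score  = |{0a, 1c} & {0a, 0b, 1c}| / |{0a, 0b, 1c}| = 2 / 3
# Adherence           = True, since every response sentence is fully supported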
generator/compute_rmse_auc_roc_metrics.py
ADDED
@@ -0,0 +1,101 @@
+
+from sklearn.metrics import roc_auc_score, root_mean_squared_error
+from generator.compute_metrics import get_metrics
+from generator.extract_attributes import extract_attributes
+from generator.generate_response import generate_response
+from retriever.retrieve_documents import retrieve_top_k_documents
+
+def compute_rmse_auc_roc_metrics(llm, dataset, vector_store):
+
+    # Lists to accumulate ground truths and predictions for AUC-ROC computation
+    all_ground_truth_relevance = []
+    all_predicted_relevance = []
+
+    all_ground_truth_utilization = []
+    all_predicted_utilization = []
+
+    all_ground_truth_adherence = []
+    all_predicted_adherence = []
+
+    # To store RMSE scores for each question
+    relevance_scores = []
+    utilization_scores = []
+    adherence_scores = []
+
+    for i, sample in enumerate(dataset):
+        print(sample)
+        sample_question = sample['question']
+
+        # Extract ground truth metrics from dataset
+        ground_truth_relevance = dataset[i]['relevance_score']
+        ground_truth_utilization = dataset[i]['utilization_score']
+        ground_truth_completeness = dataset[i]['completeness_score']
+
+        # Step 1: Retrieve relevant documents
+        relevant_docs = retrieve_top_k_documents(vector_store, sample_question, top_k=5)
+
+        # Step 2: Generate a response using LLM
+        response, source_docs = generate_response(llm, vector_store, sample_question, relevant_docs)
+
+        # Step 3: Extract attributes
+        attributes, total_sentences = extract_attributes(sample_question, source_docs, response)
+
+        # Call the process_attributes method in the main block
+        metrics = get_metrics(attributes, total_sentences)
+
+        # Extract predicted metrics (ensure these are continuous if possible)
+        predicted_relevance = metrics['Context Relevance']
+        predicted_utilization = metrics['Context Utilization']
+        predicted_completeness = metrics['Completeness Score']
+
+        # === Handle Continuous Inputs for RMSE ===
+        relevance_rmse = root_mean_squared_error([ground_truth_relevance], [predicted_relevance])
+        utilization_rmse = root_mean_squared_error([ground_truth_utilization], [predicted_utilization])
+        #adherence_rmse = mean_squared_error([ground_truth_adherence], [predicted_adherence], squared=False)
+
+        # === Handle Binary Conversion for AUC-ROC ===
+        binary_ground_truth_relevance = 1 if ground_truth_relevance > 0.5 else 0
+        binary_predicted_relevance = 1 if predicted_relevance > 0.5 else 0
+
+        binary_ground_truth_utilization = 1 if ground_truth_utilization > 0.5 else 0
+        binary_predicted_utilization = 1 if predicted_utilization > 0.5 else 0
+
+        #binary_ground_truth_adherence = 1 if ground_truth_adherence > 0.5 else 0
+        #binary_predicted_adherence = 1 if predicted_adherence > 0.5 else 0
+
+        # === Accumulate data for overall AUC-ROC computation ===
+        all_ground_truth_relevance.append(binary_ground_truth_relevance)
+        all_predicted_relevance.append(predicted_relevance)  # Use probability-based predictions
+
+        all_ground_truth_utilization.append(binary_ground_truth_utilization)
+        all_predicted_utilization.append(predicted_utilization)
+
+        #all_ground_truth_adherence.append(binary_ground_truth_adherence)
+        #all_predicted_adherence.append(predicted_adherence)
+
+        # Store RMSE scores for each question
+        relevance_scores.append(relevance_rmse)
+        utilization_scores.append(utilization_rmse)
+        #adherence_scores.append(adherence_rmse)
+        if i == 9:  # Stop after processing the first 10 rows
+            break
+    # === Compute AUC-ROC for the Entire Dataset ===
+    try:
+        print(f"All Ground Truth Relevance: {all_ground_truth_relevance}")
+        print(f"All Predicted Relevance: {all_predicted_relevance}")
+        relevance_auc = roc_auc_score(all_ground_truth_relevance, all_predicted_relevance)
+    except ValueError:
+        relevance_auc = None
+
+    try:
+        print(f"All Ground Truth Utilization: {all_ground_truth_utilization}")
+        print(f"All Predicted Utilization: {all_predicted_utilization}")
+        utilization_auc = roc_auc_score(all_ground_truth_utilization, all_predicted_utilization)
+    except ValueError:
+        utilization_auc = None
+
+    print(f"Relevance RMSE (per question): {relevance_scores}")
+    print(f"Utilization RMSE (per question): {utilization_scores}")
+    #print(f"Adherence RMSE (per question): {adherence_scores}")
+    print(f"\nOverall Relevance AUC-ROC: {relevance_auc}")
+    print(f"Overall Utilization AUC-ROC: {utilization_auc}")
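Two things worth keeping in mind when reading the new module: sklearn.metrics.root_mean_squared_error only exists in scikit-learn 1.4 and later (older releases use mean_squared_error(..., squared=False), as the commented-out adherence line still does), and calling it on a single ground-truth/prediction pair, as the loop does per question, reduces to the absolute error. A toy sketch of the two aggregate calls, with made-up numbers:

from sklearn.metrics import roc_auc_score, root_mean_squared_error  # scikit-learn >= 1.4

# Illustrative values only; real scores come from the ragbench columns and compute_metrics.
ground_truth = [0.9, 0.2, 0.7, 0.1]
predicted    = [0.8, 0.3, 0.6, 0.4]

rmse = root_mean_squared_error(ground_truth, predicted)                       # overall error
auc  = roc_auc_score([1 if g > 0.5 else 0 for g in ground_truth], predicted)  # ranking quality
print(rmse, auc)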
generator/extract_attributes.py
CHANGED
@@ -1,9 +1,9 @@
 from generator.create_prompt import create_prompt
-from generator.initialize_llm import
+from generator.initialize_llm import initialize_validation_llm
 from generator.document_utils import Document, apply_sentence_keys_documents, apply_sentence_keys_response
 
 # Initialize the LLM
-llm =
+llm = initialize_validation_llm()
 
 # Function to extract attributes
 def extract_attributes(question, relevant_docs, response):
@@ -12,9 +12,16 @@ def extract_attributes(question, relevant_docs, response):
     formatted_documents = apply_sentence_keys_documents(relevant_docs)
     formatted_responses = apply_sentence_keys_response(response)
 
+    #print(f"Formatted documents : {formatted_documents}")
+    # Print the number of sentences in each document
+    for i, doc in enumerate(formatted_documents):
+        num_sentences = len(doc)
+        print(f"Document {i} has {num_sentences} sentences.")
+
     # Calculate the total number of sentences from formatted_documents
     total_sentences = sum(len(doc) for doc in formatted_documents)
-
+    print(f"Total number of sentences {total_sentences}")
+
     attribute_prompt = create_prompt(formatted_documents, question, formatted_responses)
 
     # Instead of using BaseMessage, pass the formatted prompt directly to invoke
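The new logging relies on apply_sentence_keys_documents returning one list of keyed sentences per source document, so len(doc) is a per-document sentence count. document_utils.py is not part of this diff, so the shape below is an assumption for illustration only.

# Assumed shape of formatted_documents (hypothetical sentences and keys).
formatted_documents = [
    ["0a: COVID-19 spreads mainly through respiratory droplets.",
     "0b: Masks reduce transmission."],
    ["1a: Vaccines lower the risk of severe disease."],
]

for i, doc in enumerate(formatted_documents):
    print(f"Document {i} has {len(doc)} sentences.")  # 2, then 1

total_sentences = sum(len(doc) for doc in formatted_documents)  # 3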
generator/initialize_llm.py
CHANGED
@@ -4,4 +4,9 @@ from langchain_groq import ChatGroq
 def initialize_llm():
     os.environ["GROQ_API_KEY"] = "your_groq_api_key"
     llm = ChatGroq(model="llama3-8b-8192", temperature=0.7)
+    return llm
+
+def initialize_validation_llm():
+    os.environ["GROQ_API_KEY"] = "your_groq_api_key"
+    llm = ChatGroq(model="llama3-70b-8192", temperature=0.7)
     return llm
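With this change the pipeline uses two Groq models: the existing 8B model for answering the question and a larger 70B model for the validation/extraction step. A minimal sketch of how the two initializers sit side by side (the variable names are illustrative):

from generator.initialize_llm import initialize_llm, initialize_validation_llm

generation_llm = initialize_llm()             # llama3-8b-8192: generates the answer
validation_llm = initialize_validation_llm()  # llama3-70b-8192: judges the answer in extract_attributes

Note that both functions hard-code the placeholder "your_groq_api_key"; a real key has to be substituted or exported before running.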
main.py
CHANGED
@@ -1,12 +1,13 @@
-import logging
+import logging
 from data.load_dataset import load_data
+from generator import compute_rmse_auc_roc_metrics
 from retriever.chunk_documents import chunk_documents
 from retriever.embed_documents import embed_documents
 from retriever.retrieve_documents import retrieve_top_k_documents
 from generator.initialize_llm import initialize_llm
 from generator.generate_response import generate_response
 from generator.extract_attributes import extract_attributes
-from generator.compute_metrics import
+from generator.compute_metrics import get_metrics
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -27,7 +28,8 @@ def main():
     logging.info("Documents embedded")
 
     # Sample question
-
+    row_num = 1
+    sample_question = dataset[row_num]['question']
     logging.info(f"Sample question: {sample_question}")
 
     # Retrieve relevant documents
@@ -52,23 +54,11 @@ def main():
     # Valuations : Extract attributes from the response and source documents
     attributes, total_sentences = extract_attributes(sample_question, source_docs, response)
 
-    #
-
-
-
-
-
-        json_str = result_content[json_start:json_end]
-
-        try:
-            result_json = json.loads(json_str)
-            print(json.dumps(result_json, indent=2))
-
-            # Compute metrics using the extracted attributes
-            metrics = compute_metrics(result_json, total_sentences)
-            print(metrics)
-        except json.JSONDecodeError as e:
-            logging.error(f"JSONDecodeError: {e}")
-
+    # Call the process_attributes method in the main block
+    metrics = get_metrics(attributes, total_sentences)
+
+    #Compute RMSE and AUC-ROC for entire dataset
+    #compute_rmse_auc_roc_metrics(llm, dataset, vector_store)
+
 if __name__ == "__main__":
     main()
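The dataset-wide evaluation is left commented out in main(). One way to enable it, sketched under the assumption that llm, dataset and vector_store are the objects already built in main(); the function is imported explicitly here, since the new top-level import in main.py appears to bind the module object rather than the function of the same name:

from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics

# Runs retrieval, generation, attribute extraction and metric computation for the
# first 10 rows, then prints per-question RMSE and the overall AUC-ROC scores.
compute_rmse_auc_roc_metrics(llm, dataset, vector_store)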