Gourisankar Padihary committed on
Commit f7c2fa3 · 1 Parent(s): b1b2c27

Compute RMSE and AUC-ROC

data/load_dataset.py CHANGED
@@ -5,5 +5,5 @@ def load_data():
     logging.info("Loading dataset")
     dataset = load_dataset("rungalileo/ragbench", 'covidqa', split="test")
     logging.info("Dataset loaded successfully")
-    logging.info(dataset)
+    logging.info(f"Number of documents found: {dataset.num_rows}")
     return dataset
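
For context, the replaced log line swaps the full Dataset repr for a row count: num_rows is a standard attribute of the object returned by Hugging Face datasets.load_dataset and is equivalent to len(dataset). A minimal, standalone sketch of what the new logging reports (the surrounding pipeline is omitted):

import logging
from datasets import load_dataset

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Same split as load_data(); each row carries a question, its documents,
# and the RAGBench annotation scores used later for RMSE/AUC-ROC.
dataset = load_dataset("rungalileo/ragbench", "covidqa", split="test")
logging.info(f"Number of documents found: {dataset.num_rows}")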
generator/compute_metrics.py CHANGED
@@ -1,3 +1,6 @@
+import json
+import logging
+
 def compute_metrics(attributes, total_sentences):
     # Extract relevant information from attributes
     all_relevant_sentence_keys = attributes.get("all_relevant_sentence_keys", [])
@@ -8,17 +11,39 @@ def compute_metrics(attributes, total_sentences):
     context_relevance = len(all_relevant_sentence_keys) / total_sentences if total_sentences else 0
 
     # Compute Context Utilization
-    context_utilization = len(all_utilized_sentence_keys) / len(sentence_support_information) if sentence_support_information else 0
-
-    # Compute Completeness
-    completeness = all(info.get("fully_supported", False) for info in sentence_support_information)
+    context_utilization = len(all_utilized_sentence_keys) / total_sentences if total_sentences else 0
 
+    # Compute Completeness score
+    Ri = set(all_relevant_sentence_keys)
+    Ui = set(all_utilized_sentence_keys)
+
+    completeness_score = len(Ri & Ui) / len(Ri) if len(Ri) else 0
+
     # Compute Adherence
-    adherence = attributes.get("overall_supported", False)
+    adherence = all(info.get("fully_supported", False) for info in sentence_support_information)
 
     return {
         "Context Relevance": context_relevance,
         "Context Utilization": context_utilization,
-        "Completeness": completeness,
+        "Completeness Score": completeness_score,
         "Adherence": adherence
-    }
+    }
+
+def get_metrics(attributes, total_sentences):
+    if attributes.content:
+        result_content = attributes.content  # Access the content attribute
+        # Extract the JSON part from the result_content
+        json_start = result_content.find("{")
+        json_end = result_content.rfind("}") + 1
+        json_str = result_content[json_start:json_end]
+
+        try:
+            result_json = json.loads(json_str)
+            print(json.dumps(result_json, indent=2))
+
+            # Compute metrics using the extracted attributes
+            metrics = compute_metrics(result_json, total_sentences)
+            print(metrics)
+            return metrics
+        except json.JSONDecodeError as e:
+            logging.error(f"JSONDecodeError: {e}")
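
As a quick sanity check of the revised formulas, here is a self-contained sketch (not part of the commit) that recomputes the four metrics on a toy attributes dictionary. The key names match those consumed by compute_metrics; the sentence keys and counts are illustrative:

# Toy attributes in the shape compute_metrics() expects (values are illustrative).
attributes = {
    "all_relevant_sentence_keys": ["0a", "0b", "1c"],
    "all_utilized_sentence_keys": ["0a", "1c"],
    "sentence_support_information": [
        {"response_sentence_key": "a", "fully_supported": True},
        {"response_sentence_key": "b", "fully_supported": False},
    ],
}
total_sentences = 10  # total sentences across all retrieved documents

context_relevance = len(attributes["all_relevant_sentence_keys"]) / total_sentences    # 0.3
context_utilization = len(attributes["all_utilized_sentence_keys"]) / total_sentences  # 0.2

Ri = set(attributes["all_relevant_sentence_keys"])
Ui = set(attributes["all_utilized_sentence_keys"])
completeness_score = len(Ri & Ui) / len(Ri) if len(Ri) else 0                           # 2/3

adherence = all(info.get("fully_supported", False)
                for info in attributes["sentence_support_information"])                 # False

print(context_relevance, context_utilization, completeness_score, adherence)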
generator/compute_rmse_auc_roc_metrics.py ADDED
@@ -0,0 +1,101 @@
+
+from sklearn.metrics import roc_auc_score, root_mean_squared_error
+from generator.compute_metrics import get_metrics
+from generator.extract_attributes import extract_attributes
+from generator.generate_response import generate_response
+from retriever.retrieve_documents import retrieve_top_k_documents
+
+def compute_rmse_auc_roc_metrics(llm, dataset, vector_store):
+
+    # Lists to accumulate ground truths and predictions for AUC-ROC computation
+    all_ground_truth_relevance = []
+    all_predicted_relevance = []
+
+    all_ground_truth_utilization = []
+    all_predicted_utilization = []
+
+    all_ground_truth_adherence = []
+    all_predicted_adherence = []
+
+    # To store RMSE scores for each question
+    relevance_scores = []
+    utilization_scores = []
+    adherence_scores = []
+
+    for i, sample in enumerate(dataset):
+        print(sample)
+        sample_question = sample['question']
+
+        # Extract ground truth metrics from dataset
+        ground_truth_relevance = dataset[i]['relevance_score']
+        ground_truth_utilization = dataset[i]['utilization_score']
+        ground_truth_completeness = dataset[i]['completeness_score']
+
+        # Step 1: Retrieve relevant documents
+        relevant_docs = retrieve_top_k_documents(vector_store, sample_question, top_k=5)
+
+        # Step 2: Generate a response using LLM
+        response, source_docs = generate_response(llm, vector_store, sample_question, relevant_docs)
+
+        # Step 3: Extract attributes
+        attributes, total_sentences = extract_attributes(sample_question, source_docs, response)
+
+        # Call the process_attributes method in the main block
+        metrics = get_metrics(attributes, total_sentences)
+
+        # Extract predicted metrics (ensure these are continuous if possible)
+        predicted_relevance = metrics['Context Relevance']
+        predicted_utilization = metrics['Context Utilization']
+        predicted_completeness = metrics['Completeness Score']
+
+        # === Handle Continuous Inputs for RMSE ===
+        relevance_rmse = root_mean_squared_error([ground_truth_relevance], [predicted_relevance])
+        utilization_rmse = root_mean_squared_error([ground_truth_utilization], [predicted_utilization])
+        #adherence_rmse = mean_squared_error([ground_truth_adherence], [predicted_adherence], squared=False)
+
+        # === Handle Binary Conversion for AUC-ROC ===
+        binary_ground_truth_relevance = 1 if ground_truth_relevance > 0.5 else 0
+        binary_predicted_relevance = 1 if predicted_relevance > 0.5 else 0
+
+        binary_ground_truth_utilization = 1 if ground_truth_utilization > 0.5 else 0
+        binary_predicted_utilization = 1 if predicted_utilization > 0.5 else 0
+
+        #binary_ground_truth_adherence = 1 if ground_truth_adherence > 0.5 else 0
+        #binary_predicted_adherence = 1 if predicted_adherence > 0.5 else 0
+
+        # === Accumulate data for overall AUC-ROC computation ===
+        all_ground_truth_relevance.append(binary_ground_truth_relevance)
+        all_predicted_relevance.append(predicted_relevance)  # Use probability-based predictions
+
+        all_ground_truth_utilization.append(binary_ground_truth_utilization)
+        all_predicted_utilization.append(predicted_utilization)
+
+        #all_ground_truth_adherence.append(binary_ground_truth_adherence)
+        #all_predicted_adherence.append(predicted_adherence)
+
+        # Store RMSE scores for each question
+        relevance_scores.append(relevance_rmse)
+        utilization_scores.append(utilization_rmse)
+        #adherence_scores.append(adherence_rmse)
+        if i == 9:  # Stop after processing the first 10 rows
+            break
+    # === Compute AUC-ROC for the Entire Dataset ===
+    try:
+        print(f"All Ground Truth Relevance: {all_ground_truth_relevance}")
+        print(f"All Predicted Relevance: {all_predicted_relevance}")
+        relevance_auc = roc_auc_score(all_ground_truth_relevance, all_predicted_relevance)
+    except ValueError:
+        relevance_auc = None
+
+    try:
+        print(f"All Ground Truth Utilization: {all_ground_truth_utilization}")
+        print(f"All Predicted Utilization: {all_predicted_utilization}")
+        utilization_auc = roc_auc_score(all_ground_truth_utilization, all_predicted_utilization)
+    except ValueError:
+        utilization_auc = None
+
+    print(f"Relevance RMSE (per question): {relevance_scores}")
+    print(f"Utilization RMSE (per question): {utilization_scores}")
+    #print(f"Adherence RMSE (per question): {adherence_scores}")
+    print(f"\nOverall Relevance AUC-ROC: {relevance_auc}")
+    print(f"Overall Utilization AUC-ROC: {utilization_auc}")
generator/extract_attributes.py CHANGED
@@ -1,9 +1,9 @@
 from generator.create_prompt import create_prompt
-from generator.initialize_llm import initialize_llm
+from generator.initialize_llm import initialize_validation_llm
 from generator.document_utils import Document, apply_sentence_keys_documents, apply_sentence_keys_response
 
 # Initialize the LLM
-llm = initialize_llm()
+llm = initialize_validation_llm()
 
 # Function to extract attributes
 def extract_attributes(question, relevant_docs, response):
@@ -12,9 +12,16 @@ def extract_attributes(question, relevant_docs, response):
     formatted_documents = apply_sentence_keys_documents(relevant_docs)
     formatted_responses = apply_sentence_keys_response(response)
 
+    #print(f"Formatted documents : {formatted_documents}")
+    # Print the number of sentences in each document
+    for i, doc in enumerate(formatted_documents):
+        num_sentences = len(doc)
+        print(f"Document {i} has {num_sentences} sentences.")
+
     # Calculate the total number of sentences from formatted_documents
     total_sentences = sum(len(doc) for doc in formatted_documents)
-
+    print(f"Total number of sentences {total_sentences}")
+
     attribute_prompt = create_prompt(formatted_documents, question, formatted_responses)
 
     # Instead of using BaseMessage, pass the formatted prompt directly to invoke
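
The new debug prints count sentences per document before summing them. The structure of formatted_documents comes from apply_sentence_keys_documents, which is not shown in this diff, so the sketch below uses a hypothetical shape (one mapping of sentence key to sentence per retrieved document) purely to illustrate what the prints would report:

# Hypothetical shape, for illustration only; the real structure is produced by
# generator.document_utils.apply_sentence_keys_documents (not shown in this commit).
formatted_documents = [
    {"0a": "First sentence of document 0.", "0b": "Second sentence of document 0."},
    {"1a": "Only sentence of document 1."},
]

for i, doc in enumerate(formatted_documents):
    num_sentences = len(doc)
    print(f"Document {i} has {num_sentences} sentences.")  # 2, then 1

total_sentences = sum(len(doc) for doc in formatted_documents)
print(f"Total number of sentences {total_sentences}")      # 3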
generator/initialize_llm.py CHANGED
@@ -4,4 +4,9 @@ from langchain_groq import ChatGroq
 def initialize_llm():
     os.environ["GROQ_API_KEY"] = "your_groq_api_key"
     llm = ChatGroq(model="llama3-8b-8192", temperature=0.7)
+    return llm
+
+def initialize_validation_llm():
+    os.environ["GROQ_API_KEY"] = "your_groq_api_key"
+    llm = ChatGroq(model="llama3-70b-8192", temperature=0.7)
     return llm
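
The two initializers now differ only in the Groq model name. One possible follow-up, not part of this commit, is a single parameterized helper; the sketch assumes the same langchain_groq.ChatGroq constructor used above and keeps the placeholder API key from the original:

import os
from langchain_groq import ChatGroq

def _initialize_chat_groq(model: str, temperature: float = 0.7) -> ChatGroq:
    # Prefer an externally provided GROQ_API_KEY; fall back to the placeholder.
    os.environ.setdefault("GROQ_API_KEY", "your_groq_api_key")
    return ChatGroq(model=model, temperature=temperature)

def initialize_llm() -> ChatGroq:
    return _initialize_chat_groq("llama3-8b-8192")

def initialize_validation_llm() -> ChatGroq:
    return _initialize_chat_groq("llama3-70b-8192")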
main.py CHANGED
@@ -1,12 +1,13 @@
-import logging, json
+import logging
 from data.load_dataset import load_data
+from generator import compute_rmse_auc_roc_metrics
 from retriever.chunk_documents import chunk_documents
 from retriever.embed_documents import embed_documents
 from retriever.retrieve_documents import retrieve_top_k_documents
 from generator.initialize_llm import initialize_llm
 from generator.generate_response import generate_response
 from generator.extract_attributes import extract_attributes
-from generator.compute_metrics import compute_metrics
+from generator.compute_metrics import get_metrics
 
 # Configure logging
 logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
@@ -27,7 +28,8 @@ def main():
     logging.info("Documents embedded")
 
     # Sample question
-    sample_question = dataset[0]['question']
+    row_num = 1
+    sample_question = dataset[row_num]['question']
     logging.info(f"Sample question: {sample_question}")
 
     # Retrieve relevant documents
@@ -52,23 +54,11 @@ def main():
     # Valuations : Extract attributes from the response and source documents
     attributes, total_sentences = extract_attributes(sample_question, source_docs, response)
 
-    # Only proceed if the content is not empty
-    if attributes.content:
-        result_content = attributes.content  # Access the content attribute
-        # Extract the JSON part from the result_content
-        json_start = result_content.find("{")
-        json_end = result_content.rfind("}") + 1
-        json_str = result_content[json_start:json_end]
-
-        try:
-            result_json = json.loads(json_str)
-            print(json.dumps(result_json, indent=2))
-
-            # Compute metrics using the extracted attributes
-            metrics = compute_metrics(result_json, total_sentences)
-            print(metrics)
-        except json.JSONDecodeError as e:
-            logging.error(f"JSONDecodeError: {e}")
-
+    # Call the process_attributes method in the main block
+    metrics = get_metrics(attributes, total_sentences)
+
+    #Compute RMSE and AUC-ROC for entire dataset
+    #compute_rmse_auc_roc_metrics(llm, dataset, vector_store)
+
 if __name__ == "__main__":
     main()
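
One note on the commented-out full-dataset call: `from generator import compute_rmse_auc_roc_metrics` binds the submodule, not the function of the same name (unless generator/__init__.py re-exports it), so enabling the call exactly as written would likely fail with a "module object is not callable" error. A hedged sketch of the import that would be needed once the evaluation is switched on:

# Import the function itself rather than the submodule of the same name.
from generator.compute_rmse_auc_roc_metrics import compute_rmse_auc_roc_metrics

# ...then, at the end of main():
# compute_rmse_auc_roc_metrics(llm, dataset, vector_store)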