Gourisankar Padihary committed
Commit cfb3435 · 1 Parent(s): b58a992

corrected rmse and auroc calculation

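For context, the corrected calculation accumulates ground-truth and predicted values across all questions and only then computes one dataset-level RMSE and AUC-ROC, rather than a per-question RMSE on single-element lists. A minimal sketch of that pattern, assuming scikit-learn >= 1.4 (which provides root_mean_squared_error); the toy numbers are illustrative and not taken from the dataset:

    # Sketch: dataset-level RMSE and AUC-ROC over accumulated per-question values.
    from sklearn.metrics import roc_auc_score, root_mean_squared_error

    ground_truth_relevance = [0.8, 0.1, 0.6]   # continuous ground-truth scores (made up)
    predicted_relevance    = [0.7, 0.2, 0.5]   # continuous predictions (made up)
    ground_truth_adherence = [1, 0, 1]         # binary labels
    predicted_adherence    = [1, 0, 0]         # binary (or probabilistic) predictions

    relevance_rmse = root_mean_squared_error(ground_truth_relevance, predicted_relevance)
    adherence_auc  = roc_auc_score(ground_truth_adherence, predicted_adherence)
    print(relevance_rmse, adherence_auc)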
generator/compute_metrics.py CHANGED
@@ -20,8 +20,8 @@ def compute_metrics(attributes, total_sentences):
     completeness_score = len(Ri & Ui) / len(Ri) if len(Ri) else 0
 
     # Compute Adherence
-    #adherence = all(info.get("fully_supported", False) for info in sentence_support_information)
-    adherence = 1 if all(info.get("fully_supported", False) for info in sentence_support_information) else 0
+    adherence = all(info.get("fully_supported", False) for info in sentence_support_information)
+    #adherence = 1 if all(info.get("fully_supported", False) for info in sentence_support_information) else 0
 
     return {
         "Context Relevance": context_relevance,
generator/compute_rmse_auc_roc_metrics.py CHANGED
@@ -15,17 +15,12 @@ def compute_rmse_auc_roc_metrics(llm, dataset, vector_store, num_question):
     all_ground_truth_adherence = []
     all_predicted_adherence = []
 
-    # To store RMSE scores for each question
-    relevance_scores = []
-    utilization_scores = []
-    adherence_scores = []
-
     # For each question in dataset get the metrics
     for i, document in enumerate(dataset):
         # Extract ground truth metrics from dataset
         ground_truth_relevance = dataset[i]['relevance_score']
         ground_truth_utilization = dataset[i]['utilization_score']
-        ground_truth_adherence = dataset[i]['gpt3_adherence']
+        ground_truth_adherence = 1 if dataset[i]['adherence_score'] else 0
 
         query = document['question']
         logging.info(f'Query number: {i + 1}')
@@ -35,65 +30,38 @@ def compute_rmse_auc_roc_metrics(llm, dataset, vector_store, num_question):
         # Extract predicted metrics (ensure these are continuous if possible)
         predicted_relevance = metrics.get('Context Relevance', 0) if metrics else 0
         predicted_utilization = metrics.get('Context Utilization', 0) if metrics else 0
-        predicted_adherence = metrics.get('Adherence', 0) if metrics else 0
+        predicted_adherence = 1 if metrics.get('Adherence', False) else 0
 
         # === Handle Continuous Inputs for RMSE ===
-        relevance_rmse = root_mean_squared_error([ground_truth_relevance], [predicted_relevance])
-        utilization_rmse = root_mean_squared_error([ground_truth_utilization], [predicted_utilization])
-        adherence_rmse = root_mean_squared_error([ground_truth_adherence], [predicted_adherence])
-
-        # === Handle Binary Conversion for AUC-ROC ===
-        binary_ground_truth_relevance = 1 if ground_truth_relevance > 0.5 else 0
-        #binary_predicted_relevance = 1 if predicted_relevance > 0.5 else 0
-
-        binary_ground_truth_utilization = 1 if ground_truth_utilization > 0.2 else 0
-        #binary_predicted_utilization = 1 if predicted_utilization > 0.5 else 0
-
-        #binary_ground_truth_adherence = 1 if ground_truth_adherence > 0.5 else 0
-        #binary_predicted_adherence = 1 if predicted_adherence > 0.5 else 0
-
-        # === Accumulate data for overall AUC-ROC computation ===
-        all_ground_truth_relevance.append(binary_ground_truth_relevance)
-        all_predicted_relevance.append(predicted_relevance)  # Use probability-based predictions
-
-        all_ground_truth_utilization.append(binary_ground_truth_utilization)
+        all_ground_truth_relevance.append(ground_truth_relevance)
+        all_predicted_relevance.append(predicted_relevance)
+        all_ground_truth_utilization.append(ground_truth_utilization)
         all_predicted_utilization.append(predicted_utilization)
-
+
         all_ground_truth_adherence.append(ground_truth_adherence)
         all_predicted_adherence.append(predicted_adherence)
 
-        # Store RMSE scores for each question
-        relevance_scores.append(relevance_rmse)
-        utilization_scores.append(utilization_rmse)
-        adherence_scores.append(adherence_rmse)
         if i == num_question:
             break
 
-    # === Compute AUC-ROC for the Entire Dataset ===
+    # === Compute RMSE & AUC-ROC for the Entire Dataset ===
     try:
-        #print(f"All Ground Truth Relevance: {all_ground_truth_relevance}")
-        #print(f"All Predicted Relevance: {all_predicted_relevance}")
-        relevance_auc = roc_auc_score(all_ground_truth_relevance, all_predicted_relevance)
+        relevance_rmse = root_mean_squared_error(all_ground_truth_relevance, all_predicted_relevance)
     except ValueError:
-        relevance_auc = None
+        relevance_rmse = None
 
     try:
-        #print(f"All Ground Truth Utilization: {all_ground_truth_utilization}")
-        #print(f"All Predicted Utilization: {all_predicted_utilization}")
-        utilization_auc = roc_auc_score(all_ground_truth_utilization, all_predicted_utilization)
+        utilization_rmse = root_mean_squared_error(all_ground_truth_utilization, all_predicted_utilization)
     except ValueError:
-        utilization_auc = None
+        utilization_rmse = None
 
     try:
-        #print(f"All Ground Truth Adherence: {all_ground_truth_utilization}")
-        #print(f"All Predicted Utilization: {all_predicted_utilization}")
+        print(f"All Ground Truth Adherence: {all_ground_truth_adherence}")
+        print(f"All Predicted Adherence: {all_predicted_adherence}")
         adherence_auc = roc_auc_score(all_ground_truth_adherence, all_predicted_adherence)
     except ValueError:
         adherence_auc = None
 
-    print(f"Relevance RMSE (per question): {relevance_scores}")
-    print(f"Utilization RMSE (per question): {utilization_scores}")
-    print(f"Adherence RMSE (per question): {adherence_scores}")
-    print(f"\nOverall Relevance AUC-ROC: {relevance_auc}")
-    print(f"Overall Utilization AUC-ROC: {utilization_auc}")
+    print(f"Relevance RMSE score: {relevance_rmse}")
+    print(f"Utilization RMSE score: {utilization_rmse}")
     print(f"Overall Adherence AUC-ROC: {adherence_auc}")
main.py CHANGED
@@ -11,7 +11,7 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
 
 def main():
     logging.info("Starting the RAG pipeline")
-    data_set_name = 'techqa'
+    data_set_name = 'covidqa'
 
     # Load the dataset
     dataset = load_data(data_set_name)
@@ -39,7 +39,7 @@ def main():
     generate_metrics(llm, vector_store, sample_question)
 
     #Compute RMSE and AUC-ROC for entire dataset
-    #compute_rmse_auc_roc_metrics(llm, dataset, vector_store, 10)
+    compute_rmse_auc_roc_metrics(llm, dataset, vector_store, 10)
 
     logging.info("Finished!!!")
 