m7mdal7aj commited on
Commit
1d9397e
·
verified ·
1 Parent(s): 1ee4a06

Update my_model/results/evaluation.py

Browse files
Files changed (1) hide show
  1. my_model/results/evaluation.py +83 -9
my_model/results/evaluation.py CHANGED
@@ -8,10 +8,38 @@ import streamlit as st
8
  from my_model.config import evaluation_config as config
9
 
10
  class KBVQAEvaluator:
11
- def __init__(self):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  """
13
- Initialize the VQA Processor with the dataset and configuration settings.
 
 
 
14
  """
 
15
  self.data_path = config.EVALUATION_DATA_PATH
16
  self.use_fuzzy = config.USE_FUZZY
17
  self.stemmer = PorterStemmer()
@@ -30,17 +58,32 @@ class KBVQAEvaluator:
30
  def stem_answers(self, answers: Union[str, List[str]]) -> Union[str, List[str]]:
31
  """
32
  Apply Porter Stemmer to either a single string or a list of strings.
 
 
 
 
 
 
33
  """
 
34
  if isinstance(answers, list):
35
  return [" ".join(self.stemmer.stem(word.strip()) for word in answer.split()) for answer in answers]
36
  else:
37
  words = answers.split()
38
  return " ".join(self.stemmer.stem(word.strip()) for word in words)
39
 
40
- def calculate_vqa_score(self, ground_truths, model_answer):
41
  """
42
  Calculate VQA score based on the number of matching answers, with optional fuzzy matching.
 
 
 
 
 
 
 
43
  """
 
44
  if self.use_fuzzy:
45
  fuzzy_matches = sum(fuzz.partial_ratio(model_answer, gt) >= self.fuzzy_threshold for gt in ground_truths)
46
  return min(fuzzy_matches / 3, 1)
@@ -48,19 +91,29 @@ class KBVQAEvaluator:
48
  count = Counter(ground_truths)
49
  return min(count.get(model_answer, 0) / 3, 1)
50
 
51
- def calculate_exact_match_score(self, ground_truths, model_answer):
52
  """
53
  Calculate Exact Match score, with optional fuzzy matching.
 
 
 
 
 
 
 
54
  """
 
55
  if self.use_fuzzy:
56
  return int(any(fuzz.partial_ratio(model_answer, gt) >= self.fuzzy_threshold for gt in ground_truths))
57
  else:
58
  return int(model_answer in ground_truths)
59
 
60
- def syntactic_evaluation(self):
61
  """
62
  Process the DataFrame: stem answers, calculate scores, and store results.
 
63
  """
 
64
  self.df['raw_answers_stemmed'] = self.df['raw_answers'].apply(literal_eval).apply(self.stem_answers)
65
 
66
  for name in self.model_names:
@@ -74,10 +127,19 @@ class KBVQAEvaluator:
74
  self.vqa_scores[full_config] = round(self.df[f'vqa_score_{full_config}'].mean()*100, 2)
75
  self.exact_match_scores[full_config] = round(self.df[f'exact_match_score_{full_config}'].mean()*100, 2)
76
 
77
- def create_GPT4_messages_template(self, question, ground_truths, model_answer):
78
  """
79
  Create a message list for the GPT-4 API call based on the question, ground truths, and model answer.
 
 
 
 
 
 
 
 
80
  """
 
81
  system_message = {
82
  "role": "system",
83
  "content": """You are an AI trained to evaluate the equivalence of AI-generated answers to a set of ground truth answers for a given question. Upon reviewing a model's answer, determine if it matches the ground truths. Use the following rating system: 1 if you find that the model answer matches more than 25% of the ground truth answers, 2 if you find that the model answer matches only less than 25% of the ground truth answers, and 3 if the model answer is incorrect. Respond in the format below for easy parsing:
@@ -93,7 +155,7 @@ class KBVQAEvaluator:
93
  return [system_message, user_message]
94
 
95
 
96
- def semantic_evaluation(self):
97
  """
98
  Perform semantic evaluation using GPT-4 for each model configuration.
99
  """
@@ -109,7 +171,14 @@ class KBVQAEvaluator:
109
  rating = int(evaluation.split('\n')[0].split(":")[1].strip())
110
  self.df.at[index, f'gpt4_rating_{config}'] = rating
111
 
112
- def save_results(self, save_filename):
 
 
 
 
 
 
 
113
  # Create a DataFrame for the scores
114
  scores_data = {
115
  'Model Configuration': list(self.vqa_scores.keys()),
@@ -123,10 +192,15 @@ class KBVQAEvaluator:
123
  self.df.to_excel(writer, sheet_name='Main Data', index=False)
124
  scores_df.to_excel(writer, sheet_name='Scores', index=False)
125
 
126
- def run_evaluation(save=False, save_filename="results"):
127
  """
128
  Run the full evaluation process using KBVQAEvaluator and save the results to an Excel file.
 
 
 
 
129
  """
 
130
  # Instantiate the evaluator
131
  evaluator = KBVQAEvaluator()
132
 
 
8
  from my_model.config import evaluation_config as config
9
 
10
  class KBVQAEvaluator:
11
+ """
12
+ A class to evaluate Knowledge-Based Visual Question Answering (KB-VQA) models.
13
+
14
+ This class provides methods for syntactic and semantic evaluation of the KB-VQA model,
15
+ using both exact match and VQA scores. The evaluation results can be saved to an
16
+ Excel file for further analysis.
17
+
18
+ Attributes:
19
+ data_path (str): Path to the evaluation data.
20
+ use_fuzzy (bool): Flag to determine if fuzzy matching should be used.
21
+ stemmer (PorterStemmer): Instance of PorterStemmer for stemming answers.
22
+ scores_df (pd.DataFrame): DataFrame containing scores.
23
+ df (pd.DataFrame): Main DataFrame containing evaluation data.
24
+ vqa_scores (Dict[str, float]): Dictionary to store VQA scores for different model configurations.
25
+ exact_match_scores (Dict[str, float]): Dictionary to store exact match scores for different model configurations.
26
+ fuzzy_threshold (int): Threshold for fuzzy matching score.
27
+ openai_api_key (str): API key for OpenAI GPT-4.
28
+ model_names (List[str]): List of model names to be evaluated.
29
+ model_configurations (List[str]): List of model configurations to be evaluated.
30
+ gpt4_seed (int): Seed for GPT-4 evaluation.
31
+ gpt4_max_tokens (int): Maximum tokens for GPT-4 responses.
32
+ gpt4_temperature (float): Temperature setting for GPT-4 responses.
33
+ """
34
+
35
+ def __init__(self): -> None
36
  """
37
+ Initialize the KBVQAEvaluator with the dataset and configuration settings.
38
+
39
+ Reads data from the specified paths in the configuration and initializes
40
+ various attributes required for evaluation.
41
  """
42
+
43
  self.data_path = config.EVALUATION_DATA_PATH
44
  self.use_fuzzy = config.USE_FUZZY
45
  self.stemmer = PorterStemmer()
 
58
  def stem_answers(self, answers: Union[str, List[str]]) -> Union[str, List[str]]:
59
  """
60
  Apply Porter Stemmer to either a single string or a list of strings.
61
+
62
+ Args:
63
+ answers (Union[str, List[str]]): A single answer string or a list of answer strings.
64
+
65
+ Returns:
66
+ Union[str, List[str]]: Stemmed version of the input string or list of strings.
67
  """
68
+
69
  if isinstance(answers, list):
70
  return [" ".join(self.stemmer.stem(word.strip()) for word in answer.split()) for answer in answers]
71
  else:
72
  words = answers.split()
73
  return " ".join(self.stemmer.stem(word.strip()) for word in words)
74
 
75
+ def calculate_vqa_score(self, ground_truths: List[str], model_answer: str) -> float:
76
  """
77
  Calculate VQA score based on the number of matching answers, with optional fuzzy matching.
78
+
79
+ Args:
80
+ ground_truths (List[str]): List of ground truth answers.
81
+ model_answer (str): Model's answer to be evaluated.
82
+
83
+ Returns:
84
+ float: VQA score based on the number of matches.
85
  """
86
+
87
  if self.use_fuzzy:
88
  fuzzy_matches = sum(fuzz.partial_ratio(model_answer, gt) >= self.fuzzy_threshold for gt in ground_truths)
89
  return min(fuzzy_matches / 3, 1)
 
91
  count = Counter(ground_truths)
92
  return min(count.get(model_answer, 0) / 3, 1)
93
 
94
+ def calculate_exact_match_score(self, ground_truths: List[str], model_answer: str) -> int:
95
  """
96
  Calculate Exact Match score, with optional fuzzy matching.
97
+
98
+ Args:
99
+ ground_truths (List[str]): List of ground truth answers.
100
+ model_answer (str): Model's answer to be evaluated.
101
+
102
+ Returns:
103
+ int: Exact match score (1 if there is a match, 0 otherwise).
104
  """
105
+
106
  if self.use_fuzzy:
107
  return int(any(fuzz.partial_ratio(model_answer, gt) >= self.fuzzy_threshold for gt in ground_truths))
108
  else:
109
  return int(model_answer in ground_truths)
110
 
111
+ def syntactic_evaluation(self) -> None:
112
  """
113
  Process the DataFrame: stem answers, calculate scores, and store results.
114
+
115
  """
116
+
117
  self.df['raw_answers_stemmed'] = self.df['raw_answers'].apply(literal_eval).apply(self.stem_answers)
118
 
119
  for name in self.model_names:
 
127
  self.vqa_scores[full_config] = round(self.df[f'vqa_score_{full_config}'].mean()*100, 2)
128
  self.exact_match_scores[full_config] = round(self.df[f'exact_match_score_{full_config}'].mean()*100, 2)
129
 
130
+ def create_GPT4_messages_template(self, question: str, ground_truths: List[str], model_answer: str) -> List[dict]:
131
  """
132
  Create a message list for the GPT-4 API call based on the question, ground truths, and model answer.
133
+
134
+ Args:
135
+ question (str): The question being evaluated.
136
+ ground_truths (List[str]): List of ground truth answers.
137
+ model_answer (str): Model's answer to be evaluated.
138
+
139
+ Returns:
140
+ List[dict]: Messages formatted for GPT-4 API call.
141
  """
142
+
143
  system_message = {
144
  "role": "system",
145
  "content": """You are an AI trained to evaluate the equivalence of AI-generated answers to a set of ground truth answers for a given question. Upon reviewing a model's answer, determine if it matches the ground truths. Use the following rating system: 1 if you find that the model answer matches more than 25% of the ground truth answers, 2 if you find that the model answer matches only less than 25% of the ground truth answers, and 3 if the model answer is incorrect. Respond in the format below for easy parsing:
 
155
  return [system_message, user_message]
156
 
157
 
158
+ def semantic_evaluation(self) -> None:
159
  """
160
  Perform semantic evaluation using GPT-4 for each model configuration.
161
  """
 
171
  rating = int(evaluation.split('\n')[0].split(":")[1].strip())
172
  self.df.at[index, f'gpt4_rating_{config}'] = rating
173
 
174
+ def save_results(self, save_filename: str) -> None:
175
+ """
176
+ Save the evaluation results to an Excel file.
177
+
178
+ Args:
179
+ save_filename (str): The filename to save the results.
180
+ """
181
+
182
  # Create a DataFrame for the scores
183
  scores_data = {
184
  'Model Configuration': list(self.vqa_scores.keys()),
 
192
  self.df.to_excel(writer, sheet_name='Main Data', index=False)
193
  scores_df.to_excel(writer, sheet_name='Scores', index=False)
194
 
195
+ def run_evaluation(save: bool = False, save_filename: str = "results") -> None:
196
  """
197
  Run the full evaluation process using KBVQAEvaluator and save the results to an Excel file.
198
+
199
+ Args:
200
+ save (bool): Whether to save the results to an Excel file. Defaults to False.
201
+ save_filename (str): The filename to save the results if save is True. Defaults to "results".
202
  """
203
+
204
  # Instantiate the evaluator
205
  evaluator = KBVQAEvaluator()
206