Update my_model/results/evaluation.py
my_model/results/evaluation.py CHANGED
@@ -8,10 +8,38 @@ import streamlit as st
 from my_model.config import evaluation_config as config

 class KBVQAEvaluator:
-    def __init__(self):
-        """
-        Initialize the
-        """
+    """
+    A class to evaluate Knowledge-Based Visual Question Answering (KB-VQA) models.
+
+    This class provides methods for syntactic and semantic evaluation of the KB-VQA model,
+    using both exact match and VQA scores. The evaluation results can be saved to an
+    Excel file for further analysis.
+
+    Attributes:
+        data_path (str): Path to the evaluation data.
+        use_fuzzy (bool): Flag to determine if fuzzy matching should be used.
+        stemmer (PorterStemmer): Instance of PorterStemmer for stemming answers.
+        scores_df (pd.DataFrame): DataFrame containing scores.
+        df (pd.DataFrame): Main DataFrame containing evaluation data.
+        vqa_scores (Dict[str, float]): Dictionary to store VQA scores for different model configurations.
+        exact_match_scores (Dict[str, float]): Dictionary to store exact match scores for different model configurations.
+        fuzzy_threshold (int): Threshold for fuzzy matching score.
+        openai_api_key (str): API key for OpenAI GPT-4.
+        model_names (List[str]): List of model names to be evaluated.
+        model_configurations (List[str]): List of model configurations to be evaluated.
+        gpt4_seed (int): Seed for GPT-4 evaluation.
+        gpt4_max_tokens (int): Maximum tokens for GPT-4 responses.
+        gpt4_temperature (float): Temperature setting for GPT-4 responses.
+    """
+
+    def __init__(self) -> None:
+        """
+        Initialize the KBVQAEvaluator with the dataset and configuration settings.
+
+        Reads data from the specified paths in the configuration and initializes
+        various attributes required for evaluation.
+        """
+
         self.data_path = config.EVALUATION_DATA_PATH
         self.use_fuzzy = config.USE_FUZZY
         self.stemmer = PorterStemmer()
@@ -30,17 +58,32 @@
     def stem_answers(self, answers: Union[str, List[str]]) -> Union[str, List[str]]:
         """
         Apply Porter Stemmer to either a single string or a list of strings.
+
+        Args:
+            answers (Union[str, List[str]]): A single answer string or a list of answer strings.
+
+        Returns:
+            Union[str, List[str]]: Stemmed version of the input string or list of strings.
         """
+
         if isinstance(answers, list):
             return [" ".join(self.stemmer.stem(word.strip()) for word in answer.split()) for answer in answers]
         else:
             words = answers.split()
             return " ".join(self.stemmer.stem(word.strip()) for word in words)

-    def calculate_vqa_score(self, ground_truths, model_answer):
+    def calculate_vqa_score(self, ground_truths: List[str], model_answer: str) -> float:
         """
         Calculate VQA score based on the number of matching answers, with optional fuzzy matching.
+
+        Args:
+            ground_truths (List[str]): List of ground truth answers.
+            model_answer (str): Model's answer to be evaluated.
+
+        Returns:
+            float: VQA score based on the number of matches.
         """
+
         if self.use_fuzzy:
             fuzzy_matches = sum(fuzz.partial_ratio(model_answer, gt) >= self.fuzzy_threshold for gt in ground_truths)
             return min(fuzzy_matches / 3, 1)
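To make the scoring rule above concrete: the non-fuzzy branch gives partial credit per agreeing annotator and saturates at three matches. A minimal sketch with made-up answers, assuming NLTK's PorterStemmer:

from collections import Counter
from nltk.stem import PorterStemmer

stemmer = PorterStemmer()

def stem(text: str) -> str:
    # Same normalization as stem_answers: stem each whitespace-separated word.
    return " ".join(stemmer.stem(word.strip()) for word in text.split())

# Made-up ground-truth annotations and model answer.
ground_truths = [stem(a) for a in ["sheep dogs", "sheep dog", "sheep dog", "collie", "dog"]]
model_answer = stem("sheep dog")

# Non-fuzzy VQA score: credit grows with the number of matching annotators, capped at 3.
count = Counter(ground_truths)
vqa_score = min(count.get(model_answer, 0) / 3, 1)
print(vqa_score)  # 1.0 -- "sheep dog" appears three times after stemming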
@@ -48,19 +91,29 @@
             count = Counter(ground_truths)
             return min(count.get(model_answer, 0) / 3, 1)

-    def calculate_exact_match_score(self, ground_truths, model_answer):
+    def calculate_exact_match_score(self, ground_truths: List[str], model_answer: str) -> int:
         """
         Calculate Exact Match score, with optional fuzzy matching.
+
+        Args:
+            ground_truths (List[str]): List of ground truth answers.
+            model_answer (str): Model's answer to be evaluated.
+
+        Returns:
+            int: Exact match score (1 if there is a match, 0 otherwise).
         """
+
         if self.use_fuzzy:
             return int(any(fuzz.partial_ratio(model_answer, gt) >= self.fuzzy_threshold for gt in ground_truths))
         else:
             return int(model_answer in ground_truths)

-    def syntactic_evaluation(self):
+    def syntactic_evaluation(self) -> None:
         """
         Process the DataFrame: stem answers, calculate scores, and store results.
+
         """
+
         self.df['raw_answers_stemmed'] = self.df['raw_answers'].apply(literal_eval).apply(self.stem_answers)

         for name in self.model_names:
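For the fuzzy branch, a small sketch of how the threshold behaves. The `fuzz` import is not shown in this diff, so RapidFuzz and the threshold value of 80 are assumptions:

from rapidfuzz import fuzz  # assumption: `fuzz` could equally come from thefuzz/fuzzywuzzy

ground_truths = ["polar bear", "bear", "brown bear"]   # made-up answers
model_answer = "a polar bear"
fuzzy_threshold = 80                                    # hypothetical; the class reads its threshold from config

# Strict branch: only a verbatim match scores.
exact = int(model_answer in ground_truths)              # 0

# Fuzzy branch: partial_ratio returns 0-100, and any ground truth at or above the threshold counts.
fuzzy = int(any(fuzz.partial_ratio(model_answer, gt) >= fuzzy_threshold for gt in ground_truths))  # 1
print(exact, fuzzy)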
@@ -74,10 +127,19 @@
                 self.vqa_scores[full_config] = round(self.df[f'vqa_score_{full_config}'].mean()*100, 2)
                 self.exact_match_scores[full_config] = round(self.df[f'exact_match_score_{full_config}'].mean()*100, 2)

-    def create_GPT4_messages_template(self, question, ground_truths, model_answer):
+    def create_GPT4_messages_template(self, question: str, ground_truths: List[str], model_answer: str) -> List[dict]:
         """
         Create a message list for the GPT-4 API call based on the question, ground truths, and model answer.
+
+        Args:
+            question (str): The question being evaluated.
+            ground_truths (List[str]): List of ground truth answers.
+            model_answer (str): Model's answer to be evaluated.
+
+        Returns:
+            List[dict]: Messages formatted for GPT-4 API call.
         """
+
         system_message = {
             "role": "system",
             "content": """You are an AI trained to evaluate the equivalence of AI-generated answers to a set of ground truth answers for a given question. Upon reviewing a model's answer, determine if it matches the ground truths. Use the following rating system: 1 if you find that the model answer matches more than 25% of the ground truth answers, 2 if you find that the model answer matches only less than 25% of the ground truth answers, and 3 if the model answer is incorrect. Respond in the format below for easy parsing:
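The return value is a standard two-element chat-completions message list. The user message and the exact reply template sit outside this hunk, so the sketch below fills them with placeholder content purely for illustration:

question = "What is the capital of France?"        # made-up example inputs
ground_truths = ["paris", "paris", "city of paris"]
model_answer = "paris"

system_message = {"role": "system", "content": "...rating instructions as in the diff above..."}
user_message = {
    "role": "user",
    # Placeholder wording; the real template is defined in the part of the method not shown here.
    "content": f"Question: {question}\nGround truth answers: {ground_truths}\nModel answer: {model_answer}",
}
messages = [system_message, user_message]  # shape expected by the OpenAI chat-completions API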
@@ -93,7 +155,7 @@
         return [system_message, user_message]


-    def semantic_evaluation(self):
+    def semantic_evaluation(self) -> None:
         """
         Perform semantic evaluation using GPT-4 for each model configuration.
         """
@@ -109,7 +171,14 @@
                 rating = int(evaluation.split('\n')[0].split(":")[1].strip())
                 self.df.at[index, f'gpt4_rating_{config}'] = rating

-    def save_results(self, save_filename):
+    def save_results(self, save_filename: str) -> None:
+        """
+        Save the evaluation results to an Excel file.
+
+        Args:
+            save_filename (str): The filename to save the results.
+        """
+
         # Create a DataFrame for the scores
         scores_data = {
             'Model Configuration': list(self.vqa_scores.keys()),
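The rating parser above takes the first line of the GPT-4 reply and reads the integer after the colon, so it relies on the reply starting with a line such as "Rating: 2". A quick sketch with a made-up reply:

evaluation = "Rating: 2\nExplanation: the answer matches fewer than 25% of the ground truths."  # hypothetical reply

# Same parsing as semantic_evaluation: first line, text after the colon, cast to int.
rating = int(evaluation.split('\n')[0].split(":")[1].strip())
print(rating)  # 2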
@@ -123,10 +192,15 @@
             self.df.to_excel(writer, sheet_name='Main Data', index=False)
             scores_df.to_excel(writer, sheet_name='Scores', index=False)

-def run_evaluation(save=False, save_filename="results"):
+def run_evaluation(save: bool = False, save_filename: str = "results") -> None:
     """
     Run the full evaluation process using KBVQAEvaluator and save the results to an Excel file.
+
+    Args:
+        save (bool): Whether to save the results to an Excel file. Defaults to False.
+        save_filename (str): The filename to save the results if save is True. Defaults to "results".
     """
+
     # Instantiate the evaluator
     evaluator = KBVQAEvaluator()
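save_results writes the main data and the aggregated scores to two sheets of one workbook, and run_evaluation(save=True, save_filename=...) presumably ends in the same call. A minimal standalone sketch of that export (made-up column names and values; pandas needs an Excel engine such as openpyxl installed):

import pandas as pd

# Stand-in frames with made-up columns and values; the evaluator builds these from its own results.
df = pd.DataFrame({"question": ["q1", "q2"], "model_answer": ["a1", "a2"]})
scores_df = pd.DataFrame({"Model Configuration": ["cfg_a"], "VQA Score": [71.2], "Exact Match Score": [65.0]})

with pd.ExcelWriter("results.xlsx") as writer:
    df.to_excel(writer, sheet_name="Main Data", index=False)
    scores_df.to_excel(writer, sheet_name="Scores", index=False)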