m7mdal7aj committed
Commit 062b387
1 Parent(s): c792f00

Create demo.py

Files changed (1)
  1. my_model/results/demo.py +262 -0
my_model/results/demo.py ADDED
@@ -0,0 +1,262 @@
import os
import random

import altair as alt
import pandas as pd
import streamlit as st
from PIL import Image

from my_model.config import evaluation_config as config


class ResultDemonstrator:
    """
    A class to demonstrate the results of the Knowledge-Based Visual Question Answering (KB-VQA) model.

    Attributes:
        main_data (pd.DataFrame): Data loaded from an Excel file containing evaluation results.
        sample_img_pool (list[str]): List of image file names available for demonstration.
        model_names (list[str]): List of model names as defined in the configuration.
        model_configs (list[str]): List of model configurations as defined in the configuration.
    """

    def __init__(self) -> None:
        """Initializes the ResultDemonstrator class by loading the data from an Excel file."""
        # Load data
        self.main_data = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Main Data")
        self.sample_img_pool = list(os.listdir("Demo_Images"))
        self.model_names = config.MODEL_NAMES
        self.model_configs = config.MODEL_CONFIGURATIONS

    @staticmethod
    def display_table(data: pd.DataFrame) -> None:
        """
        Displays a DataFrame using Streamlit's dataframe display function.

        Args:
            data (pd.DataFrame): The data to display.
        """
        st.dataframe(data)

    def calculate_and_append_data(self, data_list: list, score_column: str, model_config: str) -> None:
        """
        Calculates mean scores by category and appends them to the data list.

        Args:
            data_list (list): List to append new data rows to.
            score_column (str): Name of the column to calculate mean scores for.
            model_config (str): Configuration of the model.
        """
        if score_column in self.main_data.columns:
            category_means = self.main_data.groupby('question_category')[score_column].mean()
            for category, mean_value in category_means.items():
                data_list.append({
                    "Category": category,
                    "Configuration": model_config,
                    "Mean Value": round(mean_value * 100, 2)
                })

    def display_ablation_results_per_question_category(self) -> None:
        """Displays ablation results per question category for each model configuration."""

        score_types = ['vqa', 'vqa_gpt4', 'em', 'em_gpt4']
        data_lists = {key: [] for key in score_types}
        column_names = {
            'vqa': 'vqa_score_{config}',
            'vqa_gpt4': 'gpt4_vqa_score_{config}',
            'em': 'exact_match_score_{config}',
            'em_gpt4': 'gpt4_em_score_{config}'
        }

        for model_name in self.model_names:
            for conf in self.model_configs:
                model_config = f"{model_name}_{conf}"
                for score_type, col_template in column_names.items():
                    self.calculate_and_append_data(data_lists[score_type],
                                                   col_template.format(config=model_config),
                                                   model_config)

        # Process and display results for each score type
        for score_type, data_list in data_lists.items():
            df = pd.DataFrame(data_list)
            results_df = df.pivot(index='Category', columns='Configuration',
                                  values='Mean Value').applymap(lambda x: f"{x:.2f}%")

            with st.expander(f"{score_type.upper()} Scores per Question Category and Model Configuration"):
                self.display_table(results_df)

    def display_main_results(self) -> None:
        """Displays the main model results from the 'Scores' sheet; these are read directly from the file."""
        main_scores = pd.read_excel('evaluation_results.xlsx', sheet_name="Scores", index_col=0)
        st.markdown("### Main Model Results (Inclusive of Ablation Experiments)")
        main_scores = main_scores.reset_index()  # reset_index returns a new DataFrame, so reassign it
        self.display_table(main_scores)

    def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
        """
        Plots an interactive scatter plot comparing token counts to VQA or EM scores using Altair.

        Args:
            conf (str): The configuration name.
            model_name (str): The name of the model.
            score_name (str): The type of score to plot.
        """

        # Construct the full model configuration name
        model_configuration = f"{model_name}_{conf}"

        # Determine the score column name and legend mapping based on the score type
        if score_name == 'VQA Score':
            score_column_name = f"vqa_score_{model_configuration}"
            scores = self.main_data[score_column_name]
            # Map scores to categories for the legend
            legend_map = ['Correct' if score == 1
                          else 'Partially Correct' if round(score, 2) == 0.67
                          else 'Incorrect'
                          for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Partially Correct', 'Incorrect'],
                                    range=['green', 'orange', 'red'])
        else:
            score_column_name = f"exact_match_score_{model_configuration}"
            scores = self.main_data[score_column_name]
            # Map scores to categories for the legend
            legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])

        # Retrieve token counts from the data
        token_counts = self.main_data[f'tokens_count_{conf}']

        # Create a DataFrame for the scatter plot
        scatter_data = pd.DataFrame({
            'Index': range(len(token_counts)),
            'Token Counts': token_counts,
            score_name: legend_map
        })

        # Create an interactive scatter plot using Altair
        chart = alt.Chart(scatter_data).mark_circle(
            size=60,
            fillOpacity=1,   # sets the fill opacity to maximum
            strokeWidth=1,   # border width, making the circles bolder
            stroke='black'   # border color
        ).encode(
            x=alt.X('Index', scale=alt.Scale(domain=[0, 1020])),
            y=alt.Y('Token Counts', scale=alt.Scale(domain=[token_counts.min() - 200, token_counts.max() + 200])),
            color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
            tooltip=['Index', 'Token Counts', score_name]
        ).interactive()  # enables zoom & pan

        chart = chart.properties(
            title={
                "text": f"Token Counts vs {score_name} ({model_configuration})",
                "color": "black",
                "fontSize": 20,
                "anchor": "middle",
                "offset": 0
            },
            width=700,
            height=500
        )

        # Display the interactive plot in Streamlit
        st.altair_chart(chart, use_container_width=True)

    @staticmethod
    def color_scores(value: float) -> str:
        """
        Applies color coding based on the score value.

        Args:
            value (float): The score value.

        Returns:
            str: CSS color style based on the score value.
        """

        try:
            value = float(value)  # Convert to float to handle numerical comparisons
        except ValueError:
            return 'color: black;'  # Return black if the value is not a number

        if value == 1.0:
            return 'color: green;'
        elif value == 0.0:
            return 'color: red;'
        elif value == 0.67:
            return 'color: orange;'
        return 'color: black;'

    def show_samples(self, num_samples: int = 3) -> None:
        """
        Displays random sample images along with their associated model answers and evaluations.

        Args:
            num_samples (int): Number of sample images to display.
        """

        # Sample images from the pool
        target_imgs = random.sample(self.sample_img_pool, num_samples)
        # Generate model configurations
        model_configs = [f"{model_name}_{conf}" for model_name in self.model_names for conf in self.model_configs]
        # Define column names for scores dynamically
        column_names = {
            'vqa': 'vqa_score_{config}',
            'vqa_gpt4': 'gpt4_vqa_score_{config}',
            'em': 'exact_match_score_{config}',
            'em_gpt4': 'gpt4_em_score_{config}'
        }

        for img_filename in target_imgs:
            image_data = self.main_data[self.main_data['image_filename'] == img_filename]
            # Read the image from the same folder the sample pool was built from
            im = Image.open(f"Demo_Images/{img_filename}")
            col1, col2 = st.columns([1, 2])  # to display images side by side with their data
            # Create a container for each image
            with st.container():
                st.write("-------------------------------")
                with col1:
                    st.image(im, use_column_width=True)
                    with st.expander('Show Caption'):
                        st.text(image_data.iloc[0]['caption'])
                    with st.expander('Show DETIC Objects'):
                        st.text(image_data.iloc[0]['objects_detic_trimmed'])
                    with st.expander('Show YOLOv5 Objects'):
                        st.text(image_data.iloc[0]['objects_yolov5'])
                with col2:
                    if not image_data.empty:
                        st.write(f"**Question: {image_data.iloc[0]['question']}**")
                        st.write(f"**Ground Truth Answers:** {image_data.iloc[0]['raw_answers']}")

                        # Initialize an empty DataFrame for summary data
                        summary_data = pd.DataFrame(
                            columns=['Model Configuration', 'Answer', 'VQA Score', 'VQA Score (GPT-4)',
                                     'EM Score', 'EM Score (GPT-4)'])

                        # Loop variable is named 'model_config' to avoid shadowing the imported config module
                        for model_config in model_configs:
                            # Collect data for each model configuration
                            row_data = {
                                'Model Configuration': model_config,
                                'Answer': image_data.iloc[0].get(model_config, '-')
                            }
                            for score_type, score_template in column_names.items():
                                score_col = score_template.format(config=model_config)
                                score_value = image_data.iloc[0].get(score_col, '-')
                                if pd.notna(score_value) and not isinstance(score_value, str):
                                    # Format score to two decimals if it's a valid number
                                    score_value = f"{float(score_value):.2f}"
                                row_data[score_type.replace('_', ' ').title()] = score_value

                            # Convert row data to a DataFrame and concatenate it; row_data keys are inserted
                            # in the same order as summary_data's columns, so renaming by position is safe
                            rd = pd.DataFrame([row_data])
                            rd.columns = summary_data.columns
                            summary_data = pd.concat([summary_data, rd], axis=0, ignore_index=True)

                        # Apply styling to DataFrame for score coloring, and hide the index via the
                        # Styler API (Styler.to_html silently ignores kwargs such as index=False)
                        styled_summary = summary_data.style.applymap(
                            self.color_scores,
                            subset=['VQA Score', 'VQA Score (GPT-4)', 'EM Score', 'EM Score (GPT-4)'])
                        st.markdown(styled_summary.hide(axis="index").to_html(), unsafe_allow_html=True)
                    else:
                        st.write("No data available for this image.")
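
For reference, a minimal sketch of how this class might be wired into a Streamlit page. It is not part of this commit: the file name, sidebar labels, and the `conf`/`model_name` values below are illustrative assumptions; real values would come from `evaluation_config.MODEL_NAMES` and `MODEL_CONFIGURATIONS`.

# demo_app.py - illustrative usage sketch (hypothetical, not part of this commit).
# Assumes it is launched with `streamlit run demo_app.py` from a directory
# containing Demo_Images/ and the evaluation workbook referenced in the config.
import streamlit as st

from my_model.results.demo import ResultDemonstrator

demonstrator = ResultDemonstrator()

section = st.sidebar.radio("Section", ["Main Results", "Ablation per Category", "Samples"])

if section == "Main Results":
    demonstrator.display_main_results()
    # Placeholder configuration and model name, for illustration only.
    demonstrator.plot_token_count_vs_scores(conf="caption+objects",
                                            model_name="llama_2_13b",
                                            score_name="VQA Score")
elif section == "Ablation per Category":
    demonstrator.display_ablation_results_per_question_category()
else:
    demonstrator.show_samples(num_samples=3)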