import os
import random

import altair as alt
import pandas as pd
import streamlit as st
from PIL import Image

from my_model.config import evaluation_config as config


class ResultDemonstrator:
    """
    A class to demonstrate the results of the Knowledge-Based Visual Question Answering (KB-VQA) model.

    Attributes:
        main_data (pd.DataFrame): Data loaded from the "Main Data" sheet of the evaluation Excel file.
        sample_img_pool (list[str]): File names found in the demo images directory.
        model_names (list[str]): List of model names as defined in the configuration.
        model_configs (list[str]): List of model configurations as defined in the configuration.
        demo_images_path (str): Path to the demo images directory.
    """

    # Column-name templates per score type; ``{config}`` is filled with "<model_name>_<configuration>".
    SCORE_COLUMN_TEMPLATES = {
        'vqa': 'vqa_score_{config}',
        'vqa_gpt4': 'gpt4_vqa_score_{config}',
        'em': 'exact_match_score_{config}',
        'em_gpt4': 'gpt4_em_score_{config}',
    }

    # Column headers used in the per-sample summary table, keyed by the same score types.
    SCORE_DISPLAY_NAMES = {
        'vqa': 'VQA Score',
        'vqa_gpt4': 'VQA Score (GPT-4)',
        'em': 'EM Score',
        'em_gpt4': 'EM Score (GPT-4)',
    }

    def __init__(self) -> None:
        """
        Initializes the ResultDemonstrator by loading the evaluation data and listing the demo images.
        """
        self.main_data = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Main Data")
        # os.listdir already returns a fresh list; no extra copy is needed.
        self.sample_img_pool = os.listdir(config.DEMO_IMAGES_PATH)
        self.model_names = config.MODEL_NAMES
        self.model_configs = config.MODEL_CONFIGURATIONS
        self.demo_images_path = config.DEMO_IMAGES_PATH

    @staticmethod
    def display_table(data: pd.DataFrame) -> None:
        """
        Displays a DataFrame using Streamlit's dataframe display function.

        Args:
            data (pd.DataFrame): The data to display.
        """
        st.dataframe(data)

    def _model_configurations(self) -> list:
        """Returns every "<model_name>_<configuration>" combination, iterating models in the outer loop."""
        return [f"{model_name}_{conf}" for model_name in self.model_names for conf in self.model_configs]

    def calculate_and_append_data(self, data_list: list, score_column: str, model_config: str) -> None:
        """
        Calculates mean scores by question category and appends them to the data list.

        Nothing is appended when ``score_column`` is not a column of ``main_data``.

        Args:
            data_list (list): List to append new data rows (dicts) to; mutated in place.
            score_column (str): Name of the column to calculate mean scores for.
            model_config (str): Configuration of the model, stored in each row's "Configuration" field.
        """
        if score_column not in self.main_data.columns:
            return
        category_means = self.main_data.groupby('question_category')[score_column].mean()
        for category, mean_value in category_means.items():
            data_list.append({
                "Category": category,
                "Configuration": model_config,
                # Stored as a percentage rounded to two decimals.
                "Mean Value": round(mean_value * 100, 2),
            })

    def display_ablation_results_per_question_category(self) -> None:
        """
        Displays ablation results per question category for each model configuration.

        One expander is rendered per score type. Score types for which no matching columns exist
        in ``main_data`` are skipped rather than failing inside ``pivot``.
        """
        data_lists = {score_type: [] for score_type in self.SCORE_COLUMN_TEMPLATES}

        for model_config in self._model_configurations():
            for score_type, col_template in self.SCORE_COLUMN_TEMPLATES.items():
                self.calculate_and_append_data(data_lists[score_type],
                                               col_template.format(config=model_config),
                                               model_config)

        # Process and display results for each score type
        for score_type, data_list in data_lists.items():
            if not data_list:
                # An empty frame has no 'Category' column, so pivot would raise KeyError.
                continue
            results_df = pd.DataFrame(data_list).pivot(index='Category', columns='Configuration',
                                                       values='Mean Value')
            results_df = results_df.applymap(lambda x: f"{x:.2f}%")
            with st.expander(f"{score_type.upper()} Scores per Question Category and Model Configuration"):
                self.display_table(results_df)

    def display_main_results(self) -> None:
        """Displays the main model results from the "Scores" sheet, exactly as stored in the file."""
        # The sheet's first column becomes the row labels (index_col=0).
        main_scores = pd.read_excel(config.EVALUATION_DATA_PATH, sheet_name="Scores", index_col=0)
        st.markdown("### Main Model Results (Inclusive of Ablation Experiments)")
        self.display_table(main_scores)

    def plot_token_count_vs_scores(self, conf: str, model_name: str, score_name: str = 'VQA Score') -> None:
        """
        Plots an interactive scatter plot comparing token count to VQA or EM scores using Altair.

        Args:
            conf (str): The configuration name; also selects the ``tokens_count_<conf>`` column.
            model_name (str): The name of the model.
            score_name (str): 'VQA Score' plots VQA outcomes; any other value plots exact-match
                outcomes. The value is also used as the legend title and column label.
        """
        # Construct the full model configuration name
        model_configuration = f"{model_name}_{conf}"

        # Determine the score column and legend categories based on the score type
        if score_name == 'VQA Score':
            scores = self.main_data[f"vqa_score_{model_configuration}"]
            # Scores of exactly 1 are correct; scores rounding to 0.67 are shown as partially correct.
            legend_map = ['Correct' if score == 1
                          else 'Partially Correct' if round(score, 2) == 0.67
                          else 'Incorrect'
                          for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Partially Correct', 'Incorrect'],
                                    range=['green', 'orange', 'red'])
        else:
            scores = self.main_data[f"exact_match_score_{model_configuration}"]
            legend_map = ['Correct' if score == 1 else 'Incorrect' for score in scores]
            color_scale = alt.Scale(domain=['Correct', 'Incorrect'], range=['green', 'red'])

        # Retrieve token count from the data
        token_count = self.main_data[f'tokens_count_{conf}']

        # Create a DataFrame for the scatter plot
        scatter_data = pd.DataFrame({
            'Index': range(len(token_count)),
            'Token Count': token_count,
            score_name: legend_map
        })

        # Keep the original 1020 x-axis extent, but widen it if there are more samples than that.
        x_max = max(1020, len(token_count))

        # Create an interactive scatter plot using Altair
        chart = alt.Chart(scatter_data).mark_circle(
            size=60,
            fillOpacity=1,  # Sets the fill opacity to maximum
            strokeWidth=1,  # Adjusts the border width making the circles bolder
            stroke='black'  # Sets the border color to black
        ).encode(
            x=alt.X('Index', scale=alt.Scale(domain=[0, x_max])),
            y=alt.Y('Token Count', scale=alt.Scale(domain=[token_count.min() - 200, token_count.max() + 200])),
            color=alt.Color(score_name, scale=color_scale, legend=alt.Legend(title=score_name)),
            tooltip=['Index', 'Token Count', score_name]
        ).interactive()  # Enables zoom & pan

        chart = chart.properties(
            title={
                "text": f"Token Count vs {score_name} ({model_configuration.replace('_', '-')})",
                "color": "black",
                "fontSize": 20,
                "anchor": "middle",
                "offset": 0
            },
            width=700,
            height=500
        )

        # Display the interactive plot in Streamlit
        st.altair_chart(chart, use_container_width=True)

    @staticmethod
    def color_scores(value: float) -> str:
        """
        Applies color coding based on the score value.

        Args:
            value (float): The score value; may also be a numeric string, '-', None or NaN.

        Returns:
            str: CSS color style: green for 1.0, red for 0.0, orange for 0.67, black otherwise
            (including any value that cannot be converted to float).
        """
        try:
            value = float(value)  # Convert to float to handle numerical comparisons
        except (TypeError, ValueError):
            # '-' raises ValueError and None raises TypeError; both are rendered black.
            return 'color: black;'
        if value == 1.0:
            return 'color: green;'
        if value == 0.0:
            return 'color: red;'
        if value == 0.67:
            return 'color: orange;'
        return 'color: black;'

    @staticmethod
    def _hide_index(styler):
        """Hides the row index of a pandas Styler (``hide`` on pandas>=1.4, ``hide_index`` before)."""
        if hasattr(styler, "hide"):
            return styler.hide(axis="index")
        return styler.hide_index()

    def _build_summary_table(self, sample: pd.Series, model_configs: list) -> pd.DataFrame:
        """
        Builds one row per model configuration holding its answer and formatted scores for one sample.

        Args:
            sample (pd.Series): The ``main_data`` row of the sample being shown.
            model_configs (list): "<model_name>_<configuration>" strings, one row each.

        Returns:
            pd.DataFrame: Columns 'Model Configuration', 'Answer' and the score display headers.
            Missing values are shown as '-' and numeric scores are formatted to two decimals.
        """
        rows = []
        for model_config in model_configs:
            row = {
                'Model Configuration': model_config,
                'Answer': sample.get(model_config, '-'),
            }
            for score_type, score_template in self.SCORE_COLUMN_TEMPLATES.items():
                score_value = sample.get(score_template.format(config=model_config), '-')
                if pd.notna(score_value) and not isinstance(score_value, str):
                    # Format score to two decimals if it's a valid number
                    score_value = f"{float(score_value):.2f}"
                row[self.SCORE_DISPLAY_NAMES[score_type]] = score_value
            rows.append(row)
        # Build once from all rows instead of concatenating onto an empty frame per row.
        columns = ['Model Configuration', 'Answer', *self.SCORE_DISPLAY_NAMES.values()]
        return pd.DataFrame(rows, columns=columns)

    def show_samples(self, num_samples: int = 3) -> None:
        """
        Displays random sample images and their associated models answers and evaluations.

        Args:
            num_samples (int): Number of sample images to display; capped at the number of
                available images.
        """
        # random.sample raises ValueError when asked for more items than the pool holds.
        sample_size = min(num_samples, len(self.sample_img_pool))
        target_imgs = random.sample(self.sample_img_pool, sample_size)

        model_configs = self._model_configurations()
        score_headers = list(self.SCORE_DISPLAY_NAMES.values())

        for img_filename in target_imgs:
            image_data = self.main_data[self.main_data['image_filename'] == img_filename]
            # First matching row, or None when the image has no evaluation data.
            sample = None if image_data.empty else image_data.iloc[0]

            # Image and its data side by side.
            col1, col2 = st.columns([1, 2])
            # Created after the columns, so the separator renders below this sample's row.
            st.write("-------------------------------")

            with col1:
                # Close the image file handle once Streamlit has been handed the image.
                with Image.open(os.path.join(self.demo_images_path, img_filename)) as im:
                    st.image(im, use_column_width=True)
                if sample is not None:
                    with st.expander('Show Caption'):
                        st.text(sample['caption'])
                    with st.expander('Show DETIC Objects'):
                        st.text(sample['objects_detic_trimmed'])
                    with st.expander('Show YOLOv5 Objects'):
                        st.text(sample['objects_yolov5'])

            with col2:
                if sample is None:
                    st.write("No data available for this image.")
                    continue

                st.write(f"**Question:** {sample['question']}")
                st.write(f"**Ground Truth Answers:** {sample['raw_answers']}")

                summary_data = self._build_summary_table(sample, model_configs)

                # Apply styling to DataFrame for score coloring.
                styled_summary = summary_data.style.applymap(self.color_scores, subset=score_headers)
                # Styler.to_html forwards unknown kwargs to its template, so index=False would not
                # hide the index; hide it explicitly instead.
                # NOTE(review): cell text (e.g. model answers) is not HTML-escaped and is rendered with
                # unsafe_allow_html=True — confirm the data source is trusted.
                st.markdown(self._hide_index(styled_summary).to_html(), unsafe_allow_html=True)