"""Streamlit app for a human study on judging image-classifier predictions.

For each sampled query image the participant sees the model's predicted
class (optionally with an explanation image) and marks the prediction as
"Correct" or "Wrong".  A results page then compares those answers with the
recorded correctness of the classifier.
"""

import json
import os
import pickle
import random
import time
from collections import Counter
from datetime import datetime
from glob import glob

import gdown
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from PIL import Image

import SessionState
from download_utils import *
from image_utils import *

# Seed both RNGs from the wall clock.  ``random.seed`` only accepts
# None/int/float/str/bytes/bytearray on Python 3.11+, so pass the float
# timestamp rather than the ``datetime`` object itself.
random.seed(datetime.now().timestamp())
np.random.seed(int(time.time()))

NUMBER_OF_TRIALS = 20
CLASSIFIER_TAG = ""
explaination_functions = [load_chm_nns, load_knn_nns]
selected_xai_tool = None

# Config
folder_to_name = {}
class_descriptions = {}
classifier_predictions = {}

selected_dataset = "Final"

root_visualization_dir = "./visualizations/"
viz_url = "https://static.taesiri.com/xai/Final.zip"
viz_archivefile = "Final.zip"

demonstration_url = "https://static.taesiri.com/xai/demonstrations.zip"
demonst_zipfile = "demonstrations.zip"

picklefile_url = "https://static.taesiri.com/xai/Task1_Results_CHM_and_EMD.pickle"
prediction_root = "./predictions/"
prediction_pickle = f"{prediction_root}predictions.pickle"

# Get the Data
download_files(
    root_visualization_dir,
    viz_url,
    viz_archivefile,
    demonstration_url,
    demonst_zipfile,
    picklefile_url,
    prediction_root,
    prediction_pickle,
)

################################################
# GLOBAL VARIABLES
app_mode = ""

## Shared/Global Information
with open("imagenet-labels.json", "rb") as f:
    folder_to_name = json.load(f)

# gloss.txt holds one tab-separated record per line; the first field is used
# as the key and the second as the description.  Lines without a tab (for
# example a trailing blank line) are skipped instead of raising IndexError.
with open("gloss.txt", "r") as f:
    for gloss_line in f:
        gloss_fields = gloss_line.split("\t")
        if len(gloss_fields) < 2:
            continue
        class_descriptions[gloss_fields[0]] = gloss_fields[1]

################################################

# NOTE(review): pickle.load executes arbitrary code embedded in the file, and
# this file is fetched from a remote URL.  Only safe while that host is
# trusted; consider a data-only format (JSON/CSV) for the predictions.
with open(prediction_pickle, "rb") as f:
    classifier_predictions = pickle.load(f)

# SESSION STATE
session_state = SessionState.get(
    page=1,
    first_run=1,
    user_feedback={},
    queries=[],
    is_classifier_correct={},
    XAI_tool="Unselected",
)

################################################

# Markdown shown at the top of the menu page.
_INSTRUCTIONS_MD = (
    "# Instructions\n"
    "```\n"
    "When testing this study, you should first see the class definition, "
    "then hide the expander and see the query.\n"
    "```\n"
)

# Choices of the per-query radio button; the position in this tuple is the
# ``index`` Streamlit expects when restoring a previous answer.
_FEEDBACK_CHOICES = ("-", "Correct", "Wrong")


def resmaple_queries():
    """Draw a fresh, shuffled set of query images and reset the answers.

    Does nothing unless ``session_state.first_run == 1``.  Half of the
    ``NUMBER_OF_TRIALS`` queries are sampled (without replacement) from the
    ``Both_correct`` folder and half from ``Both_wrong``.

    Raises:
        ValueError: from ``np.random.choice`` when a folder holds fewer than
            ``NUMBER_OF_TRIALS // 2`` JPEG files.
    """
    if session_state.first_run != 1:
        return

    dataset_dir = root_visualization_dir + selected_dataset
    per_group = NUMBER_OF_TRIALS // 2

    both_correct = glob(dataset_dir + "/Both_correct/*.JPEG")
    both_wrong = glob(dataset_dir + "/Both_wrong/*.JPEG")

    correct_samples = list(
        np.random.choice(a=both_correct, size=per_group, replace=False)
    )
    wrong_samples = list(
        np.random.choice(a=both_wrong, size=per_group, replace=False)
    )

    all_images = correct_samples + wrong_samples
    random.shuffle(all_images)

    session_state.queries = all_images
    session_state.first_run = -1

    # RESET INTERACTIONS
    session_state.user_feedback = {}
    session_state.is_classifier_correct = {}


def render_experiment(query):
    """Render one trial: query image, predicted class, and the answer widget.

    Args:
        query: Zero-based index into ``session_state.queries``.

    Side effects:
        Stores the classifier's recorded correctness for this query in
        ``session_state.is_classifier_correct`` and the participant's current
        radio choice in ``session_state.user_feedback``.
    """
    current_query = session_state.queries[query]
    query_id = os.path.basename(current_query)

    record = classifier_predictions[query_id]
    predicted_wnid = record[f"{CLASSIFIER_TAG}-predictions"]
    prediction_confidence = record[f"{CLASSIFIER_TAG}-confidence"]
    prediction_label = folder_to_name[predicted_wnid]
    class_def = class_descriptions[predicted_wnid]

    session_state.is_classifier_correct[query_id] = record[
        f"{CLASSIFIER_TAG.upper()}-Output"
    ]

    ################################### SHOW QUERY and PREDICTION
    col1, col2 = st.columns(2)
    with col1:
        st.image(load_query(current_query), caption=f"Query ID: {query_id}")
    with col2:
        ################################### SHOW DESCRIPTION OF CLASS
        with st.expander("Show Class Description"):
            st.write(f"**Name**: {prediction_label}")
            st.write("**Class Definition**:")
            st.markdown("`" + class_def + "`")
            st.image(
                Image.open(f"demonstrations/{predicted_wnid}.jpeg"),
                caption="Class Explanation",
                use_column_width=True,
            )

        # Pre-select the participant's earlier answer when revisiting a page.
        previous_answer = session_state.user_feedback.get(query_id)
        if previous_answer in ("Correct", "Wrong"):
            default_value = _FEEDBACK_CHOICES.index(previous_answer)
        else:
            default_value = 0

        session_state.user_feedback[query_id] = st.radio(
            "What do you think about model's prediction?",
            _FEEDBACK_CHOICES,
            key=query_id,
            index=default_value,
        )
        st.write(f"**Model Prediction**: {prediction_label}")
        st.write(f"**Model Confidence**: {prediction_confidence}")

    ################################### SHOW Model Explanation
    if selected_xai_tool is not None:
        st.image(
            selected_xai_tool(current_query),
            caption="Explaination",
            use_column_width=True,
        )

    ################################### SHOW DEBUG INFO
    if st.button("Debug: Show Everything"):
        st.image(Image.open(current_query))


def render_results():
    """Render the participant's score, per-query table, and breakdowns.

    A guess counts as correct only when the participant's answer ("Correct"
    or "Wrong") matches the classifier's recorded outcome.  An unanswered
    query ("-") is never counted as a correct guess, so the headline score
    always agrees with the "Is User Prediction Correct" column of the table.
    """
    experiment_summary = []
    user_correct_guess = 0

    for q, answer in session_state.user_feedback.items():
        category = "Correct" if session_state.is_classifier_correct[q] else "Wrong"
        is_user_correct = category == answer
        if is_user_correct:
            user_correct_guess += 1

        experiment_summary.append(
            [
                q,
                classifier_predictions[q]["real-gts"],
                folder_to_name[
                    classifier_predictions[q][f"{CLASSIFIER_TAG}-predictions"]
                ],
                category,
                answer,
                is_user_correct,
            ]
        )

    total_answers = len(session_state.user_feedback)
    st.write(
        f"User performance on {CLASSIFIER_TAG}: {user_correct_guess} out of "
        f"{total_answers} Correct"
    )

    st.markdown("## User Performance Breakdown")

    ################################### Summary Table
    experiment_summary_df = pd.DataFrame.from_records(
        experiment_summary,
        columns=[
            "Query",
            "GT Labels",
            f"{CLASSIFIER_TAG} Prediction",
            "Category",
            "User Prediction",
            "Is User Prediction Correct",
        ],
    )
    st.write("Summary", experiment_summary_df)

    csv = convert_df(experiment_summary_df)
    st.download_button(
        "Press to Download", csv, "summary.csv", "text/csv", key="download-records"
    )

    ################################### SHOW BREAKDOWN
    user_pf_by_model_pred = experiment_summary_df.groupby("Category").agg(
        {"Is User Prediction Correct": ["count", "sum", "mean"]}
    )
    # rename columns
    user_pf_by_model_pred.columns = user_pf_by_model_pred.columns.droplevel(0)
    user_pf_by_model_pred.columns = [
        "Count",
        "Correct User Guess",
        "Mean User Performance",
    ]
    user_pf_by_model_pred.index.name = "Model Prediction"
    st.write("User performance break down by Model prediction:", user_pf_by_model_pred)

    csv = convert_df(user_pf_by_model_pred)
    st.download_button(
        "Press to Download",
        csv,
        "user-performance-by-model-prediction.csv",
        "text/csv",
        key="download-performance-by-model-prediction",
    )

    ################################### CONFUSION MATRIX
    confusion_matrix = pd.crosstab(
        experiment_summary_df["Category"],
        experiment_summary_df["User Prediction"],
        rownames=["Actual"],
        colnames=["Predicted"],
    )
    st.write("Confusion Matrix", confusion_matrix)

    csv = convert_df(confusion_matrix)
    st.download_button(
        "Press to Download",
        csv,
        "confusion-matrix.csv",
        "text/csv",
        key="download-confusiion-matrix",
    )


def render_menu():
    """Render the page selector and whichever page is selected.

    Pages: the instruction text, the experiment (with previous / next /
    resample navigation), or the results summary.  On the last experiment
    page a "View" button shows the results directly.
    """
    # Render the readme as markdown using st.markdown.
    readme_text = st.markdown(_INSTRUCTIONS_MD)

    app_mode = st.selectbox(
        "Choose the page to show:",
        ["Experiment Instruction", "Start Experiment", "See the Results"],
    )

    if app_mode == "Experiment Instruction":
        st.success("To continue select an option in the dropdown menu.")

    elif app_mode == "Start Experiment":
        # Clear Canvas
        readme_text.empty()

        page_id = session_state.page
        # col4 is an unused spacer column between "Previous" and "Next".
        col1, col4, col2, col3 = st.columns(4)

        if col1.button("Previous Image"):
            page_id = max(1, page_id - 1)

        if col2.button("Next Image"):
            page_id = min(NUMBER_OF_TRIALS, page_id + 1)

        view_results = False
        if page_id == NUMBER_OF_TRIALS:
            st.success(
                'You have reached the last image. Please go to the "Results" page to see your performance.'
            )
            view_results = st.button("View")

        if col3.button("Resample"):
            st.write("Restarting ...")
            page_id = 1
            session_state.first_run = 1
            resmaple_queries()

        session_state.page = page_id

        if view_results:
            # The page dispatch above has already happened, so reassigning
            # ``app_mode`` here would have no effect; render the results
            # directly for this run instead.
            st.write("Results Summary")
            render_results()
            return

        st.write(f"Render Experiment: {session_state.page}")
        render_experiment(session_state.page - 1)

    elif app_mode == "See the Results":
        readme_text.empty()
        st.write("Results Summary")
        render_results()


def main():
    """Configure the page, let the user pick a method, and run the study.

    The method picker is only shown while no method has been chosen.  Once a
    method is selected, the module-level ``selected_xai_tool`` and
    ``CLASSIFIER_TAG`` are set accordingly, the queries are sampled (first
    run only) and the menu is rendered.
    """
    global selected_xai_tool
    global CLASSIFIER_TAG

    # State Management and General Setup
    st.set_page_config(layout="wide")
    st.title("TASK - 1 - ImageNetREAL")

    options = [
        "Unselected",
        "NOXAI",
        "KNN",
        "EMD Nearest Neighbors",
        "EMD Correspondence",
        "CHM Nearest Neighbors",
        "CHM Correspondence",
    ]

    st.markdown(
        """ """,
        unsafe_allow_html=True,
    )

    if session_state.XAI_tool == "Unselected":
        default = options.index(session_state.XAI_tool)
        session_state.XAI_tool = st.radio(
            "What explaination tool do you want to evaluate?",
            options,
            key="which_xai",
            index=default,
        )

    if session_state.XAI_tool != "Unselected":
        st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")

        if session_state.XAI_tool == "NOXAI":
            CLASSIFIER_TAG = "knn"
            selected_xai_tool = None
        elif session_state.XAI_tool == "KNN":
            selected_xai_tool = load_knn_nns
            CLASSIFIER_TAG = "knn"
        elif session_state.XAI_tool == "CHM Nearest Neighbors":
            selected_xai_tool = load_chm_nns
            CLASSIFIER_TAG = "CHM"
        elif session_state.XAI_tool == "CHM Correspondence":
            selected_xai_tool = load_chm_corrs
            CLASSIFIER_TAG = "CHM"
        elif session_state.XAI_tool == "EMD Nearest Neighbors":
            selected_xai_tool = load_emd_nns
            CLASSIFIER_TAG = "EMD"
        elif session_state.XAI_tool == "EMD Correspondence":
            selected_xai_tool = load_emd_corrs
            CLASSIFIER_TAG = "EMD"

        resmaple_queries()
        render_menu()


if __name__ == "__main__":
    main()