Spaces:
Runtime error
Runtime error
import json | |
import os | |
import pickle | |
import random | |
import time | |
from collections import Counter | |
from datetime import datetime | |
from glob import glob | |
import gdown | |
import matplotlib.pyplot as plt | |
import numpy as np | |
import pandas as pd | |
import seaborn as sns | |
import streamlit as st | |
from PIL import Image | |
import SessionState | |
from download_utils import * | |
from image_utils import * | |
# Seed both RNGs from the clock so each run draws a different query set.
# random.seed() accepts only None/int/float/str/bytes/bytearray on Python 3.11+
# (a datetime object raises TypeError), so seed with the timestamp float.
random.seed(datetime.now().timestamp())
np.random.seed(int(time.time()))
# Number of query images shown per study session (half drawn from each
# "Both_correct"/"Both_wrong" folder by resmaple_queries()).
NUMBER_OF_TRIALS = 20
# Prefix of the per-classifier keys in classifier_predictions; main() sets it
# once an XAI tool has been chosen.
CLASSIFIER_TAG = ""
# Explanation loaders brought in by the star imports above.
explaination_functions = [load_chm_nns, load_knn_nns]
# Loader picked in main(); None means no explanation image is shown.
selected_xai_tool = None

# Config
# Lookup tables; each is replaced once its data file is loaded below.
folder_to_name: dict = {}
class_descriptions: dict = {}
classifier_predictions: dict = {}

selected_dataset = "Final"
root_visualization_dir = "./visualizations/"

# Remote sources (Google Drive) paired with their local file names.
viz_url, viz_archivefile = (
    "https://drive.google.com/uc?id=1LpmOc_nFBzApYWAokO2J-s9RRXsk3pBN",
    "Final.zip",
)
demonstration_url, demonst_zipfile = (
    "https://drive.google.com/uc?id=1C92llG5VrlABrsIEvxfNlSDc_gIeLlls",
    "demonstrations.zip",
)
picklefile_url = "https://drive.google.com/uc?id=1Yx4abA4VLZGO5JkzhXVGdy6mbPltMd68"

prediction_root = "./predictions/"
prediction_pickle = prediction_root + "predictions.pickle"
# Get the Data
download_files(
    root_visualization_dir,
    viz_url,
    viz_archivefile,
    demonstration_url,
    demonst_zipfile,
    picklefile_url,
    prediction_root,
    prediction_pickle,
)
################################################
# GLOBAL VARIABLES
app_mode = ""
## Shared/Global Information
with open("imagenet-labels.json", "rb") as f:
    folder_to_name = json.load(f)
with open("gloss.txt", "r", encoding="utf-8") as f:
    # Each line is "<key>\t<definition>" (keys are presumably WordNet ids,
    # matching predicted_wnid lookups). Split on the first tab only, drop the
    # line ending, and skip lines without a tab (e.g. a trailing blank line)
    # instead of crashing on them.
    class_descriptions = {}
    for line in f:
        wnid, sep, definition = line.rstrip("\n").partition("\t")
        if sep:
            class_descriptions[wnid] = definition
################################################
# NOTE(review): pickle.load can execute arbitrary code. This file is
# presumably fetched from picklefile_url by download_files() above, so only
# point that URL at a trusted source.
with open(prediction_pickle, "rb") as f:
    classifier_predictions = pickle.load(f)
# SESSION STATE
# Defaults for the per-session values the pages read and write:
#   page                  1-based index of the query currently shown
#   first_run             1 => resmaple_queries() draws a new query set
#   user_feedback         query id -> radio answer ("-", "Correct", "Wrong")
#   queries               image paths of the current trial set
#   is_classifier_correct query id -> classifier output flag
#   XAI_tool              explanation method picked in main()
_session_defaults = dict(
    page=1,
    first_run=1,
    user_feedback={},
    queries=[],
    is_classifier_correct={},
    XAI_tool="Unselected",
)
session_state = SessionState.get(**_session_defaults)
################################################
def resmaple_queries():
    """Draw a fresh set of query images when ``session_state.first_run == 1``.

    Samples without replacement half of the trials from the dataset's
    "Both_correct" folder and the rest from "Both_wrong", shuffles them into
    ``session_state.queries``, marks sampling as done (``first_run = -1``) and
    clears the answers recorded for the previous set. Later calls do nothing
    until ``first_run`` is set back to 1.
    """
    if session_state.first_run != 1:
        return

    dataset_dir = root_visualization_dir + selected_dataset
    both_correct = glob(dataset_dir + "/Both_correct/*.JPEG")
    both_wrong = glob(dataset_dir + "/Both_wrong/*.JPEG")

    n_correct = NUMBER_OF_TRIALS // 2
    # Give any remainder to the second group so exactly NUMBER_OF_TRIALS
    # queries exist even when it is odd; render_menu pages up to
    # NUMBER_OF_TRIALS and indexes session_state.queries with page - 1.
    n_wrong = NUMBER_OF_TRIALS - n_correct
    correct_samples = list(
        np.random.choice(a=both_correct, size=n_correct, replace=False)
    )
    wrong_samples = list(
        np.random.choice(a=both_wrong, size=n_wrong, replace=False)
    )

    all_images = correct_samples + wrong_samples
    random.shuffle(all_images)
    session_state.queries = all_images
    session_state.first_run = -1
    # RESET INTERACTIONS
    session_state.user_feedback = {}
    session_state.is_classifier_correct = {}
def render_experiment(query):
    """Render one trial: the query image, the predicted class and the answer radio.

    Args:
        query: 0-based index into ``session_state.queries``.

    Stores the classifier's ``<TAG>-Output`` value for the query in
    ``session_state.is_classifier_correct`` and the user's current radio
    choice ("-", "Correct" or "Wrong") in ``session_state.user_feedback``.
    """
    current_query = session_state.queries[query]
    query_id = os.path.basename(current_query)

    prediction = classifier_predictions[query_id]
    predicted_wnid = prediction[f"{CLASSIFIER_TAG}-predictions"]
    prediction_confidence = prediction[f"{CLASSIFIER_TAG}-confidence"]
    prediction_label = folder_to_name[predicted_wnid]
    class_def = class_descriptions[predicted_wnid]
    # NOTE(review): this key upper-cases the tag while the two above do not;
    # presumably that mirrors the pickle's key names — confirm.
    session_state.is_classifier_correct[query_id] = prediction[
        f"{CLASSIFIER_TAG.upper()}-Output"
    ]

    ################################### SHOW QUERY and PREDICTION
    col1, col2 = st.columns(2)
    with col1:
        st.image(load_query(current_query), caption=f"Query ID: {query_id}")
    with col2:
        ################################### SHOW DESCRIPTION OF CLASS
        with st.expander("Show Class Description"):
            st.write(f"**Name**: {prediction_label}")
            st.write("**Class Definition**:")
            st.markdown("`" + class_def + "`")
            st.image(
                Image.open(f"demonstrations/{predicted_wnid}.jpeg"),
                caption="Class Explanation",
                use_column_width=True,
            )

        answers = ("-", "Correct", "Wrong")
        # Re-select the answer recorded on an earlier visit; anything else
        # (including a first visit) falls back to the "-" placeholder.
        previous = session_state.user_feedback.get(query_id)
        default_value = answers.index(previous) if previous in answers[1:] else 0
        session_state.user_feedback[query_id] = st.radio(
            "What do you think about model's prediction?",
            answers,
            key=query_id,
            index=default_value,
        )
        st.write(f"**Model Prediction**: {prediction_label}")
        st.write(f"**Model Confidence**: {prediction_confidence}")

    ################################### SHOW Model Explanation
    if selected_xai_tool is not None:
        st.image(
            selected_xai_tool(current_query),
            caption="Explanation",
            use_column_width=True,
        )

    ################################### SHOW DEBUG INFO
    if st.button("Debug: Show Everything"):
        st.image(Image.open(current_query))
def render_results():
    """Render the user's score, a per-query summary table and performance breakdowns.

    Reads the answers in ``session_state.user_feedback`` and the classifier
    output flags in ``session_state.is_classifier_correct`` (both filled by
    ``render_experiment``). Every table is displayed and offered as a CSV
    download built with ``convert_df``.
    """
    experiment_summary = []
    user_correct_guess = 0
    for q, user_answer in session_state.user_feedback.items():
        category = "Correct" if session_state.is_classifier_correct[q] else "Wrong"
        # The user is right when their answer names the classifier's actual
        # outcome. The hidden "-" placeholder (query not answered) never
        # matches, so unanswered queries are not scored as correct guesses;
        # the headline count and the per-row column use this same rule.
        is_user_correct = user_answer == category
        if is_user_correct:
            user_correct_guess += 1
        predicted_wnid = classifier_predictions[q][f"{CLASSIFIER_TAG}-predictions"]
        experiment_summary.append(
            [
                q,
                classifier_predictions[q]["real-gts"],
                folder_to_name[predicted_wnid],
                category,
                user_answer,
                is_user_correct,
            ]
        )

    st.write(
        f"User performance on {CLASSIFIER_TAG}: {user_correct_guess} out of "
        f"{len(session_state.user_feedback)} Correct"
    )
    st.markdown("## User Performance Breakdown")

    ################################### Summary Table
    experiment_summary_df = pd.DataFrame.from_records(
        experiment_summary,
        columns=[
            "Query",
            "GT Labels",
            f"{CLASSIFIER_TAG} Prediction",
            "Category",
            "User Prediction",
            "Is User Prediction Correct",
        ],
    )
    st.write("Summary", experiment_summary_df)
    csv = convert_df(experiment_summary_df)
    st.download_button(
        "Press to Download", csv, "summary.csv", "text/csv", key="download-records"
    )

    ################################### SHOW BREAKDOWN
    user_pf_by_model_pred = experiment_summary_df.groupby("Category").agg(
        {"Is User Prediction Correct": ["count", "sum", "mean"]}
    )
    # agg() with a list of functions yields two-level column labels; drop the
    # source-column level, then give the three statistics readable names.
    user_pf_by_model_pred.columns = user_pf_by_model_pred.columns.droplevel(0)
    user_pf_by_model_pred.columns = [
        "Count",
        "Correct User Guess",
        "Mean User Performance",
    ]
    user_pf_by_model_pred.index.name = "Model Prediction"
    st.write("User performance break down by Model prediction:", user_pf_by_model_pred)
    csv = convert_df(user_pf_by_model_pred)
    st.download_button(
        "Press to Download",
        csv,
        "user-performance-by-model-prediction.csv",
        "text/csv",
        key="download-performance-by-model-prediction",
    )

    ################################### CONFUSION MATRIX
    confusion_matrix = pd.crosstab(
        experiment_summary_df["Category"],
        experiment_summary_df["User Prediction"],
        rownames=["Actual"],
        colnames=["Predicted"],
    )
    st.write("Confusion Matrix", confusion_matrix)
    csv = convert_df(confusion_matrix)
    st.download_button(
        "Press to Download",
        csv,
        "confusion-matrix.csv",
        "text/csv",
        key="download-confusiion-matrix",
    )
def render_menu():
    """Render the instructions, the page selector and the selected page.

    "Start Experiment" shows Previous/Next/Resample controls and the current
    query (``session_state.page`` is 1-based and clamped to
    ``1..NUMBER_OF_TRIALS``); on the last query a "View" button shows the
    results instead. "See the Results" shows ``render_results``.
    """
    # Render the readme as markdown using st.markdown.
    readme_text = st.markdown(
        """
# Instructions
```
When testing this study, you should first see the class definition, then hide the expander and see the query.
```
"""
    )
    app_mode = st.selectbox(
        "Choose the page to show:",
        ["Experiment Instruction", "Start Experiment", "See the Results"],
    )

    if app_mode == "Experiment Instruction":
        st.success("To continue select an option in the dropdown menu.")
    elif app_mode == "Start Experiment":
        # Clear Canvas
        readme_text.empty()
        page_id = session_state.page
        # The second column is an unused spacer between the navigation buttons.
        col1, _spacer, col2, col3 = st.columns(4)
        if col1.button("Previous Image"):
            page_id = max(page_id - 1, 1)
        if col2.button("Next Image"):
            page_id = min(page_id + 1, NUMBER_OF_TRIALS)

        show_results = False
        if page_id == NUMBER_OF_TRIALS:
            st.success(
                'You have reached the last image. Please go to the "Results" page to see your performance.'
            )
            # Pressing "View" renders the results in place of the query for
            # this run (checked below, after the page bookkeeping).
            show_results = st.button("View")

        if col3.button("Resample"):
            st.write("Restarting ...")
            page_id = 1
            session_state.first_run = 1
            resmaple_queries()

        session_state.page = page_id
        if show_results:
            st.write("Results Summary")
            render_results()
        else:
            st.write(f"Render Experiment: {session_state.page}")
            render_experiment(session_state.page - 1)
    elif app_mode == "See the Results":
        readme_text.empty()
        st.write("Results Summary")
        render_results()
def main():
    """Set up the page, let the user pick an XAI tool once, then run the study.

    Until a tool is chosen only the tool picker is shown. Afterwards the
    module-level ``selected_xai_tool`` (explanation loader, or None for no
    explanation) and ``CLASSIFIER_TAG`` (prefix of the prediction keys read
    from ``classifier_predictions``) are set for that tool. Then the query set
    is drawn and the page menu is rendered.
    """
    global selected_xai_tool
    global CLASSIFIER_TAG

    # State Management and General Setup
    st.set_page_config(layout="wide")
    st.title("Visual Correspondence Human Study - ImageNet")
    options = [
        "Unselected",
        "NOXAI",
        "KNN",
        "EMD-Corr Nearest Neighbors",
        "EMD-Corr Correspondence",
        "CHM-Corr Nearest Neighbors",
        "CHM-Corr Correspondence",
    ]
    # Hide the first option of every radio group on the page: the
    # "Unselected" placeholder below and the "-" placeholder answer in
    # render_experiment, so neither is shown as a visible choice.
    st.markdown(
        """ <style>
div[role="radiogroup"] > :first-child{
display: none !important;
}
</style>
""",
        unsafe_allow_html=True,
    )

    # The picker is only rendered until a real tool has been chosen.
    if session_state.XAI_tool == "Unselected":
        default = options.index(session_state.XAI_tool)
        session_state.XAI_tool = st.radio(
            "What explanation tool do you want to evaluate?",
            options,
            key="which_xai",
            index=default,
        )

    if session_state.XAI_tool != "Unselected":
        st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")
        if session_state.XAI_tool == "NOXAI":
            CLASSIFIER_TAG = "knn"
            selected_xai_tool = None
        elif session_state.XAI_tool == "KNN":
            selected_xai_tool = load_knn_nns
            CLASSIFIER_TAG = "knn"
        elif session_state.XAI_tool == "CHM-Corr Nearest Neighbors":
            selected_xai_tool = load_chm_nns
            CLASSIFIER_TAG = "CHM"
        elif session_state.XAI_tool == "CHM-Corr Correspondence":
            selected_xai_tool = load_chm_corrs
            CLASSIFIER_TAG = "CHM"
        elif session_state.XAI_tool == "EMD-Corr Nearest Neighbors":
            selected_xai_tool = load_emd_nns
            CLASSIFIER_TAG = "EMD"
        elif session_state.XAI_tool == "EMD-Corr Correspondence":
            selected_xai_tool = load_emd_corrs
            CLASSIFIER_TAG = "EMD"
        resmaple_queries()
        render_menu()


if __name__ == "__main__":
    main()