import json
import os
import pickle
import random
import time
from collections import Counter
from datetime import datetime
from glob import glob
import gdown
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import streamlit as st
from PIL import Image
import SessionState
from download_utils import *
from image_utils import *
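# Seed both RNGs from the current time so each run draws a different sample of queries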
random.seed(datetime.now().timestamp())
np.random.seed(int(time.time()))
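# Experiment settings: trials per participant; the classifier tag and the
# explanation (XAI) tool are filled in by main() once the user picks a method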
NUMBER_OF_TRIALS = 20
CLASSIFIER_TAG = ""
explanation_functions = [load_chm_nns, load_knn_nns]
selected_xai_tool = None
# Config
folder_to_name = {}
class_descriptions = {}
classifier_predictions = {}
selected_dataset = "Final"
root_visualization_dir = "./visualizations/"
viz_url = "https://drive.google.com/uc?id=1LpmOc_nFBzApYWAokO2J-s9RRXsk3pBN"
viz_archivefile = "Final.zip"
demonstration_url = "https://drive.google.com/uc?id=1C92llG5VrlABrsIEvxfNlSDc_gIeLlls"
demonst_zipfile = "demonstrations.zip"
picklefile_url = "https://drive.google.com/uc?id=1Yx4abA4VLZGO5JkzhXVGdy6mbPltMd68"
prediction_root = "./predictions/"
prediction_pickle = f"{prediction_root}predictions.pickle"
# Get the Data
download_files(
root_visualization_dir,
viz_url,
viz_archivefile,
demonstration_url,
demonst_zipfile,
picklefile_url,
prediction_root,
prediction_pickle,
)
################################################
# GLOBAL VARIABLES
app_mode = ""
## Shared/Global Information
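# WNID -> human-readable class name, and WNID -> WordNet gloss (shown as the class definition)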
with open("imagenet-labels.json", "rb") as f:
folder_to_name = json.load(f)
with open("gloss.txt", "r") as f:
description_file = f.readlines()
class_descriptions = {line.split("\t")[0]: line.split("\t")[1].strip() for line in description_file}
################################################
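# Precomputed classifier outputs, keyed by query image filename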
with open(prediction_pickle, "rb") as f:
classifier_predictions = pickle.load(f)
# SESSION STATE
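# Per-user session state: current page, sampled queries, recorded feedback, and the chosen XAI method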
session_state = SessionState.get(
page=1,
first_run=1,
user_feedback={},
queries=[],
is_classifier_correct={},
XAI_tool="Unselected",
)
################################################
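# On the first run (or after a resample), draw a balanced set of queries
# (half classified correctly, half incorrectly), shuffle them, and clear any recorded feedback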
def resample_queries():
if session_state.first_run == 1:
both_correct = glob(
root_visualization_dir + selected_dataset + "/Both_correct/*.JPEG"
)
both_wrong = glob(
root_visualization_dir + selected_dataset + "/Both_wrong/*.JPEG"
)
correct_samples = list(
np.random.choice(a=both_correct, size=NUMBER_OF_TRIALS // 2, replace=False)
)
wrong_samples = list(
np.random.choice(a=both_wrong, size=NUMBER_OF_TRIALS // 2, replace=False)
)
all_images = correct_samples + wrong_samples
random.shuffle(all_images)
session_state.queries = all_images
session_state.first_run = -1
# RESET INTERACTIONS
session_state.user_feedback = {}
session_state.is_classifier_correct = {}
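# Render a single trial: the query image, the predicted class (name, definition, example image),
# the user's Correct/Wrong judgment, and, if selected, the explanation produced by the XAI tool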
def render_experiment(query):
current_query = session_state.queries[query]
query_id = os.path.basename(current_query)
predicted_wnid = classifier_predictions[query_id][f"{CLASSIFIER_TAG}-predictions"]
prediction_confidence = classifier_predictions[query_id][
f"{CLASSIFIER_TAG}-confidence"
]
prediction_label = folder_to_name[predicted_wnid]
class_def = class_descriptions[predicted_wnid]
session_state.is_classifier_correct[query_id] = classifier_predictions[query_id][
f"{CLASSIFIER_TAG.upper()}-Output"
]
################################### SHOW QUERY and PREDICTION
col1, col2 = st.columns(2)
with col1:
st.image(load_query(current_query), caption=f"Query ID: {query_id}")
with col2:
################################### SHOW DESCRIPTION OF CLASS
with st.expander("Show Class Description"):
st.write(f"**Name**: {prediction_label}")
st.write("**Class Definition**:")
st.markdown("`" + class_def + "`")
st.image(
Image.open(f"demonstrations/{predicted_wnid}.jpeg"),
caption=f"Class Explanation",
use_column_width=True,
)
default_value = 0
if query_id in session_state.user_feedback.keys():
if session_state.user_feedback[query_id] == "Correct":
default_value = 1
elif session_state.user_feedback[query_id] == "Wrong":
default_value = 2
session_state.user_feedback[query_id] = st.radio(
"What do you think about model's prediction?",
("-", "Correct", "Wrong"),
key=query_id,
index=default_value,
)
st.write(f"**Model Prediction**: {prediction_label}")
st.write(f"**Model Confidence**: {prediction_confidence}")
################################### SHOW Model Explanation
if selected_xai_tool is not None:
st.image(
selected_xai_tool(current_query),
caption=f"Explaination",
use_column_width=True,
)
################################### SHOW DEBUG INFO
if st.button("Debug: Show Everything"):
st.image(Image.open(current_query))
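# Summarize the session: overall user accuracy, a breakdown by whether the model was right,
# and a confusion matrix, each offered as a CSV download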
def render_results():
user_correct_guess = 0
for q in session_state.user_feedback.keys():
        uf = session_state.user_feedback[q] == "Correct"
if session_state.is_classifier_correct[q] == uf:
user_correct_guess += 1
st.write(
f"User performance on {CLASSIFIER_TAG}: {user_correct_guess} out of {len( session_state.user_feedback)} Correct"
)
st.markdown("## User Performance Breakdown")
categories = [
"Correct",
"Wrong",
] # set(session_state.is_classifier_correct.values())
breakdown_stats_correct = {c: 0 for c in categories}
breakdown_stats_wrong = {c: 0 for c in categories}
experiment_summary = []
for q in session_state.user_feedback.keys():
category = "Correct" if session_state.is_classifier_correct[q] else "Wrong"
is_user_correct = category == session_state.user_feedback[q]
if is_user_correct:
breakdown_stats_correct[category] += 1
else:
breakdown_stats_wrong[category] += 1
experiment_summary.append(
[
q,
classifier_predictions[q]["real-gts"],
folder_to_name[
classifier_predictions[q][f"{CLASSIFIER_TAG}-predictions"]
],
category,
session_state.user_feedback[q],
is_user_correct,
]
)
################################### Summary Table
experiment_summary_df = pd.DataFrame.from_records(
experiment_summary,
columns=[
"Query",
"GT Labels",
f"{CLASSIFIER_TAG} Prediction",
"Category",
"User Prediction",
"Is User Prediction Correct",
],
)
st.write("Summary", experiment_summary_df)
csv = convert_df(experiment_summary_df)
st.download_button(
"Press to Download", csv, "summary.csv", "text/csv", key="download-records"
)
################################### SHOW BREAKDOWN
user_pf_by_model_pred = experiment_summary_df.groupby("Category").agg(
{"Is User Prediction Correct": ["count", "sum", "mean"]}
)
# rename columns
user_pf_by_model_pred.columns = user_pf_by_model_pred.columns.droplevel(0)
user_pf_by_model_pred.columns = [
"Count",
"Correct User Guess",
"Mean User Performance",
]
user_pf_by_model_pred.index.name = "Model Prediction"
st.write("User performance break down by Model prediction:", user_pf_by_model_pred)
csv = convert_df(user_pf_by_model_pred)
st.download_button(
"Press to Download",
csv,
"user-performance-by-model-prediction.csv",
"text/csv",
key="download-performance-by-model-prediction",
)
################################### CONFUSION MATRIX
confusion_matrix = pd.crosstab(
experiment_summary_df["Category"],
experiment_summary_df["User Prediction"],
rownames=["Actual"],
colnames=["Predicted"],
)
st.write("Confusion Matrix", confusion_matrix)
csv = convert_df(confusion_matrix)
st.download_button(
"Press to Download",
csv,
"confusion-matrix.csv",
"text/csv",
key="download-confusiion-matrix",
)
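# Top-level navigation between the instructions, the experiment itself (with paging and resampling),
# and the results summary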
def render_menu():
# Render the readme as markdown using st.markdown.
readme_text = st.markdown(
"""
# Instructions
```
In this study, first read the class definition, then hide the expander and examine the query image.
```
"""
)
app_mode = st.selectbox(
"Choose the page to show:",
["Experiment Instruction", "Start Experiment", "See the Results"],
)
if app_mode == "Experiment Instruction":
st.success("To continue select an option in the dropdown menu.")
elif app_mode == "Start Experiment":
# Clear Canvas
readme_text.empty()
page_id = session_state.page
col1, col4, col2, col3 = st.columns(4)
prev_page = col1.button("Previous Image")
if prev_page:
page_id -= 1
if page_id < 1:
page_id = 1
next_page = col2.button("Next Image")
if next_page:
page_id += 1
if page_id > NUMBER_OF_TRIALS:
page_id = NUMBER_OF_TRIALS
if page_id == NUMBER_OF_TRIALS:
st.success(
                'You have reached the last image. Please go to the "See the Results" page to see your performance.'
)
if st.button("View"):
app_mode = "See the Results"
if col3.button("Resample"):
st.write("Restarting ...")
page_id = 1
session_state.first_run = 1
            resample_queries()
session_state.page = page_id
st.write(f"Render Experiment: {session_state.page}")
render_experiment(session_state.page - 1)
elif app_mode == "See the Results":
readme_text.empty()
st.write("Results Summary")
render_results()
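# Entry point: configure the page, let the user pick an XAI method, then map that choice
# to its explanation loader and classifier tag before rendering the menu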
def main():
global app_mode
global session_state
global selected_xai_tool
global CLASSIFIER_TAG
# Set the session state
# State Management and General Setup
st.set_page_config(layout="wide")
st.title("Visual CorrespondenceHuman Study - ImageNet")
options = [
"Unselected",
"NOXAI",
"KNN",
"EMD-Corr Nearest Neighbors",
"EMD-Corr Correspondence",
"CHM-Corr Nearest Neighbors",
"CHM-Corr Correspondence",
]
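    # CSS hack: hide the first option of every radio group (the "Unselected"/"-" placeholder)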
st.markdown(
""" <style>
div[role="radiogroup"] > :first-child{
display: none !important;
}
</style>
""",
unsafe_allow_html=True,
)
if session_state.XAI_tool == "Unselected":
default = options.index(session_state.XAI_tool)
session_state.XAI_tool = st.radio(
"What explaination tool do you want to evaluate?",
options,
key="which_xai",
index=default,
)
# print(session_state.XAI_tool)
if session_state.XAI_tool != "Unselected":
st.markdown(f"## SELECTED METHOD ``{session_state.XAI_tool}``")
if session_state.XAI_tool == "NOXAI":
CLASSIFIER_TAG = "knn"
selected_xai_tool = None
elif session_state.XAI_tool == "KNN":
selected_xai_tool = load_knn_nns
CLASSIFIER_TAG = "knn"
elif session_state.XAI_tool == "CHM-Corr Nearest Neighbors":
selected_xai_tool = load_chm_nns
CLASSIFIER_TAG = "CHM"
elif session_state.XAI_tool == "CHM-Corr Correspondence":
selected_xai_tool = load_chm_corrs
CLASSIFIER_TAG = "CHM"
elif session_state.XAI_tool == "EMD-Corr Nearest Neighbors":
selected_xai_tool = load_emd_nns
CLASSIFIER_TAG = "EMD"
elif session_state.XAI_tool == "EMD-Corr Correspondence":
selected_xai_tool = load_emd_corrs
CLASSIFIER_TAG = "EMD"
        resample_queries()
render_menu()
if __name__ == "__main__":
main()