# src/load_predictions.py
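"""Deeper evaluation of different inference strategies to classify a document.

Loads packed prediction arrays (.npz) produced per strategy (first/second/last page,
max_confidence, soft_voting, hard_voting, grid) and reports accuracy, weighted and
macro F1, ECE, and AURC, plus an error-overlap comparison between page strategies.
"""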
import sys
import numpy as np
import pandas as pd
from metrics import ece_logits, aurc_logits, multi_aurc_plot  # multi_aurc_plot feeds the commented-out plotting block
from sklearn.metrics import f1_score
from collections import OrderedDict
EXPERIMENT_ROOT = "/mnt/lerna/experiments"
def softmax(x, axis=-1):
# Subtract the maximum value for numerical stability
x = x - np.max(x, axis=axis, keepdims=True)
# Compute the exponentials of the shifted input
exps = np.exp(x)
# Compute the sum of exponentials along the last axis
exps_sum = np.sum(exps, axis=axis, keepdims=True)
# Compute the softmax probabilities
softmax_probs = exps / exps_sum
return softmax_probs
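
# Worked example of why the max is subtracted first: without the shift,
# softmax(np.array([1000.0, 1001.0])) would overflow np.exp (inf / inf -> nan);
# with it, the computation reduces to softmax([0., 1.]) ~= [0.269, 0.731].
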
def predictions_loader(predictions_path):
    """Unpack a packed predictions array into (logits, labels, predictions, dataset_idx)."""
    data = np.load(predictions_path)["arr_0"]
    dataset_idx = data[:, -1]
    labels = data[:, -2].astype(int)
    if "DiT-base-rvl_cdip_MP" in predictions_path and any(x in predictions_path for x in ["first", "second", "last"]):
        # rows are [logits..., label, dataset_idx]: recover the predictions from the logits
        data = data[:, :-2]  # logits
        predictions = np.argmax(data, -1)
    else:
        # rows are [logits..., prediction, label, dataset_idx]
        predictions = data[:, -3].astype(int)
        data = data[:, :-3]  # logits
    return data, labels, predictions, dataset_idx
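
# A minimal sketch of a compatible .npz writer (assumed layout, not necessarily the
# pipeline's actual saving code), matching the [logits..., prediction, label, dataset_idx] rows:
#   packed = np.hstack([logits, predictions[:, None], labels[:, None], dataset_idx[:, None]])
#   np.savez(predictions_path, packed)  # stored under "arr_0", as read above
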
def compare_errors():
"""
from scipy.stats import pearsonr, spearmanr
#idx = [x for x in strategy_correctness['first'] if x ==0]
spearmanr(strategy_correctness['first'], strategy_correctness['second'])
#SignificanceResult(statistic=0.5429413617297623, pvalue=0.0)
spearmanr(strategy_correctness['first'], strategy_correctness['last'])
#SignificanceResult(statistic=0.5005224326802595, pvalue=0.0)
pearsonr(strategy_correctness['first'], strategy_correctness['second'])
#PearsonRResult(statistic=0.5429413617297617, pvalue=0.0)
pearsonr(strategy_correctness['first'], strategy_correctness['last'])
#PearsonRResult(statistic=0.5005224326802583, pvalue=0.0)
"""
    for dataset in ["rvl_cdip_n_mp"]:  # "DiT-base-rvl_cdip_MP",
        strategy_correctness = {}
        for strategy in ["first", "second", "last"]:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            _, labels, predictions, _ = predictions_loader(path)
            strategy_correctness[strategy] = (predictions == labels).astype(int)
        # For 0/1 correctness vectors, np.maximum is the elementwise OR: a sample counts
        # as correct when the primary strategy or the fallback page strategy got it right.
        first, second, last = (strategy_correctness[k] for k in ["first", "second", "last"])
        print("Base accuracy of first: ", np.mean(first))
        print(f"Accuracy of first when adding knowledge from second page: {np.mean(np.maximum(first, second))}")
        print(f"Accuracy of first when adding knowledge from last page: {np.mean(np.maximum(first, last))}")
        print(
            f"Accuracy of first when adding knowledge from second/last page: {np.mean(np.maximum(first, np.maximum(second, last)))}"
        )
        # inverse
        print("Base accuracy of second: ", np.mean(second))
        print(f"Accuracy of second when adding knowledge from first page: {np.mean(np.maximum(second, first))}")
        print(f"Accuracy of second when adding knowledge from last page: {np.mean(np.maximum(second, last))}")
        # inverse second
        print("Base accuracy of last: ", np.mean(last))
        print(f"Accuracy of last when adding knowledge from first page: {np.mean(np.maximum(last, first))}")
        print(f"Accuracy of last when adding knowledge from second page: {np.mean(np.maximum(last, second))}")
def review_one(path):
collect = OrderedDict()
try:
logits, labels, predictions, dataset_idx = predictions_loader(path)
    except Exception as e:
        print(f"something went wrong in inference loading: {e}")
        return None
# print(logits.shape, labels.shape, logits[-1], labels[-1], dataset_idx[-1])
y_correct = (predictions == labels).astype(int)
acc = np.mean(y_correct)
    probs = softmax(logits, axis=-1)  # row-wise softmax over classes
    p_hat = probs[np.arange(len(predictions)), predictions]  # confidence of the predicted class
res = aurc_logits(
y_correct, p_hat, plot=False, get_cache=True, use_as_is=True
) # DEV: implementation hack to allow for passing I[Y==y_hat] and p_hat instead of logits and label indices
collect["aurc"] = res["aurc"]
collect["accuracy"] = 100 * acc
collect["f1"] = 100 * f1_score(labels, predictions, average="weighted")
collect["f1_macro"] = 100 * f1_score(labels, predictions, average="macro")
collect["ece"] = ece_logits(np.logical_not(y_correct), np.expand_dims(p_hat, -1), use_as_is=True)
df = pd.DataFrame.from_dict([collect])
# df = df[["accuracy", "f1", "f1_macro", "ece", "aurc"]]
print(df.to_latex())
print(df.to_string())
return collect, res
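
# Usage (path illustrative; experiments_review shows the real naming scheme):
#   collect, res = review_one(f"{EXPERIMENT_ROOT}/rvl_cdip_n_mp/dit-base-finetuned-rvlcdip_first-0-final.npz")
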
def experiments_review():
STRATEGIES = ["first", "second", "last", "max_confidence", "soft_voting", "hard_voting", "grid"]
for dataset in ["DiT-base-rvl_cdip_MP", "rvl_cdip_n_mp"]:
collect = {}
aurcs = []
caches = []
        for strategy in STRATEGIES:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            result = review_one(path)
            if result is None:  # loading failed; skip this strategy
                continue
            collect[strategy], res = result
            aurcs.append(res["aurc"])
            caches.append(res["cache"])
df = pd.DataFrame.from_dict(collect, orient="index")
df = df[["accuracy", "f1", "f1_macro", "ece", "aurc"]]
print(df.to_latex())
print(df.to_string())
"""
subset = [0, 1, 2]
multi_aurc_plot(
[x for i, x in enumerate(caches) if i in subset],
[x for i, x in enumerate(STRATEGIES) if i in subset],
aurcs=[x for i, x in enumerate(aurcs) if i in subset],
)
"""
if __name__ == "__main__":
from argparse import ArgumentParser
    parser = ArgumentParser(description="Deeper evaluation of different inference strategies to classify a document")
DEFAULT = "./dit-base-finetuned-rvlcdip_last-10.npz"
parser.add_argument(
"predictions_path",
type=str,
default=DEFAULT,
nargs="?",
help="path to predictions",
)
args = parser.parse_args()
    if args.predictions_path == DEFAULT:
        experiments_review()
        compare_errors()
        sys.exit(0)  # not an error: the full review ran instead of a single-file evaluation
    print(f"Running evaluation on {args.predictions_path}")
    review_one(args.predictions_path)
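
# Example invocations (file name illustrative):
#   python load_predictions.py                 # no argument: full strategy review + error comparison
#   python load_predictions.py my_preds.npz    # evaluate a single predictions file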