"""Deeper evaluation of different inference strategies to classify a document."""

import sys
from collections import OrderedDict

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score

from metrics import ece_logits, aurc_logits, multi_aurc_plot

EXPERIMENT_ROOT = "/mnt/lerna/experiments"

def softmax(x, axis=-1):
    """Numerically stable softmax along `axis`."""
    # Shift by the max so np.exp cannot overflow on large logits;
    # the subtraction leaves the result unchanged.
    x = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(x)
    exps_sum = np.sum(exps, axis=axis, keepdims=True)
    return exps / exps_sum

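# Quick sanity check for softmax (these are the textbook values for logits
# [1, 2, 3], not repo-specific numbers):
#   softmax(np.array([1.0, 2.0, 3.0]))
#   -> array([0.09003057, 0.24472847, 0.66524096])
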
def predictions_loader(predictions_path):
    """Load an .npz predictions dump.

    The trailing columns are assumed to be [..., predictions, labels, dataset_idx];
    the DiT-base-rvl_cdip_MP per-page dumps store no predictions column, so the
    prediction is recovered as the argmax over the logits.
    """
    data = np.load(predictions_path)["arr_0"]
    dataset_idx = data[:, -1]
    labels = data[:, -2].astype(int)
    if "DiT-base-rvl_cdip_MP" in predictions_path and any(x in predictions_path for x in ["first", "second", "last"]):
        data = data[:, :-2]
        predictions = np.argmax(data, -1)
    else:
        predictions = data[:, -3].astype(int)
        data = data[:, :-3]
    return data, labels, predictions, dataset_idx

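# Minimal usage sketch (same path pattern as compare_errors() below; the column
# layout is the one assumed in the docstring above):
#   logits, labels, predictions, dataset_idx = predictions_loader(
#       f"{EXPERIMENT_ROOT}/rvl_cdip_n_mp/dit-base-finetuned-rvlcdip_first-0-final.npz"
#   )
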
def compare_errors():
    """Measure how complementary the errors of the page strategies are.

    Exploratory correlation between the per-sample correctness vectors:

        from scipy.stats import pearsonr, spearmanr
        spearmanr(strategy_correctness['first'], strategy_correctness['second'])
        # SignificanceResult(statistic=0.5429413617297623, pvalue=0.0)
        spearmanr(strategy_correctness['first'], strategy_correctness['last'])
        # SignificanceResult(statistic=0.5005224326802595, pvalue=0.0)

        pearsonr(strategy_correctness['first'], strategy_correctness['second'])
        # PearsonRResult(statistic=0.5429413617297617, pvalue=0.0)
        pearsonr(strategy_correctness['first'], strategy_correctness['last'])
        # PearsonRResult(statistic=0.5005224326802583, pvalue=0.0)
    """
    for dataset in ["rvl_cdip_n_mp"]:
        strategy_logits = {}
        strategy_correctness = {}
        for strategy in ["first", "second", "last"]:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            strategy_logits[strategy], labels, predictions, dataset_idx = predictions_loader(path)
            strategy_correctness[strategy] = (predictions == labels).astype(int)

        # Oracle combinations: a sample counts as correct if the base strategy or
        # any added strategy classified it correctly.
        print("Base accuracy of first: ", np.mean(strategy_correctness["first"]))
        first_or_second = [
            x if x == 1 else strategy_correctness["second"][i]
            for i, x in enumerate(strategy_correctness["first"])
        ]
        print(f"Accuracy of first when adding knowledge from second page: {np.mean(first_or_second)}")
        first_or_last = [
            x if x == 1 else strategy_correctness["last"][i]
            for i, x in enumerate(strategy_correctness["first"])
        ]
        print(f"Accuracy of first when adding knowledge from last page: {np.mean(first_or_last)}")

        first_or_second_or_last = [
            x if x == 1 else (strategy_correctness["second"][i] or strategy_correctness["last"][i])
            for i, x in enumerate(strategy_correctness["first"])
        ]
        print(
            f"Accuracy of first when adding knowledge from second/last page: {np.mean(first_or_second_or_last)}"
        )

        print("Base accuracy of second: ", np.mean(strategy_correctness["second"]))
        second_or_first = [
            x if x == 1 else strategy_correctness["first"][i]
            for i, x in enumerate(strategy_correctness["second"])
        ]
        print(f"Accuracy of second when adding knowledge from first page: {np.mean(second_or_first)}")
        second_or_last = [
            x if x == 1 else strategy_correctness["last"][i]
            for i, x in enumerate(strategy_correctness["second"])
        ]
        print(f"Accuracy of second when adding knowledge from last page: {np.mean(second_or_last)}")

        print("Base accuracy of last: ", np.mean(strategy_correctness["last"]))
        last_or_first = [
            x if x == 1 else strategy_correctness["first"][i]
            for i, x in enumerate(strategy_correctness["last"])
        ]
        print(f"Accuracy of last when adding knowledge from first page: {np.mean(last_or_first)}")
        last_or_second = [
            x if x == 1 else strategy_correctness["second"][i]
            for i, x in enumerate(strategy_correctness["last"])
        ]
        print(f"Accuracy of last when adding knowledge from second page: {np.mean(last_or_second)}")

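# Note: with 0/1 correctness arrays, the `x if x == 1 else other[i]` pattern in
# compare_errors() is an element-wise OR, so a vectorized equivalent (sketch) is:
#   oracle = np.maximum(correct_a, correct_b)  # logical OR for binary arrays
#   oracle.mean()
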
def review_one(path):
    collect = OrderedDict()
    try:
        logits, labels, predictions, dataset_idx = predictions_loader(path)
    except Exception as e:
        print(f"Something went wrong when loading predictions: {e}")
        return None

    y_correct = (predictions == labels).astype(int)
    acc = np.mean(y_correct)
    # Confidence assigned to the predicted class for each sample.
    p_hat = np.array([softmax(p, -1)[predictions[i]] for i, p in enumerate(logits)])

    res = aurc_logits(y_correct, p_hat, plot=False, get_cache=True, use_as_is=True)

    collect["aurc"] = res["aurc"]
    collect["accuracy"] = 100 * acc
    collect["f1"] = 100 * f1_score(labels, predictions, average="weighted")
    collect["f1_macro"] = 100 * f1_score(labels, predictions, average="macro")
    collect["ece"] = ece_logits(np.logical_not(y_correct), np.expand_dims(p_hat, -1), use_as_is=True)

    df = pd.DataFrame.from_dict([collect])
    print(df.to_latex())
    print(df.to_string())
    return collect, res

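# For reference, AURC can also be computed directly from y_correct and p_hat via
# the standard risk-coverage definition (a sketch; metrics.aurc_logits may differ
# in interpolation/weighting details):
#   order = np.argsort(-p_hat)                   # most confident samples first
#   risks = np.cumsum(1 - y_correct[order]) / np.arange(1, len(p_hat) + 1)
#   aurc_manual = risks.mean()
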
def experiments_review():
    STRATEGIES = ["first", "second", "last", "max_confidence", "soft_voting", "hard_voting", "grid"]
    for dataset in ["DiT-base-rvl_cdip_MP", "rvl_cdip_n_mp"]:
        collect = {}
        aurcs = []
        caches = []
        for strategy in STRATEGIES:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            result = review_one(path)
            if result is None:  # loading failed; skip this strategy
                continue
            collect[strategy], res = result
            aurcs.append(res["aurc"])
            caches.append(res["cache"])

        df = pd.DataFrame.from_dict(collect, orient="index")
        df = df[["accuracy", "f1", "f1_macro", "ece", "aurc"]]
        print(df.to_latex())
        print(df.to_string())
        """
        subset = [0, 1, 2]
        multi_aurc_plot(
            [x for i, x in enumerate(caches) if i in subset],
            [x for i, x in enumerate(STRATEGIES) if i in subset],
            aurcs=[x for i, x in enumerate(aurcs) if i in subset],
        )
        """

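# CLI usage (inferred from the argparse setup below; the script name is illustrative):
#   python this_script.py                  -> experiments_review() + compare_errors()
#   python this_script.py predictions.npz  -> review_one() on the given file
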
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Deeper evaluation of different inference strategies to classify a document")
    DEFAULT = "./dit-base-finetuned-rvlcdip_last-10.npz"
    parser.add_argument(
        "predictions_path",
        type=str,
        default=DEFAULT,
        nargs="?",
        help="path to predictions",
    )

    args = parser.parse_args()
    if args.predictions_path == DEFAULT:
        # No explicit path given: run the full review and the error comparison.
        experiments_review()
        compare_errors()
        sys.exit(0)

    print(f"Running a single review on {args.predictions_path}")
    review_one(args.predictions_path)