import sys
from collections import OrderedDict

import numpy as np
import pandas as pd
from sklearn.metrics import f1_score

from metrics import ece_logits, aurc_logits, multi_aurc_plot, apply_metrics

EXPERIMENT_ROOT = "/mnt/lerna/experiments"


def softmax(x, axis=-1):
    # Subtract the maximum value for numerical stability
    x = x - np.max(x, axis=axis, keepdims=True)
    # Exponentiate the shifted inputs and normalize by their sum
    exps = np.exp(x)
    return exps / np.sum(exps, axis=axis, keepdims=True)


def predictions_loader(predictions_path):
    data = np.load(predictions_path)["arr_0"]
    dataset_idx = data[:, -1]
    if "DiT-base-rvl_cdip_MP" in predictions_path and any(
        x in predictions_path for x in ["first", "second", "last"]
    ):
        # Row layout: [logits..., label, dataset_idx]; predictions are derived from the logits
        labels = data[:, -2].astype(int)
        data = data[:, :-2]  # logits
        predictions = np.argmax(data, -1)
    else:
        # Row layout: [logits..., prediction, label, dataset_idx]
        labels = data[:, -2].astype(int)
        predictions = data[:, -3].astype(int)
        data = data[:, :-3]  # logits
    return data, labels, predictions, dataset_idx
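# A minimal sketch of the .npz layout predictions_loader expects in its second
# branch, assuming one row per document of [logits..., prediction, label,
# dataset_idx] stored under the default "arr_0" key. The file name and shapes
# below are hypothetical, for illustration only:
#
#   n_samples, n_classes = 4, 16
#   logits = np.random.randn(n_samples, n_classes)
#   rows = np.column_stack(
#       [logits, logits.argmax(-1), np.random.randint(0, n_classes, n_samples), np.arange(n_samples)]
#   )
#   np.savez("dummy_predictions.npz", rows)
#   data, labels, predictions, dataset_idx = predictions_loader("dummy_predictions.npz")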
def compare_errors():
    """Quantify how much the page strategies' errors overlap: for each base
    strategy, report the oracle accuracy obtained when another strategy's
    correct answers are allowed to fix its mistakes.

    Correlations between the correctness vectors (computed offline):

        from scipy.stats import pearsonr, spearmanr
        spearmanr(strategy_correctness['first'], strategy_correctness['second'])
        # SignificanceResult(statistic=0.5429413617297623, pvalue=0.0)
        spearmanr(strategy_correctness['first'], strategy_correctness['last'])
        # SignificanceResult(statistic=0.5005224326802595, pvalue=0.0)
        pearsonr(strategy_correctness['first'], strategy_correctness['second'])
        # PearsonRResult(statistic=0.5429413617297617, pvalue=0.0)
        pearsonr(strategy_correctness['first'], strategy_correctness['last'])
        # PearsonRResult(statistic=0.5005224326802583, pvalue=0.0)
    """
    for dataset in ["rvl_cdip_n_mp"]:  # "DiT-base-rvl_cdip_MP",
        strategy_logits = {}
        strategy_correctness = {}
        for strategy in ["first", "second", "last"]:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            strategy_logits[strategy], labels, predictions, dataset_idx = predictions_loader(path)
            strategy_correctness[strategy] = (predictions == labels).astype(int)

        first = strategy_correctness["first"]
        second = strategy_correctness["second"]
        last = strategy_correctness["last"]

        # np.maximum acts as an elementwise OR on the 0/1 correctness vectors:
        # a sample counts as correct if either strategy classified it correctly.
        print("Base accuracy of first: ", first.mean())
        print(f"Accuracy of first when adding knowledge from second page: {np.maximum(first, second).mean()}")
        print(f"Accuracy of first when adding knowledge from last page: {np.maximum(first, last).mean()}")
        print(
            f"Accuracy of first when adding knowledge from second/last page: "
            f"{np.maximum.reduce([first, second, last]).mean()}"
        )

        # Inverse direction: start from the second page
        print("Base accuracy of second: ", second.mean())
        print(f"Accuracy of second when adding knowledge from first page: {np.maximum(second, first).mean()}")
        print(f"Accuracy of second when adding knowledge from last page: {np.maximum(second, last).mean()}")

        # Inverse direction: start from the last page
        print("Base accuracy of last: ", last.mean())
        print(f"Accuracy of last when adding knowledge from first page: {np.maximum(last, first).mean()}")
        print(f"Accuracy of last when adding knowledge from second page: {np.maximum(last, second).mean()}")


def review_one(path):
    collect = OrderedDict()
    try:
        logits, labels, predictions, dataset_idx = predictions_loader(path)
    except Exception as e:
        print(f"something went wrong in inference loading: {e}")
        return None
    y_correct = (predictions == labels).astype(int)
    acc = np.mean(y_correct)
    # Confidence assigned to the predicted class for each sample
    p_hat = np.array([softmax(p, -1)[predictions[i]] for i, p in enumerate(logits)])
    res = aurc_logits(
        y_correct, p_hat, plot=False, get_cache=True, use_as_is=True
    )  # DEV: implementation hack that allows passing I[Y==y_hat] and p_hat instead of logits and label indices
    collect["aurc"] = res["aurc"]
    collect["accuracy"] = 100 * acc
    collect["f1"] = 100 * f1_score(labels, predictions, average="weighted")
    collect["f1_macro"] = 100 * f1_score(labels, predictions, average="macro")
    collect["ece"] = ece_logits(np.logical_not(y_correct), np.expand_dims(p_hat, -1), use_as_is=True)

    df = pd.DataFrame.from_dict([collect])
    # df = df[["accuracy", "f1", "f1_macro", "ece", "aurc"]]
    print(df.to_latex())
    print(df.to_string())
    return collect, res


def experiments_review():
    STRATEGIES = ["first", "second", "last", "max_confidence", "soft_voting", "hard_voting", "grid"]
    for dataset in ["DiT-base-rvl_cdip_MP", "rvl_cdip_n_mp"]:
        collect = {}
        aurcs = []
        caches = []
        for strategy in STRATEGIES:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            result = review_one(path)
            if result is None:  # loading failed; skip this strategy
                continue
            collect[strategy], res = result
            aurcs.append(res["aurc"])
            caches.append(res["cache"])
        df = pd.DataFrame.from_dict(collect, orient="index")
        df = df[["accuracy", "f1", "f1_macro", "ece", "aurc"]]
        print(df.to_latex())
        print(df.to_string())
        """
        subset = [0, 1, 2]
        multi_aurc_plot(
            [x for i, x in enumerate(caches) if i in subset],
            [x for i, x in enumerate(STRATEGIES) if i in subset],
            aurcs=[x for i, x in enumerate(aurcs) if i in subset],
        )
        """


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser(description="Deeper evaluation of different inference strategies to classify a document")
    DEFAULT = "./dit-base-finetuned-rvlcdip_last-10.npz"
    parser.add_argument(
        "predictions_path",
        type=str,
        default=DEFAULT,
        nargs="?",
        help="path to predictions",
    )
    args = parser.parse_args()

    if args.predictions_path == DEFAULT:
        # No explicit path given: run the full review plus the error-overlap analysis
        experiments_review()
        compare_errors()
        sys.exit(0)
    print(f"Running evaluation on {args.predictions_path}")
    review_one(args.predictions_path)
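# Example invocations (the script name is hypothetical; adjust to wherever this
# file lives):
#   python review_inference.py
#       -> experiments_review() + compare_errors() over EXPERIMENT_ROOT
#   python review_inference.py /path/to/dit-base-finetuned-rvlcdip_first-0-final.npz
#       -> review_one() on that single predictions file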