Upload 9 files

- DiT_inference.py +99 -0
- README.md +18 -0
- experiments.sh +23 -0
- inference_methods.py +114 -0
- load_predictions.py +180 -0
- mapping_functions.py +157 -0
- metrics.py +309 -0
- pyproject.toml +36 -0
- simulate_document_classifier.py +222 -0
DiT_inference.py
ADDED
@@ -0,0 +1,99 @@
import os

import numpy as np
from tqdm import tqdm
import torch
from datasets import load_dataset, ClassLabel
from datasets import Features, Array3D
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from metrics import apply_metrics


def process_label_ids(batch, remapper, label_column="label"):
    batch[label_column] = [remapper[label_id] for label_id in batch[label_column]]
    return batch


CACHE_DIR = "/mnt/lerna/data/HFcache" if os.path.exists("/mnt/lerna/data/HFcache") else None


def main(args):
    dataset = load_dataset(args.dataset, split="test", cache_dir=CACHE_DIR)
    if args.dataset == "rvl_cdip":
        dataset = dataset.select([i for i in range(len(dataset)) if i != 33669])  # skip corrupt sample
    batch_size = 100 if args.dataset == "jordyvl/RVL-CDIP-N" else 1000

    feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
    model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")

    label2idx = {label.replace(" ", "_"): i for label, i in model.config.label2id.items()}
    data_idx2label = dict(enumerate(dataset.features["label"].names))  # dict(zip(...)) with one argument raises a TypeError
    data_label2idx = {label: i for i, label in enumerate(dataset.features["label"].names)}
    model_idx2label = dict(zip(label2idx.values(), label2idx.keys()))
    diff = [i for i in range(len(data_label2idx)) if data_idx2label[i] != model_idx2label[i]]

    if diff:
        print(f"aligning labels {diff}")
        print(f"model labels: {model_idx2label}")
        print(f"data labels: {data_idx2label}")
        print(f"Remapping to {label2idx}")  # yet this cannot change the number of labels

        remapper = {}
        for k, v in label2idx.items():
            if k in data_label2idx:
                remapper[data_label2idx[k]] = v

        print(remapper)
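        # Toy sketch of the alignment (hypothetical two-class labels, not the
        # real RVL-CDIP set): if the model orders classes {"letter": 0, "form": 1}
        # while the dataset stores {"form": 0, "letter": 1}, the remapper becomes
        # {0: 1, 1: 0}, and process_label_ids rewrites each dataset label id to
        # the model's id.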
        new_features = Features(
            {
                **{k: v for k, v in dataset.features.items() if k != "label"},
                "label": ClassLabel(num_classes=len(label2idx), names=list(label2idx.keys())),
            }
        )

        dataset = dataset.map(
            lambda example: process_label_ids(example, remapper),
            features=new_features,
            batched=True,
            batch_size=batch_size,
            desc="Aligning the labels",
        )

    features = Features({**dataset.features, "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224))})

    encoded_dataset = dataset.map(
        lambda examples: feature_extractor([image.convert("RGB") for image in examples["image"]]),
        batched=True,
        batch_size=batch_size,
        features=features,
    )
    encoded_dataset.set_format(type="torch", columns=["pixel_values", "label"])
    BATCH_SIZE = 16
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=BATCH_SIZE)

    all_logits = np.zeros((len(encoded_dataset), len(label2idx)))
    all_references = np.zeros(len(encoded_dataset), dtype=int)

    count = 0
    for i, batch in tqdm(enumerate(dataloader)):
        with torch.no_grad():
            outputs = model(batch["pixel_values"])
        logits = outputs.logits
        all_logits[count : count + BATCH_SIZE] = logits.detach().cpu().numpy()
        all_references[count : count + BATCH_SIZE] = batch["label"].detach().cpu().numpy()
        count += len(batch["label"])

    results = apply_metrics(all_references, all_logits)
    print(results)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser("""DiT inference on dataset test set""")
    parser.add_argument("-d", dest="dataset", type=str, default="rvl_cdip", help="the dataset to be evaluated")
    args = parser.parse_args()

    main(args)
README.md
ADDED
@@ -0,0 +1,18 @@
# Beyond Document Page Classification

## Installation

The scripts require [python >= 3.8](https://www.python.org/downloads/release/python-380/) to run.
We will create a fresh virtual environment in which to install all required packages.
```sh
mkvirtualenv -p /usr/bin/python3 BYD
```

Using poetry and the readily defined pyproject.toml, we will install all required packages:
```sh
workon BYD
pip3 install poetry
poetry install
```

## Usage
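
The full set of runs lives in `experiments.sh`. As a quick sanity check (a sketch using the scripts' defaults): `python3 DiT_inference.py -d rvl_cdip` runs single-page inference, and `python3 simulate_document_classifier.py first -d bdpc/rvl_cdip_mp` runs a multi-page inference strategy.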
experiments.sh
ADDED
@@ -0,0 +1,23 @@
# Single page f_p experiments
command="python3 DiT_inference.py -d jordyvl/RVL_CDIP_N" ; echo $command ; $command


# Multi-page f_d experiments
command="python3 simulate_document_classifier.py grid -d bdpc/rvl_cdip_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py first -d bdpc/rvl_cdip_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py second -d bdpc/rvl_cdip_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py last -d bdpc/rvl_cdip_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py max_confidence -d bdpc/rvl_cdip_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py soft_voting -d bdpc/rvl_cdip_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py hard_voting -d bdpc/rvl_cdip_mp" ; echo $command ; $command


command="python3 simulate_document_classifier.py grid -d bdpc/rvl_cdip_n_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py first -d bdpc/rvl_cdip_n_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py second -d bdpc/rvl_cdip_n_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py last -d bdpc/rvl_cdip_n_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py hard_voting -d bdpc/rvl_cdip_n_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py soft_voting -d bdpc/rvl_cdip_n_mp" ; echo $command ; $command
command="python3 simulate_document_classifier.py max_confidence -d bdpc/rvl_cdip_n_mp" ; echo $command ; $command
inference_methods.py
ADDED
@@ -0,0 +1,114 @@
from enum import Enum
from PIL import Image


def unravel_index(index, shape):
    out = []
    for dim in reversed(shape):
        out.append(index % dim)
        index = index // dim
    return tuple(reversed(out))


class ExplicitEnum(Enum):
    """
    Enum with a more explicit error message for missing values and a helper to list all options
    """

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )

    @classmethod
    def options(cls):
        return list(cls._value2member_map_.keys())


class InferenceMethod(ExplicitEnum):
    """All the implemented inference methods"""

    FIRST = "first"  # check on data setup
    SECOND = "second"  # robustness for multipage categories
    LAST = "last"  # robustness for multipage categories

    GRID = "grid"  # create a grid (equally spaced / OCR-density based)
    # downscales resolution; might be fine for classification

    MAX_CONFIDENCE = "max_confidence"  # page with highest confidence overall
    SOFT_VOTING = "soft_voting"  # average logits over pages, then argmax
    HARD_VOTING = "hard_voting"  # majority vote over per-page predictions

    @property
    def scope(self):
        if self in [InferenceMethod.FIRST, InferenceMethod.SECOND, InferenceMethod.LAST]:
            return "sample"
        if self in [InferenceMethod.GRID]:
            return "sample-grid"  # single image, yet a transformation is required
        else:
            return "iter"

    def get_page_scope(self, pages):
        if self.scope == "iter":
            return pages
        if self == InferenceMethod.GRID:
            try:
                return equal_image_grid(pages)
            except Exception:
                return pages[-1]  # fall back to the last page to not positively bias
        if self == InferenceMethod.FIRST:
            return pages[0]
        if self == InferenceMethod.SECOND:
            if len(pages) > 1:
                return pages[1]
            return pages[0]  # backoff
        if self == InferenceMethod.LAST:
            return pages[-1]

    def apply_decision_strategy(self, page_logits):
        """
        page_logits is of shape [NUM_PAGES x CLASSES]
        """
        if self == InferenceMethod.MAX_CONFIDENCE:
            index = page_logits.argmax()  # tensor with one number
            indices = unravel_index(index, page_logits.shape)
            print(f"The page which is max confident: {indices[0]}")
            return indices[-1]
        if self == InferenceMethod.HARD_VOTING:
            # majority vote: most frequent per-page prediction (the previous
            # `.max()` returned the highest class index, not the vote count winner)
            return page_logits.argmax(-1).mode().values
        if self == InferenceMethod.SOFT_VOTING:
            return page_logits.mean(0).argmax(-1)
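# A minimal sketch of the decision strategies on toy page logits (assumes torch,
# which callers pass in; shape is [NUM_PAGES x CLASSES]):
#
#   import torch
#   page_logits = torch.tensor([[2.0, 0.5], [0.1, 1.5], [0.2, 1.0]])
#   InferenceMethod.MAX_CONFIDENCE.apply_decision_strategy(page_logits)  # class 0 (page 0 is most confident)
#   InferenceMethod.SOFT_VOTING.apply_decision_strategy(page_logits)    # class 1 (highest mean logit)
#   InferenceMethod.HARD_VOTING.apply_decision_strategy(page_logits)    # class 1 (2 of 3 pages vote for it)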
def equal_image_grid(images):
    def compute_grid(n, max_cols=6):
        equal_divisor = int(n**0.5)
        cols = min(equal_divisor, max_cols)
        rows = equal_divisor
        if rows * cols >= n:
            return rows, cols
        cols += 1
        if rows * cols >= n:
            return rows, cols
        while rows * cols < n:
            rows += 1
        return rows, cols

    rows, cols = compute_grid(len(images))

    # rescale all pages to the minimum width (heights are padded by the grid cells)
    images = [im for im in images if (im.height > 0) and (im.width > 0)]  # could be NA

    min_width = min(im.width for im in images)
    images = [im.resize((min_width, int(im.height * min_width / im.width)), resample=Image.BICUBIC) for im in images]

    w, h = max(img.size[0] for img in images), max(img.size[1] for img in images)

    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
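# Usage sketch (hypothetical file names; any list of PIL images works):
#
#   from PIL import Image
#   pages = [Image.open(f"page_{i}.png") for i in range(4)]
#   grid = equal_image_grid(pages)  # a 2x2 collage for 4 pages
#   grid.save("document_grid.png")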
load_predictions.py
ADDED
@@ -0,0 +1,180 @@
import sys
import numpy as np
import pandas as pd
from metrics import ece_logits, aurc_logits, multi_aurc_plot, apply_metrics
from sklearn.metrics import f1_score
from collections import OrderedDict

EXPERIMENT_ROOT = "/mnt/lerna/experiments"


def softmax(x, axis=-1):
    # subtract the maximum value for numerical stability
    x = x - np.max(x, axis=axis, keepdims=True)

    # exponentiate the shifted input and normalize by the sum of exponentials
    exps = np.exp(x)
    exps_sum = np.sum(exps, axis=axis, keepdims=True)
    return exps / exps_sum


def predictions_loader(predictions_path):
    data = np.load(predictions_path)["arr_0"]
    dataset_idx = data[:, -1]
    if "DiT-base-rvl_cdip_MP" in predictions_path and any(x in predictions_path for x in ["first", "second", "last"]):
        labels = data[:, -2]
        data = data[:, :-2]  # logits
        predictions = np.argmax(data, -1)
    else:
        # column order follows logits_monitor in simulate_document_classifier.py,
        # [..., label, prediction, dataset_idx]; these two were read swapped before
        labels = data[:, -3].astype(int)
        predictions = data[:, -2].astype(int)
        data = data[:, :-3]  # logits
    return data, labels, predictions, dataset_idx
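# Layout sketch of the saved .npz arrays (inferred from logits_monitor in
# simulate_document_classifier.py): each row of "arr_0" is
#
#   [logit_0, ..., logit_{K-1}, label, prediction, dataset_idx]
#
# except for the per-page DiT first/second/last runs, where rows end in
# [..., label, dataset_idx] and predictions are recovered via argmax.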
def compare_errors():
    """
    from scipy.stats import pearsonr, spearmanr
    spearmanr(strategy_correctness['first'], strategy_correctness['second'])
    #SignificanceResult(statistic=0.5429413617297623, pvalue=0.0)
    spearmanr(strategy_correctness['first'], strategy_correctness['last'])
    #SignificanceResult(statistic=0.5005224326802595, pvalue=0.0)

    pearsonr(strategy_correctness['first'], strategy_correctness['second'])
    #PearsonRResult(statistic=0.5429413617297617, pvalue=0.0)
    pearsonr(strategy_correctness['first'], strategy_correctness['last'])
    #PearsonRResult(statistic=0.5005224326802583, pvalue=0.0)
    """
    for dataset in ["rvl_cdip_n_mp"]:  # "DiT-base-rvl_cdip_MP",
        strategy_logits = {}
        strategy_correctness = {}
        for strategy in ["first", "second", "last"]:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"

            strategy_logits[strategy], labels, predictions, dataset_idx = predictions_loader(path)
            strategy_correctness[strategy] = (predictions == labels).astype(int)

        print("Base accuracy of first: ", np.mean(strategy_correctness["first"]))
        first_correct_if_second_correct = [
            x if x == 1 else strategy_correctness["second"][i] for i, x in enumerate(strategy_correctness["first"])
        ]
        print(f"Accuracy of first when adding knowledge from second page: {np.mean(first_correct_if_second_correct)}")
        first_correct_if_last_correct = [
            x if x == 1 else strategy_correctness["last"][i] for i, x in enumerate(strategy_correctness["first"])
        ]
        print(f"Accuracy of first when adding knowledge from last page: {np.mean(first_correct_if_last_correct)}")

        first_correct_if_second_or_last_correct = [
            x if x == 1 else (strategy_correctness["second"][i] or strategy_correctness["last"][i])
            for i, x in enumerate(strategy_correctness["first"])
        ]
        print(
            f"Accuracy of first when adding knowledge from second/last page: {np.mean(first_correct_if_second_or_last_correct)}"
        )

        # inverse: start from second
        print("Base accuracy of second: ", np.mean(strategy_correctness["second"]))
        second_correct_if_first_correct = [
            x if x == 1 else strategy_correctness["first"][i] for i, x in enumerate(strategy_correctness["second"])
        ]
        print(f"Accuracy of second when adding knowledge from first page: {np.mean(second_correct_if_first_correct)}")
        second_correct_if_last_correct = [
            x if x == 1 else strategy_correctness["last"][i] for i, x in enumerate(strategy_correctness["second"])
        ]
        print(f"Accuracy of second when adding knowledge from last page: {np.mean(second_correct_if_last_correct)}")

        # inverse: start from last
        print("Base accuracy of last: ", np.mean(strategy_correctness["last"]))
        last_correct_if_first_correct = [
            x if x == 1 else strategy_correctness["first"][i] for i, x in enumerate(strategy_correctness["last"])
        ]
        print(f"Accuracy of last when adding knowledge from first page: {np.mean(last_correct_if_first_correct)}")
        last_correct_if_second_correct = [
            x if x == 1 else strategy_correctness["second"][i] for i, x in enumerate(strategy_correctness["last"])
        ]
        print(f"Accuracy of last when adding knowledge from second page: {np.mean(last_correct_if_second_correct)}")


def review_one(path):
    collect = OrderedDict()
    try:
        logits, labels, predictions, dataset_idx = predictions_loader(path)
    except Exception as e:
        print(f"something went wrong in inference loading {e}")
        return
    y_correct = (predictions == labels).astype(int)
    acc = np.mean(y_correct)
    p_hat = np.array([softmax(p, -1)[predictions[i]] for i, p in enumerate(logits)])

    res = aurc_logits(
        y_correct, p_hat, plot=False, get_cache=True, use_as_is=True
    )  # DEV: implementation hack to allow passing I[Y==y_hat] and p_hat instead of logits and label indices

    collect["aurc"] = res["aurc"]
    collect["accuracy"] = 100 * acc
    collect["f1"] = 100 * f1_score(labels, predictions, average="weighted")
    collect["f1_macro"] = 100 * f1_score(labels, predictions, average="macro")
    collect["ece"] = ece_logits(np.logical_not(y_correct), np.expand_dims(p_hat, -1), use_as_is=True)

    df = pd.DataFrame.from_dict([collect])
    print(df.to_latex())
    print(df.to_string())
    return collect, res


def experiments_review():
    STRATEGIES = ["first", "second", "last", "max_confidence", "soft_voting", "hard_voting", "grid"]
    for dataset in ["DiT-base-rvl_cdip_MP", "rvl_cdip_n_mp"]:
        collect = {}
        aurcs = []
        caches = []
        for strategy in STRATEGIES:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            collect[strategy], res = review_one(path)
            aurcs.append(res["aurc"])
            caches.append(res["cache"])

        df = pd.DataFrame.from_dict(collect, orient="index")
        df = df[["accuracy", "f1", "f1_macro", "ece", "aurc"]]
        print(df.to_latex())
        print(df.to_string())
        """
        subset = [0, 1, 2]
        multi_aurc_plot(
            [x for i, x in enumerate(caches) if i in subset],
            [x for i, x in enumerate(STRATEGIES) if i in subset],
            aurcs=[x for i, x in enumerate(aurcs) if i in subset],
        )
        """


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser("""Deeper evaluation of different inference strategies to classify a document""")
    DEFAULT = "./dit-base-finetuned-rvlcdip_last-10.npz"
    parser.add_argument(
        "predictions_path",
        type=str,
        default=DEFAULT,
        nargs="?",
        help="path to predictions",
    )

    args = parser.parse_args()
    if args.predictions_path == DEFAULT:
        experiments_review()
        compare_errors()
        sys.exit(0)  # normal exit; exiting with 1 would signal failure

    print(f"Reviewing predictions at {args.predictions_path}")
    review_one(args.predictions_path)
mapping_functions.py
ADDED
@@ -0,0 +1,157 @@
from io import BytesIO
from typing import List

import numpy as np
from PIL import Image as PIL_Image
from PIL.Image import Image

import PyPDF2
import pdf2image
from datasets import logging

logger = logging.get_logger(__name__)

MAX_PAGES = 50
MAX_PDF_SIZE = 100000000  # almost 100MB
MIN_WIDTH, MIN_HEIGHT = 150, 150


def pdf2image_image_extraction(pdf_stream):
    try:
        images: List[Image] = pdf2image.convert_from_bytes(pdf_stream)
        return images
    except Exception as e:
        logger.warning(f"{e}")


def pdf_to_pixelvalues_extractor(example, feature_extractor, inference_method):
    example["pages"] = 0
    example["pixel_values"] = None
    pixel_values = []
    if len(example["file"]) > MAX_PDF_SIZE:
        logger.warning(f"too large file {len(example['file'])}")
        return example
    try:
        reader = PyPDF2.PdfReader(BytesIO(example["file"]))
    except Exception as e:
        logger.warning(f"read_pdf {e}")
        return example
    example["pages"] = len(reader.pages)
    reached_page_limit = False
    if "sample" in inference_method.scope and inference_method.scope != "sample-grid":
        page_iterator = [inference_method.get_page_scope(reader.pages)]
    else:
        page_iterator = reader.pages

    try:
        for p, page in enumerate(page_iterator):
            if reached_page_limit:
                break
            for image in page.images:
                if len(pixel_values) == MAX_PAGES:
                    reached_page_limit = True
                    break
                im = PIL_Image.open(BytesIO(image.data))
                if im.width < MIN_WIDTH and im.height < MIN_HEIGHT:
                    continue
                if inference_method.scope != "sample-grid":
                    im = feature_extractor([im.convert("RGB")])["pixel_values"][0]
                pixel_values.append(im)
    except Exception as e:
        print(f"{example.get('id')} PyPDF get_images {e}")
        pixel_values = []

    if len(pixel_values) == 0:
        # at least try with another API
        try:
            images = pdf2image_image_extraction(example["file"])
        except Exception as e:
            print(f"{example.get('id')} pdf2image get_images {e}")
            images = []

        if not images:
            print(f"{example.get('id')} pdf2image has no images")
            example["pages"] = 0
            return example

        # got lucky with pdf2image
        example["pages"] = len(images)
        for im in images:
            if len(pixel_values) == MAX_PAGES:
                reached_page_limit = True
                break
            if im.width < MIN_WIDTH and im.height < MIN_HEIGHT:
                continue
            if inference_method.scope != "sample-grid":
                im = feature_extractor([im.convert("RGB")])["pixel_values"][0]
            pixel_values.append(im)

        if len(pixel_values) == 0:  # guard against all fallback pages being filtered out
            print(f"{example.get('id')} pdf2image has no valid images")
            example["pages"] = 0
            return example

    if inference_method.scope == "sample-grid":
        grid = inference_method.get_page_scope(pixel_values)
        pixel_values = feature_extractor([grid.convert("RGB")])["pixel_values"][0]
    elif "sample" in inference_method.scope:
        pixel_values = pixel_values[0]
    example["pixel_values"] = np.array(pixel_values)
    return example


def nativepdf_to_pixelvalues_extractor(example, feature_extractor, inference_method):
    IMPOSSIBLE = ["6483941-Letter-to-John-Campbell.pdf", "7276809-Ocoee-Newspaper-Pages.pdf"]
    example["pages"] = 0
    example["pixel_values"] = None
    pixel_values = []
    if len(example["file"]) > MAX_PDF_SIZE:
        logger.warning(f"too large file {len(example['file'])}")
        return example

    try:
        images = pdf2image_image_extraction(example["file"])
    except Exception as e:
        print(f"{example.get('id')} pdf2image get_images {e}")
        images = []

    if not images:
        print(f"{example.get('id')} pdf2image has no images")
        example["pages"] = 0
        return example

    # do image size checks before feature extraction
    images = [im for im in images if im.width >= MIN_WIDTH and im.height >= MIN_HEIGHT]

    if not images or (example.get("id") in IMPOSSIBLE and inference_method.scope == "sample-grid"):
        print(f"{example.get('id')} pdf2image has no images")
        example["pages"] = 0
        return example

    example["pages"] = len(images)
    reached_page_limit = False
    if "sample" in inference_method.scope and inference_method.scope != "sample-grid":
        page_iterator = [inference_method.get_page_scope(images)]
    else:
        page_iterator = images

    for im in page_iterator:
        if len(pixel_values) == MAX_PAGES:
            reached_page_limit = True
            break
        if inference_method.scope != "sample-grid":
            im = feature_extractor([im.convert("RGB")])["pixel_values"][0]
        pixel_values.append(im)

    if len(pixel_values) == 0:
        print(f"{example.get('id')} pdf2image has no valid images")
        example["pages"] = 0
        return example

    if inference_method.scope == "sample-grid":
        grid = inference_method.get_page_scope(pixel_values)
        pixel_values = feature_extractor([grid.convert("RGB")])["pixel_values"][0]
    elif "sample" in inference_method.scope:
        pixel_values = pixel_values[0]
    example["pixel_values"] = np.array(pixel_values)
    return example
metrics.py
ADDED
@@ -0,0 +1,309 @@
# base calibration metrics

# https://github.com/JonathanWenger/pycalib/blob/master/pycalib/scoring.py
# https://github.com/google-research/robustness_metrics/blob/master/robustness_metrics/metrics/uncertainty.py
# https://github.com/kjdhfg/fd-shifts

from __future__ import annotations

import scipy.sparse
import scipy.special

from dataclasses import dataclass
from functools import cached_property
from typing import Any
import pandas as pd
import numpy as np
import numpy.typing as npt
from collections import OrderedDict
from sklearn import metrics as skm
import evaluate as HF_evaluate

ArrayType = npt.NDArray[np.floating]


## https://github.com/IML-DKFZ/fd-shifts/blob/main/fd_shifts/analysis/confid_scores.py#L20

# ----------------------------------- general metrics with consistent metric(y_true, p_hat) API -----------------------------------


def f1_w(y_true, p_hat, y_hat=None):
    if y_hat is None:
        y_hat = np.argmax(p_hat, axis=-1)
    return skm.f1_score(y_true, y_hat, average="weighted")


def f1_micro(y_true, p_hat, y_hat=None):
    if y_hat is None:
        y_hat = np.argmax(p_hat, axis=-1)
    return skm.f1_score(y_true, y_hat, average="micro")


def f1_macro(y_true, p_hat, y_hat=None):
    if y_hat is None:
        y_hat = np.argmax(p_hat, axis=-1)
    return skm.f1_score(y_true, y_hat, average="macro")


# ----------------------------------- pure numpy implementations of proper losses (as metrics) -----------------------------------


def brier_loss(y_true, p_hat):
    r"""Brier score.
    If the true label is k, while the predicted vector of probabilities is
    [y_1, ..., y_n], then the Brier score is equal to
    \sum_{i != k} y_i^2 + (y_k - 1)^2.

    The smaller the Brier score, the better, hence the naming with "loss".
    Across all items in a set of N predictions, the Brier score measures the
    mean squared difference between (1) the predicted probability assigned
    to the possible outcomes for item i, and (2) the actual outcome.
    Therefore, the lower the Brier score is for a set of predictions, the
    better the predictions are calibrated. Note that the Brier score always
    takes on a value between zero and one, since this is the largest
    possible difference between a predicted probability (which must be
    between zero and one) and the actual outcome (which can take on values
    of only 0 and 1). The Brier loss is composed of refinement loss and
    calibration loss.
    """
    N = len(y_true)
    K = p_hat.shape[-1]

    y_onehot = y_true
    if y_true.shape != p_hat.shape:  # onehot-encode integer labels when shapes differ
        y_onehot = scipy.sparse.lil_matrix((N, K))
        for i in range(N):
            y_onehot[i, y_true[i]] = 1

    if not np.isclose(np.sum(p_hat), len(p_hat)):
        p_hat = scipy.special.softmax(p_hat, axis=-1)

    return np.mean(np.sum(np.array(p_hat - y_onehot) ** 2, axis=1))


def nll(y_true, p_hat):
    r"""Multi-class negative log likelihood.
    If the true label is k, while the predicted vector of probabilities is
    [p_1, ..., p_K], then the negative log likelihood is -log(p_k).
    Does not require onehot encoding.
    """
    labels = np.arange(p_hat.shape[-1])
    return skm.log_loss(y_true, p_hat, labels=labels)


def accuracy(y_true, p_hat):
    y_pred = np.argmax(p_hat, axis=-1)
    return skm.accuracy_score(y_true=y_true, y_pred=y_pred)


AURC_DISPLAY_SCALE = 1  # 1000

"""
From: https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1204/reports/custom/report52.pdf

The risk-coverage (RC) curve [28, 16] is a measure of the trade-off between the coverage
(the proportion of test data encountered) and the risk (the error rate under this coverage).
Since each prediction comes with a confidence score, given a list of prediction correctness
Z paired up with the confidence scores C, we sort C in reverse order to obtain sorted C'
and its corresponding correctness Z'. Note that the correctness is computed based on Exact
Match (EM) as described in [22]. The RC curve is then obtained by computing the risk of the
coverage from the beginning of Z' (most confident) to the end (least confident). In
particular, these metrics evaluate the relative order of the confidence scores, which means
that we want wrong answers to have lower confidence scores than the correct ones, ignoring
their absolute values.

Source: https://github.com/kjdhfg/fd-shifts

References:
-----------

[1] Jaeger, P.F., Lüth, C.T., Klein, L. and Bungert, T.J., 2022. A Call to Reflect on Evaluation Practices for Failure Detection in Image Classification. arXiv preprint arXiv:2211.15259.

[2] Kamath, A., Jia, R. and Liang, P., 2020. Selective Question Answering under Domain Shift. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics (pp. 5684-5696).
"""


@dataclass
class StatsCache:
    """Cache for stats computed by scikit-learn that are used by multiple metrics.

    Attributes:
        confids (array_like): confidence values
        correct (array_like): boolean array (best converted to int) of whether predictions were correct
    """

    confids: npt.NDArray[Any]
    correct: npt.NDArray[Any]

    @cached_property
    def roc_curve_stats(self) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
        fpr, tpr, _ = skm.roc_curve(self.correct, self.confids)
        return fpr, tpr

    @property
    def residuals(self) -> npt.NDArray[Any]:
        return 1 - self.correct

    @cached_property
    def rc_curve_stats(self) -> tuple[list[float], list[float], list[float]]:
        coverages = []
        risks = []

        n_residuals = len(self.residuals)
        idx_sorted = np.argsort(self.confids)

        coverage = n_residuals
        error_sum = sum(self.residuals[idx_sorted])

        coverages.append(coverage / n_residuals)
        risks.append(error_sum / n_residuals)

        weights = []

        tmp_weight = 0
        for i in range(0, len(idx_sorted) - 1):
            coverage = coverage - 1
            error_sum = error_sum - self.residuals[idx_sorted[i]]
            selective_risk = error_sum / (n_residuals - 1 - i)
            tmp_weight += 1
            if i == 0 or self.confids[idx_sorted[i]] != self.confids[idx_sorted[i - 1]]:
                coverages.append(coverage / n_residuals)
                risks.append(selective_risk)
                weights.append(tmp_weight / n_residuals)
                tmp_weight = 0

        # add a well-defined final point to the RC-curve
        if tmp_weight > 0:
            coverages.append(0)
            risks.append(risks[-1])
            weights.append(tmp_weight / n_residuals)
        return coverages, risks, weights


def aurc(stats_cache: StatsCache):
    """AURC (area under the risk-coverage curve) metric function.
    Args:
        stats_cache (StatsCache): StatsCache object
    Returns:
        metric value
    Important for assessment: LOWER is better!
    """
    _, risks, weights = stats_cache.rc_curve_stats
    return sum([(risks[i] + risks[i + 1]) * 0.5 * weights[i] for i in range(len(weights))]) * AURC_DISPLAY_SCALE
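# Worked toy example (a sketch; values are illustrative only):
#
#   cache = StatsCache(confids=np.array([0.9, 0.8, 0.4]), correct=np.array([1, 1, 0]))
#   aurc(cache)  # (1/3 + 0) * 0.5 * (1/3) + (0 + 0) * 0.5 * (1/3) ~= 0.056
#
# At full coverage the risk is 1/3; abstaining on the least-confident (wrong)
# sample drops the risk to 0, so the weighted trapezoidal sum stays small
# (a low AURC means confidence ranks errors well).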
def aurc_logits(references, predictions, plot=False, get_cache=False, use_as_is=False):
    if not use_as_is:
        if not np.isclose(np.sum(references), len(references)):
            references = (np.argmax(predictions, -1) == references).astype(int)  # correctness

        if not np.isclose(np.sum(predictions), len(predictions)):
            predictions = scipy.special.softmax(predictions, axis=-1)

        if predictions.ndim == 2:
            predictions = np.max(predictions, -1)

    cache = StatsCache(confids=predictions, correct=references)

    if plot:
        coverages, risks, weights = cache.rc_curve_stats
        pd.options.plotting.backend = "plotly"
        df = pd.DataFrame(zip(coverages, risks, weights), columns=["% Coverage", "% Risk", "weights"])
        fig = df.plot(x="% Coverage", y="% Risk")
        fig.show()
    if get_cache:
        return {"aurc": aurc(cache), "cache": cache}
    return aurc(cache)


def multi_aurc_plot(caches, names, aurcs=None, verbose=False):
    pd.options.plotting.backend = "plotly"
    df = pd.DataFrame()
    for cache, name in zip(caches, names):
        coverages, risks, weights = cache.rc_curve_stats
        df[name] = pd.Series(risks, index=coverages)
    if verbose:
        print(df.head(), df.index, df.columns)
    fig = df.plot()
    title = ""
    if aurcs is not None:
        title = "AURC: " + " - ".join([str(round(aurc, 4)) for aurc in aurcs])
    fig.update_layout(title=title, xaxis_title="% Coverage", yaxis_title="% Risk")
    fig.show()


def ece_logits(references, predictions, use_as_is=False):
    if not use_as_is:
        if not np.isclose(np.sum(predictions), len(predictions)):
            predictions = scipy.special.softmax(predictions, axis=-1)

    metric = HF_evaluate.load("jordyvl/ece")
    kwargs = dict(
        n_bins=min(len(predictions) - 1, 100),
        scheme="equal-mass",
        bin_range=[0, 1],
        proxy="upper-edge",
        p=1,
        detail=False,
    )

    ece_result = metric.compute(
        references=references,
        predictions=predictions,
        **kwargs,
    )
    return ece_result["ECE"]


METRICS = [accuracy, brier_loss, nll, f1_w, f1_macro, ece_logits, aurc_logits]


def apply_metrics(y_true, y_probs, metrics=METRICS):
    predictive_performance = OrderedDict()
    for metric in metrics:
        try:
            predictive_performance[f"{metric.__name__.replace('_logits', '')}"] = metric(y_true, y_probs)
        except Exception as e:
            print(e)
    return predictive_performance
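# Usage sketch (assumes integer labels and raw logits; shapes are illustrative):
#
#   y_true = np.array([0, 1, 2, 3])
#   y_logits = np.random.randn(4, 16)  # 16 classes, as in RVL-CDIP
#   apply_metrics(y_true, y_logits)
#   # -> OrderedDict([('accuracy', ...), ('brier_loss', ...), ('nll', ...), ...])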
def evaluate_coverages(
    logits, labels, confidence, coverages=[100, 99, 98, 97, 95, 90, 85, 80, 75, 70, 60, 50, 40, 30, 20, 10]
):
    correctness = np.equal(logits.argmax(-1), labels)
    abstention_results = list(zip(list(confidence), list(correctness)))
    # sort the abstention results by their reservations, ascending (most confident first)
    abstention_results.sort(key=lambda x: x[0])
    # get the "correct or not" list for the sorted results
    sorted_correct = list(map(lambda x: int(x[1]), abstention_results))
    size = len(sorted_correct)
    print("Abstention Logit: accuracy of coverage ")  # 1 - risk
    for coverage in coverages:
        covered_correct = sorted_correct[: round(size / 100 * coverage)]
        print("{:.0f}: {:.3f}, ".format(coverage, sum(covered_correct) / len(covered_correct) * 100.0), end="")
    print("")

    sr_results = list(zip(list(logits.max(-1)), list(correctness)))
    # sort the results by Softmax Response score, from high to low
    sr_results.sort(key=lambda x: -x[0])
    # get the "correct or not" list for the sorted results
    sorted_correct = list(map(lambda x: int(x[1]), sr_results))
    size = len(sorted_correct)
    print("Softmax Response: accuracy of coverage ")
    for coverage in coverages:
        covered_correct = sorted_correct[: round(size / 100 * coverage)]
        print("{:.0f}: {:.3f}, ".format(coverage, sum(covered_correct) / len(covered_correct) * 100.0), end="")
    print("")


def compute_metrics(eval_preds):
    logits, labels = eval_preds  # output of forward
    if isinstance(logits, tuple):
        confidence = logits[1]
        logits = logits[0]
        if confidence.size == logits.shape[0]:
            evaluate_coverages(logits, labels, confidence)
    results = apply_metrics(labels, logits)
    return results
pyproject.toml
ADDED
@@ -0,0 +1,36 @@
[tool.poetry]
name = "BDPC"
version = "0.1.0"
description = "something"
authors = ["ANON"]

readme = 'README.md'  # Markdown files are supported
keywords = ['research', 'evaluation', 'classification', 'documents']

classifiers = [
    "Development Status :: 1 - Planning",
    "Environment :: Web Environment",
    "Operating System :: POSIX :: Linux",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3 :: Only",
    "Topic :: Scientific/Engineering :: Artificial Intelligence"
]

[tool.poetry.dependencies]
python = "^3.8"
numpy = "^1.24.1"
evaluate = "^0.4.0"
scikit-learn = "^1.2.0"
transformers = "^4.26.0"
pandas = "^1.5.3"
torch = "^1.13.1"
pillow = "^9.4.0"
# imported by the scripts; versions left open
datasets = "*"
tqdm = "*"
scipy = "*"
PyPDF2 = "*"
pdf2image = "*"
plotly = "*"

[tool.poetry.dev-dependencies]
pytest = "^3.10.1"
pytest-cov = "^2.9.0"


[build-system]
requires = ["poetry-core"]
build-backend = "poetry.masonry.api"
simulate_document_classifier.py
ADDED
@@ -0,0 +1,222 @@
import os
from tqdm import tqdm
import numpy as np
import torch
from datasets import load_dataset, logging
from datasets import Features, Value, Array3D, Array4D
import evaluate

from transformers import AutoFeatureExtractor, AutoModelForImageClassification  # DiT

logger = logging.get_logger(__name__)

from mapping_functions import (
    pdf_to_pixelvalues_extractor,
    nativepdf_to_pixelvalues_extractor,
)
from inference_methods import InferenceMethod

EXPERIMENT_ROOT = "/mnt/lerna/experiments"


def load_base_model():
    feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
    model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
    return model, feature_extractor


def logits_monitor(args, running_logits, references, predictions, identifier="a"):
    """Dump one buffer of raw outputs; each row is [logits..., label, prediction, dataset index]."""
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}-i{identifier}.npz"

    raw_output = torch.cat(
        [
            torch.cat(running_logits, dim=0).cpu(),
            torch.Tensor(references).unsqueeze(1),
            torch.Tensor(predictions).unsqueeze(1),
            torch.Tensor(np.arange(int(identifier) - len(references), int(identifier))).unsqueeze(1),
        ],
        dim=1,
    )
    np.savez_compressed(output_path, raw_output.cpu().data.numpy())
    tqdm.write("saved raw test outputs to {}".format(output_path))


def monitor_cleanup(args, buffer_keys):
    """
    Merge all previous buffer files into one final file, then delete the buffers
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}"

    for i, identifier in enumerate(buffer_keys):
        identifier_path = f"{output_path}-i{identifier}.npz"
        saved = np.load(identifier_path)["arr_0"]
        if i == 0:
            catted = saved
        else:
            catted = np.concatenate([catted, saved])
    out_path = f"{output_path}-final.npz"
    np.savez_compressed(out_path, catted)
    tqdm.write("saved raw test outputs to {}".format(out_path))
    # cleanup
    for i, identifier in enumerate(buffer_keys):
        identifier_path = f"{output_path}-i{identifier}.npz"
        os.remove(identifier_path)


def main(args):
    testds = load_dataset(
        args.dataset,
        cache_dir="/mnt/lerna/data/HFcache",
        split="test",
        revision=None if args.dataset != "bdpc/rvl_cdip_mp" else "d3a654c9f63f14d0aaa94e08aa30aa3dc20713c1",
    )

    if args.downsampling:
        testds = testds.select(list(range(0, args.downsampling)))

    model = AutoModelForImageClassification.from_pretrained(args.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    label2idx = dict(model.config.label2id)
    print(label2idx)

    data_idx2label = dict(enumerate(testds.features["labels"].names))
    model_idx2label = dict(zip(label2idx.values(), label2idx.keys()))
    diff = [i for i in range(len(data_idx2label)) if data_idx2label[i] != model_idx2label[i]]
    if diff:
        print(f"aligning labels {diff}")
        testds = testds.align_labels_with_mapping(label2idx, "labels")

    inference_method = InferenceMethod[args.inference_method.upper()]
    dummy_inference_method = inference_method
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model)

    features = {
        **{k: v for k, v in testds.features.items() if k in ["labels", "pixel_values", "id"]},
        "pages": Value(dtype="int32"),
        "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224)),
    }
    if "sample" not in inference_method.scope:
        features["pixel_values"] = Array4D(dtype="float32", shape=(None, 3, 224, 224))
        dummy_inference_method = InferenceMethod["max_confidence".upper()]
    features = Features(features)

    remove_columns = ["file"]
    if args.dataset == "bdpc/rvl_cdip_mp":
        image_preprocessor = lambda batch: pdf_to_pixelvalues_extractor(
            batch, feature_extractor, dummy_inference_method
        )
        encoded_testds = testds.map(
            image_preprocessor, features=features, remove_columns=remove_columns, desc="pdf_to_pixelvalues"
        )
    else:
        image_preprocessor = lambda batch: nativepdf_to_pixelvalues_extractor(
            batch, feature_extractor, dummy_inference_method
        )
        encoded_testds = testds.map(
            image_preprocessor,
            features=features,
            remove_columns=remove_columns,
            desc="pdf_to_pixelvalues",
            batch_size=10,
        )

    # drop documents without usable pages or with NaN pixel values
    print(f"Before filtering: {len(encoded_testds)}")
    more_complex_filter = lambda example: example["pages"] != 0 and not np.any(np.isnan(example["pixel_values"]))
    good_indices = [i for i, x in tqdm(enumerate(encoded_testds), desc="filter") if more_complex_filter(x)]
    encoded_testds = encoded_testds.select(good_indices)
    print(f"After filtering: {len(encoded_testds)}")

    metric = evaluate.load("accuracy")

    # manually iterate with a dataloader and do tensor conversion
    encoded_testds.set_format(type="torch", columns=["pixel_values", "labels"])
    args.batch_size = args.batch_size if "sample" in inference_method.scope else 1
    dataloader = torch.utils.data.DataLoader(encoded_testds, batch_size=args.batch_size)

    running_logits = []
    predictions, references = [], []
    buffer_references = []
    buffer_predictions = []
    buffer = 0
    BUFFER_SIZE = 5000
    buffer_keys = []
    for i, batch in tqdm(enumerate(dataloader), desc="Inference loop"):
        with torch.no_grad():
            batch["labels"] = batch["labels"].to(device)
            batch["pixel_values"] = batch["pixel_values"].to(device)
            if "sample" in inference_method.scope:
                outputs = model(batch["pixel_values"])
                logits = outputs.logits
                buffer_predictions.extend(logits.argmax(-1).tolist())
                buffer_references.extend(batch["labels"].tolist())
                running_logits.append(logits)
            else:
                try:
                    page_logits = model(batch["pixel_values"][0]).logits
                except Exception as e:
                    print(f"something went wrong in inference {e}")
                    continue
                prediction = inference_method.apply_decision_strategy(page_logits)  # apply the method-specific logic
                buffer_predictions.append(prediction.tolist())
                buffer_references.extend(batch["labels"].tolist())
                running_logits.append(page_logits.mean(0).unsqueeze(0))  # average over pages as representative

        buffer += args.batch_size
        if buffer >= BUFFER_SIZE:
            predictions.extend(buffer_predictions)
            references.extend(buffer_references)
            logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=str(i))
            buffer_keys.append(str(i))
            running_logits = []
            buffer_references = []
            buffer_predictions = []
            buffer = 0

    if buffer != 0:  # dump the remainder of the buffer
        predictions.extend(buffer_predictions)
        references.extend(buffer_references)
        logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=str(i))
        buffer_keys.append(str(i))

    accuracy = metric.compute(references=references, predictions=predictions)
    print(f"Accuracy on this inference configuration {inference_method}:", accuracy)
    monitor_cleanup(args, buffer_keys)
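# Usage (the invocations exercised by experiments.sh):
#
#   python3 simulate_document_classifier.py grid -d bdpc/rvl_cdip_mp
#   python3 simulate_document_classifier.py soft_voting -d bdpc/rvl_cdip_n_mp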
if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser("""Test different inference strategies to classify a document""")
    parser.add_argument(
        "inference_method",
        type=str,
        default="first",
        nargs="?",
        help="how to evaluate DiT on RVL-CDIP_multi",
    )
    parser.add_argument("-s", dest="downsampling", type=int, default=0, help="number of testset samples")
    parser.add_argument("-d", dest="dataset", type=str, default="bdpc/rvl_cdip_mp", help="the dataset to be evaluated")
    parser.add_argument(
        "-m",
        dest="model",
        type=str,
        default="microsoft/dit-base-finetuned-rvlcdip",
        help="the model checkpoint to be evaluated",
    )
    parser.add_argument("-b", dest="batch_size", type=int, default=16, help="batch size")
    parser.add_argument(
        "-k",
        dest="keep_in_memory",
        default=False,
        action="store_true",
        help="do not cache operations (for testing)",
    )

    args = parser.parse_args()

    main(args)
|