bdpc commited on
Commit
1ceb840
1 Parent(s): 8237fd2

Upload 9 files

Browse files
DiT_inference.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+ import torch
5
+ from datasets import load_dataset, ClassLabel
6
+ from datasets import Features, Array3D
7
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification
8
+ from metrics import apply_metrics
9
+
10
+
11
def process_label_ids(batch, remapper, label_column="label"):
    """Rewrite the label ids of a batch through a lookup table.

    Args:
        batch: Mapping of column name to a list of values (one entry per example).
        remapper: Mapping from the current label id to the replacement label id.
        label_column: Name of the column in ``batch`` that holds the label ids.

    Returns:
        The same ``batch`` object, with ``label_column`` replaced by the remapped ids.

    Raises:
        KeyError: If a label id in the batch has no entry in ``remapper``.
    """
    translated = []
    for old_id in batch[label_column]:
        translated.append(remapper[old_id])
    batch[label_column] = translated
    return batch
14
+
15
+
16
+ CACHE_DIR = "/mnt/lerna/data/HFcache" if os.path.exists("/mnt/lerna/data/HFcache") else None
17
+
18
+
19
def main(args):
    """Run the fine-tuned DiT classifier over a dataset's test split and print metrics.

    The test split of ``args.dataset`` is loaded, its label ids are aligned with the
    label ids of the ``microsoft/dit-base-finetuned-rvlcdip`` checkpoint when the two
    orderings differ, images are converted to pixel values, and the collected logits
    and reference labels are handed to ``apply_metrics``.

    Args:
        args: Namespace with a ``dataset`` attribute (dataset name passed to ``load_dataset``).
    """
    dataset = load_dataset(args.dataset, split="test", cache_dir=CACHE_DIR)
    if args.dataset == "rvl_cdip":
        # Index 33669 is a corrupt sample; drop it before any image is decoded.
        dataset = dataset.select([i for i in range(len(dataset)) if i != 33669])
    batch_size = 100 if args.dataset == "jordyvl/RVL-CDIP-N" else 1000

    feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
    model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")

    # Model label names use spaces; normalise them to underscores before comparing.
    label2idx = {label.replace(" ", "_"): i for label, i in model.config.label2id.items()}
    data_label_names = dataset.features["label"].names
    # dict(enumerate(...)) builds {index: name}; zipping a single iterable would
    # yield 1-tuples and make dict() raise ValueError.
    data_idx2label = dict(enumerate(data_label_names))
    data_label2idx = {label: i for i, label in enumerate(data_label_names)}
    model_idx2label = {i: label for label, i in label2idx.items()}
    # .get() keeps the comparison safe when the dataset has more labels than the model.
    diff = [i for i, label in data_idx2label.items() if model_idx2label.get(i) != label]

    if diff:
        print(f"aligning labels {diff}")
        print(f"model labels: {model_idx2label}")
        print(f"data labels: {data_idx2label}")
        print(f"Remapping to {label2idx}")  # YET cannot change the length of the labels

        # Map dataset label id -> model label id for every label name both sides share.
        remapper = {}
        for k, v in label2idx.items():
            if k in data_label2idx:
                remapper[data_label2idx[k]] = v

        print(remapper)
        new_features = Features(
            {
                **{k: v for k, v in dataset.features.items() if k != "label"},
                "label": ClassLabel(num_classes=len(label2idx), names=list(label2idx.keys())),
            }
        )

        dataset = dataset.map(
            lambda example: process_label_ids(example, remapper),
            features=new_features,
            batched=True,
            batch_size=batch_size,
            desc="Aligning the labels",
        )

    features = Features({**dataset.features, "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224))})

    encoded_dataset = dataset.map(
        lambda examples: feature_extractor([image.convert("RGB") for image in examples["image"]]),
        batched=True,
        batch_size=batch_size,
        features=features,
    )
    encoded_dataset.set_format(type="torch", columns=["pixel_values", "label"])
    BATCH_SIZE = 16
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=BATCH_SIZE)

    num_samples = len(encoded_dataset)
    all_logits = np.zeros((num_samples, len(label2idx)))
    all_references = np.zeros(num_samples, dtype=int)

    offset = 0
    # Iterating the dataloader directly lets tqdm display the total number of batches.
    for batch in tqdm(dataloader):
        with torch.no_grad():
            logits = model(batch["pixel_values"]).logits
        # Slice by the actual batch length so a short final batch is written exactly.
        n_in_batch = len(batch["label"])
        all_logits[offset : offset + n_in_batch] = logits.detach().cpu().numpy()
        all_references[offset : offset + n_in_batch] = batch["label"].detach().cpu().numpy()
        offset += n_in_batch

    results = apply_metrics(all_references, all_logits)
    print(results)
90
+
91
+
92
if __name__ == "__main__":
    from argparse import ArgumentParser

    # The first positional parameter of ArgumentParser is `prog`; pass the blurb
    # as `description` so it shows up as help text instead of the program name.
    parser = ArgumentParser(description="""DiT inference on dataset test set""")
    parser.add_argument("-d", dest="dataset", type=str, default="rvl_cdip", help="the dataset to be evaluated")
    args = parser.parse_args()

    main(args)
README.md ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Beyond Document Page Classification
2
+
3
+ ## Installation
4
+
5
+ The scripts require [python >= 3.8](https://www.python.org/downloads/release/python-380/) to run.
6
+ We will create a fresh virtual environment in which to install all required packages.
7
+ ```sh
8
+ mkvirtualenv -p /usr/bin/python3 BYD
9
+ ```
10
+
11
+ Using Poetry and the provided pyproject.toml, we will install all required packages:
12
+ ```sh
13
+ workon BYD
14
+ pip3 install poetry
15
+ poetry install
16
+ ```
17
+
18
+ ## something
experiments.sh ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Print a command line, then execute it.
run() {
    echo "$*"
    "$@"
}

# Single page f_p experiments
# NOTE(review): the same command is issued twice in a row; confirm whether one of
# the two runs was meant to target a different dataset.
run python3 DiT_inference.py -d jordyvl/RVL_CDIP_N
run python3 DiT_inference.py -d jordyvl/RVL_CDIP_N


# Multi-page f_d experiments
for strategy in grid first second last max_confidence soft_voting hard_voting; do
    run python3 simulate_document_classifier.py "$strategy" -d bdpc/rvl_cdip_mp
done


for strategy in grid first second last hard_voting soft_voting max_confidence; do
    run python3 simulate_document_classifier.py "$strategy" -d bdpc/rvl_cdip_n_mp
done
23
+
inference_methods.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from PIL import Image
3
+
4
+
5
def unravel_index(index, shape):
    """Convert a flat (row-major) index into one coordinate per dimension.

    Args:
        index: Flat position; only ``%`` and ``//`` are applied to it.
        shape: Sequence with the size of every dimension.

    Returns:
        Tuple of coordinates, one per entry of ``shape``, in the same order.
    """
    coords = [None] * len(shape)
    remainder = index
    # Peel off the fastest-varying (last) dimension first.
    for axis in range(len(shape) - 1, -1, -1):
        extent = shape[axis]
        coords[axis] = remainder % extent
        remainder = remainder // extent
    return tuple(coords)
11
+
12
+
13
class ExplicitEnum(Enum):
    """
    Enum with more explicit error message for missing values or getting all options
    """

    @classmethod
    def _missing_(cls, value):
        # Called by Enum when ``cls(value)`` finds no member; list the valid values.
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
        )

    @classmethod
    def options(cls):
        """Return the list of all valid values of this enum."""
        return list(cls._value2member_map_.keys())


class InferenceMethod(ExplicitEnum):
    """All the implemented inference methods"""

    FIRST = "first"  # check on data setup
    SECOND = "second"  # robustness for multipage categories
    LAST = "last"  # robustness for multipage categories

    GRID = "grid"  # create a grid (equal spaced/ OCR density based)
    # downscale resolution ; might be fine for classification

    MAX_CONFIDENCE = "max_confidence"  # page with highest confidence overall
    SOFT_VOTING = "soft_voting"  # sum conf/N -> logits/softmax
    HARD_VOTING = "hard_voting"  # count votes

    @property
    def scope(self):
        """Return how many pages the method consumes.

        ``"sample"`` for the single-page methods, ``"sample-grid"`` for the grid
        method (single image yet transformation required) and ``"iter"`` for the
        methods that look at every page.
        """
        if self in (InferenceMethod.FIRST, InferenceMethod.SECOND, InferenceMethod.LAST):
            return "sample"
        if self is InferenceMethod.GRID:
            return "sample-grid"
        return "iter"

    def get_page_scope(self, pages):
        """Select the page(s) this method operates on.

        Args:
            pages: Sequence of pages of one document.

        Returns:
            ``pages`` unchanged for the "iter" methods, a single grid image for
            GRID, or the selected single page for FIRST / SECOND / LAST.
        """
        if self.scope == "iter":
            return pages
        if self == InferenceMethod.GRID:
            try:
                return equal_image_grid(pages)
            except Exception:
                # Deliberate best effort: when no grid can be built, fall back to
                # the last page to not positively bias.
                return pages[-1]
        if self == InferenceMethod.FIRST:
            return pages[0]
        if self == InferenceMethod.SECOND:
            if len(pages) > 1:
                return pages[1]
            return pages[0]  # backoff
        if self == InferenceMethod.LAST:
            return pages[-1]

    def apply_decision_strategy(self, page_logits):
        """
        Reduce per-page logits to one predicted class index.

        page logits is of shape [NUM_PAGES x CLASSES]

        Returns the class index for MAX_CONFIDENCE, HARD_VOTING and SOFT_VOTING;
        returns None for every other method.
        """
        if self == InferenceMethod.MAX_CONFIDENCE:
            index = page_logits.argmax()  # tensor with one number
            indices = unravel_index(index, page_logits.shape)
            print(f"The page which is max confident: {indices[0]}")
            return indices[-1]
        if self == InferenceMethod.HARD_VOTING:
            # One vote per page; the class with the most votes wins. Taking
            # ``.max()`` of the votes would return the largest class index instead.
            votes = page_logits.argmax(-1).reshape(-1)
            ballots = votes.tolist()
            # max() keeps the first maximal element, so ties go to the earliest page.
            winner = max(ballots, key=ballots.count)
            # Index back into ``votes`` so the return type matches the input library.
            return votes[ballots.index(winner)]
        if self == InferenceMethod.SOFT_VOTING:
            return page_logits.mean(0).argmax(-1)
82
+
83
+
84
def equal_image_grid(images):
    """Tile images into one RGB grid image.

    Images with a zero width or height are dropped, the rest are rescaled to the
    smallest width among them (keeping aspect ratio) and pasted row by row into
    cells sized to the largest rescaled image.

    Args:
        images: Sequence of PIL images.

    Returns:
        A new RGB ``PIL.Image`` containing all usable images.

    Raises:
        ValueError: If no image with a non-zero size is left.
    """

    def compute_grid(n, max_cols=6):
        # Start from a square layout, then grow columns once and rows as needed
        # until the grid can hold ``n`` cells.
        equalDivisor = int(n**0.5)
        cols = min(equalDivisor, max_cols)
        rows = equalDivisor
        if rows * cols >= n:
            return rows, cols
        cols += 1
        if rows * cols >= n:
            return rows, cols
        while rows * cols < n:
            rows += 1
        return rows, cols

    # Filter first so the grid is sized for the images that are actually pasted.
    images = [im for im in images if (im.height > 0) and (im.width > 0)]  # could be NA
    if not images:
        raise ValueError("equal_image_grid needs at least one image with a non-zero size")

    rows, cols = compute_grid(len(images))

    # rescaling to min width [height padding]
    min_width = min(im.width for im in images)
    images = [
        # max(1, ...) guards against a rescaled height that rounds down to 0.
        im.resize((min_width, max(1, int(im.height * min_width / im.width))), resample=Image.BICUBIC)
        for im in images
    ]

    w = max(img.size[0] for img in images)
    h = max(img.size[1] for img in images)

    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(images):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid
load_predictions.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import pandas as pd
4
+ from metrics import ece_logits, aurc_logits, multi_aurc_plot, apply_metrics
5
+ from sklearn.metrics import f1_score
6
+ from collections import OrderedDict
7
+
8
+ EXPERIMENT_ROOT = "/mnt/lerna/experiments"
9
+
10
+
11
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis``.

    The per-slice maximum is subtracted before exponentiating so that large
    inputs do not overflow.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    normaliser = np.sum(exps, axis=axis, keepdims=True)
    return exps / normaliser
25
+
26
+
27
def predictions_loader(predictions_path):
    """Load a saved predictions archive and split it into its columns.

    The array stored under ``"arr_0"`` keeps the dataset index in the last column
    and the label in the second-to-last column. Files whose path contains
    ``"DiT-base-rvl_cdip_MP"`` together with one of ``first`` / ``second`` /
    ``last`` carry no prediction column, so predictions are the argmax of the
    logits; every other file stores the prediction in the third-to-last column.

    Args:
        predictions_path: Path to the ``.npz`` file.

    Returns:
        Tuple ``(logits, labels, predictions, dataset_idx)``.
    """
    table = np.load(predictions_path)["arr_0"]
    dataset_idx = table[:, -1]

    single_page_names = ["first", "second", "last"]
    without_prediction_column = "DiT-base-rvl_cdip_MP" in predictions_path and any(
        name in predictions_path for name in single_page_names
    )
    if without_prediction_column:
        logits = table[:, :-2]
        return logits, table[:, -2], np.argmax(logits, -1), dataset_idx

    labels = table[:, -2].astype(int)
    predictions = table[:, -3].astype(int)
    return table[:, :-3], labels, predictions, dataset_idx
39
+
40
+
41
def compare_errors():
    """Print how much each single-page strategy gains from the other pages.

    For every base strategy (first, second, last) the base accuracy is printed,
    followed by the accuracy obtained when a sample also counts as correct if one
    of the listed other strategies classified it correctly.

    Notes kept from an earlier exploratory run on ``strategy_correctness``:
    Spearman and Pearson correlations were about 0.543 for first vs. second and
    about 0.501 for first vs. last (p-value reported as 0.0).
    """
    single_page_strategies = ["first", "second", "last"]
    # (base strategy, groups of helper strategies), in the order they are reported.
    report_plan = (
        ("first", (("second",), ("last",), ("second", "last"))),
        ("second", (("first",), ("last",))),
        ("last", (("first",), ("second",))),
    )

    for dataset in ["rvl_cdip_n_mp"]:  # "DiT-base-rvl_cdip_MP",
        strategy_logits = {}
        strategy_correctness = {}
        for strategy in single_page_strategies:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"

            strategy_logits[strategy], labels, predictions, dataset_idx = predictions_loader(path)
            strategy_correctness[strategy] = (predictions == labels).astype(int)

        for base, helper_groups in report_plan:
            base_hits = strategy_correctness[base]
            print(f"Base accuracy of {base}: ", np.mean(base_hits))
            for helpers in helper_groups:
                # A sample is a hit when the base or any helper got it right (0/1 arrays).
                rescued = base_hits
                for helper in helpers:
                    rescued = np.maximum(rescued, strategy_correctness[helper])
                print(
                    f"Accuracy of {base} when adding knowledge from {'/'.join(helpers)} page: {np.mean(rescued)}"
                )
103
+
104
+
105
def review_one(path):
    """Compute and print the evaluation metrics for one predictions file.

    Args:
        path: Path handed to ``predictions_loader``.

    Returns:
        ``(collect, res)`` where ``collect`` is an ordered mapping with the keys
        ``aurc``, ``accuracy``, ``f1``, ``f1_macro`` and ``ece`` and ``res`` is the
        dictionary returned by ``aurc_logits(..., get_cache=True)``; ``None`` when
        the predictions could not be loaded.
    """
    try:
        logits, labels, predictions, dataset_idx = predictions_loader(path)
    except Exception as e:
        print(f"something went wrong in inference loading {e}")
        return None

    y_correct = (predictions == labels).astype(int)
    # Confidence of a sample = softmax probability assigned to its predicted class.
    p_hat = np.array([softmax(row, -1)[predicted] for row, predicted in zip(logits, predictions)])

    # DEV: implementation hack to allow for passing I[Y==y_hat] and p_hat instead of logits and label indices
    res = aurc_logits(y_correct, p_hat, plot=False, get_cache=True, use_as_is=True)

    collect = OrderedDict(
        [
            ("aurc", res["aurc"]),
            ("accuracy", 100 * np.mean(y_correct)),
            ("f1", 100 * f1_score(labels, predictions, average="weighted")),
            ("f1_macro", 100 * f1_score(labels, predictions, average="macro")),
            ("ece", ece_logits(np.logical_not(y_correct), np.expand_dims(p_hat, -1), use_as_is=True)),
        ]
    )

    df = pd.DataFrame.from_dict([collect])
    print(df.to_latex())
    print(df.to_string())
    return collect, res
132
+
133
+
134
def experiments_review():
    """Print a metrics table per dataset, with one row per inference strategy.

    For every dataset directory under ``EXPERIMENT_ROOT`` the predictions of each
    strategy are reviewed with ``review_one`` and the results are printed both as
    LaTeX and as plain text. Strategies whose predictions cannot be loaded are
    skipped instead of aborting the whole review.
    """
    STRATEGIES = ["first", "second", "last", "max_confidence", "soft_voting", "hard_voting", "grid"]
    for dataset in ["DiT-base-rvl_cdip_MP", "rvl_cdip_n_mp"]:
        collect = {}
        # Kept for the optional risk-coverage plot sketched at the end of the loop.
        aurcs = []
        caches = []
        for strategy in STRATEGIES:
            path = f"{EXPERIMENT_ROOT}/{dataset}/dit-base-finetuned-rvlcdip_{strategy}-0-final.npz"
            outcome = review_one(path)
            if outcome is None:
                # review_one returns None (after printing the error) when loading fails;
                # unpacking that would raise TypeError.
                print(f"skipping {strategy} for {dataset}: no predictions loaded from {path}")
                continue
            collect[strategy], res = outcome
            aurcs.append(res["aurc"])
            caches.append(res["cache"])

        if not collect:
            # Selecting columns on an empty frame would raise KeyError.
            print(f"no predictions could be loaded for {dataset}")
            continue

        df = pd.DataFrame.from_dict(collect, orient="index")
        df = df[["accuracy", "f1", "f1_macro", "ece", "aurc"]]
        print(df.to_latex())
        print(df.to_string())
        # Optional plot of the risk-coverage curves for a subset of the loaded strategies:
        #
        # subset = [0, 1, 2]
        # multi_aurc_plot(
        #     [x for i, x in enumerate(caches) if i in subset],
        #     [x for i, x in enumerate(collect) if i in subset],
        #     aurcs=[x for i, x in enumerate(aurcs) if i in subset],
        # )
158
+
159
+
160
if __name__ == "__main__":
    from argparse import ArgumentParser

    # The first positional parameter of ArgumentParser is `prog`; pass the blurb
    # as `description` so it is shown as help text instead of the program name.
    parser = ArgumentParser(
        description="""Deeper evaluation of different inference strategies to classify a document"""
    )
    DEFAULT = "./dit-base-finetuned-rvlcdip_last-10.npz"
    parser.add_argument(
        "predictions_path",
        type=str,
        default=DEFAULT,
        nargs="?",
        help="path to predictions",
    )

    args = parser.parse_args()
    if args.predictions_path == DEFAULT:
        # No explicit path given: run the full review over all experiments.
        experiments_review()
        compare_errors()
        # Exit status 0: the review finished normally (1 would signal a failure).
        sys.exit(0)

    print(f"Running default experiment on {args.predictions_path}")
    review_one(args.predictions_path)
mapping_functions.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from io import BytesIO
3
+ from tqdm import tqdm
4
+ import numpy as np
5
+ from typing import Callable, Dict, List # , Literal, NamedTuple, Optional, Tuple, Union
6
+ from PIL import Image as PIL_Image
7
+ from PIL.Image import Image
8
+
9
+ from datasets import logging
10
+
11
+ logger = logging.get_logger(__name__)
12
+ import PyPDF2
13
+
14
+ MAX_PAGES = 50
15
+ MAX_PDF_SIZE = 100000000 # almost 100MB
16
+ MIN_WIDTH, MIN_HEIGHT = 150, 150
17
+ import pdf2image
18
+
19
+
20
def pdf2image_image_extraction(pdf_stream):
    """Render the pages of a PDF given as bytes with pdf2image.

    Args:
        pdf_stream: Raw bytes of the PDF file.

    Returns:
        The list of images produced by ``pdf2image.convert_from_bytes``, or
        ``None`` when the conversion raised (the error is logged as a warning).
    """
    try:
        rendered: List[Image] = pdf2image.convert_from_bytes(pdf_stream)
    except Exception as e:
        logger.warning(f"{e}")
        return None
    return rendered
26
+
27
+
28
def pdf_to_pixelvalues_extractor(example, feature_extractor, inference_method):
    """Extract page images from a PDF example and turn them into pixel values.

    Images embedded in the PDF are read with PyPDF2 first; when that yields
    nothing (or fails), the PDF is rendered with pdf2image instead. Depending on
    ``inference_method.scope`` the result is a single page ("sample"), one grid
    image built from all pages ("sample-grid") or all pages ("iter").

    Args:
        example: Mapping with the raw PDF bytes under ``"file"`` (and optionally ``"id"``).
        feature_extractor: Callable taking a list of RGB images and returning a
            mapping with a ``"pixel_values"`` entry.
        inference_method: Object exposing ``scope`` and ``get_page_scope``.

    Returns:
        ``example`` with ``"pages"`` and ``"pixel_values"`` filled in; on failure
        ``"pages"`` is 0 and ``"pixel_values"`` is ``None``.
    """
    example["pages"] = 0
    example["pixel_values"] = None
    pixel_values = []
    if len(example["file"]) > MAX_PDF_SIZE:
        logger.warning(f"too large file {len(example['file'])}")
        return example
    try:
        reader = PyPDF2.PdfReader(BytesIO(example["file"]))
    except Exception as e:
        logger.warning(f"read_pdf {e}")
        return example
    example["pages"] = len(reader.pages)

    is_grid = inference_method.scope == "sample-grid"
    single_page = "sample" in inference_method.scope and not is_grid

    def is_too_small(im):
        # NOTE(review): a page is dropped only when BOTH sides are below the minimum,
        # whereas nativepdf_to_pixelvalues_extractor drops it when EITHER side is —
        # confirm which rule is intended.
        return im.width < MIN_WIDTH and im.height < MIN_HEIGHT

    def encode(im):
        # Grid pages stay PIL images until the grid is assembled.
        if is_grid:
            return im
        return feature_extractor([im.convert("RGB")])["pixel_values"][0]

    try:
        # Page selection happens inside the try so that a PDF without pages falls
        # through to the pdf2image fallback instead of raising IndexError.
        if single_page:
            page_iterator = [inference_method.get_page_scope(reader.pages)]
        else:
            page_iterator = reader.pages
        for page in page_iterator:
            if len(pixel_values) == MAX_PAGES:
                break
            for image in page.images:
                if len(pixel_values) == MAX_PAGES:
                    break
                im = PIL_Image.open(BytesIO(image.data))
                if is_too_small(im):
                    continue
                pixel_values.append(encode(im))
    except Exception as e:
        print(f"{example.get('id')} PyPDF get_images {e}")
        pixel_values = []

    if len(pixel_values) == 0:
        # at least try with another API
        try:
            images = pdf2image_image_extraction(example["file"])
        except Exception as e:
            print(f"{example.get('id')} pdf2image get_images {e}")
            images = []

        if not images:
            print(f"{example.get('id')} pdf2image has no images")
            example["pages"] = 0
            return example

        # got lucky with pdf2image
        example["pages"] = len(images)
        images = [im for im in images if not is_too_small(im)]
        if single_page and images:
            # Honour the requested page (first/second/last); without this the
            # fallback always ended up returning the first rendered page.
            images = [inference_method.get_page_scope(images)]
        for im in images[:MAX_PAGES]:
            pixel_values.append(encode(im))

    if len(pixel_values) == 0:
        # Every candidate was filtered out; indexing or building a grid would fail.
        print(f"{example.get('id')} pdf2image has no valid images")
        example["pages"] = 0
        return example

    if is_grid:
        grid = inference_method.get_page_scope(pixel_values)
        pixel_values = feature_extractor([grid.convert("RGB")])["pixel_values"][0]
    elif "sample" in inference_method.scope:
        pixel_values = pixel_values[0]
    example["pixel_values"] = np.array(pixel_values)
    return example
100
+
101
+
102
def nativepdf_to_pixelvalues_extractor(example, feature_extractor, inference_method):
    """Render a PDF example with pdf2image and turn its pages into pixel values.

    Pages narrower than ``MIN_WIDTH`` or shorter than ``MIN_HEIGHT`` are dropped.
    Depending on ``inference_method.scope`` the result is a single page
    ("sample"), one grid image built from the pages ("sample-grid") or up to
    ``MAX_PAGES`` pages ("iter").

    Args:
        example: Mapping with the raw PDF bytes under ``"file"`` (and optionally ``"id"``).
        feature_extractor: Callable taking a list of RGB images and returning a
            mapping with a ``"pixel_values"`` entry.
        inference_method: Object exposing ``scope`` and ``get_page_scope``.

    Returns:
        ``example`` with ``"pages"`` and ``"pixel_values"`` filled in; on failure
        ``"pages"`` is 0 and ``"pixel_values"`` is ``None``.
    """
    IMPOSSIBLE = ["6483941-Letter-to-John-Campbell.pdf", "7276809-Ocoee-Newspaper-Pages.pdf"]
    example["pages"] = 0
    example["pixel_values"] = None

    def give_up(message):
        # Report why the example is unusable and return it marked as page-less.
        print(message)
        example["pages"] = 0
        return example

    if len(example["file"]) > MAX_PDF_SIZE:
        logger.warning(f"too large file {len(example['file'])}")
        return example

    # images = example['images']
    try:
        rendered = pdf2image_image_extraction(example["file"])
    except Exception as e:
        print(f"{example.get('id')} pdf2image get_images {e}")
        rendered = []

    if not rendered:
        return give_up(f"{example.get('id')} pdf2image has no images")

    # do image checks before and after
    usable = [im for im in rendered if im.width >= MIN_WIDTH and im.height >= MIN_HEIGHT]

    scope = inference_method.scope
    grid_mode = scope == "sample-grid"
    # Documents listed in IMPOSSIBLE are skipped for the grid method only.
    if not usable or (example.get("id") in IMPOSSIBLE and grid_mode):
        return give_up(f"{example.get('id')} pdf2image has no images")

    example["pages"] = len(usable)
    if "sample" in scope and not grid_mode:
        candidates = [inference_method.get_page_scope(usable)]
    else:
        candidates = usable

    # At most MAX_PAGES pages are kept; grid pages stay PIL images until tiled.
    pixel_values = [
        im if grid_mode else feature_extractor([im.convert("RGB")])["pixel_values"][0]
        for im in candidates[:MAX_PAGES]
    ]

    if len(pixel_values) == 0:
        return give_up(f"{example.get('id')} pdf2image has no valid images")

    if grid_mode:
        grid = inference_method.get_page_scope(pixel_values)
        pixel_values = feature_extractor([grid.convert("RGB")])["pixel_values"][0]
    elif "sample" in scope:
        pixel_values = pixel_values[0]
    example["pixel_values"] = np.array(pixel_values)
    return example
metrics.py ADDED
@@ -0,0 +1,309 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # base calibration metric
2
+
3
+ # https://github.com/JonathanWenger/pycalib/blob/master/pycalib/scoring.py
4
+ # https://github.com/google-research/robustness_metrics/blob/master/robustness_metrics/metrics/uncertainty.py
5
+ # https://github.com/kjdhfg/fd-shifts
6
+
7
+ from __future__ import annotations
8
+ import scipy
9
+ import sklearn.utils.validation
10
+
11
+ from dataclasses import dataclass
12
+ from functools import cached_property
13
+ from typing import Any
14
+ import pandas as pd
15
+ import numpy as np
16
+ import numpy.typing as npt
17
+ from collections import OrderedDict
18
+ from sklearn import metrics as skm
19
+ import evaluate as HF_evaluate
20
+
21
+ ArrayType = npt.NDArray[np.floating]
22
+
23
+
24
+ ## https://github.com/IML-DKFZ/fd-shifts/blob/main/fd_shifts/analysis/confid_scores.py#L20
25
+
26
+ # ----------------------------------- general metrics with consistent metric(y_true, p_hat) API -----------------------------------
27
+
28
+
29
def f1_w(y_true, p_hat, y_hat=None):
    """Return the weighted-average F1 score.

    Predictions are the argmax of ``p_hat`` over the last axis unless ``y_hat``
    is given explicitly.
    """
    predicted = np.argmax(p_hat, axis=-1) if y_hat is None else y_hat
    return skm.f1_score(y_true, predicted, average="weighted")
33
+
34
+
35
def f1_micro(y_true, p_hat, y_hat=None):
    """Return the micro-average F1 score.

    Predictions are the argmax of ``p_hat`` over the last axis unless ``y_hat``
    is given explicitly.
    """
    predicted = np.argmax(p_hat, axis=-1) if y_hat is None else y_hat
    return skm.f1_score(y_true, predicted, average="micro")
39
+
40
+
41
def f1_macro(y_true, p_hat, y_hat=None):
    """Return the macro-average F1 score.

    Predictions are the argmax of ``p_hat`` over the last axis unless ``y_hat``
    is given explicitly.
    """
    predicted = np.argmax(p_hat, axis=-1) if y_hat is None else y_hat
    return skm.f1_score(y_true, predicted, average="macro")
45
+
46
+
47
+ # Pure numpy and TF implementations of proper losses (as metrics) -----------------------------------
48
+
49
+
50
def brier_loss(y_true, p_hat):
    r"""Brier score.
    If the true label is k, while the predicted vector of probabilities is
    [y_1, ..., y_n], then the Brier score is equal to
    \sum_{i != k} y_i^2 + (y_k - 1)^2.

    The smaller the Brier score, the better, hence the naming with "loss".
    Across all items in a set of N predictions, the Brier score measures the
    mean squared difference between (1) the predicted probability assigned
    to the possible outcomes for item i, and (2) the actual outcome.
    Therefore, the lower the Brier score is for a set of predictions, the
    better the predictions are calibrated. With the multi-class sum used here
    the score lies between zero (perfect) and two (probability one on a wrong
    class). The Brier loss is composed of refinement loss and calibration loss.

    Args:
        y_true: Either class indices of shape (N,) or a one-hot array with the
            same shape as ``p_hat``.
        p_hat: Array of shape (N, K). When its entries do not sum to N it is
            treated as logits and passed through a softmax over the last axis.

    Returns:
        Mean over the N items of the summed squared differences.
    """
    y_true = np.asarray(y_true)
    p_hat = np.asarray(p_hat)
    N = len(y_true)
    K = p_hat.shape[-1]

    if y_true.shape != p_hat.shape:
        # Class indices: expand to one-hot targets.
        targets = np.zeros((N, K))
        targets[np.arange(N), y_true.astype(int)] = 1
    else:
        # Already one-hot; previously this path hit an undefined variable.
        targets = y_true

    if not np.isclose(np.sum(p_hat), len(p_hat)):
        p_hat = scipy.special.softmax(p_hat, axis=-1)

    return np.mean(np.sum((p_hat - targets) ** 2, axis=1))
81
+
82
+
83
def nll(y_true, p_hat):
    r"""Multi-class negative log likelihood.

    For a true label k and predicted probabilities [p_1, ..., p_K] the negative
    log likelihood is -log(p_k). ``y_true`` holds class indices; no one-hot
    encoding is required because the full label set ``0..K-1`` is passed along.
    """
    class_ids = np.arange(p_hat.shape[-1])
    return skm.log_loss(y_true, p_hat, labels=class_ids)
91
+
92
+
93
def accuracy(y_true, p_hat):
    """Return the fraction of items whose argmax over ``p_hat`` equals ``y_true``."""
    predicted = np.argmax(p_hat, axis=-1)
    # skm is sklearn.metrics, the same module the qualified name resolves to.
    return skm.accuracy_score(y_true=y_true, y_pred=predicted)
96
+
97
+
98
+ AURC_DISPLAY_SCALE = 1 # 1000
99
+
100
+ """
101
+ From: https://web.stanford.edu/class/archive/cs/cs224n/cs224n.1204/reports/custom/report52.pdf
102
+
103
+ The risk-coverage (RC) curve [28, 16] is a measure of the trade-off between the
104
+ coverage (the proportion of test data encountered), and the risk (the error rate under this coverage). Since each
105
+ prediction comes with a confidence score, given a list of prediction correctness Z paired up with the confidence
106
+ scores C, we sort C in reverse order to obtain sorted C'
107
+ , and its corresponding correctness Z'
108
+ . Note that the correctness is computed based on Exact Match (EM) as described in [22]. The RC curve is then obtained by
109
+ computing the risk of the coverage from the beginning of Z'
110
+ (most confident) to the end (least confident). In particular, these metrics evaluate
111
+ the relative order of the confidence score, which means that we want wrong
112
+ answers have lower confidence score than the correct ones, ignoring their absolute values.
113
+
114
+ Source: https://github.com/kjdhfg/fd-shifts
115
+
116
+ References:
117
+ -----------
118
+
119
+ [1] Jaeger, P.F., Lüth, C.T., Klein, L. and Bungert, T.J., 2022. A Call to Reflect on Evaluation Practices for Failure Detection in Image Classification. arXiv preprint arXiv:2211.15259.
120
+
121
+ [2] Kamath, A., Jia, R. and Liang, P., 2020. Selective Question Answering under Domain Shift. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics (pp. 5684-5696).
122
+
123
+ """
124
+
125
+
126
@dataclass
class StatsCache:
    """Cache for stats computed by scikit used by multiple metrics.

    Attributes:
        confids (array_like): Confidence values
        correct (array_like): Boolean array (best converted to int) where predictions were correct
    """

    confids: npt.NDArray[Any]
    correct: npt.NDArray[Any]

    @cached_property
    def roc_curve_stats(self) -> tuple[npt.NDArray[Any], npt.NDArray[Any]]:
        """Return ``(fpr, tpr)`` of the ROC curve of confidence vs. correctness."""
        fpr, tpr, _ = skm.roc_curve(self.correct, self.confids)
        return fpr, tpr

    @property
    def residuals(self) -> npt.NDArray[Any]:
        """Return ``1 - correct``: 1 where the prediction was wrong, 0 where right."""
        return 1 - self.correct

    @cached_property
    def rc_curve_stats(self) -> tuple[list[float], list[float], list[float]]:
        """Return ``(coverages, risks, weights)`` of the risk-coverage curve.

        Samples are removed one by one in order of increasing confidence; after
        each removal the selective risk of the remaining samples is computed. A
        curve point is emitted whenever the confidence value changes, weighted by
        the fraction of samples removed since the previous point.
        """
        # Evaluate the property once: it allocates a new array on every access,
        # which would make the loop below quadratic.
        residuals = self.residuals
        confids = self.confids

        n_residuals = len(residuals)
        idx_sorted = np.argsort(confids)

        coverage = n_residuals
        error_sum = sum(residuals[idx_sorted])

        # First point: full coverage.
        coverages = [coverage / n_residuals]
        risks = [error_sum / n_residuals]
        weights = []

        tmp_weight = 0
        for i, sample in enumerate(idx_sorted[:-1]):
            coverage = coverage - 1
            error_sum = error_sum - residuals[sample]
            tmp_weight += 1
            if i == 0 or confids[sample] != confids[idx_sorted[i - 1]]:
                coverages.append(coverage / n_residuals)
                risks.append(error_sum / (n_residuals - 1 - i))
                weights.append(tmp_weight / n_residuals)
                tmp_weight = 0

        # add a well-defined final point to the RC-curve.
        if tmp_weight > 0:
            coverages.append(0)
            risks.append(risks[-1])
            weights.append(tmp_weight / n_residuals)
        return coverages, risks, weights
181
+
182
+
183
def aurc(stats_cache: StatsCache):
    """Area under the risk-coverage curve.

    Args:
        stats_cache (StatsCache): StatsCache object
    Returns:
        Sum over curve segments of the mean of the two end-point risks times the
        segment weight, multiplied by ``AURC_DISPLAY_SCALE``.
    Important for assessment: LOWER is better!
    """
    _, risks, weights = stats_cache.rc_curve_stats
    # Pair each risk with its successor: one trapezoid per weight.
    segments = zip(risks, risks[1:], weights)
    area = sum([(left + right) * 0.5 * weight for left, right, weight in segments])
    return area * AURC_DISPLAY_SCALE
193
+
194
+
195
def aurc_logits(references, predictions, plot=False, get_cache=False, use_as_is=False):
    """Compute AURC from labels and model outputs, optionally plotting the risk-coverage curve.

    Args:
        references: labels, or correctness indicators when ``use_as_is`` is set.
        predictions: model outputs (one row per sample), or confidences when ``use_as_is`` is set.
        plot: when true, show the risk-coverage curve with the pandas "plotly" plotting backend.
        get_cache: when true, return ``{"aurc": value, "cache": StatsCache}`` instead of the value.
        use_as_is: skip the label-to-correctness conversion and the softmax.

    Returns:
        The AURC value, or a dict with the value and the ``StatsCache`` when ``get_cache`` is set.
    """
    # accept lists/tuples as well as arrays: `.ndim` and the arithmetic in StatsCache need ndarrays
    references = np.asarray(references)
    predictions = np.asarray(predictions)

    if not use_as_is:
        # NOTE(review): both checks are heuristics on the *total* sum; label vectors whose sum happens
        # to equal their length are taken for correctness indicators — pass use_as_is explicitly
        # when the inputs are already prepared.
        if not np.isclose(np.sum(references), len(references)):
            references = (np.argmax(predictions, -1) == references).astype(int)  # correctness

        if not np.isclose(np.sum(predictions), len(predictions)):
            predictions = scipy.special.softmax(predictions, axis=-1)

    # StatsCache needs one confidence per sample: reduce per-class scores to the top score
    if predictions.ndim == 2:
        predictions = np.max(predictions, -1)

    cache = StatsCache(confids=predictions, correct=references)

    if plot:
        coverages, risks, weights = cache.rc_curve_stats
        pd.options.plotting.backend = "plotly"
        df = pd.DataFrame(zip(coverages, risks, weights), columns=["% Coverage", "% Risk", "weights"])
        fig = df.plot(x="% Coverage", y="% Risk")
        fig.show()
    if get_cache:
        return {"aurc": aurc(cache), "cache": cache}
    return aurc(cache)
217
+
218
+
219
def multi_aurc_plot(caches, names, aurcs=None, verbose=False):
    """Show the risk-coverage curves of several caches in a single figure.

    Args:
        caches: objects exposing ``rc_curve_stats``.
        names: column/legend name for each cache, paired with ``caches`` by position.
        aurcs: optional AURC values; when given they are rounded to 4 digits and shown in the title.
        verbose: when true, print the head, index and columns of the assembled frame.
    """
    pd.options.plotting.backend = "plotly"
    curves = pd.DataFrame()
    for cache, name in zip(caches, names):
        coverages, risks, _ = cache.rc_curve_stats
        # one column per cache, indexed by coverage
        curves[name] = pd.Series(risks, index=coverages)
    if verbose:
        print(curves.head(), curves.index, curves.columns)
    fig = curves.plot()
    if aurcs is None:
        title = ""
    else:
        title = "AURC: " + " - ".join(str(round(value, 4)) for value in aurcs)
    fig.update_layout(title=title, xaxis_title="% Coverage", yaxis_title="% Risk")
    fig.show()
233
+
234
+
235
def ece_logits(references, predictions, use_as_is=False):
    """Compute the expected calibration error with the "jordyvl/ece" metric.

    Args:
        references: passed to the metric as ``references``.
        predictions: model outputs; softmax-normalised first unless ``use_as_is`` is set or their
            total sum is already close to the number of rows.
        use_as_is: skip the softmax normalisation check.

    Returns:
        The ``"ECE"`` entry of the metric result.
    """
    needs_softmax = not use_as_is and not np.isclose(np.sum(predictions), len(predictions))
    if needs_softmax:
        predictions = scipy.special.softmax(predictions, axis=-1)

    ece_metric = HF_evaluate.load("jordyvl/ece")
    ece_result = ece_metric.compute(
        references=references,
        predictions=predictions,
        # at most 100 bins, and fewer than the number of samples
        n_bins=min(len(predictions) - 1, 100),
        scheme="equal-mass",
        bin_range=[0, 1],
        proxy="upper-edge",
        p=1,
        detail=False,
    )
    return ece_result["ECE"]
256
+
257
+
258
+ METRICS = [accuracy, brier_loss, nll, f1_w, f1_macro, ece_logits, aurc_logits]
259
+
260
+
261
def apply_metrics(y_true, y_probs, metrics=METRICS):
    """Evaluate every metric on ``(y_true, y_probs)`` and collect the results by metric name.

    A metric that raises is reported with ``print`` and left out of the result.

    Args:
        y_true: first argument handed to each metric.
        y_probs: second argument handed to each metric.
        metrics: callables to evaluate; defaults to ``METRICS``.

    Returns:
        OrderedDict mapping the metric's ``__name__`` (with "_logits" removed) to its value.
    """
    predictive_performance = OrderedDict()
    for metric in metrics:
        try:
            # compute the value first, then derive the key, mirroring assignment evaluation order
            value = metric(y_true, y_probs)
            key = metric.__name__.replace("_logits", "")
            predictive_performance[key] = value
        except Exception as e:
            print(e)
    return predictive_performance
270
+
271
+
272
def _print_coverage_accuracies(sorted_correct, coverages):
    """Print the accuracy (in %) over the first ``coverage`` percent of ``sorted_correct``."""
    size = len(sorted_correct)
    for coverage in coverages:
        covered_correct = sorted_correct[: round(size / 100 * coverage)]
        # on small inputs the selection can round down to nothing; report nan instead of dividing by zero
        if covered_correct:
            accuracy = sum(covered_correct) / len(covered_correct) * 100.0
        else:
            accuracy = float("nan")
        print("{:.0f}: {:.3f}, ".format(coverage, accuracy), end="")
    print("")


def evaluate_coverages(
    logits, labels, confidence, coverages=(100, 99, 98, 97, 95, 90, 85, 80, 75, 70, 60, 50, 40, 30, 20, 10)
):
    """Print accuracy at several coverage levels for two ways of ranking the samples.

    Args:
        logits: per-class scores, one row per sample; ``argmax(-1)`` is the prediction and
            ``max(-1)`` the "Softmax Response" ranking score.
        labels: ground-truth class ids compared against the predictions.
        confidence: one score per sample used for the "Abstention Logit" ranking; samples with
            the SMALLEST value are ranked first.
        coverages: percentages of the ranked samples to evaluate.

    Returns:
        None; the results are printed.
    """
    correctness = np.equal(logits.argmax(-1), labels)

    # ascending sort: the samples with the lowest `confidence` value fill the covered prefix
    abstention_results = sorted(zip(list(confidence), list(correctness)), key=lambda x: x[0])
    print("Abstention Logit: accuracy of coverage ")  # 1-risk
    _print_coverage_accuracies([int(correct) for _, correct in abstention_results], coverages)

    # descending sort: the samples with the highest top score fill the covered prefix
    sr_results = sorted(zip(list(logits.max(-1)), list(correctness)), key=lambda x: -x[0])
    print("Softmax Response: accuracy of coverage ")
    _print_coverage_accuracies([int(correct) for _, correct in sr_results], coverages)
299
+
300
+
301
def compute_metrics(eval_preds):
    """Compute the metric dictionary for a ``(logits, labels)`` evaluation pair.

    When ``logits`` is a tuple, its first element is used as the logits and its second as a
    confidence output; if that output holds exactly one value per sample, the coverage report
    is printed as well.

    Args:
        eval_preds: pair ``(logits, labels)``.

    Returns:
        The result of ``apply_metrics(labels, logits)``.
    """
    logits, labels = eval_preds  # output of forward
    if isinstance(logits, tuple):
        logits, confidence = logits[0], logits[1]
        has_one_score_per_sample = confidence.size == logits.shape[0]
        if has_one_score_per_sample:
            evaluate_coverages(logits, labels, confidence)
    return apply_metrics(labels, logits)
pyproject.toml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
[tool.poetry]
name = "BDPC"
version = "0.1.0"
description = "something"
authors = ["ANON"]

readme = 'README.md' # Markdown files are supported
keywords = ['research', 'evaluation', 'classification', 'documents']

classifiers = [
    "Development Status :: 1 - Planning",
    "Environment :: Web Environment",
    "Operating System :: POSIX :: Linux",
    "Programming Language :: Python",
    "Programming Language :: Python :: 3 :: Only",
    "Topic :: Scientific/Engineering :: Artificial Intelligence"
]

# NOTE(review): the scripts in this upload also import datasets, scipy and tqdm (and plot with
# the "plotly" pandas backend), none of which are declared here — confirm and add them.
[tool.poetry.dependencies]
python = "^3.8"
numpy = "^1.24.1"
evaluate = "^0.4.0"
scikit-learn = "^1.2.0"
transformers = "^4.26.0"
pandas = "^1.5.3"
torch = "^1.13.1"
pillow = "^9.4.0"

[tool.poetry.dev-dependencies]
# NOTE(review): pytest 3.x is far older than the other pins — confirm the intended version.
pytest = "^3.10.1"
pytest-cov = "^2.9.0"


[build-system]
requires = ["poetry-core"]
# poetry-core exposes its PEP 517 backend as poetry.core.masonry.api
build-backend = "poetry.core.masonry.api"
simulate_document_classifier.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from tqdm import tqdm
3
+ import pandas as pd
4
+ import numpy as np
5
+ import torch
6
+ from datasets import load_dataset, logging
7
+ from datasets import Features, Value, Image, Sequence, Array3D, Array4D
8
+ import evaluate
9
+ from metrics import apply_metrics
10
+
11
+ from transformers import AutoFeatureExtractor, AutoModelForImageClassification # DiT
12
+
13
+ logger = logging.get_logger(__name__)
14
+
15
+ from mapping_functions import (
16
+ pdf_to_pixelvalues_extractor,
17
+ nativepdf_to_pixelvalues_extractor,
18
+ )
19
+ from inference_methods import InferenceMethod
20
+
21
+ EXPERIMENT_ROOT = "/mnt/lerna/experiments"
22
+
23
+
24
def load_base_model(checkpoint="microsoft/dit-base-finetuned-rvlcdip"):
    """Load an image-classification model together with its feature extractor.

    Args:
        checkpoint: name or path handed to both ``from_pretrained`` calls; defaults to the
            checkpoint this function previously hard-coded.

    Returns:
        Tuple ``(model, feature_extractor)``.
    """
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)
    return model, feature_extractor
28
+
29
+
30
def logits_monitor(args, running_logits, references, predictions, identifier="a"):
    """Save one chunk of raw outputs as a compressed ``.npz`` file under ``EXPERIMENT_ROOT``.

    The saved matrix has one row per sample: the logits, then the reference label, the
    prediction and a running index.

    Args:
        args: namespace providing ``model``, ``dataset``, ``inference_method`` and
            ``downsampling``, which make up the file name.
        running_logits: list of logit tensors, concatenated along dim 0.
        references: reference labels, one per row.
        predictions: predicted labels, one per row.
        identifier: chunk key used in the file name. When it parses as an integer it is taken as
            the exclusive end of the running index; otherwise rows are numbered from zero.
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}-i{identifier}.npz"

    try:
        index_end = int(identifier)
    except ValueError:
        # a non-numeric key (such as the default "a") used to crash here; number the rows from zero
        index_end = len(references)
    sample_index = np.arange(index_end - len(references), index_end)

    raw_output = torch.cat(
        [
            torch.cat(running_logits, dim=0).cpu(),
            torch.Tensor(references).unsqueeze(1),
            torch.Tensor(predictions).unsqueeze(1),
            torch.Tensor(sample_index).unsqueeze(1),
        ],
        dim=1,
    )
    np.savez_compressed(output_path, raw_output.detach().cpu().numpy())
    tqdm.write("saved raw test outputs to {}".format(output_path))
44
+
45
+
46
def monitor_cleanup(args, buffer_keys):
    """
    This merges all previous buffers to 1 file

    The per-chunk ``-i<identifier>.npz`` files are concatenated in the order of ``buffer_keys``,
    written to ``<prefix>-final.npz`` and then deleted.

    Args:
        args: namespace providing ``model``, ``dataset``, ``inference_method`` and
            ``downsampling``, which make up the file prefix.
        buffer_keys: identifiers of the chunk files to merge; nothing is written when empty.
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}"

    if not buffer_keys:
        # no chunk was ever written (e.g. nothing survived filtering)
        tqdm.write("no buffered test outputs to merge")
        return

    identifier_paths = [f"{output_path}-i{identifier}.npz" for identifier in buffer_keys]

    chunks = []
    for identifier_path in identifier_paths:
        # close the archive right away so the file can be removed below on every platform
        with np.load(identifier_path) as saved:
            chunks.append(saved["arr_0"])
    catted = np.concatenate(chunks)

    out_path = f"{output_path}-final.npz"
    np.savez_compressed(out_path, catted)
    tqdm.write("saved raw test outputs to {}".format(out_path))
    # cleanup
    for identifier_path in identifier_paths:
        os.remove(identifier_path)
66
+
67
+
68
def main(args):
    """Evaluate a document classifier on a test split with the chosen inference strategy.

    Loads ``args.dataset`` and ``args.model``, aligns the dataset labels with the model's label
    ids when they differ, converts the documents to pixel values, runs inference, prints the
    accuracy and writes the raw outputs to disk in chunks that are merged at the end.

    Args:
        args: parsed command-line namespace; reads ``dataset``, ``downsampling``, ``model``,
            ``inference_method`` and ``batch_size``. ``batch_size`` is overwritten with 1 when the
            strategy's scope does not contain "sample".
    """
    # fall back to the default Hugging Face cache when the shared cache directory is not available
    cache_dir = "/mnt/lerna/data/HFcache"
    if not os.path.exists(cache_dir):
        cache_dir = None

    testds = load_dataset(
        args.dataset,
        cache_dir=cache_dir,
        split="test",
        revision=None if args.dataset != "bdpc/rvl_cdip_mp" else "d3a654c9f63f14d0aaa94e08aa30aa3dc20713c1",
    )

    if args.downsampling:
        testds = testds.select(list(range(0, args.downsampling)))

    model = AutoModelForImageClassification.from_pretrained(args.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    label2idx = dict(model.config.label2id)
    print(label2idx)

    data_idx2label = dict(enumerate(testds.features["labels"].names))
    model_idx2label = dict(zip(label2idx.values(), label2idx.keys()))
    diff = [i for i in range(len(data_idx2label)) if data_idx2label[i] != model_idx2label[i]]
    if diff:
        print(f"aligning labels {diff}")
        testds = testds.align_labels_with_mapping(label2idx, "labels")

    inference_method = InferenceMethod[args.inference_method.upper()]
    per_sample = "sample" in inference_method.scope
    dummy_inference_method = inference_method
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model)

    features = {
        **{k: v for k, v in testds.features.items() if k in ["labels", "pixel_values", "id"]},
        "pages": Value(dtype="int32"),
        "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224)),
    }
    if not per_sample:
        # keep every page of the document: variable-length stack of page images
        features["pixel_values"] = Array4D(dtype="float32", shape=(None, 3, 224, 224))
        dummy_inference_method = InferenceMethod["max_confidence".upper()]
    features = Features(features)

    remove_columns = ["file"]
    if args.dataset == "bdpc/rvl_cdip_mp":

        def image_preprocessor(batch):
            return pdf_to_pixelvalues_extractor(batch, feature_extractor, dummy_inference_method)

        encoded_testds = testds.map(
            image_preprocessor, features=features, remove_columns=remove_columns, desc="pdf_to_pixelvalues"
        )
    else:

        def image_preprocessor(batch):
            return nativepdf_to_pixelvalues_extractor(batch, feature_extractor, dummy_inference_method)

        encoded_testds = testds.map(
            image_preprocessor,
            features=features,
            remove_columns=remove_columns,
            desc="pdf_to_pixelvalues",
            batch_size=10,
        )

    # drop documents without pages or with NaN pixel values
    print(f"Before filtering: {len(encoded_testds)}")

    def is_usable(example):
        return example["pages"] != 0 and not np.any(np.isnan(example["pixel_values"]))

    good_indices = [i for i, x in tqdm(enumerate(encoded_testds), desc="filter") if is_usable(x)]
    encoded_testds = encoded_testds.select(good_indices)
    print(f"After filtering: {len(encoded_testds)}")

    metric = evaluate.load("accuracy")

    encoded_testds.set_format(type="torch", columns=["pixel_values", "labels"])
    # document-level strategies handle one (variable-length) document at a time
    args.batch_size = args.batch_size if per_sample else 1
    dataloader = torch.utils.data.DataLoader(encoded_testds, batch_size=args.batch_size)

    running_logits = []
    predictions, references = [], []
    buffer_references = []
    buffer_predictions = []
    buffer = 0
    BUFFER_SIZE = 5000
    buffer_keys = []

    def flush_buffer():
        """Write the buffered outputs to disk and move them into the overall lists."""
        predictions.extend(buffer_predictions)
        references.extend(buffer_references)
        # logits_monitor numbers the rows as [int(identifier) - len(buffer), int(identifier)),
        # so the key has to be the running sample count; the batch index used before gave
        # shifted (and, for batch sizes above 1, negative) indices.
        identifier = str(len(references))
        logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=identifier)
        buffer_keys.append(identifier)
        running_logits.clear()
        buffer_references.clear()
        buffer_predictions.clear()

    for i, batch in tqdm(enumerate(dataloader), desc="Inference loop"):
        with torch.no_grad():
            batch["labels"] = batch["labels"].to(device)
            batch["pixel_values"] = batch["pixel_values"].to(device)
            if per_sample:
                outputs = model(batch["pixel_values"])
                logits = outputs.logits
                buffer_predictions.extend(logits.argmax(-1).tolist())
                buffer_references.extend(batch["labels"].tolist())
                running_logits.append(logits)
            else:
                try:
                    page_logits = model(batch["pixel_values"][0]).logits
                except Exception as e:
                    # best effort: skip documents the model cannot process
                    print(f"something went wrong in inference {e}")
                    continue
                prediction = inference_method.apply_decision_strategy(page_logits)  # apply logic depending on method
                buffer_predictions.append(prediction.tolist())
                buffer_references.extend(batch["labels"].tolist())
                running_logits.append(page_logits.mean(0).unsqueeze(0))  # average over pages as representative

        buffer += args.batch_size
        if buffer >= BUFFER_SIZE:
            flush_buffer()
            buffer = 0

    if buffer != 0:  # dump remaining out of buffer
        flush_buffer()

    accuracy = metric.compute(references=references, predictions=predictions)
    print(f"Accuracy on this inference configuration {inference_method}:", accuracy)
    monitor_cleanup(args, buffer_keys)
189
+
190
+
191
def build_arg_parser():
    """Build the command-line parser of the inference-strategy evaluation script.

    Returns:
        An ``ArgumentParser`` with the optional positional ``inference_method`` and the
        ``-s``, ``-d``, ``-m``, ``-b`` and ``-k`` options.
    """
    from argparse import ArgumentParser

    # the text is a description; passed positionally it would become the program name (`prog`)
    parser = ArgumentParser(description="Test different inference strategies to classify a document")
    parser.add_argument(
        "inference_method",
        type=str,
        default="first",
        nargs="?",
        help="how to evaluate DiT on RVL-CDIP_multi",
    )
    parser.add_argument("-s", dest="downsampling", type=int, default=0, help="number of testset samples")
    parser.add_argument("-d", dest="dataset", type=str, default="bdpc/rvl_cdip_mp", help="the dataset to be evaluated")
    parser.add_argument(
        "-m",
        dest="model",
        type=str,
        default="microsoft/dit-base-finetuned-rvlcdip",
        help="the model checkpoint to be evaluated",
    )
    parser.add_argument("-b", dest="batch_size", type=int, default=16, help="batch size")
    parser.add_argument(
        "-k",
        dest="keep_in_memory",
        default=False,
        action="store_true",
        help="do not cache operations (for testing)",
    )
    return parser


if __name__ == "__main__":
    args = build_arg_parser().parse_args()

    main(args)