|
import os |
|
from tqdm import tqdm |
|
import pandas as pd |
|
import numpy as np |
|
import torch |
|
from datasets import load_dataset, logging |
|
from datasets import Features, Value, Image, Sequence, Array3D, Array4D |
|
import evaluate |
|
from metrics import apply_metrics |
|
|
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
|
|
logger = logging.get_logger(__name__) |
|
|
|
from mapping_functions import ( |
|
pdf_to_pixelvalues_extractor, |
|
nativepdf_to_pixelvalues_extractor, |
|
) |
|
from inference_methods import InferenceMethod |
|
|
|
EXPERIMENT_ROOT = "/mnt/lerna/experiments" |
|
|
|
|
|
def load_base_model(checkpoint="microsoft/dit-base-finetuned-rvlcdip"):
    """Load an image-classification model and its matching feature extractor.

    Generalized so other checkpoints can be loaded; the default preserves the
    original behavior (DiT base fine-tuned on RVL-CDIP).

    Args:
        checkpoint: HuggingFace hub id or local path of the model checkpoint.

    Returns:
        Tuple ``(model, feature_extractor)``.
    """
    # Single source of truth for the checkpoint name (was duplicated twice).
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)
    return model, feature_extractor
|
|
|
|
|
def logits_monitor(args, running_logits, references, predictions, identifier="a"):
    """Flush accumulated raw logits to a compressed .npz buffer file.

    Each row of the saved array is ``[logits... | reference | prediction |
    running sample index]``; the index column is reconstructed from
    ``identifier`` (the batch counter at flush time) minus the buffer length.

    Args:
        args: parsed CLI namespace; ``model``, ``dataset``,
            ``inference_method`` and ``downsampling`` name the output file.
        running_logits: list of per-batch logit tensors to concatenate.
        references: gold labels for the buffered samples.
        predictions: predicted labels for the buffered samples.
        identifier: batch counter (stringified int) used as buffer id.
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}-i{identifier}.npz"

    # Build each column block separately, then stack them side by side.
    logit_block = torch.cat(running_logits, dim=0).cpu()
    ref_col = torch.Tensor(references).unsqueeze(1)
    pred_col = torch.Tensor(predictions).unsqueeze(1)
    sample_ids = np.arange(int(identifier) - len(references), int(identifier))
    idx_col = torch.Tensor(sample_ids).unsqueeze(1)

    raw_output = torch.cat([logit_block, ref_col, pred_col, idx_col], dim=1)
    np.savez_compressed(output_path, raw_output.cpu().data.numpy())
    tqdm.write("saved raw test outputs to {}".format(output_path))
|
|
|
|
|
def monitor_cleanup(args, buffer_keys):
    """Merge all intermediate .npz buffers into one "-final.npz" file.

    Concatenates the buffers written by ``logits_monitor`` in order, writes
    the merged array, then deletes the per-buffer files.

    Args:
        args: parsed CLI namespace; ``model``, ``dataset``,
            ``inference_method`` and ``downsampling`` determine the shared
            output-path prefix.
        buffer_keys: buffer identifiers (stringified batch indices) in the
            order they were flushed.
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}"

    # BUG FIX: with no buffers the original raised NameError on `catted`;
    # nothing was written, so there is nothing to merge or delete.
    if not buffer_keys:
        return

    identifier_paths = [f"{output_path}-i{identifier}.npz" for identifier in buffer_keys]

    # Load all buffers and concatenate once instead of growing the array
    # inside the loop (the original re-copied the accumulated data on every
    # iteration, which is quadratic in total size).
    catted = np.concatenate([np.load(path)["arr_0"] for path in identifier_paths])
    out_path = f"{output_path}-final.npz"
    np.savez_compressed(out_path, catted)
    tqdm.write("saved raw test outputs to {}".format(out_path))

    # Remove the now-merged intermediate buffer files.
    for identifier_path in identifier_paths:
        os.remove(identifier_path)
|
|
|
|
|
def main(args):
    """Evaluate a document-image classifier on a (multi-page) test set.

    Pipeline: load the test split, align dataset label ids with the model's
    label mapping, convert documents to pixel values, run inference with the
    selected strategy, and report accuracy. Raw logits are periodically
    flushed to .npz buffer files and merged into one file at the end.

    Args:
        args: parsed CLI namespace with ``dataset``, ``model``,
            ``inference_method``, ``downsampling`` and ``batch_size``.
    """
    # Pin the dataset revision for bdpc/rvl_cdip_mp so runs are reproducible.
    testds = load_dataset(
        args.dataset,
        cache_dir="/mnt/lerna/data/HFcache",
        split="test",
        revision=None if args.dataset != "bdpc/rvl_cdip_mp" else "d3a654c9f63f14d0aaa94e08aa30aa3dc20713c1",
    )

    # Optional down-sampling for quick experiments (-s flag; 0 = full set).
    if args.downsampling:
        testds = testds.select(list(range(0, args.downsampling)))

    model = AutoModelForImageClassification.from_pretrained(args.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    label2idx = {label: i for label, i in model.config.label2id.items()}
    print(label2idx)

    # Re-map dataset label ids onto the model's ids when the orderings differ.
    data_idx2label = dict(enumerate(testds.features["labels"].names))
    model_idx2label = dict(zip(label2idx.values(), label2idx.keys()))
    diff = [i for i in range(len(data_idx2label)) if data_idx2label[i] != model_idx2label[i]]
    if diff:
        print(f"aligning labels {diff}")
        testds = testds.align_labels_with_mapping(label2idx, "labels")

    inference_method = InferenceMethod[args.inference_method.upper()]
    dummy_inference_method = inference_method
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model)

    # Per-sample methods keep one 3x224x224 image per example; document-scope
    # methods keep a variable-length stack of page images (Array4D) and apply
    # the decision strategy at inference time instead of in preprocessing.
    features = {
        **{k: v for k, v in testds.features.items() if k in ["labels", "pixel_values", "id"]},
        "pages": Value(dtype="int32"),
        "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224)),
    }
    if not "sample" in inference_method.scope:
        features["pixel_values"] = Array4D(dtype="float32", shape=(None, 3, 224, 224))
        # NOTE(review): preprocessing gets a per-sample placeholder method here
        # (presumably so all pages are extracted); the real strategy is applied
        # in the inference loop below — confirm against mapping_functions.
        dummy_inference_method = InferenceMethod["max_confidence".upper()]
    features = Features(features)

    remove_columns = ["file"]
    if args.dataset == "bdpc/rvl_cdip_mp":
        image_preprocessor = lambda batch: pdf_to_pixelvalues_extractor(
            batch, feature_extractor, dummy_inference_method
        )
        encoded_testds = testds.map(
            image_preprocessor, features=features, remove_columns=remove_columns, desc="pdf_to_pixelvalues"
        )
    else:
        image_preprocessor = lambda batch: nativepdf_to_pixelvalues_extractor(
            batch, feature_extractor, dummy_inference_method
        )
        encoded_testds = testds.map(
            image_preprocessor,
            features=features,
            remove_columns=remove_columns,
            desc="pdf_to_pixelvalues",
            batch_size=10,
        )

    # Drop samples whose conversion failed: zero pages or NaN pixel values.
    print(f"Before filtering: {len(encoded_testds)}")
    more_complex_filter = lambda example: example["pages"] != 0 and not np.any(np.isnan(example["pixel_values"]))
    good_indices = [i for i, x in tqdm(enumerate(encoded_testds), desc="filter") if more_complex_filter(x)]
    encoded_testds = encoded_testds.select(good_indices)
    print(f"After filtering: {len(encoded_testds)}")

    metric = evaluate.load("accuracy")

    encoded_testds.set_format(type="torch", columns=["pixel_values", "labels"])
    # Document-scope methods handle one document (variable page count) at a
    # time, so the batch size is forced to 1 there.
    args.batch_size = args.batch_size if "sample" in inference_method.scope else 1
    dataloader = torch.utils.data.DataLoader(encoded_testds, batch_size=args.batch_size)

    running_logits = []
    predictions, references = [], []
    buffer_references = []
    buffer_predictions = []
    buffer = 0  # samples accumulated since the last flush to disk
    BUFFER_SIZE = 5000  # flush raw logits roughly every 5000 samples
    buffer_keys = []  # identifiers of the flushed buffer files
    for i, batch in tqdm(enumerate(dataloader), desc="Inference loop"):
        with torch.no_grad():
            batch["labels"] = batch["labels"].to(device)
            batch["pixel_values"] = batch["pixel_values"].to(device)
            if "sample" in inference_method.scope:
                # One image per example: standard batched forward pass with
                # plain argmax predictions.
                outputs = model(batch["pixel_values"].to(device))
                logits = outputs.logits
                buffer_predictions.extend(logits.argmax(-1).tolist())
                buffer_references.extend(batch["labels"].tolist())
                running_logits.append(logits)
            else:
                # Document scope: batch size is 1, so pixel_values[0] is the
                # page stack; all pages go through the model in one call.
                try:
                    page_logits = model(batch["pixel_values"][0]).logits
                except Exception as e:
                    # Best-effort: skip documents that fail (e.g. OOM on very
                    # long page stacks) rather than aborting the whole run.
                    print(f"something went wrong in inference {e}")
                    continue
                prediction = inference_method.apply_decision_strategy(page_logits)
                buffer_predictions.append(prediction.tolist())
                buffer_references.extend(batch["labels"].tolist())
                # Store the page-averaged logits so buffer rows stay rectangular.
                running_logits.append(page_logits.mean(0).unsqueeze(0))

        buffer += args.batch_size
        if buffer >= BUFFER_SIZE:
            # Flush the buffered logits/labels/predictions to an .npz file and
            # reset the in-memory buffers to bound memory use.
            predictions.extend(buffer_predictions)
            references.extend(buffer_references)
            logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=str(i))
            buffer_keys.append(str(i))
            running_logits = []
            buffer_references = []
            buffer_predictions = []
            buffer = 0

    # Flush whatever remains after the last full buffer.
    if buffer != 0:
        predictions.extend(buffer_predictions)
        references.extend(buffer_references)
        logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=str(i))
        buffer_keys.append(str(i))

    accuracy = metric.compute(references=references, predictions=predictions)
    print(f"Accuracy on this inference configuration {inference_method}:", accuracy)
    # Merge all per-buffer .npz files into one "-final" file and delete them.
    monitor_cleanup(args, buffer_keys)
|
|
|
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    # BUG FIX: ArgumentParser's first positional parameter is `prog` (the
    # program name shown in the usage line), not the description. The original
    # passed the descriptive text positionally, so --help displayed it as the
    # program name; pass it as `description` instead.
    parser = ArgumentParser(description="""Test different inference strategies to classify a document""")
    parser.add_argument(
        "inference_method",
        type=str,
        default="first",
        nargs="?",
        help="how to evaluate DiT on RVL-CDIP_multi",
    )
    parser.add_argument("-s", dest="downsampling", type=int, default=0, help="number of testset samples")
    parser.add_argument("-d", dest="dataset", type=str, default="bdpc/rvl_cdip_mp", help="the dataset to be evaluated")
    parser.add_argument(
        "-m",
        dest="model",
        type=str,
        default="microsoft/dit-base-finetuned-rvlcdip",
        help="the model checkpoint to be evaluated",
    )
    parser.add_argument("-b", dest="batch_size", type=int, default=16, help="batch size")
    parser.add_argument(
        "-k",
        dest="keep_in_memory",
        default=False,
        action="store_true",
        help="do not cache operations (for testing)",
    )

    args = parser.parse_args()

    main(args)
|
|