# src/simulate_document_classifier.py
# Provenance: uploaded by bdpc ("Upload 9 files", revision 1ceb840)
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset, logging
from datasets import Features, Value, Image, Sequence, Array3D, Array4D
import evaluate
from metrics import apply_metrics
from transformers import AutoFeatureExtractor, AutoModelForImageClassification # DiT
logger = logging.get_logger(__name__)
from mapping_functions import (
pdf_to_pixelvalues_extractor,
nativepdf_to_pixelvalues_extractor,
)
from inference_methods import InferenceMethod
# Root directory where raw logit buffers (*.npz) and merged results are written.
EXPERIMENT_ROOT = "/mnt/lerna/experiments"
def load_base_model():
    """Load the pretrained DiT RVL-CDIP classifier and its feature extractor.

    Returns:
        tuple: ``(model, feature_extractor)`` for the
        ``microsoft/dit-base-finetuned-rvlcdip`` checkpoint.
    """
    checkpoint = "microsoft/dit-base-finetuned-rvlcdip"
    extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    classifier = AutoModelForImageClassification.from_pretrained(checkpoint)
    return classifier, extractor
def logits_monitor(args, running_logits, references, predictions, identifier="a"):
    """Persist one buffer of raw outputs to a compressed ``.npz`` file.

    The saved 2-D array has, per row: the logits, the reference label, the
    predicted label, and a running sample index derived from ``identifier``
    (the batch index at flush time) minus the buffer length.

    Args:
        args: CLI namespace (model, dataset, inference_method, downsampling).
        running_logits: list of logit tensors to concatenate along dim 0.
        references: gold labels for the buffered samples.
        predictions: predicted labels for the buffered samples.
        identifier: string batch index used in the file name and index column.
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}-i{identifier}.npz"
    logits_block = torch.cat(running_logits, dim=0).cpu()
    reference_col = torch.Tensor(references).unsqueeze(1)
    prediction_col = torch.Tensor(predictions).unsqueeze(1)
    end = int(identifier)
    index_col = torch.Tensor(np.arange(end - len(references), end)).unsqueeze(1)
    raw_output = torch.cat([logits_block, reference_col, prediction_col, index_col], dim=1)
    np.savez_compressed(output_path, raw_output.cpu().data.numpy())
    tqdm.write("saved raw test outputs to {}".format(output_path))
def monitor_cleanup(args, buffer_keys):
    """
    This merges all previous buffers to 1 file

    Loads every intermediate ``-i<key>.npz`` buffer written by
    :func:`logits_monitor`, concatenates them in order into a single
    ``-final.npz`` archive, then deletes the intermediate files.

    Args:
        args: CLI namespace (model, dataset, inference_method, downsampling).
        buffer_keys: ordered list of string identifiers of flushed buffers.
    """
    output_path = f"{EXPERIMENT_ROOT}/{args.model.split('/')[-1]}_{args.dataset.split('/')[-1]}_{args.inference_method}-{args.downsampling}"
    chunks = [np.load(f"{output_path}-i{identifier}.npz")["arr_0"] for identifier in buffer_keys]
    catted = chunks[0] if len(chunks) == 1 else np.concatenate(chunks)
    out_path = f"{output_path}-final.npz"
    np.savez_compressed(out_path, catted)
    tqdm.write("saved raw test outputs to {}".format(out_path))
    # cleanup: the merged archive supersedes the per-buffer files
    for identifier in buffer_keys:
        os.remove(f"{output_path}-i{identifier}.npz")
def main(args):
    """Evaluate a document-image classifier on a (multi-page) test set.

    Pipeline: load the test split, optionally downsample, align dataset labels
    with the model's ``label2id`` mapping, convert documents to pixel values,
    filter out unreadable documents, then run inference with the selected
    ``InferenceMethod``. Raw logits are periodically flushed to ``.npz``
    buffers (``logits_monitor``) and merged at the end (``monitor_cleanup``).

    Args:
        args: argparse namespace with ``dataset``, ``downsampling``, ``model``,
            ``inference_method``, ``batch_size`` (see the parser at the bottom
            of this file).
    """
    testds = load_dataset(
        args.dataset,
        cache_dir="/mnt/lerna/data/HFcache",
        split="test",
        # pin a fixed revision of the multi-page RVL-CDIP dataset for reproducibility
        revision=None if args.dataset != "bdpc/rvl_cdip_mp" else "d3a654c9f63f14d0aaa94e08aa30aa3dc20713c1",
    )
    if args.downsampling:
        # quick-test mode: keep only the first `downsampling` samples
        testds = testds.select(list(range(0, args.downsampling)))
    model = AutoModelForImageClassification.from_pretrained(args.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    label2idx = {label: i for label, i in model.config.label2id.items()}  # .replace(" ", "_")
    print(label2idx)
    # Detect any index positions where the dataset's label order differs from the model's.
    data_idx2label = dict(enumerate(testds.features["labels"].names))
    model_idx2label = dict(zip(label2idx.values(), label2idx.keys()))
    diff = [i for i in range(len(data_idx2label)) if data_idx2label[i] != model_idx2label[i]]
    if diff:
        print(f"aligning labels {diff}")
        testds = testds.align_labels_with_mapping(label2idx, "labels")
    inference_method = InferenceMethod[args.inference_method.upper()]
    dummy_inference_method = inference_method
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model)
    features = {
        **{k: v for k, v in testds.features.items() if k in ["labels", "pixel_values", "id"]},
        "pages": Value(dtype="int32"),
        "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224)),
    }
    if not "sample" in inference_method.scope:
        # Document-level methods keep one 224x224 tensor per page (variable page count).
        # NOTE(review): the extractor then runs with MAX_CONFIDENCE as a stand-in;
        # presumably this only affects page extraction, not the final decision —
        # confirm against mapping_functions.
        features["pixel_values"] = Array4D(dtype="float32", shape=(None, 3, 224, 224))
        dummy_inference_method = InferenceMethod["max_confidence".upper()]
    features = Features(features)
    remove_columns = ["file"]
    if args.dataset == "bdpc/rvl_cdip_mp":
        image_preprocessor = lambda batch: pdf_to_pixelvalues_extractor(
            batch, feature_extractor, dummy_inference_method
        )
        encoded_testds = testds.map(
            image_preprocessor, features=features, remove_columns=remove_columns, desc="pdf_to_pixelvalues"
        )
    else:
        image_preprocessor = lambda batch: nativepdf_to_pixelvalues_extractor(
            batch, feature_extractor, dummy_inference_method
        )
        encoded_testds = testds.map(
            image_preprocessor,
            features=features,
            remove_columns=remove_columns,
            desc="pdf_to_pixelvalues",
            batch_size=10,
        )
    # remove_columns.append("images")
    # select approach
    print(f"Before filtering: {len(encoded_testds)}")
    # Drop documents with zero readable pages or NaN pixel values.
    more_complex_filter = lambda example: example["pages"] != 0 and not np.any(np.isnan(example["pixel_values"]))
    good_indices = [i for i, x in tqdm(enumerate(encoded_testds), desc="filter") if more_complex_filter(x)]
    encoded_testds = encoded_testds.select(good_indices)
    print(f"After filtering: {len(encoded_testds)}")
    metric = evaluate.load("accuracy")
    # going to have to manually iterate without dataloader and do tensor conversion
    encoded_testds.set_format(type="torch", columns=["pixel_values", "labels"])
    # Document-level methods see one document (all of its pages) per step.
    args.batch_size = args.batch_size if "sample" in inference_method.scope else 1
    dataloader = torch.utils.data.DataLoader(encoded_testds, batch_size=args.batch_size)
    running_logits = []
    predictions, references = [], []
    buffer_references = []
    buffer_predictions = []
    buffer = 0  # samples accumulated since the last flush to disk
    BUFFER_SIZE = 5000
    buffer_keys = []  # batch indices at which buffers were flushed (used as file ids)
    for i, batch in tqdm(enumerate(dataloader), desc="Inference loop"):
        with torch.no_grad():
            batch["labels"] = batch["labels"].to(device)
            batch["pixel_values"] = batch["pixel_values"].to(device)
            if "sample" in inference_method.scope:
                # Per-sample scope: classify each image independently.
                outputs = model(batch["pixel_values"].to(device))
                logits = outputs.logits
                buffer_predictions.extend(logits.argmax(-1).tolist())
                buffer_references.extend(batch["labels"].tolist())
                running_logits.append(logits)
            else:
                # Document scope: batch size is 1; pixel_values[0] holds all pages
                # of one document, so page logits are combined by the method.
                try:
                    page_logits = model(batch["pixel_values"][0]).logits
                except Exception as e:
                    print(f"something went wrong in inference {e}")
                    continue
                prediction = inference_method.apply_decision_strategy(page_logits)  # apply logic depending on method
                buffer_predictions.append(prediction.tolist())
                buffer_references.extend(batch["labels"].tolist())
                running_logits.append(page_logits.mean(0).unsqueeze(0))  # average over pages as representative
        buffer += args.batch_size
        if buffer >= BUFFER_SIZE:
            # Flush accumulated logits/labels to disk to bound memory usage.
            predictions.extend(buffer_predictions)
            references.extend(buffer_references)
            logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=str(i))
            buffer_keys.append(str(i))
            running_logits = []
            buffer_references = []
            buffer_predictions = []
            buffer = 0
    if buffer != 0:  # dump remaining out of buffer
        predictions.extend(buffer_predictions)
        references.extend(buffer_references)
        logits_monitor(args, running_logits, buffer_references, buffer_predictions, identifier=str(i))
        buffer_keys.append(str(i))
    accuracy = metric.compute(references=references, predictions=predictions)
    print(f"Accuracy on this inference configuration {inference_method}:", accuracy)
    monitor_cleanup(args, buffer_keys)
if __name__ == "__main__":
    from argparse import ArgumentParser

    # CLI: positional inference method plus model/dataset/runtime knobs.
    cli = ArgumentParser("""Test different inference strategies to classify a document""")
    cli.add_argument(
        "inference_method",
        type=str,
        default="first",
        nargs="?",
        help="how to evaluate DiT on RVL-CDIP_multi",
    )
    cli.add_argument("-s", dest="downsampling", type=int, default=0, help="number of testset samples")
    cli.add_argument("-d", dest="dataset", type=str, default="bdpc/rvl_cdip_mp", help="the dataset to be evaluated")
    cli.add_argument(
        "-m",
        dest="model",
        type=str,
        default="microsoft/dit-base-finetuned-rvlcdip",
        help="the model checkpoint to be evaluated",
    )
    cli.add_argument("-b", dest="batch_size", type=int, default=16, help="batch size")
    cli.add_argument(
        "-k",
        dest="keep_in_memory",
        default=False,
        action="store_true",
        help="do not cache operations (for testing)",
    )
    main(cli.parse_args())