import os

from tqdm import tqdm
import pandas as pd
import numpy as np
import torch
from datasets import load_dataset, logging
from datasets import Features, Value, Image, Sequence, Array3D, Array4D
import evaluate
from metrics import apply_metrics
from transformers import AutoFeatureExtractor, AutoModelForImageClassification  # DiT

logger = logging.get_logger(__name__)

from mapping_functions import (
    pdf_to_pixelvalues_extractor,
    nativepdf_to_pixelvalues_extractor,
)
from inference_methods import InferenceMethod

EXPERIMENT_ROOT = "/mnt/lerna/experiments"


def load_base_model():
    """Load the RVL-CDIP fine-tuned DiT base checkpoint.

    Returns:
        A ``(model, feature_extractor)`` tuple for
        ``microsoft/dit-base-finetuned-rvlcdip``.
    """
    checkpoint = "microsoft/dit-base-finetuned-rvlcdip"
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)
    return model, feature_extractor


def _output_prefix(args):
    """Return the file-path prefix shared by all raw-output dumps of one run configuration."""
    model_name = args.model.split("/")[-1]
    dataset_name = args.dataset.split("/")[-1]
    return f"{EXPERIMENT_ROOT}/{model_name}_{dataset_name}_{args.inference_method}-{args.downsampling}"


def logits_monitor(args, running_logits, references, predictions, identifier="a", sample_ids=None):
    """Save one buffer of raw inference outputs to a compressed ``.npz`` file.

    Each saved row is ``[logits..., reference, prediction, sample_id]``.

    Args:
        args: parsed CLI arguments; ``model``, ``dataset``, ``inference_method`` and
            ``downsampling`` determine the output file name.
        running_logits: list of 2-D logit tensors (one per batch), concatenated
            along dim 0; must yield one row per reference.
        references: gold label ids, one per logit row.
        predictions: predicted label ids, one per logit row.
        identifier: buffer key embedded in the file name (``...-i{identifier}.npz``).
            When ``sample_ids`` is not given it must be an integer string and is
            treated as the *exclusive* end index of this buffer, i.e. the ids
            ``identifier - len(references) .. identifier - 1`` are recorded.
        sample_ids: optional explicit per-row sample ids (e.g. dataset positions);
            takes precedence over the range derived from ``identifier``.
    """
    output_path = f"{_output_prefix(args)}-i{identifier}.npz"
    if sample_ids is None:
        end = int(identifier)
        sample_ids = np.arange(end - len(references), end)
    raw_output = torch.cat(
        [
            torch.cat(running_logits, dim=0).cpu(),
            torch.Tensor(references).unsqueeze(1),
            torch.Tensor(predictions).unsqueeze(1),
            torch.Tensor(np.asarray(sample_ids)).unsqueeze(1),
        ],
        dim=1,
    )
    np.savez_compressed(output_path, raw_output.cpu().data.numpy())
    tqdm.write("saved raw test outputs to {}".format(output_path))


def monitor_cleanup(args, buffer_keys):
    """Merge all per-buffer dumps into one ``-final.npz`` file, then delete the parts.

    Args:
        args: parsed CLI arguments (used to rebuild the file-name prefix).
        buffer_keys: identifiers previously passed to ``logits_monitor``, in the
            order the rows should appear in the merged file.
    """
    if not buffer_keys:
        tqdm.write("no raw test outputs to merge")
        return
    output_path = _output_prefix(args)
    part_paths = [f"{output_path}-i{identifier}.npz" for identifier in buffer_keys]
    parts = []
    for part_path in part_paths:
        # np.load returns an NpzFile holding an open handle; close it before removal below
        with np.load(part_path) as saved:
            parts.append(saved["arr_0"])
    catted = np.concatenate(parts)
    out_path = f"{output_path}-final.npz"
    np.savez_compressed(out_path, catted)
    tqdm.write("saved raw test outputs to {}".format(out_path))
    # cleanup: remove the per-buffer parts only after the merged file has been written
    for part_path in part_paths:
        os.remove(part_path)


def main(args):
    """Evaluate a document classifier on a (multi-page) test set with one inference strategy.

    Loads the test split, aligns dataset label ids with the model config, converts
    documents to pixel values, drops unusable samples, runs inference while dumping
    raw outputs every ``BUFFER_SIZE`` samples via ``logits_monitor``, prints the
    accuracy and finally merges the dumps with ``monitor_cleanup``.
    """
    testds = load_dataset(
        args.dataset,
        cache_dir="/mnt/lerna/data/HFcache",
        split="test",
        revision=None if args.dataset != "bdpc/rvl_cdip_mp" else "d3a654c9f63f14d0aaa94e08aa30aa3dc20713c1",
    )
    if args.downsampling:
        # never request more rows than the split has (select would raise on out-of-range indices)
        testds = testds.select(list(range(min(args.downsampling, len(testds)))))

    model = AutoModelForImageClassification.from_pretrained(args.model)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    label2idx = dict(model.config.label2id)  # .replace(" ", "_")
    print(label2idx)
    data_idx2label = dict(enumerate(testds.features["labels"].names))
    model_idx2label = {idx: label for label, idx in label2idx.items()}
    # ids the model does not know (.get -> None) also count as a mismatch and trigger alignment
    diff = [i for i, label in data_idx2label.items() if model_idx2label.get(i) != label]
    if diff:
        print(f"aligning labels {diff}")
        testds = testds.align_labels_with_mapping(label2idx, "labels")

    inference_method = InferenceMethod[args.inference_method.upper()]
    per_sample = "sample" in inference_method.scope
    dummy_inference_method = inference_method
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model)
    features = {
        **{k: v for k, v in testds.features.items() if k in {"labels", "pixel_values", "id"}},
        "pages": Value(dtype="int32"),
        "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224)),
    }
    if not per_sample:
        # document-level methods get a variable-length stack of pages per document
        features["pixel_values"] = Array4D(dtype="float32", shape=(None, 3, 224, 224))
        # NOTE(review): preprocessing uses MAX_CONFIDENCE as a stand-in for every
        # document-level method (presumably so all pages are kept) - confirm in mapping_functions
        dummy_inference_method = InferenceMethod["MAX_CONFIDENCE"]
    features = Features(features)

    remove_columns = ["file"]
    map_kwargs = {
        "features": features,
        "remove_columns": remove_columns,
        "desc": "pdf_to_pixelvalues",
        # -k: keep map results in memory instead of writing cache files
        "keep_in_memory": getattr(args, "keep_in_memory", False),
    }
    if args.dataset == "bdpc/rvl_cdip_mp":

        def image_preprocessor(batch):
            return pdf_to_pixelvalues_extractor(batch, feature_extractor, dummy_inference_method)

    else:

        def image_preprocessor(batch):
            return nativepdf_to_pixelvalues_extractor(batch, feature_extractor, dummy_inference_method)

        # NOTE(review): map is not called with batched=True, so batch_size has no effect here
        map_kwargs["batch_size"] = 10
    encoded_testds = testds.map(image_preprocessor, **map_kwargs)
    # remove_columns.append("images")

    # drop documents that produced no pages or whose pixel values contain NaNs
    print(f"Before filtering: {len(encoded_testds)}")

    def is_usable(example):
        return example["pages"] != 0 and not np.any(np.isnan(example["pixel_values"]))

    good_indices = [
        i for i, x in tqdm(enumerate(encoded_testds), total=len(encoded_testds), desc="filter") if is_usable(x)
    ]
    encoded_testds = encoded_testds.select(good_indices)
    print(f"After filtering: {len(encoded_testds)}")
    if len(encoded_testds) == 0:
        print("No usable test samples left after filtering; nothing to evaluate.")
        return

    metric = evaluate.load("accuracy")

    encoded_testds.set_format(type="torch", columns=["pixel_values", "labels"])
    # document-level methods run one document (= a stack of pages) per forward pass
    args.batch_size = args.batch_size if per_sample else 1
    dataloader = torch.utils.data.DataLoader(encoded_testds, batch_size=args.batch_size)

    BUFFER_SIZE = 5000
    predictions, references = [], []
    running_logits, buffer_references, buffer_predictions, buffer_ids = [], [], [], []
    buffer_keys = []
    # position (in the filtered dataset) of the first sample of the next batch; the
    # DataLoader is not shuffled, so batches follow dataset order
    next_position = 0

    def flush_buffer(end_position):
        """Move the buffered results into the totals and dump them to disk."""
        predictions.extend(buffer_predictions)
        references.extend(buffer_references)
        identifier = str(end_position)  # exclusive end position: unique and increasing
        logits_monitor(
            args, running_logits, buffer_references, buffer_predictions, identifier=identifier, sample_ids=buffer_ids
        )
        buffer_keys.append(identifier)
        # clear in place: the data has already been copied into the totals / written out
        running_logits.clear()
        buffer_references.clear()
        buffer_predictions.clear()
        buffer_ids.clear()

    for batch in tqdm(dataloader, desc="Inference loop"):
        labels = batch["labels"]
        # advance before any `continue` so skipped documents still consume their position
        start, next_position = next_position, next_position + len(labels)
        pixel_values = batch["pixel_values"].to(device)
        with torch.no_grad():
            if per_sample:
                logits = model(pixel_values).logits
                buffer_predictions.extend(logits.argmax(-1).tolist())
                buffer_references.extend(labels.tolist())
                buffer_ids.extend(range(start, next_position))
                running_logits.append(logits)
            else:
                try:
                    # batch size is 1 here: [0] selects the page stack of the single document
                    page_logits = model(pixel_values[0]).logits
                except Exception as e:
                    print(f"something went wrong in inference {e}")
                    continue
                # apply logic depending on method
                prediction = inference_method.apply_decision_strategy(page_logits)
                buffer_predictions.append(prediction.tolist())
                buffer_references.extend(labels.tolist())
                buffer_ids.append(start)
                # average over pages as representative
                running_logits.append(page_logits.mean(0).unsqueeze(0))

        if len(buffer_references) >= BUFFER_SIZE:
            flush_buffer(next_position)

    if buffer_references:  # dump remaining out of buffer
        flush_buffer(next_position)

    if not references:
        print(f"No predictions were produced for inference configuration {inference_method}.")
        return
    accuracy = metric.compute(references=references, predictions=predictions)
    print(f"Accuracy on this inference configuration {inference_method}:", accuracy)
    monitor_cleanup(args, buffer_keys)


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser("""Test different inference strategies to classify a document""")
    parser.add_argument(
        "inference_method",
        type=str,
        default="first",
        nargs="?",
        help="how to evaluate DiT on RVL-CDIP_multi",
    )
    parser.add_argument("-s", dest="downsampling", type=int, default=0, help="number of testset samples")
    parser.add_argument("-d", dest="dataset", type=str, default="bdpc/rvl_cdip_mp", help="the dataset to be evaluated")
    parser.add_argument(
        "-m",
        dest="model",
        type=str,
        default="microsoft/dit-base-finetuned-rvlcdip",
        help="the model checkpoint to be evaluated",
    )
    parser.add_argument("-b", dest="batch_size", type=int, default=16, help="batch size")
    parser.add_argument(
        "-k",
        dest="keep_in_memory",
        default=False,
        action="store_true",
        help="do not cache operations (for testing)",
    )

    args = parser.parse_args()

    main(args)