"""Evaluate a DiT image classifier on the test split of a Hugging Face dataset."""

import os

import numpy as np
from tqdm import tqdm
import torch
from datasets import load_dataset, ClassLabel
from datasets import Features, Array3D
from transformers import AutoFeatureExtractor, AutoModelForImageClassification

from metrics import apply_metrics

CACHE_DIR = "/mnt/lerna/data/HFcache" if os.path.exists("/mnt/lerna/data/HFcache") else None
DEFAULT_MODEL = "microsoft/dit-base-finetuned-rvlcdip"
# Index of a sample in the rvl_cdip test split that the original script skips as corrupt.
RVL_CDIP_CORRUPT_TEST_INDEX = 33669


def process_label_ids(batch, remapper, label_column="label"):
    """Remap the integer labels of a batched ``datasets`` example.

    Args:
        batch: Batch dict as passed by ``Dataset.map(batched=True)``; ``batch[label_column]``
            holds a list of integer label ids.
        remapper: Mapping from dataset label id to model label id.
        label_column: Name of the column holding the label ids.

    Returns:
        The same ``batch`` object, with ``batch[label_column]`` replaced by the remapped ids.

    Raises:
        KeyError: If a label id in the batch has no entry in ``remapper``.
    """
    batch[label_column] = [remapper[label_id] for label_id in batch[label_column]]
    return batch


def _normalize_label(name):
    """Canonicalize a class name (spaces -> underscores) so both vocabularies compare equally."""
    return name.replace(" ", "_")


def _align_labels(dataset, label2idx, batch_size):
    """Remap dataset label ids onto the model's label ids when the two vocabularies disagree.

    Args:
        dataset: A ``datasets.Dataset`` with a ``ClassLabel`` column named ``label``.
        label2idx: Mapping from normalized model class name to model label id.
        batch_size: Batch size used for ``Dataset.map``.

    Returns:
        ``dataset`` unchanged if every dataset id already names the same class as the model id.
        Otherwise, a remapped dataset whose ``label`` column uses the model's class names in
        model-id order.

    Raises:
        ValueError: If a dataset class has no counterpart among the model's classes.
    """
    data_idx2label = dict(enumerate(_normalize_label(name) for name in dataset.features["label"].names))
    data_label2idx = {label: i for i, label in data_idx2label.items()}
    model_idx2label = {i: label for label, i in label2idx.items()}

    # .get: the dataset may declare more classes than the model.
    diff = [i for i, label in data_idx2label.items() if model_idx2label.get(i) != label]
    if not diff:
        return dataset

    print(f"aligning labels {diff}")
    print(f"model labels: {model_idx2label}")
    print(f"data labels: {data_idx2label}")
    print(f"Remapping to {label2idx}")

    # The remapped column must use the model's full label set (and size), since
    # predictions are indexed by model ids; dataset classes cannot be added to it.
    remapper = {data_label2idx[label]: i for label, i in label2idx.items() if label in data_label2idx}
    print(remapper)
    unmapped = sorted(set(data_label2idx) - set(label2idx))
    if unmapped:
        raise ValueError(f"dataset labels without a model counterpart: {unmapped}")

    # ClassLabel names are positional (names[i] is the class of id i), so order them by
    # model id rather than by label2id's key order.
    names = [model_idx2label[i] for i in sorted(model_idx2label)]
    new_features = Features(
        {
            **{k: v for k, v in dataset.features.items() if k != "label"},
            "label": ClassLabel(num_classes=len(names), names=names),
        }
    )
    return dataset.map(
        lambda batch: process_label_ids(batch, remapper),
        features=new_features,
        batched=True,
        batch_size=batch_size,
        desc="Aligning the labels",
    )


def _encode_images(dataset, feature_extractor, batch_size):
    """Add a ``pixel_values`` column computed by ``feature_extractor`` and switch to torch format.

    The declared feature is a float32 (3, 224, 224) array, i.e. the extractor is expected
    to emit 224x224 RGB images.
    """
    features = Features({**dataset.features, "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224))})
    encoded_dataset = dataset.map(
        lambda examples: feature_extractor([image.convert("RGB") for image in examples["image"]]),
        batched=True,
        batch_size=batch_size,
        features=features,
    )
    encoded_dataset.set_format(type="torch", columns=["pixel_values", "label"])
    return encoded_dataset


def _predict(model, encoded_dataset, num_labels, batch_size=16):
    """Run ``model`` over ``encoded_dataset`` and collect logits and reference labels.

    Returns:
        ``(logits, references)``: a float array of shape ``(len(encoded_dataset), num_labels)``
        and an int array of length ``len(encoded_dataset)``, both in dataset order.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=batch_size)

    all_logits = np.zeros((len(encoded_dataset), num_labels))
    all_references = np.zeros(len(encoded_dataset), dtype=int)
    count = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            logits = model(batch["pixel_values"].to(device)).logits
            n = len(batch["label"])  # the final batch may be shorter than batch_size
            all_logits[count : count + n] = logits.cpu().numpy()
            all_references[count : count + n] = batch["label"].numpy()
            count += n
    return all_logits, all_references


def main(args):
    """Evaluate a DiT classifier on the test split of ``args.dataset`` and print the metrics.

    Args:
        args: Parsed CLI namespace. ``args.dataset`` is the Hugging Face dataset id.
            ``args.model`` is optional and defaults to ``DEFAULT_MODEL``.
    """
    model_name = getattr(args, "model", DEFAULT_MODEL)

    dataset = load_dataset(args.dataset, split="test", cache_dir=CACHE_DIR)
    if args.dataset == "rvl_cdip":
        dataset = dataset.select([i for i in range(len(dataset)) if i != RVL_CDIP_CORRUPT_TEST_INDEX])
    batch_size = 100 if args.dataset == "jordyvl/RVL-CDIP-N" else 1000

    feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
    model = AutoModelForImageClassification.from_pretrained(model_name)

    label2idx = {_normalize_label(label): i for label, i in model.config.label2id.items()}
    dataset = _align_labels(dataset, label2idx, batch_size)
    encoded_dataset = _encode_images(dataset, feature_extractor, batch_size)

    all_logits, all_references = _predict(model, encoded_dataset, num_labels=len(label2idx))
    results = apply_metrics(all_references, all_logits)
    print(results)


if __name__ == "__main__":
    from argparse import ArgumentParser

    # The first positional argument of ArgumentParser is `prog`; the text is a description.
    parser = ArgumentParser(description="DiT inference on dataset test set")
    parser.add_argument("-d", dest="dataset", type=str, default="rvl_cdip", help="the dataset to be evaluated")
    parser.add_argument(
        "-m", dest="model", type=str, default=DEFAULT_MODEL, help="the image-classification checkpoint to evaluate"
    )
    args = parser.parse_args()
    main(args)