# src/DiT_inference.py
# Uploaded by bdpc in commit 1ceb840 ("Upload 9 files").
import os
import numpy as np
from tqdm import tqdm
import torch
from datasets import load_dataset, ClassLabel
from datasets import Features, Array3D
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from metrics import apply_metrics
def process_label_ids(batch, remapper, label_column="label"):
    """Rewrite the label ids of a batch through ``remapper``, in place.

    Intended for ``Dataset.map(..., batched=True)``, where ``batch`` maps a
    column name to a list of values (assumed from how ``main`` calls it).

    Args:
        batch: mapping whose ``label_column`` entry is an iterable of ids.
        remapper: lookup from an old id to the new id; an id missing from
            it raises ``KeyError``.
        label_column: name of the column holding the label ids.

    Returns:
        The same ``batch`` object, with ``label_column`` replaced by a new
        list of remapped ids.
    """
    current_ids = batch[label_column]
    aligned_ids = [remapper[old_id] for old_id in current_ids]
    batch[label_column] = aligned_ids
    return batch
CACHE_DIR = "/mnt/lerna/data/HFcache" if os.path.exists("/mnt/lerna/data/HFcache") else None
def _align_dataset_labels(dataset, label2idx, batch_size):
    """Re-index the dataset's ``label`` column so its ids match the model's ids.

    Args:
        dataset: dataset whose ``label`` feature is a ``ClassLabel``.
        label2idx: model label name (spaces replaced by underscores) -> model
            class id.
        batch_size: batch size used for the remapping ``map`` call.

    Returns:
        ``dataset`` unchanged when every dataset id already names the same
        label as the model id; otherwise a mapped copy whose label ids are the
        model's ids and whose ``label`` feature is a ``ClassLabel`` over the
        model's label names.
    """
    data_names = dataset.features["label"].names
    data_idx2label = dict(enumerate(data_names))
    data_label2idx = {label: i for i, label in enumerate(data_names)}
    model_idx2label = {i: label for label, i in label2idx.items()}

    # .get(): the dataset may declare more classes than the model knows about.
    diff = [i for i in range(len(data_label2idx)) if data_idx2label[i] != model_idx2label.get(i)]
    if not diff:
        return dataset

    print(f"aligning labels {diff}")
    print(f"model labels: {model_idx2label}")
    print(f"data labels: {data_idx2label}")
    print(f"Remapping to {label2idx}")  # NOTE: the original author flagged that the label count cannot change here

    remapper = {data_label2idx[label]: idx for label, idx in label2idx.items() if label in data_label2idx}
    print(remapper)
    unmapped = [label for label in data_label2idx if label not in label2idx]
    if unmapped:
        # Examples carrying these labels will fail in process_label_ids (KeyError).
        print(f"WARNING: dataset labels without a model counterpart: {unmapped}")

    # ClassLabel names must be listed in id order; label2id's insertion order is
    # not guaranteed to be id order. Assumes model ids are contiguous 0..n-1.
    model_names = sorted(label2idx, key=label2idx.get)
    new_features = Features(
        {
            **{k: v for k, v in dataset.features.items() if k != "label"},
            "label": ClassLabel(num_classes=len(model_names), names=model_names),
        }
    )
    return dataset.map(
        lambda example: process_label_ids(example, remapper),
        features=new_features,
        batched=True,
        batch_size=batch_size,
        desc="Aligning the labels",
    )


def _collect_logits(model, encoded_dataset, num_labels, batch_size=16):
    """Run ``model`` over ``encoded_dataset``; return (logits, references) arrays.

    ``encoded_dataset`` must yield torch tensors under ``pixel_values`` and
    ``label``. The returned logits have shape ``(len(encoded_dataset),
    num_labels)``, and the references are an int vector of the same length.
    """
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=batch_size)
    all_logits = np.zeros((len(encoded_dataset), num_labels))
    all_references = np.zeros(len(encoded_dataset), dtype=int)
    model.eval()  # explicit, so dropout etc. are off even if the caller trained the model
    count = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            logits = model(batch["pixel_values"]).logits
            n = len(batch["label"])
            all_logits[count : count + n] = logits.cpu().numpy()
            all_references[count : count + n] = batch["label"].cpu().numpy()
            count += n
    return all_logits, all_references


def main(args):
    """Evaluate the RVL-CDIP fine-tuned DiT model on ``args.dataset``'s test split.

    Steps:
    1. Load the ``test`` split, dropping index 33669 of ``rvl_cdip`` (flagged
       as a corrupt sample).
    2. Re-index the labels so they match the model's class ids.
    3. Encode the images as RGB ``pixel_values``.
    4. Run the model.
    5. Print the result of ``apply_metrics(references, logits)``.

    Args:
        args: parsed CLI namespace; only ``args.dataset`` is read.
    """
    dataset = load_dataset(args.dataset, split="test", cache_dir=CACHE_DIR)
    if args.dataset == "rvl_cdip":
        dataset = dataset.select([i for i in range(len(dataset)) if i != 33669])  # corrupt sample
    batch_size = 100 if args.dataset == "jordyvl/RVL-CDIP-N" else 1000

    checkpoint = "microsoft/dit-base-finetuned-rvlcdip"
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)
    label2idx = {label.replace(" ", "_"): i for label, i in model.config.label2id.items()}

    dataset = _align_dataset_labels(dataset, label2idx, batch_size)

    features = Features({**dataset.features, "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224))})
    encoded_dataset = dataset.map(
        lambda examples: feature_extractor([image.convert("RGB") for image in examples["image"]]),
        batched=True,
        batch_size=batch_size,
        features=features,
    )
    encoded_dataset.set_format(type="torch", columns=["pixel_values", "label"])

    all_logits, all_references = _collect_logits(model, encoded_dataset, len(label2idx))
    results = apply_metrics(all_references, all_logits)
    print(results)
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser("""DiT inference on dataset test set""")
parser.add_argument("-d", dest="dataset", type=str, default="rvl_cdip", help="the dataset to be evaluated")
args = parser.parse_args()
main(args)