|
import os |
|
import numpy as np |
|
from tqdm import tqdm |
|
import torch |
|
from datasets import load_dataset, ClassLabel |
|
from datasets import Features, Array3D |
|
from transformers import AutoFeatureExtractor, AutoModelForImageClassification |
|
from metrics import apply_metrics |
|
|
|
|
|
def process_label_ids(batch, remapper, label_column="label"):
    """Translate every label id in ``batch[label_column]`` through ``remapper``.

    Intended as a batched ``Dataset.map`` function.

    Args:
        batch: mapping whose ``label_column`` entry is an iterable of label ids.
        remapper: mapping from old label id to new label id; a missing id raises KeyError.
        label_column: name of the column to rewrite.

    Returns:
        The same ``batch`` object, with ``batch[label_column]`` replaced by a new list.
    """
    old_ids = batch[label_column]
    batch[label_column] = list(map(remapper.__getitem__, old_ids))
    return batch
|
|
|
|
|
# Use the shared Hugging Face cache when this mount exists on the current machine;
# otherwise pass None so load_dataset presumably uses its default cache location.
CACHE_DIR = (
    "/mnt/lerna/data/HFcache"
    if os.path.exists("/mnt/lerna/data/HFcache")
    else None
)
|
|
|
|
|
def _normalize_label(name):
    """Return ``name`` with spaces replaced by underscores (the model-label spelling used here)."""
    return name.replace(" ", "_")


def _align_labels(dataset, label2idx, batch_size):
    """Re-index the dataset's ``label`` column to the model's class ids when they differ.

    Label names on both sides are compared after ``_normalize_label`` so that
    "file folder" and "file_folder" count as the same class.

    Args:
        dataset: a ``datasets.Dataset`` whose ``label`` feature exposes ``.names``.
        label2idx: normalized model label name -> model class id.
        batch_size: batch size for the ``Dataset.map`` call that rewrites labels.

    Returns:
        ``dataset`` unchanged if every data class id already names the same class
        in the model; otherwise a mapped dataset whose ``label`` column holds model
        class ids and whose ``label`` feature is a ``ClassLabel`` with the model's names.

    Raises:
        ValueError: if a dataset label has no counterpart among the model labels.
    """
    data_names = [_normalize_label(name) for name in dataset.features["label"].names]
    data_idx2label = dict(enumerate(data_names))
    data_label2idx = {label: i for i, label in data_idx2label.items()}
    model_idx2label = {i: label for label, i in label2idx.items()}

    # Data class ids whose name differs from (or is missing in) the model at the same id.
    diff = [i for i, label in data_idx2label.items() if model_idx2label.get(i) != label]
    if not diff:
        return dataset

    print(f"aligning labels {diff}")
    print(f"model labels: {model_idx2label}")
    print(f"data labels: {data_idx2label}")
    print(f"Remapping to {label2idx}")

    # data class id -> model class id, matched by normalized label name.
    remapper = {data_label2idx[k]: v for k, v in label2idx.items() if k in data_label2idx}
    print(remapper)

    # Fail here with a readable message rather than with a KeyError deep inside Dataset.map.
    unmapped = sorted(set(data_names) - set(label2idx))
    if unmapped:
        raise ValueError(f"dataset labels with no matching model label: {unmapped}")

    new_features = Features(
        {
            **{k: v for k, v in dataset.features.items() if k != "label"},
            "label": ClassLabel(num_classes=len(label2idx), names=list(label2idx.keys())),
        }
    )
    return dataset.map(
        lambda example: process_label_ids(example, remapper),
        features=new_features,
        batched=True,
        batch_size=batch_size,
        desc="Aligning the labels",
    )


def _predict(model, dataloader, num_examples, num_labels, device):
    """Run ``model`` over ``dataloader`` and collect logits and reference labels.

    Args:
        model: image-classification model whose output has a ``.logits`` tensor.
        dataloader: yields dicts with torch tensors ``pixel_values`` and ``label``.
        num_examples: total number of examples the dataloader yields.
        num_labels: number of model classes (width of the logits).
        device: device the inputs are moved to (must match the model's device).

    Returns:
        ``(references, logits)``: an int array of shape ``(num_examples,)`` and a
        float array of shape ``(num_examples, num_labels)``, rows in dataloader order.
    """
    all_logits = np.zeros((num_examples, num_labels))
    all_references = np.zeros(num_examples, dtype=int)
    count = 0
    with torch.no_grad():
        # Passing the DataLoader itself (not enumerate(...)) lets tqdm show a total.
        for batch in tqdm(dataloader):
            logits = model(batch["pixel_values"].to(device)).logits
            n = len(batch["label"])
            all_logits[count : count + n] = logits.cpu().numpy()
            all_references[count : count + n] = batch["label"].numpy()
            count += n
    return all_references, all_logits


def main(args):
    """Evaluate microsoft/dit-base-finetuned-rvlcdip on the ``test`` split of ``args.dataset``.

    Args:
        args: namespace with a ``dataset`` attribute naming a Hugging Face dataset
            whose ``test`` split has ``image`` and ``label`` columns.

    Prints any label re-alignment performed and the result of ``apply_metrics``.
    """
    dataset = load_dataset(args.dataset, split="test", cache_dir=CACHE_DIR)
    if args.dataset == "rvl_cdip":
        # NOTE(review): test example 33669 is excluded, presumably an unreadable image — confirm.
        dataset = dataset.select([i for i in range(len(dataset)) if i != 33669])
    batch_size = 100 if args.dataset == "jordyvl/RVL-CDIP-N" else 1000

    checkpoint = "microsoft/dit-base-finetuned-rvlcdip"
    feature_extractor = AutoFeatureExtractor.from_pretrained(checkpoint)
    model = AutoModelForImageClassification.from_pretrained(checkpoint)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # explicit inference mode (dropout etc. disabled)

    # Normalized model label name -> model class id.
    label2idx = {_normalize_label(label): i for label, i in model.config.label2id.items()}
    dataset = _align_labels(dataset, label2idx, batch_size)

    features = Features(
        {**dataset.features, "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224))}
    )
    encoded_dataset = dataset.map(
        # convert("RGB") so every image yields the 3-channel input declared above.
        lambda examples: feature_extractor([image.convert("RGB") for image in examples["image"]]),
        batched=True,
        batch_size=batch_size,
        features=features,
    )
    encoded_dataset.set_format(type="torch", columns=["pixel_values", "label"])

    eval_batch_size = 16
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=eval_batch_size)

    all_references, all_logits = _predict(
        model, dataloader, len(encoded_dataset), len(label2idx), device
    )
    results = apply_metrics(all_references, all_logits)
    print(results)
|
|
|
|
|
if __name__ == "__main__":
    from argparse import ArgumentParser

    # Pass the text as `description`: ArgumentParser's first positional parameter is
    # `prog`, which would otherwise replace the program name in the usage line.
    parser = ArgumentParser(description="DiT inference on dataset test set")
    parser.add_argument(
        "-d",
        dest="dataset",
        type=str,
        default="rvl_cdip",
        help="the dataset to be evaluated",
    )
    args = parser.parse_args()
    main(args)
|
|