File size: 3,860 Bytes
1ceb840
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import os
import numpy as np
from tqdm import tqdm
import torch
from datasets import load_dataset, ClassLabel
from datasets import Features, Array3D
from transformers import AutoFeatureExtractor, AutoModelForImageClassification
from metrics import apply_metrics


def process_label_ids(batch, remapper, label_column="label"):
    """Translate every label id in ``batch[label_column]`` through ``remapper``.

    Meant to be used as a batched ``Dataset.map`` function: ``batch`` is a dict of
    columns, the label column is replaced in place and the same dict is returned.

    Args:
        batch: mapping of column name -> list of values for one batch.
        remapper: lookup from an old label id to the new label id.
        label_column: name of the column holding the label ids.

    Returns:
        ``batch`` itself, with ``batch[label_column]`` rewritten.

    Raises:
        KeyError: if a label id has no entry in ``remapper``.
    """
    old_ids = batch[label_column]
    batch[label_column] = list(map(remapper.__getitem__, old_ids))
    return batch


# Use the shared HuggingFace cache directory when that path exists on this machine;
# otherwise pass None so load_dataset falls back to its default cache location.
CACHE_DIR = None
if os.path.exists("/mnt/lerna/data/HFcache"):
    CACHE_DIR = "/mnt/lerna/data/HFcache"


def _align_labels(dataset, label2idx, batch_size):
    """Remap ``dataset``'s integer labels onto the model's label ids when they disagree.

    Args:
        dataset: HF dataset with a ClassLabel ``label`` column.
        label2idx: model label name (spaces replaced by underscores) -> model class index.
        batch_size: batch size used for the remapping ``map`` call.

    Returns:
        ``dataset`` unchanged when every dataset index already names the same label as
        the model; otherwise a new dataset whose ``label`` column holds model ids and whose
        ClassLabel names are listed in model-index order.
    """
    data_names = dataset.features["label"].names
    data_idx2label = dict(enumerate(data_names))
    data_label2idx = {label: i for i, label in enumerate(data_names)}
    model_idx2label = {i: label for label, i in label2idx.items()}
    # .get(): the model may define fewer classes than the dataset; count those as mismatches.
    diff = [i for i, label in data_idx2label.items() if label != model_idx2label.get(i)]
    if not diff:
        return dataset

    print(f"aligning labels {diff}")
    print(f"model labels: {model_idx2label}")
    print(f"data labels: {data_idx2label}")
    print(f"Remapping to {label2idx}")  # YET cannot change the length of the labels

    remapper = {data_label2idx[k]: v for k, v in label2idx.items() if k in data_label2idx}
    print(remapper)
    unmapped = [label for label in data_label2idx if label not in label2idx]
    if unmapped:
        # Any example carrying one of these labels will make process_label_ids raise KeyError.
        print(f"WARNING: data labels without a model counterpart: {unmapped}")

    # Names must follow class-index order; config.label2id is not guaranteed to iterate
    # in index order (HF writes config.json with sorted keys).
    model_names = [model_idx2label[i] for i in sorted(model_idx2label)]
    new_features = Features(
        {
            **{k: v for k, v in dataset.features.items() if k != "label"},
            "label": ClassLabel(num_classes=len(label2idx), names=model_names),
        }
    )
    return dataset.map(
        lambda example: process_label_ids(example, remapper),
        features=new_features,
        batched=True,
        batch_size=batch_size,
        desc="Aligning the labels",
    )


def _predict(model, dataloader, num_samples, num_labels, device):
    """Run ``model`` over every batch of ``dataloader`` and collect logits and labels.

    Returns:
        ``(logits, references)``: a float array of shape ``(num_samples, num_labels)``
        and an int array of shape ``(num_samples,)``, both in dataloader order.
    """
    all_logits = np.zeros((num_samples, num_labels))
    all_references = np.zeros(num_samples, dtype=int)
    count = 0
    with torch.no_grad():
        for batch in tqdm(dataloader):
            logits = model(batch["pixel_values"].to(device)).logits
            n = len(batch["label"])  # the final batch may be smaller than the loader's batch size
            all_logits[count : count + n] = logits.cpu().numpy()
            all_references[count : count + n] = batch["label"].cpu().numpy()
            count += n
    return all_logits, all_references


def main(args):
    """Evaluate microsoft/dit-base-finetuned-rvlcdip on the test split of ``args.dataset``.

    Loads the test split, aligns its label ids with the model's when needed, extracts
    pixel features, runs the model, and prints ``apply_metrics(references, logits)``.

    Args:
        args: parsed CLI namespace; only ``args.dataset`` (a HF dataset id) is read.
    """
    dataset = load_dataset(args.dataset, split="test", cache_dir=CACHE_DIR)
    if args.dataset == "rvl_cdip":
        dataset = dataset.select([i for i in range(len(dataset)) if i != 33669])  # corrupt sample
    batch_size = 100 if args.dataset == "jordyvl/RVL-CDIP-N" else 1000

    feature_extractor = AutoFeatureExtractor.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
    model = AutoModelForImageClassification.from_pretrained("microsoft/dit-base-finetuned-rvlcdip")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()

    label2idx = {label.replace(" ", "_"): i for label, i in model.config.label2id.items()}
    dataset = _align_labels(dataset, label2idx, batch_size)

    features = Features({**dataset.features, "pixel_values": Array3D(dtype="float32", shape=(3, 224, 224))})
    encoded_dataset = dataset.map(
        lambda examples: feature_extractor([image.convert("RGB") for image in examples["image"]]),
        batched=True,
        batch_size=batch_size,
        features=features,
    )
    encoded_dataset.set_format(type="torch", columns=["pixel_values", "label"])
    BATCH_SIZE = 16
    dataloader = torch.utils.data.DataLoader(encoded_dataset, batch_size=BATCH_SIZE)

    all_logits, all_references = _predict(model, dataloader, len(encoded_dataset), len(label2idx), device)
    results = apply_metrics(all_references, all_logits)
    print(results)


if __name__ == "__main__":
    from argparse import ArgumentParser

    # ArgumentParser's first positional parameter is `prog` (the program name shown in
    # the usage line), so the summary text must be passed as `description`.
    parser = ArgumentParser(description="DiT inference on dataset test set")
    parser.add_argument("-d", dest="dataset", type=str, default="rvl_cdip", help="the dataset to be evaluated")
    args = parser.parse_args()

    main(args)