File size: 4,046 Bytes
b0c0186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6906236
b0c0186
 
d3beb1f
b0c0186
6906236
b0c0186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import open_clip
import torch
import os
import random
import numpy as np
import argparse
from inference_tool import (zeroshot_evaluation,
                            retrieval_evaluation,
                            semantic_localization_evaluation,
                            get_preprocess
                            )


def random_seed(seed, *, deterministic=False):
    """Seed every random number generator used during evaluation.

    Seeds PyTorch (CPU and all CUDA devices), NumPy's global RNG and
    Python's ``random`` module with the same ``seed``, then configures cuDNN.

    Args:
        seed: Integer seed applied to all generators.
        deterministic: Keyword-only. When False (the default, matching the
            previous hard-coded behaviour) cuDNN runs in benchmark mode with
            non-deterministic kernels allowed, so GPU results may still vary
            between runs. Pass True to disable benchmarking and request
            deterministic cuDNN kernels for reproducible results.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    # Benchmark mode autotunes convolution algorithms per input shape; the
    # chosen algorithm can differ between runs, which breaks bit-exact
    # reproducibility, so it is only enabled when determinism is not requested.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic


def build_model(model_name, ckpt_path, device):
    """Build an open_clip model, load RS5M weights into it and move it to ``device``.

    The base architecture is first created by
    ``open_clip.create_model_and_transforms`` with its public pretrained
    weights. Those weights are then replaced by the checkpoint at ``ckpt_path``.

    Args:
        model_name: ``"ViT-B-32"`` or ``"ViT-H-14"``.
        ckpt_path: Path to the checkpoint file. It is read with ``torch.load``
            and passed to ``model.load_state_dict``, so it presumably holds a
            plain state dict — confirm against the checkpoint format.
        device: Target device for the model, e.g. ``"cuda"`` or ``"cpu"``.

    Returns:
        Tuple ``(model, preprocess_val)``, where ``preprocess_val`` is the
        evaluation transform returned by ``get_preprocess(image_resolution=224)``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
            The check runs before any weights are downloaded or loaded.
    """
    # CLI model name -> (open_clip architecture, pretrained tag) for the base net.
    base_configs = {
        "ViT-B-32": ("ViT-B/32", "openai"),
        "ViT-H-14": ("ViT-H/14", "laion2b_s32b_b79k"),
    }
    if model_name not in base_configs:
        raise ValueError(
            f"Unsupported model_name {model_name!r}; "
            f"expected one of {sorted(base_configs)}"
        )
    arch, pretrained = base_configs[model_name]

    model, _, _ = open_clip.create_model_and_transforms(arch, pretrained=pretrained)
    # NOTE: torch.load unpickles the file, which can execute arbitrary code;
    # only point ckpt_path at checkpoints you trust.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    msg = model.load_state_dict(checkpoint)

    print(msg)
    model = model.to(device)
    print("loaded RSCLIP")

    preprocess_val = get_preprocess(
        image_resolution=224,
    )

    return model, preprocess_val


def evaluate(model, preprocess, args):
    """Run the RS5M benchmark suite and return every metric in one dict.

    Stages, in order:
      1. zero-shot classification on EuroSAT, RESISC45 and AID;
      2. image-text retrieval on RSITMD and RSICD with recall@{1, 5, 10};
      3. semantic localization on AIR-SLT.
    The running metric dict is printed after zero-shot classification, after
    each retrieval dataset, and after semantic localization. Partial results
    are therefore visible even if a later stage fails.

    Args:
        model: The model to evaluate; switched to eval mode here.
        preprocess: Image transform forwarded to every evaluation helper.
        args: Parsed CLI namespace forwarded to the helpers. The helpers
            presumably read fields such as ``test_dataset_dir``,
            ``batch_size``, ``workers`` and ``device`` — confirm in
            ``inference_tool``.

    Returns:
        dict merging the metric dicts returned by every helper. If two
        helpers return the same key, the later stage's value wins.
    """
    print("making val dataset with transformation: ")
    print(preprocess)

    zeroshot_datasets = ["EuroSAT", "RESISC45", "AID"]
    retrieval_datasets = ["rsitmd", "rsicd"]
    selo_datasets = ["AIR-SLT"]

    model.eval()
    all_metrics = {}

    # Zero-shot classification
    for zeroshot_dataset in zeroshot_datasets:
        all_metrics.update(
            zeroshot_evaluation(model, zeroshot_dataset, preprocess, args)
        )
    print(all_metrics)

    # Image-text retrieval; progress is printed after every dataset.
    for retrieval_dataset in retrieval_datasets:
        all_metrics.update(
            retrieval_evaluation(model, preprocess, args, recall_k_list=[1, 5, 10],
                                 dataset_name=retrieval_dataset)
        )
        print(all_metrics)

    # Semantic localization
    for selo_dataset in selo_datasets:
        all_metrics.update(
            semantic_localization_evaluation(model, selo_dataset, preprocess, args)
        )
    print(all_metrics)

    return all_metrics


def main():
    """Command-line entry point: parse arguments, build the model, print all metrics.

    Seeding is opt-in via ``--set-seed``; by default no RNG is seeded, which
    matches the previous behaviour.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-name", default="ViT-B-32", type=str,
        choices=["ViT-B-32", "ViT-H-14"],
        help="ViT-B-32 or ViT-H-14",
    )
    # NOTE: the default paths below point at the original author's machine;
    # override them on the command line.
    parser.add_argument(
        "--ckpt-path", default="/home/zilun/RS5M_v5/ckpt/RS5M_ViT-B-32.pt", type=str,
        help="Path to the RS5M checkpoint matching --model-name "
             "(e.g. RS5M_ViT-B-32.pt)",
    )
    parser.add_argument(
        "--random-seed", default=3407, type=int,
        help="random seed",
    )
    parser.add_argument(
        "--set-seed", action="store_true",
        help="seed all RNGs with --random-seed before evaluating (off by default)",
    )
    parser.add_argument(
        "--test-dataset-dir", default="/home/zilun/RS5M_v5/data/rs5m_test_data", type=str,
        help="test dataset dir",
    )
    parser.add_argument(
        "--batch-size", default=500, type=int,
        help="batch size",
    )
    parser.add_argument(
        "--workers", default=8, type=int,
        help="number of workers",
    )
    args = parser.parse_args()
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    print(args)
    if args.set_seed:
        random_seed(args.random_seed)

    model, img_preprocess = build_model(args.model_name, args.ckpt_path, args.device)

    eval_result = evaluate(model, img_preprocess, args)

    for key, value in eval_result.items():
        print(f"{key}: {value}")


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()