|
|
import open_clip |
|
|
import torch |
|
|
import os |
|
|
import random |
|
|
import numpy as np |
|
|
import argparse |
|
|
from inference_tool import (zeroshot_evaluation, |
|
|
retrieval_evaluation, |
|
|
semantic_localization_evaluation, |
|
|
get_preprocess |
|
|
) |
|
|
|
|
|
|
|
|
def random_seed(seed):
    """Seed the torch (CPU and all CUDA devices), NumPy and stdlib RNGs.

    Also sets the cuDNN backend flags: ``benchmark`` is switched on and
    ``deterministic`` is switched off.

    Args:
        seed: Integer seed handed unchanged to every generator.
    """
    # Same seeding order as before: torch CPU, NumPy, torch CUDA, stdlib.
    seeders = (
        torch.manual_seed,
        np.random.seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # NOTE(review): per PyTorch's reproducibility notes these two settings
    # can let GPU results vary between runs despite seeding -- confirm that
    # this trade-off is intended.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = True
    cudnn.deterministic = False
|
|
|
|
|
|
|
|
def build_model(model_name, ckpt_path, device):
    """Create an open_clip model, load checkpoint weights and move it to ``device``.

    Args:
        model_name: Either ``"ViT-B-32"`` or ``"ViT-H-14"``.
        ckpt_path: Path of the checkpoint file read with ``torch.load``; the
            loaded object is passed directly to ``model.load_state_dict``.
        device: Target passed to ``model.to``.

    Returns:
        A ``(model, preprocess_val)`` tuple, where ``preprocess_val`` is the
        result of ``get_preprocess(image_resolution=224)``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    # (architecture, pretrained tag) handed to open_clip for each supported name.
    architectures = {
        "ViT-B-32": ("ViT-B/32", "openai"),
        "ViT-H-14": ("ViT-H/14", "laion2b_s32b_b79k"),
    }
    # Fail early with a clear message; previously an unknown name fell through
    # both branches and crashed later with UnboundLocalError.
    if model_name not in architectures:
        raise ValueError(
            "Unsupported model name: {!r} (expected one of: {})".format(
                model_name, ", ".join(architectures)
            )
        )

    arch, pretrained = architectures[model_name]
    model, _, _ = open_clip.create_model_and_transforms(arch, pretrained=pretrained)

    # NOTE: torch.load unpickles the file -- only load checkpoints that come
    # from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    msg = model.load_state_dict(checkpoint)

    # Show the result returned by load_state_dict.
    print(msg)
    model = model.to(device)
    print("loaded RSCLIP")

    preprocess_val = get_preprocess(
        image_resolution=224,
    )

    return model, preprocess_val
|
|
|
|
|
|
|
|
def evaluate(model, preprocess, args):
    """Run every evaluation stage and collect the metrics in one dict.

    Stages, in order: zero-shot evaluation on EuroSAT, RESISC45 and AID;
    retrieval evaluation on "rsitmd" and then "rsicd" (recall at 1, 5, 10);
    semantic localization evaluation on AIR-SLT. The accumulated metrics are
    printed after each stage.

    Args:
        model: Model to evaluate; ``model.eval()`` is called on it first.
        preprocess: Transform forwarded to every evaluation helper.
        args: Options object forwarded unchanged to every evaluation helper.

    Returns:
        Dict merged (via ``dict.update``) from the dicts returned by each
        evaluation helper; later stages overwrite equal keys.
    """
    print("making val dataset with transformation: ")
    print(preprocess)
    zeroshot_datasets = [
        'EuroSAT',
        'RESISC45',
        'AID',
    ]
    retrieval_datasets = [
        "rsitmd",
        "rsicd",
    ]
    selo_datasets = [
        'AIR-SLT',
    ]

    model.eval()
    all_metrics = {}

    # Zero-shot stage: printed once after all datasets have been evaluated.
    for zeroshot_dataset in zeroshot_datasets:
        zeroshot_metrics = zeroshot_evaluation(model, zeroshot_dataset, preprocess, args)
        all_metrics.update(zeroshot_metrics)
    print(all_metrics)

    # Retrieval stage: each dataset is its own stage, so print after each one.
    for retrieval_dataset in retrieval_datasets:
        retrieval_metrics = retrieval_evaluation(
            model, preprocess, args, recall_k_list=[1, 5, 10],
            dataset_name=retrieval_dataset,
        )
        all_metrics.update(retrieval_metrics)
        print(all_metrics)

    # Semantic localization stage: printed once after all datasets.
    for selo_dataset in selo_datasets:
        selo_metrics = semantic_localization_evaluation(model, selo_dataset, preprocess, args)
        all_metrics.update(selo_metrics)
    print(all_metrics)

    return all_metrics
|
|
|
|
|
|
|
|
def main():
    """Parse the command-line options, build the model, evaluate it and print the metrics.

    The device is chosen automatically: "cuda" when ``torch.cuda.is_available()``
    returns True, otherwise "cpu". The ``--random-seed`` value is applied through
    ``random_seed`` before the model is built.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-name", default="ViT-B-32", type=str,
        # Reject unsupported names at parse time instead of failing later
        # inside build_model.
        choices=["ViT-B-32", "ViT-H-14"],
        help="ViT-B-32 or ViT-H-14",
    )
    parser.add_argument(
        "--ckpt-path", default="/home/zilun/RS5M_v5/ckpt/RS5M_ViT-B-32.pt", type=str,
        help="Path to RS5M_ViT-B-32.pt",
    )
    parser.add_argument(
        "--random-seed", default=3407, type=int,
        help="random seed",
    )
    parser.add_argument(
        "--test-dataset-dir", default="/home/zilun/RS5M_v5/data/rs5m_test_data", type=str,
        help="test dataset dir",
    )
    parser.add_argument(
        "--batch-size", default=500, type=int,
        help="batch size",
    )
    parser.add_argument(
        "--workers", default=8, type=int,
        help="number of workers",
    )
    args = parser.parse_args()
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    print(args)

    # Previously --random-seed was parsed but never used; apply it so the
    # option actually takes effect.
    random_seed(args.random_seed)

    model, img_preprocess = build_model(args.model_name, args.ckpt_path, args.device)

    eval_result = evaluate(model, img_preprocess, args)

    for key, value in eval_result.items():
        print("{}: {}".format(key, value))
|
|
|
|
|
|
|
|
# Script entry point: run the evaluation only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|
|