# Page-scrape residue kept for provenance (not part of the program):
#   Zilun's picture
#   Update codebase/inference/inference.py
#   6906236
import open_clip
import torch
import os
import random
import numpy as np
import argparse
from inference_tool import (zeroshot_evaluation,
retrieval_evaluation,
semantic_localization_evaluation,
get_preprocess
)
def random_seed(seed):
    """Seed the PyTorch, NumPy and Python RNGs and set the cuDNN flags.

    The same ``seed`` is handed to ``torch.manual_seed``,
    ``np.random.seed``, ``torch.cuda.manual_seed_all`` and ``random.seed``.

    Args:
        seed: Seed value forwarded unchanged to every generator.

    Note:
        cuDNN benchmarking is switched on and cuDNN deterministic mode is
        switched off here. Per the PyTorch reproducibility notes, that
        combination favours speed over run-to-run reproducibility of GPU
        kernels.
    """
    seeders = (
        torch.manual_seed,
        np.random.seed,
        torch.cuda.manual_seed_all,
        random.seed,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # Global cuDNN switches: autotune kernels, do not force deterministic ones.
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
# Maps the model names accepted by this script to the (architecture,
# pretrained tag) pair passed to open_clip.create_model_and_transforms.
_OPEN_CLIP_SPECS = {
    "ViT-B-32": ("ViT-B/32", "openai"),
    "ViT-H-14": ("ViT-H/14", "laion2b_s32b_b79k"),
}


def build_model(model_name, ckpt_path, device):
    """Create an open_clip model, load a checkpoint into it and move it to ``device``.

    Args:
        model_name: ``"ViT-B-32"`` or ``"ViT-H-14"``.
        ckpt_path: Path to a checkpoint file whose content is passed directly
            to ``model.load_state_dict``.
        device: Target device handed to ``model.to``.

    Returns:
        A ``(model, preprocess_val)`` tuple, where ``preprocess_val`` is the
        result of ``get_preprocess(image_resolution=224)``.

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
            (Previously this case surfaced as an ``UnboundLocalError`` when
            the load message was printed.)
    """
    # Validate first so a typo fails fast, before any weights are touched.
    try:
        architecture, pretrained_tag = _OPEN_CLIP_SPECS[model_name]
    except KeyError:
        supported = ", ".join(sorted(_OPEN_CLIP_SPECS))
        raise ValueError(
            "Unsupported model name {!r}; expected one of: {}".format(model_name, supported)
        ) from None

    # The model is first built with its pretrained weights; the checkpoint's
    # state dict is then loaded on top of it.
    model, _, _ = open_clip.create_model_and_transforms(architecture, pretrained=pretrained_tag)

    # NOTE(review): torch.load unpickles the file; only point ckpt_path at
    # checkpoints from a trusted source.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    msg = model.load_state_dict(checkpoint)
    print(msg)

    model = model.to(device)
    print("loaded RSCLIP")

    preprocess_val = get_preprocess(
        image_resolution=224,
    )
    return model, preprocess_val
def evaluate(model, preprocess, args):
    """Run every evaluation suite on ``model`` and collect the metrics.

    The suites run in this order: zero-shot classification (EuroSAT,
    RESISC45, AID), retrieval on "rsitmd" and then "rsicd" with recall at
    1/5/10, and semantic localization (AIR-SLT). The accumulated metrics are
    printed after the zero-shot suite, after each retrieval dataset, and
    after the semantic-localization suite.

    Args:
        model: Model to evaluate; ``model.eval()`` is called before any suite.
        preprocess: Image transform, printed once and forwarded to each helper.
        args: Options object forwarded unchanged to each evaluation helper.

    Returns:
        A single dict produced by merging the dict returned from each helper
        call; on a key collision the later result overwrites the earlier one.
    """
    print("making val dataset with transformation: ")
    print(preprocess)

    classification_sets = ('EuroSAT', 'RESISC45', 'AID')
    retrieval_sets = ("rsitmd", "rsicd")
    localization_sets = ('AIR-SLT',)

    model.eval()
    collected = {}

    # Zero-shot classification.
    for name in classification_sets:
        collected.update(zeroshot_evaluation(model, name, preprocess, args))
    print(collected)

    # Image-text retrieval: RSITMD, then RSICD, reporting after each one.
    for name in retrieval_sets:
        collected.update(
            retrieval_evaluation(
                model, preprocess, args,
                recall_k_list=[1, 5, 10],
                dataset_name=name,
            )
        )
        print(collected)

    # Semantic localization.
    for name in localization_sets:
        collected.update(semantic_localization_evaluation(model, name, preprocess, args))
    print(collected)

    return collected
def _build_arg_parser():
    """Create the command-line parser for the evaluation script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model-name", default="ViT-B-32", type=str,
        # Reject unsupported names at parse time instead of failing later
        # during model construction.
        choices=["ViT-B-32", "ViT-H-14"],
        help="ViT-B-32 or ViT-H-14",
    )
    parser.add_argument(
        "--ckpt-path", default="/home/zilun/RS5M_v5/ckpt/RS5M_ViT-B-32.pt", type=str,
        help="Path to RS5M_ViT-B-32.pt",
    )
    parser.add_argument(
        "--random-seed", default=3407, type=int,
        help="random seed",
    )
    parser.add_argument(
        "--test-dataset-dir", default="/home/zilun/RS5M_v5/data/rs5m_test_data", type=str,
        help="test dataset dir",
    )
    parser.add_argument(
        "--batch-size", default=500, type=int,
        help="batch size",
    )
    parser.add_argument(
        "--workers", default=8, type=int,
        help="number of workers",
    )
    return parser


def main(argv=None):
    """Parse the options, build the model, run the evaluation and print the metrics.

    Args:
        argv: Optional list of command-line tokens. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``, which matches the previous
            zero-argument behaviour.

    Raises:
        SystemExit: Raised by argparse on invalid options, including a
            ``--model-name`` outside the supported choices.
    """
    args = _build_arg_parser().parse_args(argv)
    # Use the GPU when one is visible to PyTorch, otherwise fall back to CPU.
    args.device = "cuda" if torch.cuda.is_available() else "cpu"
    print(args)

    # NOTE: seeding is currently disabled, so --random-seed is parsed but has
    # no effect. Re-enabling the call below would also change the cuDNN flags
    # that random_seed sets.
    # random_seed(args.random_seed)

    model, img_preprocess = build_model(args.model_name, args.ckpt_path, args.device)
    eval_result = evaluate(model, img_preprocess, args)

    for key, value in eval_result.items():
        print("{}: {}".format(key, value))


if __name__ == "__main__":
    main()