import os
import operator

import numpy as np
import pandas as pd
import torch
from os.path import join
from torch.utils.data import DataLoader
from tqdm import tqdm

from backbone import get_backbone
from utils import haversine, get_filenames, get_match_values, compute_print_accuracy

# GeoDataset is defined elsewhere in this project; the module path below is an
# assumption -- point it at wherever the dataset class actually lives.
from dataset import GeoDataset

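# Usage sketch (the script's file name is an assumption; adjust to this file):
#   python features.py --id 1     # cache backbone features for training shard 1
#   python features.py --test     # cache test features, then evaluate 1-NN accuracy
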
def compute_features(path, data_dir, csv_file, tag, args):
    """Cache one `.npy` feature file per image under `path`, skipping the pass
    entirely if the cache already holds a feature for every image."""
    data = GeoDataset(data_dir, csv_file, tag=tag)
    if not os.path.isdir(path) or len(os.listdir(path)) != len(data):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model, transform, inference, collate_fn = get_backbone(args.name)
        dataloader = DataLoader(
            data,
            batch_size=args.batch_size,
            shuffle=False,
            num_workers=8,
            collate_fn=collate_fn,
        )
        model = model.to(device)
        os.makedirs(path, exist_ok=True)
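
        # Batched inference: the backbone-specific `inference` hook returns the
        # ids of the images in the batch together with their feature vectors.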
        for x in tqdm(dataloader):
            image_ids, features = inference(model, x)

            for j, image_id in enumerate(image_ids):
                # .cpu() is a no-op on CPU tensors and guards against GPU outputs.
                np.save(join(path, f"{image_id}.npy"), features[j].unsqueeze(0).cpu().numpy())

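
# Pair every test image's ground-truth GPS with the coordinates of its single
# nearest training neighbor, memoising the result on disk.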
def get_results(args):
    import joblib

    cache_file = join(args.features_parent, ".cache", "1-nn.pkl")
    if not os.path.isfile(cache_file):
        import bisect
        import faiss  # the shard indexes returned by get_filenames are faiss indexes

        # One (index, image ids) pair per training shard.
        indexes = [
            get_filenames(idx) for idx in tqdm(range(1, 6), desc="Loading indexes...")
        ]
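
        # Ground-truth GPS coordinates for the training and test splits.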
        train_gt = pd.read_csv(
            join(args.data_parent, args.annotation_file), dtype={"image_id": str}
        )[["image_id", "latitude", "longitude"]]
        test_gt = pd.read_csv(
            join(args.data_parent, "test.csv"), dtype={"id": str}
        )[["id", "latitude", "longitude"]]

        train_gt = {
            row["image_id"]: np.array([row["latitude"], row["longitude"]])
            for _, row in tqdm(
                train_gt.iterrows(), total=len(train_gt), desc="Loading train_gt"
            )
        }
        test_gt = {
            row["id"]: np.array([row["latitude"], row["longitude"]])
            for _, row in tqdm(
                test_gt.iterrows(), total=len(test_gt), desc="Loading test_gt"
            )
        }
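
        # 1-NN lookup: keep the best match from each shard, then take the
        # overall winner for every cached test feature.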
        train_test = []
        test_features_dir = join(args.features_parent, "features-test")
        os.makedirs(join(args.features_parent, ".cache"), exist_ok=True)
        for f in tqdm(os.listdir(test_features_dir)):
            query_vector = np.load(join(test_features_dir, f))

            # bisect.insort with a key argument requires Python 3.10+.
            neighbors = []
            for index, ids in indexes:
                distances, indices = index.search(query_vector, 1)
                distances, indices = np.squeeze(distances), np.squeeze(indices)
                bisect.insort(
                    neighbors, (ids[indices], distances), key=operator.itemgetter(1)
                )

            # Highest score first; this assumes the shards are inner-product
            # (similarity) indexes, where larger is better.
            neighbors = list(reversed(neighbors))
            train_gps = train_gt[neighbors[0][0].replace(".npy", "")][None, :]
            test_gps = test_gt[f.replace(".npy", "")][None, :]
            train_test.append((train_gps, test_gps))
        joblib.dump(train_test, cache_file)
    else:
        train_test = joblib.load(cache_file)

    return train_test

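
# CLI entry point: with --test, cache features for the test split and run the
# 1-NN evaluation; otherwise cache features for training shard `--id`.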
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--id", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=512)
    parser.add_argument(
        "--annotation_file", type=str, required=False, default="train.csv"
    )
    parser.add_argument("--name", type=str, default="openai/clip-vit-base-patch32")
    parser.add_argument("--features_parent", type=str, default="faiss/")
    parser.add_argument("--data_parent", type=str, default="data/")
    parser.add_argument("--test", action="store_true")

    args = parser.parse_args()
    args.features_parent = join(args.features_parent, args.name)
    if args.test:
        csv_file = join(args.data_parent, "test.csv")
        data_dir = join(args.data_parent, "test")
        path = join(args.features_parent, "features-test")
        compute_features(path, data_dir, csv_file, tag="id", args=args)
        train_test = get_results(args)
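
        # Evaluate 1-NN geolocation: N accumulates totals and pos correct
        # matches per distance threshold; the bucketing lives in the utils helpers.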
        from collections import Counter

        N, pos = Counter(), Counter()
        for train_gps, test_gps in tqdm(train_test, desc="Computing accuracy..."):
            get_match_values(train_gps, test_gps, N, pos)

        for train_gps, test_gps in tqdm(train_test, desc="Computing haversine..."):
            haversine(train_gps, test_gps, N, pos)

        compute_print_accuracy(N, pos)
    else:
        csv_file = join(args.data_parent, args.annotation_file)
        path = join(args.features_parent, f"features-{args.id}")
        data_dir = join(args.data_parent, f"images-{args.id}", "train")
        compute_features(path, data_dir, csv_file, tag="image_id", args=args)