Spaces:
Running
Running
import torch | |
from metrics.utils import haversine, reverse | |
from sklearn.metrics import pairwise_distances | |
from torchmetrics import Metric | |
import numpy as np | |
from utils.kde import BatchedKDE | |
from tqdm import tqdm | |
class HaversineMetrics(Metric): | |
""" | |
Computes the average haversine distance between the predicted and ground truth points. | |
Compute the accuracy given some radiuses. | |
Compute the Geoguessr score given some radiuses. | |
Args: | |
acc_radiuses (list): list of radiuses to compute the accuracy from | |
acc_area (list): list of areas to compute the accuracy from. | |
""" | |
def __init__( | |
self, | |
acc_radiuses=[], | |
acc_area=["country", "region", "sub-region", "city"], | |
use_kde=False, | |
manifold_k=3, | |
): | |
super().__init__() | |
self.use_kde = use_kde | |
self.add_state("haversine_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") | |
self.add_state("geoguessr_sum", default=torch.tensor(0.0), dist_reduce_fx="sum") | |
for acc in acc_radiuses: | |
self.add_state( | |
f"close_enough_points_{acc}", | |
default=torch.tensor(0.0), | |
dist_reduce_fx="sum", | |
) | |
for acc in acc_area: | |
self.add_state( | |
f"close_enough_points_{acc}", | |
default=torch.tensor(0.0), | |
dist_reduce_fx="sum", | |
) | |
self.add_state( | |
f"count_{acc}", default=torch.tensor(0), dist_reduce_fx="sum" | |
) | |
self.acc_radius = acc_radiuses | |
self.acc_area = acc_area | |
self.add_state("count", default=torch.tensor(0), dist_reduce_fx="sum") | |
self.add_state( | |
"real_points", | |
[], | |
dist_reduce_fx=None, | |
) | |
self.add_state( | |
"fake_points", | |
[], | |
dist_reduce_fx=None, | |
) | |
self.manifold_k = manifold_k | |
def update(self, pred, gt): | |
if self.use_kde: | |
(x_mode, y_mode), kde = estimate_kde_mode(pred["gps"]) | |
# self.nll_sum += -torch.log( | |
# kde.score(gt["gps"].unsqueeze(1).to(pred["gps"].device)) | |
# ).sum() | |
pred["gps"] = torch.stack([x_mode, y_mode], dim=1) | |
# Handle NaN values without modifying the original inputs | |
if pred["gps"].isnan().any(): | |
valid_mask = ~pred["gps"].isnan().any(dim=1) | |
pred_gps = pred["gps"][valid_mask] | |
gt_gps = gt["gps"][valid_mask] | |
if len(pred_gps) == 0: # Skip if no valid predictions remain | |
return | |
else: | |
pred_gps = pred["gps"] | |
gt_gps = gt["gps"] | |
haversine_distance = haversine(pred_gps, gt_gps) | |
for acc in self.acc_radius: | |
self.__dict__[f"close_enough_points_{acc}"] += ( | |
haversine_distance < acc | |
).sum() | |
if len(self.acc_area) > 0: | |
area_pred, area_gt = reverse(pred_gps, gt, self.acc_area) | |
for acc in self.acc_area: | |
self.__dict__[f"close_enough_points_{acc}"] += ( | |
area_pred[acc] == area_gt["_".join(["unique", acc])] | |
).sum() | |
self.__dict__[f"count_{acc}"] += len(area_gt["_".join(["unique", acc])]) | |
self.haversine_sum += haversine_distance.sum() | |
self.geoguessr_sum += 5000 * torch.exp(-haversine_distance / 1492.7).sum() | |
self.real_points.append(gt_gps) | |
self.fake_points.append(pred_gps) | |
self.count += pred_gps.shape[0] | |
def compute(self): | |
output = { | |
"Haversine": self.haversine_sum / self.count, | |
"Geoguessr": self.geoguessr_sum / self.count, | |
} | |
for acc in self.acc_radius: | |
output[f"Accuracy_{acc}_km_radius"] = ( | |
self.__dict__[f"close_enough_points_{acc}"] / self.count | |
) | |
for acc in self.acc_area: | |
output[f"Accuracy_{acc}"] = ( | |
self.__dict__[f"close_enough_points_{acc}"] | |
/ self.__dict__[f"count_{acc}"] | |
) | |
real_points = torch.cat(self.real_points, dim=0) | |
fake_points = torch.cat(self.fake_points, dim=0) | |
( | |
output["precision"], | |
output["recall"], | |
output["density"], | |
output["coverage"], | |
) = self.manifold_metrics(real_points, fake_points, self.manifold_k) | |
return output | |
def compute_pairwise_distance(self, data_x, data_y=None): | |
""" | |
Args: | |
data_x: numpy.ndarray([N, feature_dim], dtype=np.float32) | |
data_y: numpy.ndarray([N, feature_dim], dtype=np.float32) | |
Returns: | |
numpy.ndarray([N, N], dtype=np.float32) of pairwise distances. | |
""" | |
if data_y is None: | |
data_y = data_x | |
dists = pairwise_distances(data_x, data_y, metric="haversine", n_jobs=8) | |
return dists | |
def get_kth_value(self, unsorted, k, axis=-1): | |
""" | |
Args: | |
unsorted: numpy.ndarray of any dimensionality. | |
k: int | |
Returns: | |
kth values along the designated axis. | |
""" | |
indices = np.argpartition(unsorted, k, axis=axis)[..., :k] | |
k_smallests = np.take_along_axis(unsorted, indices, axis=axis) | |
kth_values = k_smallests.max(axis=axis) | |
return kth_values | |
def compute_nearest_neighbour_distances(self, input_features, nearest_k): | |
""" | |
Args: | |
input_features: numpy.ndarray([N, feature_dim], dtype=np.float32) | |
nearest_k: int | |
Returns: | |
Distances to kth nearest neighbours. | |
""" | |
distances = self.compute_pairwise_distance(input_features) | |
radii = self.get_kth_value(distances, k=nearest_k + 1, axis=-1) | |
return radii | |
def compute_prdc(self, real_features, fake_features, nearest_k): | |
""" | |
Computes precision, recall, density, and coverage given two manifolds. | |
Args: | |
real_features: numpy.ndarray([N, feature_dim], dtype=np.float32) | |
fake_features: numpy.ndarray([N, feature_dim], dtype=np.float32) | |
nearest_k: int. | |
Returns: | |
dict of precision, recall, density, and coverage. | |
""" | |
real_nearest_neighbour_distances = self.compute_nearest_neighbour_distances( | |
real_features, nearest_k | |
) | |
fake_nearest_neighbour_distances = self.compute_nearest_neighbour_distances( | |
fake_features, nearest_k | |
) | |
distance_real_fake = self.compute_pairwise_distance( | |
real_features, fake_features | |
) | |
precision = ( | |
( | |
distance_real_fake | |
< np.expand_dims(real_nearest_neighbour_distances, axis=1) | |
) | |
.any(axis=0) | |
.mean() | |
) | |
recall = ( | |
( | |
distance_real_fake | |
< np.expand_dims(fake_nearest_neighbour_distances, axis=0) | |
) | |
.any(axis=1) | |
.mean() | |
) | |
density = (1.0 / float(nearest_k)) * ( | |
distance_real_fake | |
< np.expand_dims(real_nearest_neighbour_distances, axis=1) | |
).sum(axis=0).mean() | |
coverage = ( | |
distance_real_fake.min(axis=1) < real_nearest_neighbour_distances | |
).mean() | |
return precision, recall, density, coverage | |
def manifold_metrics(self, real_features, fake_features, nearest_k, num_splits=20): | |
""" | |
Computes precision, recall, density, and coverage given two manifolds. | |
Args: | |
real_features: torch.Tensor([N, feature_dim], dtype=torch.float32) | |
fake_features: torch.Tensor([N, feature_dim], dtype=torch.float32) | |
nearest_k: int. | |
num_splits: int. Number of splits to use for computing metrics. | |
Returns: | |
dict of precision, recall, density, and coverage. | |
""" | |
real_features = real_features.chunk(num_splits, dim=0) | |
fake_features = fake_features.chunk(num_splits, dim=0) | |
precision, recall, density, coverage = [], [], [], [] | |
for real, fake in tqdm( | |
zip(real_features, fake_features), desc="Computing manifold" | |
): | |
p, r, d, c = self.compute_prdc( | |
real.cpu().numpy(), fake.cpu().numpy(), nearest_k=nearest_k | |
) | |
precision.append(torch.tensor(p, device=real.device)) | |
recall.append(torch.tensor(r, device=real.device)) | |
density.append(torch.tensor(d, device=real.device)) | |
coverage.append(torch.tensor(c, device=real.device)) | |
return ( | |
torch.stack(precision).mean().item(), | |
torch.stack(recall).mean().item(), | |
torch.stack(density).mean().item(), | |
torch.stack(coverage).mean().item(), | |
) | |
def estimate_kde_mode(points): | |
kde = BatchedKDE() | |
kde.fit(points) | |
batch_size = points.shape[0] | |
X, Y, positions = batched_make_grid(points.cpu()) | |
X = X.to(points.device) | |
Y = Y.to(points.device) | |
positions = positions.to(points.device) | |
Z = kde.score(positions).reshape(X.shape) | |
x_mode = X.reshape(batch_size, -1)[ | |
torch.arange(batch_size), Z.reshape(batch_size, -1).argmax(dim=1) | |
] | |
y_mode = Y.reshape(batch_size, -1)[ | |
torch.arange(batch_size), Z.reshape(batch_size, -1).argmax(dim=1) | |
] | |
return (x_mode, y_mode), kde | |
def make_grid(points): | |
(lat_min, long_min), _ = points.min(dim=-2) | |
(lat_max, long_max), _ = points.max(dim=-2) | |
x = torch.linspace(lat_min, lat_max, 100) | |
y = torch.linspace(long_min, long_max, 100) | |
X, Y = torch.meshgrid(x, y) | |
positions = torch.vstack([X.flatten(), Y.flatten()]).transpose(-1, -2) | |
return X, Y, positions | |
batched_make_grid = torch.vmap(make_grid) | |