# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch
import torchmetrics
from torchmetrics.utilities.data import dim_zero_cat

from .utils import deg2rad, rotmat2d


def _follow_device(metric, tensor):
    """Move ``metric``'s states onto the device that ``tensor`` lives on.

    This replaces an unconditional ``metric.cuda()``. Following the input keeps
    the metric usable on CPU-only hosts and on whichever GPU the batch is on. It
    also keeps every accumulated list entry on a single device.
    """
    metric.to(tensor.device)


def location_error(uv, uv_gt, ppm=1):
    """Return the Euclidean distance between ``uv`` and ``uv_gt``, divided by ``ppm``.

    The norm is taken over the last dimension. Before subtracting, ``uv_gt`` is
    cast to the dtype and device of ``uv``.
    """
    return torch.norm(uv - uv_gt.to(uv), dim=-1) / ppm


def location_error_single(uv, uv_gt, ppm=1):
    """Alias of :func:`location_error`, kept for backward compatibility."""
    return location_error(uv, uv_gt, ppm)


def angle_error(t, t_gt):
    """Return the absolute wrapped difference between angles ``t`` and ``t_gt``.

    Angles are treated as degrees. Both are reduced modulo 360, and the shorter
    way around the circle is taken, so the result lies in [0, 180]. ``t_gt`` is
    cast to the dtype and device of ``t``.
    """
    error = torch.abs(t % 360 - t_gt.to(t) % 360)
    return torch.minimum(error, 360 - error)


class Location2DRecall(torchmetrics.MeanMetric):
    """Fraction of samples whose location error is within ``threshold``.

    Args:
        threshold: Largest error (after division by ``pixel_per_meter``) that
            still counts as a hit. The comparison is inclusive (``<=``).
        pixel_per_meter: Divisor applied to the pixel-space distance.
        key: Key of the predicted location in the ``pred`` dict.
        *args, **kwargs: Forwarded to ``torchmetrics.MeanMetric``.
    """

    def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
        self.threshold = threshold
        self.ppm = pixel_per_meter
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, pred, data):
        """Accumulate hits for ``pred[self.key]`` against ``data["uv"]``."""
        error = location_error(pred[self.key], data["uv"], self.ppm)
        _follow_device(self, error)
        super().update((error <= self.threshold).float())


class Location1DRecall(torchmetrics.MeanMetric):
    """Fraction of samples whose location error is within ``threshold``.

    NOTE(review): this computes the same full Euclidean error over the last
    dimension as ``Location2DRecall``. Confirm whether a 1-D (per-axis) error
    was intended.

    Args:
        threshold: Largest error (after division by ``pixel_per_meter``) that
            still counts as a hit. The comparison is inclusive (``<=``).
        pixel_per_meter: Divisor applied to the pixel-space distance.
        key: Key of the predicted location in the ``pred`` dict.
        *args, **kwargs: Forwarded to ``torchmetrics.MeanMetric``.
    """

    def __init__(self, threshold, pixel_per_meter, key="uv_max", *args, **kwargs):
        self.threshold = threshold
        self.ppm = pixel_per_meter
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, pred, data):
        """Accumulate hits for ``pred[self.key]`` against ``data["uv"]``."""
        error = location_error(pred[self.key], data["uv"], self.ppm)
        _follow_device(self, error)
        super().update((error <= self.threshold).float())


class AngleRecall(torchmetrics.MeanMetric):
    """Fraction of samples whose wrapped yaw error is within ``threshold``.

    The ground-truth yaw is the last entry of ``data["roll_pitch_yaw"]``. The
    comparison is inclusive (``<=``).

    Args:
        threshold: Largest angular error (degrees) that still counts as a hit.
        key: Key of the predicted yaw in the ``pred`` dict.
        *args, **kwargs: Forwarded to ``torchmetrics.MeanMetric``.
    """

    def __init__(self, threshold, key="yaw_max", *args, **kwargs):
        self.threshold = threshold
        self.key = key
        super().__init__(*args, **kwargs)

    def update(self, pred, data):
        """Accumulate hits for ``pred[self.key]`` against the ground-truth yaw."""
        error = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
        _follow_device(self, error)
        super().update((error <= self.threshold).float())


class MeanMetricWithRecall(torchmetrics.Metric):
    """Accumulate raw per-sample errors and report their mean and recalls.

    Subclasses append error tensors to the ``value`` list state in ``update``.
    """

    full_state_update = True

    def __init__(self):
        super().__init__()
        # List of per-batch error tensors; "cat" concatenates them on sync.
        self.add_state("value", default=[], dist_reduce_fx="cat")

    def compute(self):
        """Return the mean of all accumulated errors over the sample dimension."""
        return dim_zero_cat(self.value).mean(0)

    def get_errors(self):
        """Return all accumulated errors concatenated along dimension 0."""
        return dim_zero_cat(self.value)

    def recall(self, thresholds):
        """Return the percentage of errors strictly below each threshold.

        Args:
            thresholds: Sequence of thresholds, in the same units as the errors.

        Returns:
            Tensor of percentages in [0, 100] whose trailing dimension has size
            ``len(thresholds)``. Any per-sample error dimensions (for example
            the two columns stored by ``LateralLongitudinalError``) come first.
        """
        # ``update`` keeps every stored error on the metric's device, so no
        # explicit device move is needed before concatenating.
        error = self.get_errors()
        thresholds = error.new_tensor(thresholds)
        # (N, ..., 1) < (T,) -> (N, ..., T); averaging over N gives the hit rate.
        return (error.unsqueeze(-1) < thresholds).float().mean(0) * 100


class AngleError(MeanMetricWithRecall):
    """Accumulate wrapped yaw errors between ``pred[key]`` and the ground truth.

    The ground-truth yaw is the last entry of ``data["roll_pitch_yaw"]``.
    """

    def __init__(self, key):
        super().__init__()
        self.key = key

    def update(self, pred, data):
        """Append the batch's yaw errors, skipping empty batches."""
        value = angle_error(pred[self.key], data["roll_pitch_yaw"][..., -1])
        _follow_device(self, value)
        if value.numel():
            self.value.append(value)


class Location2DError(MeanMetricWithRecall):
    """Accumulate location errors between ``pred[key]`` and ``data["uv"]``.

    Each error is the Euclidean distance divided by ``pixel_per_meter``.
    """

    def __init__(self, key, pixel_per_meter):
        super().__init__()
        self.key = key
        self.ppm = pixel_per_meter

    def update(self, pred, data):
        """Append the batch's location errors, skipping empty batches."""
        value = location_error(pred[self.key], data["uv"], self.ppm)
        _follow_device(self, value)
        if value.numel():
            self.value.append(value)


class LateralLongitudinalError(MeanMetricWithRecall):
    """Accumulate per-axis location errors in a frame rotated by the GT yaw.

    Each stored row holds the absolute offset along the two rotated axes,
    divided by ``pixel_per_meter``. The class name suggests these are the
    lateral and longitudinal components. TODO: confirm the column order
    against ``rotmat2d``.
    """

    def __init__(self, pixel_per_meter, key="uv_max"):
        super().__init__()
        self.ppm = pixel_per_meter
        self.key = key

    def update(self, pred, data):
        """Append the batch's rotated per-axis errors, skipping empty batches."""
        uv = pred[self.key]
        yaw = deg2rad(data["roll_pitch_yaw"][..., -1])
        # Cast GT to the prediction's dtype/device, matching location_error().
        # Flip the sign of the first component before rotating; presumably this
        # converts the image axis convention to the one rotmat2d expects —
        # NOTE(review): confirm.
        shift = (uv - data["uv"].to(uv)) * yaw.new_tensor([-1, 1])
        shift = (rotmat2d(yaw) @ shift.unsqueeze(-1)).squeeze(-1)
        error = torch.abs(shift) / self.ppm
        # Flatten any batch dimensions into rows of two per-axis errors.
        value = error.view(-1, 2)
        _follow_device(self, value)
        if value.numel():
            self.value.append(value)