import torch
import torchmetrics
import torchmetrics.classification


class PixelAccuracy(torchmetrics.Metric):
    """Running ratio of matching mask elements to valid-BEV cell count.

    State is accumulated across ``update`` calls and summed across processes.
    """

    def __init__(self):
        super().__init__()
        self.add_state(
            "correct_pixels", default=torch.tensor(0), dist_reduce_fx="sum")
        self.add_state(
            "total_pixels", default=torch.tensor(0), dist_reduce_fx="sum")

    def update(self, pred, data):
        """Accumulate match counts for one batch.

        Args:
            pred: Mapping with ``'output'`` (scores thresholded at 0.5) and
                ``'valid_bev'``.
            data: Mapping with ``'seg_masks'``; it is permuted with
                ``(0, 3, 1, 2)`` before comparison, so it is presumably
                channel-last while ``pred['output']`` is channel-first
                (TODO confirm against the caller).
        """
        output_mask = pred['output'] > 0.5
        gt_mask = data["seg_masks"].permute(0, 3, 1, 2)

        self.correct_pixels += (output_mask == gt_mask).sum()
        # NOTE(review): the numerator counts every element of the comparison
        # (all channels), while the denominator is the element count of
        # valid_bev[..., :-1]. If those differ in size the ratio can exceed 1.
        # Behavior is kept as-is; confirm the intended normalization.
        self.total_pixels += torch.numel(pred["valid_bev"][..., :-1])

    def compute(self):
        """Return ``correct_pixels / total_pixels``."""
        return self.correct_pixels / self.total_pixels


class IOU(torchmetrics.Metric):
    """Base metric accumulating per-class intersection and union counts.

    Counts are kept separately for "observable" and "non-observable" regions.
    Subclasses implement ``compute`` to turn the accumulated counts into a
    value.

    Args:
        num_classes: Number of class channels to accumulate.
        **kwargs: Forwarded to ``torchmetrics.Metric``.
    """

    def __init__(self, num_classes=3, **kwargs):
        super().__init__(**kwargs)
        self.num_classes = num_classes
        for state_name in (
            "intersection_observable",
            "union_observable",
            "intersection_non_observable",
            "union_non_observable",
        ):
            self.add_state(
                state_name,
                default=torch.zeros(num_classes),
                dist_reduce_fx="sum",
            )

    def update(self, output, data):
        """Accumulate intersection/union counts for one batch.

        Args:
            output: Mapping with ``'output'`` (per-class scores, indexed as
                ``[:, class_idx]`` and thresholded at 0.5) and ``'valid_bev'``.
            data: Mapping with ``'seg_masks'`` (indexed as
                ``[..., class_idx]``) and optionally ``'confidence_map'``.
        """
        gt = data["seg_masks"]
        pred = output['output']
        valid = output["valid_bev"][..., :-1]

        if "confidence_map" in data:
            confidence = data["confidence_map"]
            # Cells with confidence 0 are treated as observable and cells
            # with confidence 1 as non-observable; other values fall in
            # neither region.
            observable_mask = torch.logical_and(valid, confidence == 0)
            non_observable_mask = torch.logical_and(valid, confidence == 1)
        else:
            observable_mask = valid
            non_observable_mask = torch.logical_not(observable_mask)

        for class_idx in range(self.num_classes):
            pred_mask = pred[:, class_idx] > 0.5
            gt_mask = gt[..., class_idx]

            # Region-independent masks, computed once and reused for both
            # the observable and non-observable regions.
            overlap = torch.logical_and(pred_mask, gt_mask)
            either = torch.logical_or(pred_mask, gt_mask)

            # For observable areas
            self.intersection_observable[class_idx] += torch.logical_and(
                overlap, observable_mask).sum()
            self.union_observable[class_idx] += torch.logical_and(
                either, observable_mask).sum()

            # For non-observable areas
            self.intersection_non_observable[class_idx] += torch.logical_and(
                overlap, non_observable_mask).sum()
            self.union_non_observable[class_idx] += torch.logical_and(
                either, non_observable_mask).sum()

    def compute(self):
        """Abstract; subclasses define how the counts are reduced.

        Raises:
            NotImplementedError: Always, on the base class.
        """
        # ``NotImplemented`` is a comparison sentinel, not an exception;
        # raising it fails with an unrelated TypeError.
        raise NotImplementedError(
            "IOU subclasses must implement compute()")


class ObservableIOU(IOU):
    """IoU of a single class over the observable region.

    Args:
        class_idx: Index of the class whose IoU is returned.
        **kwargs: Forwarded to ``IOU``.
    """

    def __init__(self, class_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.class_idx = class_idx

    def compute(self):
        """Return intersection / (union + 1e-6) for ``class_idx``."""
        per_class = self.intersection_observable / (
            self.union_observable + 1e-6)
        return per_class[self.class_idx]


class UnobservableIOU(IOU):
    """IoU of a single class over the non-observable region.

    Args:
        class_idx: Index of the class whose IoU is returned.
        **kwargs: Forwarded to ``IOU``.
    """

    def __init__(self, class_idx=0, **kwargs):
        super().__init__(**kwargs)
        self.class_idx = class_idx

    def compute(self):
        """Return intersection / (union + 1e-6) for ``class_idx``."""
        per_class = self.intersection_non_observable / (
            self.union_non_observable + 1e-6)
        return per_class[self.class_idx]


class MeanObservableIOU(IOU):
    """IoU over the observable region with counts pooled across classes."""

    def compute(self):
        """Return summed intersections / (summed unions + 1e-6).

        Counts are summed over classes before dividing; this is not the
        average of the per-class IoU values.
        """
        return self.intersection_observable.sum() / (
            self.union_observable.sum() + 1e-6)


class MeanUnobservableIOU(IOU):
    """IoU over the non-observable region with counts pooled across classes."""

    def compute(self):
        """Return summed intersections / (summed unions + 1e-6).

        Counts are summed over classes before dividing; this is not the
        average of the per-class IoU values.
        """
        return self.intersection_non_observable.sum() / (
            self.union_non_observable.sum() + 1e-6)


class mAP(torchmetrics.classification.MultilabelPrecision):
    """Multilabel precision restricted to observable cells.

    Despite the name, this subclasses ``MultilabelPrecision``; the value it
    reports is whatever that parent computes for the given ``kwargs``.

    Args:
        num_labels: Number of labels, forwarded to the parent.
        **kwargs: Forwarded to ``MultilabelPrecision``.
    """

    def __init__(self, num_labels, **kwargs):
        super().__init__(num_labels=num_labels, **kwargs)

    def update(self, output, data):
        """Select observable cells and feed them to the parent metric.

        Args:
            output: Mapping with ``'output'`` and ``'valid_bev'``.
            data: Mapping with ``'seg_masks'`` and optionally
                ``'confidence_map'``.
        """
        valid = output["valid_bev"][..., :-1]
        if "confidence_map" in data:
            observable_mask = torch.logical_and(
                valid, data["confidence_map"] == 0)
        else:
            observable_mask = valid

        # Move the class axis last so boolean-mask indexing yields one row
        # per selected cell, matching how ``seg_masks`` is indexed below.
        pred = output['output'].permute(0, 2, 3, 1)[observable_mask]
        target = data['seg_masks'][observable_mask]

        super().update(pred, target)