Spaces:
Running
on
Zero
Running
on
Zero
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from enum import Enum | |
| import logging | |
| from typing import Any, Dict, Optional | |
| import torch | |
| from torch import Tensor | |
| from torchmetrics import Metric, MetricCollection | |
| from torchmetrics.classification import MulticlassAccuracy | |
| from torchmetrics.utilities.data import dim_zero_cat, select_topk | |
| logger = logging.getLogger("dinov2") | |
class MetricType(Enum):
    """Identifiers of the evaluation metrics handled by `build_metric`."""

    MEAN_ACCURACY = "mean_accuracy"
    MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy"
    PER_CLASS_ACCURACY = "per_class_accuracy"
    IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy"

    def accuracy_averaging(self):
        """Return the `AccuracyAveraging` member that shares this member's name.

        The result is ``None`` when `AccuracyAveraging` has no member with the
        same name (as is the case for ``IMAGENET_REAL_ACCURACY``).

        NOTE: this is a regular method (there is no ``@property``), so it has
        to be called; a bare attribute access yields a bound method.
        """
        return AccuracyAveraging.__members__.get(self.name)

    def __str__(self):
        # Render as the raw string value rather than "MetricType.NAME".
        return self.value
class AccuracyAveraging(Enum):
    """Averaging modes for accuracy metrics.

    Member names mirror those of `MetricType`; each value is the string that
    `build_topk_accuracy_metric` passes as ``average`` to ``MulticlassAccuracy``.
    """

    MEAN_ACCURACY = "micro"
    MEAN_PER_CLASS_ACCURACY = "macro"
    PER_CLASS_ACCURACY = "none"

    def __str__(self) -> str:
        # Render as the raw string value rather than "AccuracyAveraging.NAME".
        return self.value
def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None):
    """Build the metric collection that corresponds to ``metric_type``.

    Args:
        metric_type: Which metric to build.
        num_classes: Number of classes, forwarded to the selected builder.
        ks: Top-k values to evaluate; ``(1, 5)`` is used when ``None``.

    Returns:
        The result of ``build_topk_accuracy_metric`` for metric types that map
        to an ``AccuracyAveraging`` member, or of
        ``build_topk_imagenet_real_accuracy_metric`` for
        ``MetricType.IMAGENET_REAL_ACCURACY``.

    Raises:
        ValueError: If ``metric_type`` is handled by neither branch.
    """
    ks = (1, 5) if ks is None else ks

    # `accuracy_averaging` is declared as a plain method on MetricType, so a
    # bare attribute access yields a bound method: it is never None and has no
    # `.value`. Call it when it is callable; if it is exposed as a property the
    # attribute is already the enum member (or None) and is used as is.
    averaging = metric_type.accuracy_averaging
    if callable(averaging):
        averaging = averaging()

    if averaging is not None:
        return build_topk_accuracy_metric(
            average_type=averaging,
            num_classes=num_classes,
            ks=ks,
        )
    if metric_type == MetricType.IMAGENET_REAL_ACCURACY:
        return build_topk_imagenet_real_accuracy_metric(
            num_classes=num_classes,
            ks=ks,
        )

    raise ValueError(f"Unknown metric type {metric_type}")
def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)):
    """Create one ``MulticlassAccuracy`` per ``k`` and bundle them together.

    Args:
        average_type: Averaging mode; its ``value`` is passed as ``average``.
        num_classes: Number of classes (converted with ``int``).
        ks: Top-k values; each yields an entry keyed ``"top-{k}"``.

    Returns:
        A ``MetricCollection`` holding the per-``k`` metrics.
    """
    per_k: Dict[str, Metric] = {}
    for k in ks:
        per_k[f"top-{k}"] = MulticlassAccuracy(
            top_k=k,
            num_classes=int(num_classes),
            average=average_type.value,
        )
    return MetricCollection(per_k)
def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)):
    """Create one ``ImageNetReaLAccuracy`` per ``k`` and bundle them together.

    Args:
        num_classes: Number of classes (converted with ``int``).
        ks: Top-k values; each yields an entry keyed ``"top-{k}"``.

    Returns:
        A ``MetricCollection`` holding the per-``k`` metrics.
    """
    per_k: Dict[str, Metric] = {}
    for k in ks:
        per_k[f"top-{k}"] = ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes))
    return MetricCollection(per_k)
class ImageNetReaLAccuracy(Metric):
    """Top-k accuracy against multi-label targets with optional undefined entries.

    A sample counts as correct when at least one of its top-k predictions is
    among its targets. Target entries equal to ``-1`` are treated as undefined,
    and samples without any defined target are excluded from the average.
    """

    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        top_k: int = 1,
        **kwargs: Any,
    ) -> None:
        """Store the configuration and register the accumulated state.

        Args:
            num_classes: Number of classes; also used as the placeholder index
                for undefined (``-1``) targets.
            top_k: Number of highest-scoring predictions considered per sample.
            **kwargs: Forwarded to ``Metric.__init__``.
        """
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.top_k = top_k
        # List of per-batch hit tensors, concatenated across processes.
        self.add_state("tp", [], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:  # type: ignore
        """Accumulate per-sample hits for one batch.

        Args:
            preds: Scores of shape [B, D].
            target: Class indices of shape [B, A]; ``-1`` marks an undefined
                entry. The tensor passed in is left unmodified.
        """
        # preds_oh [B, D] with 0 and 1
        # select top K highest probabilities, use one hot representation
        preds_oh = select_topk(preds, self.top_k)
        # target_oh [B, D + 1] with 0 and 1 (extra column absorbs undefined targets)
        target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
        target = target.long()
        # for undefined targets (-1) use a fake value `num_classes`.
        # `.long()` returns the very same tensor when the input is already
        # int64, so this is done out of place to avoid overwriting the
        # caller's tensor.
        target = target.masked_fill(target == -1, self.num_classes)
        # fill targets, use one hot representation
        target_oh.scatter_(1, target, 1)
        # target_oh [B, D] (remove the fake target at index `num_classes`)
        target_oh = target_oh[:, :-1]
        # tp [B]: number of matches between predictions and targets
        tp = (preds_oh * target_oh == 1).sum(dim=1)
        # at least one match between prediction and target
        tp.clip_(max=1)
        # ignore instances where no targets are defined
        mask = target_oh.sum(dim=1) > 0
        tp = tp[mask]
        self.tp.append(tp)  # type: ignore

    def compute(self) -> Tensor:
        """Return the mean of all accumulated per-sample hits."""
        tp = dim_zero_cat(self.tp)  # type: ignore
        return tp.float().mean()