from enum import Enum
import logging
from typing import Any, Dict, Optional

import torch
from torch import Tensor
from torchmetrics import Metric, MetricCollection
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.utilities.data import dim_zero_cat, select_topk


logger = logging.getLogger("dinov2")


class MetricType(Enum):
    MEAN_ACCURACY = "mean_accuracy"
    MEAN_PER_CLASS_ACCURACY = "mean_per_class_accuracy"
    PER_CLASS_ACCURACY = "per_class_accuracy"
    IMAGENET_REAL_ACCURACY = "imagenet_real_accuracy"

    @property
    def accuracy_averaging(self):
        return getattr(AccuracyAveraging, self.name, None)

    def __str__(self):
        return self.value


class AccuracyAveraging(Enum):
    MEAN_ACCURACY = "micro"
    MEAN_PER_CLASS_ACCURACY = "macro"
    PER_CLASS_ACCURACY = "none"

    def __str__(self):
        return self.value

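
# Quick reference (comments only): `MetricType.accuracy_averaging` maps a MetricType
# member to the AccuracyAveraging member of the same name, or None when there is no
# counterpart:
#
#   MetricType.MEAN_ACCURACY.accuracy_averaging           -> AccuracyAveraging.MEAN_ACCURACY ("micro")
#   MetricType.MEAN_PER_CLASS_ACCURACY.accuracy_averaging -> AccuracyAveraging.MEAN_PER_CLASS_ACCURACY ("macro")
#   MetricType.PER_CLASS_ACCURACY.accuracy_averaging      -> AccuracyAveraging.PER_CLASS_ACCURACY ("none")
#   MetricType.IMAGENET_REAL_ACCURACY.accuracy_averaging  -> None (handled separately in build_metric)
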
def build_metric(metric_type: MetricType, *, num_classes: int, ks: Optional[tuple] = None):
    if metric_type.accuracy_averaging is not None:
        return build_topk_accuracy_metric(
            average_type=metric_type.accuracy_averaging,
            num_classes=num_classes,
            ks=(1, 5) if ks is None else ks,
        )
    elif metric_type == MetricType.IMAGENET_REAL_ACCURACY:
        return build_topk_imagenet_real_accuracy_metric(
            num_classes=num_classes,
            ks=(1, 5) if ks is None else ks,
        )

    raise ValueError(f"Unknown metric type {metric_type}")

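
# Minimal usage sketch (commented out, illustrative only): build a top-1/top-5
# accuracy collection and feed it logits and integer labels. The batch size and
# `num_classes=1000` below are assumptions for the example, not values used elsewhere.
#
#   metric = build_metric(MetricType.MEAN_ACCURACY, num_classes=1000)
#   logits = torch.randn(8, 1000)          # [batch, num_classes]
#   labels = torch.randint(0, 1000, (8,))  # [batch]
#   metric.update(logits, labels)
#   metric.compute()                       # {"top-1": tensor(...), "top-5": tensor(...)}
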

def build_topk_accuracy_metric(average_type: AccuracyAveraging, num_classes: int, ks: tuple = (1, 5)):
    metrics: Dict[str, Metric] = {
        f"top-{k}": MulticlassAccuracy(top_k=k, num_classes=int(num_classes), average=average_type.value) for k in ks
    }
    return MetricCollection(metrics)

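
# Note (illustrative): the collection is keyed by the requested `ks`, e.g. "top-1" and
# "top-5" by default. With AccuracyAveraging.PER_CLASS_ACCURACY ("none"), each
# MulticlassAccuracy returns a per-class accuracy vector of shape [num_classes] rather
# than a single scalar:
#
#   build_topk_accuracy_metric(AccuracyAveraging.PER_CLASS_ACCURACY, num_classes=10, ks=(1,))
#   # -> MetricCollection with a single "top-1" entry that computes per-class accuracy
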

def build_topk_imagenet_real_accuracy_metric(num_classes: int, ks: tuple = (1, 5)):
    metrics: Dict[str, Metric] = {f"top-{k}": ImageNetReaLAccuracy(top_k=k, num_classes=int(num_classes)) for k in ks}
    return MetricCollection(metrics)


class ImageNetReaLAccuracy(Metric):
    is_differentiable: bool = False
    higher_is_better: Optional[bool] = None
    full_state_update: bool = False

    def __init__(
        self,
        num_classes: int,
        top_k: int = 1,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.num_classes = num_classes
        self.top_k = top_k
        self.add_state("tp", [], dist_reduce_fx="cat")

    def update(self, preds: Tensor, target: Tensor) -> None:
        # preds: [B, num_classes] class scores; target: [B, A] multi-label targets padded with -1
        # one-hot encoding of the top-k predictions: preds_oh [B, num_classes]
        preds_oh = select_topk(preds, self.top_k)
        # target_oh [B, num_classes + 1]; the extra column absorbs undefined targets
        target_oh = torch.zeros((preds_oh.shape[0], preds_oh.shape[1] + 1), device=target.device, dtype=torch.int32)
        target = target.long()
        # map undefined targets (-1) to the fake class `num_classes`
        target[target == -1] = self.num_classes
        # scatter the multi-label targets into one-hot form
        target_oh.scatter_(1, target, 1)
        # drop the fake class column -> target_oh [B, num_classes]
        target_oh = target_oh[:, :-1]
        # tp [B]: 1 if any top-k prediction matches any target, 0 otherwise
        tp = (preds_oh * target_oh == 1).sum(dim=1)
        tp.clip_(max=1)
        # ignore samples that have no defined target at all
        mask = target_oh.sum(dim=1) > 0
        tp = tp[mask]
        self.tp.append(tp)

    def compute(self) -> Tensor:
        tp = dim_zero_cat(self.tp)
        return tp.float().mean()
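
# Worked sketch (commented out, illustrative values): ImageNet-ReaL labels are
# multi-label, padded with -1 where no label is defined. With num_classes=5 and
# top_k=1, the second sample below has no valid label and is masked out, so the
# metric reduces to whether sample 0's top prediction (class 1) is in {1, 3}.
#
#   metric = ImageNetReaLAccuracy(num_classes=5, top_k=1)
#   preds = torch.tensor([[0.1, 0.7, 0.1, 0.05, 0.05],
#                         [0.6, 0.1, 0.1, 0.1, 0.1]])
#   target = torch.tensor([[1, 3, -1],
#                          [-1, -1, -1]])
#   metric.update(preds, target)
#   metric.compute()  # tensor(1.)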