# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


@METRICS.register_module()
class MultiTasksMetric(BaseMetric):
    """Metrics for multi-task learning.

    Args:
        task_metrics (dict): A dictionary whose keys are task names and whose
            values are lists of metric configs for the corresponding task.

    Examples:
        >>> import torch
        >>> from mmpretrain.evaluation import MultiTasksMetric
        # -------------------- The Basic Usage --------------------
        >>> task_metrics = {
            'task0': [dict(type='Accuracy', topk=(1, ))],
            'task1': [dict(type='Accuracy', topk=(1, 3))]
        }
        >>> pred = [{
            'pred_task': {
                'task0': torch.tensor([0.7, 0.0, 0.3]),
                'task1': torch.tensor([0.5, 0.2, 0.3])
            },
            'gt_task': {
                'task0': torch.tensor(0),
                'task1': torch.tensor(2)
            }
        }, {
            'pred_task': {
                'task0': torch.tensor([0.0, 0.0, 1.0]),
                'task1': torch.tensor([0.0, 0.0, 1.0])
            },
            'gt_task': {
                'task0': torch.tensor(2),
                'task1': torch.tensor(2)
            }
        }]
        >>> metric = MultiTasksMetric(task_metrics)
        >>> metric.process(None, pred)
        >>> results = metric.evaluate(2)
        results = {
            'task0_accuracy/top1': 100.0,
            'task1_accuracy/top1': 50.0,
            'task1_accuracy/top3': 100.0
        }
    """

    def __init__(self,
                 task_metrics: Dict,
                 collect_device: str = 'cpu') -> None:
        self.task_metrics = task_metrics
        super().__init__(collect_device=collect_device)

        # Build one metric instance per config, grouped by task name.
        self._metrics = {}
        for task_name in self.task_metrics.keys():
            self._metrics[task_name] = []
            for metric in self.task_metrics[task_name]:
                self._metrics[task_name].append(METRICS.build(metric))

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for task_name in self.task_metrics.keys():
            # Only keep the samples that should be evaluated for this task.
            filtered_data_samples = []
            for data_sample in data_samples:
                eval_mask = data_sample[task_name]['eval_mask']
                if eval_mask:
                    filtered_data_samples.append(data_sample[task_name])
            # Delegate the filtered batch to every metric of this task.
            for metric in self._metrics[task_name]:
                metric.process(data_batch, filtered_data_samples)

    def compute_metrics(self, results: list) -> dict:
        """Not used. Each sub-metric computes its own results in
        ``evaluate``."""
        raise NotImplementedError(
            '`compute_metrics` should not be called directly, '
            'use `evaluate` instead.')

    def evaluate(self, size):
        """Evaluate the model performance of the whole dataset after
        processing all batches.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are
            "{task_name}_{metric_name}", and the values are the corresponding
            results.
        """
        metrics = {}
        for task_name in self._metrics:
            for metric in self._metrics[task_name]:
                name = metric.__class__.__name__
                if name == 'MultiTasksMetric' or metric.results:
                    results = metric.evaluate(size)
                else:
                    # The metric received no samples for this task, so report
                    # a placeholder score of 0 instead of failing.
                    results = {metric.__class__.__name__: 0}
                for key in results:
                    name = f'{task_name}_{key}'
                    if name in metrics:
                        # Inspired by the duplicate-name check in mmengine:
                        # https://github.com/open-mmlab/mmengine/blob/ed20a9cba52ceb371f7c825131636b9e2747172e/mmengine/evaluator/evaluator.py#L84-L87
                        raise ValueError(
                            'There are multiple metric results with the same '
                            f'metric name {name}. Please make sure all '
                            'metrics have different prefixes.')
                    metrics[name] = results[key]
        return metrics
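

# ---------------------------------------------------------------------------
# Minimal usage sketch (an illustrative assumption, not part of the module
# above): in an MMEngine-style config, this metric is typically selected
# through the ``val_evaluator``/``test_evaluator`` fields. The task names and
# per-task metric settings below are hypothetical and must match the heads of
# your multi-task model.
#
# val_evaluator = dict(
#     type='MultiTasksMetric',
#     task_metrics={
#         'task0': [dict(type='Accuracy', topk=(1, ))],
#         'task1': [dict(type='Accuracy', topk=(1, 3))],
#     })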