# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS

@METRICS.register_module()
class MultiTasksMetric(BaseMetric):
    """Metrics for MultiTask.

    Args:
        task_metrics (dict): A dictionary whose keys are the task names and
            whose values are lists of metric configs for the corresponding
            task.

    Examples:
        >>> import torch
        >>> from mmpretrain.evaluation import MultiTasksMetric
        >>> # -------------------- The Basic Usage --------------------
        >>> task_metrics = {
        ...     'task0': [dict(type='Accuracy', topk=(1, ))],
        ...     'task1': [dict(type='Accuracy', topk=(1, 3))]
        ... }
        >>> pred = [{
        ...     'pred_task': {
        ...         'task0': torch.tensor([0.7, 0.0, 0.3]),
        ...         'task1': torch.tensor([0.5, 0.2, 0.3])
        ...     },
        ...     'gt_task': {
        ...         'task0': torch.tensor(0),
        ...         'task1': torch.tensor(2)
        ...     }
        ... }, {
        ...     'pred_task': {
        ...         'task0': torch.tensor([0.0, 0.0, 1.0]),
        ...         'task1': torch.tensor([0.0, 0.0, 1.0])
        ...     },
        ...     'gt_task': {
        ...         'task0': torch.tensor(2),
        ...         'task1': torch.tensor(2)
        ...     }
        ... }]
        >>> metric = MultiTasksMetric(task_metrics)
        >>> metric.process(None, pred)
        >>> results = metric.evaluate(2)
        >>> results
        {
            'task0_accuracy/top1': 100.0,
            'task1_accuracy/top1': 50.0,
            'task1_accuracy/top3': 100.0
        }
    """

    def __init__(self,
                 task_metrics: Dict,
                 collect_device: str = 'cpu') -> None:
        self.task_metrics = task_metrics
        super().__init__(collect_device=collect_device)

        # Build one metric instance per config for every task.
        self._metrics = {}
        for task_name in self.task_metrics.keys():
            self._metrics[task_name] = []
            for metric in self.task_metrics[task_name]:
                self._metrics[task_name].append(METRICS.build(metric))
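
        # Illustrative sketch (an assumption, not part of the original file):
        # with the ``task_metrics`` from the class docstring, the loop above
        # produces roughly the following structure, assuming ``Accuracy`` is
        # registered in ``METRICS``:
        #
        #   self._metrics = {
        #       'task0': [Accuracy(topk=(1, ))],
        #       'task1': [Accuracy(topk=(1, 3))],
        #   }
        #
        # so every task owns its own independent list of metric objects.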

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for task_name in self.task_metrics.keys():
            # Only keep the samples that are marked as evaluable for this
            # task, then forward them to every metric registered for it.
            filtered_data_samples = []
            for data_sample in data_samples:
                eval_mask = data_sample[task_name]['eval_mask']
                if eval_mask:
                    filtered_data_samples.append(data_sample[task_name])
            for metric in self._metrics[task_name]:
                metric.process(data_batch, filtered_data_samples)
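
        # Illustrative sketch (an assumption, not taken from the original
        # file): each ``data_sample`` is expected to hold one sub-dict per
        # task, and each sub-dict carries an ``eval_mask`` flag telling
        # whether the sample has a valid annotation for that task, e.g.
        #
        #   {
        #       'task0': {'eval_mask': True,  ...per-task prediction fields...},
        #       'task1': {'eval_mask': False, ...per-task prediction fields...},
        #   }
        #
        # Samples whose ``eval_mask`` is False are skipped for that task, so
        # each per-task metric only sees the samples annotated for it.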

    def compute_metrics(self, results: list) -> dict:
        """Not used; ``evaluate`` delegates to the per-task metrics instead."""
        raise NotImplementedError(
            'compute_metrics should not be used here directly')

    def evaluate(self, size):
        """Evaluate the model performance of the whole dataset after
        processing all batches.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are
            "{task_name}_{metric_name}", and the values are the corresponding
            results.
        """
        metrics = {}
        for task_name in self._metrics:
            for metric in self._metrics[task_name]:
                name = metric.__class__.__name__
                # A nested ``MultiTasksMetric`` is always evaluated; a plain
                # metric is only evaluated if it has collected results,
                # otherwise it is reported as 0.
                if name == 'MultiTasksMetric' or metric.results:
                    results = metric.evaluate(size)
                else:
                    results = {metric.__class__.__name__: 0}
                for key in results:
                    name = f'{task_name}_{key}'
                    # Inspired by https://github.com/open-mmlab/mmengine/blob/ed20a9cba52ceb371f7c825131636b9e2747172e/mmengine/evaluator/evaluator.py#L84-L87
                    if name in metrics:
                        raise ValueError(
                            'There are multiple metric results with the same '
                            f'metric name {name}. Please make sure all metrics '
                            'have different prefixes.')
                    metrics[name] = results[key]
        return metrics
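

# A minimal usage sketch (not part of the original module), mirroring the
# class docstring. It assumes ``torch`` and ``mmpretrain`` are installed and
# that the underlying ``Accuracy`` metric reads ``pred_score`` / ``gt_label``
# from each per-task sub-dict; those key names are assumptions here.
if __name__ == '__main__':
    import torch

    task_metrics = {
        'task0': [dict(type='Accuracy', topk=(1, ))],
        'task1': [dict(type='Accuracy', topk=(1, 3))],
    }
    metric = MultiTasksMetric(task_metrics)

    # Two toy samples; ``eval_mask`` marks both tasks as annotated.
    data_samples = [
        {
            'task0': {'pred_score': torch.tensor([0.7, 0.0, 0.3]),
                      'gt_label': torch.tensor(0),
                      'eval_mask': True},
            'task1': {'pred_score': torch.tensor([0.5, 0.2, 0.3]),
                      'gt_label': torch.tensor(2),
                      'eval_mask': True},
        },
        {
            'task0': {'pred_score': torch.tensor([0.0, 0.0, 1.0]),
                      'gt_label': torch.tensor(2),
                      'eval_mask': True},
            'task1': {'pred_score': torch.tensor([0.0, 0.0, 1.0]),
                      'gt_label': torch.tensor(2),
                      'eval_mask': True},
        },
    ]
    metric.process(None, data_samples)
    # Expected keys: 'task0_accuracy/top1', 'task1_accuracy/top1',
    # 'task1_accuracy/top3'.
    print(metric.evaluate(len(data_samples)))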