# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Sequence

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


class MultiTasksMetric(BaseMetric):
"""Metrics for MultiTask | |
Args: | |
task_metrics(dict): a dictionary in the keys are the names of the tasks | |
and the values is a list of the metric corresponds to this task | |
Examples: | |
>>> import torch | |
>>> from mmpretrain.evaluation import MultiTasksMetric | |
# -------------------- The Basic Usage -------------------- | |
>>>task_metrics = { | |
'task0': [dict(type='Accuracy', topk=(1, ))], | |
'task1': [dict(type='Accuracy', topk=(1, 3))] | |
} | |
>>>pred = [{ | |
'pred_task': { | |
'task0': torch.tensor([0.7, 0.0, 0.3]), | |
'task1': torch.tensor([0.5, 0.2, 0.3]) | |
}, | |
'gt_task': { | |
'task0': torch.tensor(0), | |
'task1': torch.tensor(2) | |
} | |
}, { | |
'pred_task': { | |
'task0': torch.tensor([0.0, 0.0, 1.0]), | |
'task1': torch.tensor([0.0, 0.0, 1.0]) | |
}, | |
'gt_task': { | |
'task0': torch.tensor(2), | |
'task1': torch.tensor(2) | |
} | |
}] | |
>>>metric = MultiTasksMetric(task_metrics) | |
>>>metric.process(None, pred) | |
>>>results = metric.evaluate(2) | |
results = { | |
'task0_accuracy/top1': 100.0, | |
'task1_accuracy/top1': 50.0, | |
'task1_accuracy/top3': 100.0 | |
} | |
""" | |

    def __init__(self,
                 task_metrics: Dict,
                 collect_device: str = 'cpu') -> None:
        self.task_metrics = task_metrics
        super().__init__(collect_device=collect_device)

        self._metrics = {}
        for task_name in self.task_metrics.keys():
            self._metrics[task_name] = []
            for metric in self.task_metrics[task_name]:
                self._metrics[task_name].append(METRICS.build(metric))
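        # ``self._metrics`` now maps every task name to its list of built
        # metric objects (e.g. {'task0': [Accuracy], 'task1': [Accuracy]} for
        # the configs shown in the class docstring).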

    def process(self, data_batch, data_samples: Sequence[dict]):
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which will
        be used to compute the metrics when all batches have been processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for task_name in self.task_metrics.keys():
            filtered_data_samples = []
            for data_sample in data_samples:
                eval_mask = data_sample[task_name]['eval_mask']
                if eval_mask:
                    filtered_data_samples.append(data_sample[task_name])
            for metric in self._metrics[task_name]:
                metric.process(data_batch, filtered_data_samples)
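        # Samples whose ``eval_mask`` is False for a task are filtered out
        # above, so each sample only contributes to the metrics of the tasks
        # it is actually annotated for.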

    def compute_metrics(self, results: list) -> dict:
        raise NotImplementedError(
            '`compute_metrics` should not be called here directly; '
            'use `evaluate` instead.')

    def evaluate(self, size):
        """Evaluate the model performance of the whole dataset after
        processing all batches.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation metrics dict on the val dataset. The keys are
            "{task_name}_{metric_name}" and the values are the corresponding
            results.
        """
        metrics = {}
        for task_name in self._metrics:
            for metric in self._metrics[task_name]:
                name = metric.__class__.__name__
                if name == 'MultiTasksMetric' or metric.results:
                    results = metric.evaluate(size)
                else:
                    # No sample of this task was processed; report a
                    # placeholder result instead of failing.
                    results = {metric.__class__.__name__: 0}
                for key in results:
                    name = f'{task_name}_{key}'
                    if name in metrics:
                        # Inspired by mmengine's Evaluator.evaluate:
                        # https://github.com/open-mmlab/mmengine/blob/ed20a9cba52ceb371f7c825131636b9e2747172e/mmengine/evaluator/evaluator.py#L84-L87
                        raise ValueError(
                            'There are multiple metric results with the same '
                            f'metric name {name}. Please make sure all metrics '
                            'have different prefixes.')
                    metrics[name] = results[key]
        return metrics
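

# A minimal, hedged usage sketch (not part of the original module): it mirrors
# the class docstring example but drives ``process``/``evaluate`` end to end.
# The per-task sample keys used below (``pred_score``, ``gt_label``,
# ``eval_mask``) are assumptions about what the wrapped ``Accuracy`` metric
# and this class expect, and may differ across mmpretrain versions.
if __name__ == '__main__':
    import torch

    task_metrics = {
        'task0': [dict(type='Accuracy', topk=(1, ))],
        'task1': [dict(type='Accuracy', topk=(1, 3))],
    }
    metric = MultiTasksMetric(task_metrics)

    # Two samples, each annotated for both tasks (1-element label tensors).
    data_samples = [{
        'task0': dict(pred_score=torch.tensor([0.7, 0.0, 0.3]),
                      gt_label=torch.tensor([0]), eval_mask=True),
        'task1': dict(pred_score=torch.tensor([0.5, 0.2, 0.3]),
                      gt_label=torch.tensor([2]), eval_mask=True),
    }, {
        'task0': dict(pred_score=torch.tensor([0.0, 0.0, 1.0]),
                      gt_label=torch.tensor([2]), eval_mask=True),
        'task1': dict(pred_score=torch.tensor([0.0, 0.0, 1.0]),
                      gt_label=torch.tensor([2]), eval_mask=True),
    }]

    metric.process(None, data_samples)
    # Expected keys: 'task0_accuracy/top1', 'task1_accuracy/top1',
    # 'task1_accuracy/top3' (values as in the docstring example).
    print(metric.evaluate(2))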