# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.evaluation.metrics.vqa import (_process_digit_article,
                                               _process_punctuation)
from mmpretrain.registry import METRICS

@METRICS.register_module()
class GQAAcc(BaseMetric):
    """GQA accuracy metric.

    Computes the exact-match answer accuracy on the GQA dataset after
    normalizing both the predicted and the ground-truth answers.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added to the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Defaults to None.
    """
    default_prefix = 'GQA'

    def __init__(self,
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for sample in data_samples:
            result = {
                'pred_answer': sample.get('pred_answer'),
                'gt_answer': sample.get('gt_answer'),
            }
            self.results.append(result)
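
    # Each stored result is a flat dict such as
    #   {'pred_answer': 'two', 'gt_answer': '2'}
    # (hypothetical values; the keys mirror what ``process`` reads above).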
    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        acc = []
        for result in results:
            pred_answer = self._process_answer(result['pred_answer'])
            gt_answer = self._process_answer(result['gt_answer'])
            acc.append(1 if pred_answer == gt_answer else 0)

        accuracy = sum(acc) / len(acc)
        metrics = {'acc': accuracy}
        return metrics
    def _process_answer(self, answer) -> str:
        """Normalize an answer string before comparison.

        Strips punctuation and maps number words to digits while dropping
        articles, so that e.g. 'A dog.' and 'dog' compare as equal.
        """
        answer = _process_punctuation(answer)
        answer = _process_digit_article(answer)
        return answer
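

# A minimal usage sketch, not part of the library: feed fake model outputs
# straight through the metric. The sample dicts below are hypothetical; only
# the keys ('pred_answer', 'gt_answer') are assumed, matching ``process``.
# In an mmengine config the metric would instead be selected via the
# registry, e.g. ``val_evaluator = dict(type='GQAAcc')``.
if __name__ == '__main__':
    metric = GQAAcc()
    fake_samples = [
        {'pred_answer': 'A dog.', 'gt_answer': 'dog'},  # equal after cleanup
        {'pred_answer': 'two', 'gt_answer': '2'},       # number-word mapping
        {'pred_answer': 'cat', 'gt_answer': 'dog'},     # genuine mismatch
    ]
    metric.process(data_batch=None, data_samples=fake_samples)
    print(metric.compute_metrics(metric.results))  # -> {'acc': 0.666...}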