# Copyright (c) OpenMMLab. All rights reserved.
import random
from typing import List, Optional

from mmengine.evaluator import BaseMetric

from mmpretrain.registry import METRICS


def get_pred_idx(prediction: str, choices: List[str],
                 options: List[str]) -> int:
    """Get the index (e.g. 2) from the prediction (e.g. 'C').

    Args:
        prediction (str): The prediction from the model,
            from ['A', 'B', 'C', 'D', 'E'].
        choices (List[str]): The choices for the question.
        options (List[str]): The option letters, from
            ['A', 'B', 'C', 'D', 'E'].

    Returns:
        int: The index of the prediction, from [0, 1, 2, 3, 4].
    """
    if prediction in options[:len(choices)]:
        return options.index(prediction)
    else:
        return random.choice(range(len(choices)))


@METRICS.register_module()
class ScienceQAMetric(BaseMetric):
    """Evaluation Metric for ScienceQA.

    Args:
        options (List[str]): Options for each question. Defaults to
            ``['A', 'B', 'C', 'D', 'E']``.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
        prefix (str, optional): The prefix that will be added in the metric
            names to disambiguate homonymous metrics of different evaluators.
            If prefix is not provided in the argument, ``self.default_prefix``
            will be used instead. Defaults to None.
    """

    def __init__(self,
                 options: List[str] = ['A', 'B', 'C', 'D', 'E'],
                 collect_device: str = 'cpu',
                 prefix: Optional[str] = None) -> None:
        super().__init__(collect_device=collect_device, prefix=prefix)
        self.options = options

    def process(self, data_batch, data_samples) -> None:
        """Process one batch of data samples.

        data_samples should contain the following keys:

        1. pred_answer (str): The prediction from the model,
           from ['A', 'B', 'C', 'D', 'E'].
        2. choices (List[str]): The choices for the question.
        3. grade (str): The grade of the question, from 'grade1' to
           'grade12'.
        4. subject (str): The subject of the question, from
           ['natural science', 'social science', 'language science'].
        5. gt_answer (int): The ground-truth answer index for the question.
        6. hint (str): The hint for the question.
        7. has_image (bool): Whether the question has an image.

        The processed results should be stored in ``self.results``, which
        will be used to compute the metrics when all batches have been
        processed.

        Args:
            data_batch: A batch of data from the dataloader.
            data_samples (Sequence[dict]): A batch of outputs from the model.
        """
        for data_sample in data_samples:
            result = dict()
            choices = data_sample.get('choices')
            result['prediction'] = get_pred_idx(
                data_sample.get('pred_answer'), choices, self.options)
            result['grade'] = data_sample.get('grade')
            result['subject'] = data_sample.get('subject')
            result['answer'] = data_sample.get('gt_answer')
            hint = data_sample.get('hint')
            has_image = data_sample.get('has_image', False)
            result['no_context'] = not has_image and hint is None
            result['has_text'] = hint is not None
            result['has_image'] = has_image

            # Save the result to `self.results`.
            self.results.append(result)

    def compute_metrics(self, results: List) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (List): The processed results of each batch.

        Returns:
            Dict: The computed metrics. The keys are the names of the
            metrics, and the values are corresponding results.
        """
        # NOTICE: don't access `self.results` from the method.
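        # Each bucket below collects one boolean per sample; a bucket's
        # reported accuracy is the mean of those booleans (sum / len).
        # Empty buckets are skipped later to avoid division by zero.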
        metrics = dict()

        all_acc = []
        acc_natural = []
        acc_social = []
        acc_language = []
        acc_has_text = []
        acc_has_image = []
        acc_no_context = []
        acc_grade_1_6 = []
        acc_grade_7_12 = []

        for result in results:
            correct = result['prediction'] == result['answer']
            all_acc.append(correct)

            # different subjects
            if result['subject'] == 'natural science':
                acc_natural.append(correct)
            elif result['subject'] == 'social science':
                acc_social.append(correct)
            elif result['subject'] == 'language science':
                acc_language.append(correct)

            # different context
            if result['has_text']:
                acc_has_text.append(correct)
            elif result['has_image']:
                acc_has_image.append(correct)
            elif result['no_context']:
                acc_no_context.append(correct)

            # different grade
            if result['grade'] in [
                    'grade1', 'grade2', 'grade3', 'grade4', 'grade5', 'grade6'
            ]:
                acc_grade_1_6.append(correct)
            elif result['grade'] in [
                    'grade7', 'grade8', 'grade9', 'grade10', 'grade11',
                    'grade12'
            ]:
                acc_grade_7_12.append(correct)

        metrics['all_acc'] = sum(all_acc) / len(all_acc)
        if len(acc_natural) > 0:
            metrics['acc_natural'] = sum(acc_natural) / len(acc_natural)
        if len(acc_social) > 0:
            metrics['acc_social'] = sum(acc_social) / len(acc_social)
        if len(acc_language) > 0:
            metrics['acc_language'] = sum(acc_language) / len(acc_language)
        if len(acc_has_text) > 0:
            metrics['acc_has_text'] = sum(acc_has_text) / len(acc_has_text)
        if len(acc_has_image) > 0:
            metrics['acc_has_image'] = sum(acc_has_image) / len(acc_has_image)
        if len(acc_no_context) > 0:
            metrics['acc_no_context'] = sum(acc_no_context) / len(
                acc_no_context)
        if len(acc_grade_1_6) > 0:
            metrics['acc_grade_1_6'] = sum(acc_grade_1_6) / len(acc_grade_1_6)
        if len(acc_grade_7_12) > 0:
            metrics['acc_grade_7_12'] = sum(acc_grade_7_12) / len(
                acc_grade_7_12)

        return metrics
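

# Minimal usage sketch (not part of the metric implementation): constructs
# the metric directly and feeds it two hand-written sample dicts. The keys
# and values below are illustrative assumptions about what a ScienceQA
# model/dataset pairing would produce, not real model outputs.
if __name__ == '__main__':
    metric = ScienceQAMetric()
    fake_samples = [
        dict(
            pred_answer='B',
            choices=['the sun', 'the moon'],
            gt_answer=1,
            grade='grade3',
            subject='natural science',
            hint='Read the passage about the sky.',
            has_image=False),
        dict(
            pred_answer='A',
            choices=['solid', 'liquid', 'gas'],
            gt_answer=0,
            grade='grade8',
            subject='natural science',
            hint=None,
            has_image=True),
    ]
    # `data_batch` is unused by `process`, so None is fine here.
    metric.process(None, fake_samples)
    # Compute metrics directly from the accumulated results; expect keys
    # such as 'all_acc', 'acc_natural', 'acc_has_text', 'acc_grade_1_6'.
    print(metric.compute_metrics(metric.results))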