Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| import re | |
| from collections import Counter | |
| from typing import List | |
| import mmengine | |
| from mmengine.dataset import BaseDataset | |
| from mmpretrain.registry import DATASETS | |
@DATASETS.register_module()
class COCOVQA(BaseDataset):
    """VQAv2 dataset.

    Args:
        data_root (str): The root directory for ``data_prefix``, ``ann_file``
            and ``question_file``.
        data_prefix (str): The directory of images.
        question_file (str): Question file path.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    # Image file names are expected to contain the image id as a 12-digit
    # number; compiled once because it is applied to every file in the
    # image directory.
    _IMAGE_ID_PATTERN = re.compile(r'\d{12}')

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 question_file: str,
                 ann_file: str = '',
                 **kwargs):
        # Set before ``super().__init__`` because the overridden
        # ``_join_prefix`` / ``load_data_list`` read it, and BaseDataset's
        # constructor is expected to invoke them.
        self.question_file = question_file
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def _join_prefix(self):
        """Join a relative ``question_file`` with ``data_root``.

        ``ann_file`` and ``data_prefix`` are handled by the parent
        implementation, which is called afterwards.
        """
        if (self.question_file and self.data_root
                and not mmengine.is_abs(self.question_file)):
            # Use ``mmengine.join_path`` (as done for image paths below) so
            # the join follows the storage backend inferred from the path.
            self.question_file = mmengine.join_path(self.data_root,
                                                    self.question_file)
        return super()._join_prefix()

    def _create_image_index(self):
        """Map image ids to image file paths.

        Returns:
            dict: ``{image_id (int): file path (str)}`` for every file in
            ``data_prefix['img_path']`` whose name contains a 12-digit
            number. Files without such a number are skipped; if a name has
            several, the last one is taken as the id.
        """
        img_prefix = self.data_prefix['img_path']
        image_index = {}
        for file in mmengine.list_dir_or_file(img_prefix, list_dir=False):
            matches = self._IMAGE_ID_PATTERN.findall(file)
            if matches:
                image_index[int(matches[-1])] = mmengine.join_path(
                    img_prefix, file)
        return image_index

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Reads ``question_file`` and, when ``ann_file`` is set, pairs each
        question with the annotation at the same position.

        Returns:
            List[dict]: One dict per question containing the question fields
            plus ``img_path``. Annotated samples also get ``gt_answer``
            (unique answers, in first-seen order), ``gt_answer_weight``
            (share of all listed answers equal to each unique answer) and
            the remaining annotation fields.

        Raises:
            AssertionError: If the question and annotation files differ in
                length, or a paired question/annotation have different
                ``question_id``.
            KeyError: If a question's ``image_id`` has no matching file in
                the image directory.
        """
        questions = mmengine.load(self.question_file)['questions']
        if self.ann_file:
            annotations = mmengine.load(self.ann_file)['annotations']
            assert len(questions) == len(annotations), (
                f'{self.question_file} has {len(questions)} questions but '
                f'{self.ann_file} has {len(annotations)} annotations.')
        else:
            annotations = [None] * len(questions)

        # The original VQAv2 annotation file and question file includes
        # only image id but no image file paths.
        self.image_index = self._create_image_index()

        data_list = []
        for question, ann in zip(questions, annotations):
            # question example
            # {
            #     'image_id': 262144,
            #     'question': "Is the ball flying towards the batter?",
            #     'question_id': 262144000
            # }
            #
            # ann example
            # {
            #     'question_type': "what are the",
            #     'answer_type': "other",
            #     'answers': [
            #         {'answer': 'watching',
            #          'answer_id': 1,
            #          'answer_confidence': 'yes'},
            #         ...
            #     ],
            #     'image_id': 262148,
            #     'question_id': 262148000,
            #     'multiple_choice_answer': 'watching',
            # }
            data_info = question
            image_id = question['image_id']
            try:
                data_info['img_path'] = self.image_index[image_id]
            except KeyError:
                raise KeyError(
                    f'No image file with id {image_id} found under '
                    f'{self.data_prefix["img_path"]}.') from None

            if ann is not None:
                assert ann['question_id'] == question['question_id'], (
                    f'Question id mismatch: {question["question_id"]} '
                    f'(questions) vs {ann["question_id"]} (annotations).')
                # Deduplicate answers and weight each by its share of all
                # listed answers. ``answers`` is popped so the raw list is
                # not copied into ``data_info`` by ``update`` below.
                answers = [item['answer'] for item in ann.pop('answers')]
                count = Counter(answers)
                data_info['gt_answer'] = list(count.keys())
                data_info['gt_answer_weight'] = [
                    n / len(answers) for n in count.values()
                ]
                data_info.update(ann)
            data_list.append(data_info)
        return data_list