Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import os.path as osp | |
import re | |
from collections import Counter | |
from typing import List | |
import mmengine | |
from mmengine.dataset import BaseDataset | |
from mmpretrain.registry import DATASETS | |
class COCOVQA(BaseDataset): | |
"""VQAv2 dataset. | |
Args: | |
data_root (str): The root directory for ``data_prefix``, ``ann_file`` | |
and ``question_file``. | |
data_prefix (str): The directory of images. | |
question_file (str): Question file path. | |
ann_file (str, optional): Annotation file path for training and | |
validation. Defaults to an empty string. | |
**kwargs: Other keyword arguments in :class:`BaseDataset`. | |
""" | |
def __init__(self, | |
data_root: str, | |
data_prefix: str, | |
question_file: str, | |
ann_file: str = '', | |
**kwarg): | |
self.question_file = question_file | |
super().__init__( | |
data_root=data_root, | |
data_prefix=dict(img_path=data_prefix), | |
ann_file=ann_file, | |
**kwarg, | |
) | |
def _join_prefix(self): | |
if not mmengine.is_abs(self.question_file) and self.question_file: | |
self.question_file = osp.join(self.data_root, self.question_file) | |
return super()._join_prefix() | |
def _create_image_index(self): | |
img_prefix = self.data_prefix['img_path'] | |
files = mmengine.list_dir_or_file(img_prefix, list_dir=False) | |
image_index = {} | |
for file in files: | |
image_id = re.findall(r'\d{12}', file) | |
if len(image_id) > 0: | |
image_id = int(image_id[-1]) | |
image_index[image_id] = mmengine.join_path(img_prefix, file) | |
return image_index | |
def load_data_list(self) -> List[dict]: | |
"""Load data list.""" | |
questions = mmengine.load(self.question_file)['questions'] | |
if self.ann_file: | |
annotations = mmengine.load(self.ann_file)['annotations'] | |
assert len(questions) == len(annotations) | |
else: | |
annotations = [None] * len(questions) | |
# The original VQAv2 annotation file and question file includes | |
# only image id but no image file paths. | |
self.image_index = self._create_image_index() | |
data_list = [] | |
for question, ann in zip(questions, annotations): | |
# question example | |
# { | |
# 'image_id': 262144, | |
# 'question': "Is the ball flying towards the batter?", | |
# 'question_id': 262144000 | |
# } | |
# | |
# ann example | |
# { | |
# 'question_type': "what are the", | |
# 'answer_type': "other", | |
# 'answers': [ | |
# {'answer': 'watching', | |
# 'answer_id': 1, | |
# 'answer_confidence': 'yes'}, | |
# ... | |
# ], | |
# 'image_id': 262148, | |
# 'question_id': 262148000, | |
# 'multiple_choice_answer': 'watching', | |
# 'answer_type': 'other', | |
# } | |
data_info = question | |
data_info['img_path'] = self.image_index[question['image_id']] | |
if ann is not None: | |
assert ann['question_id'] == question['question_id'] | |
# add answer_weight & answer_count, delete duplicate answer | |
answers = [item['answer'] for item in ann.pop('answers')] | |
count = Counter(answers) | |
answer_weight = [i / len(answers) for i in count.values()] | |
data_info['gt_answer'] = list(count.keys()) | |
data_info['gt_answer_weight'] = answer_weight | |
data_info.update(ann) | |
data_list.append(data_info) | |
return data_list | |