Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import re | |
from itertools import chain | |
from typing import List | |
import mmengine | |
from mmengine.dataset import BaseDataset | |
from mmpretrain.registry import DATASETS | |
class VisualGenomeQA(BaseDataset):
    """Visual Genome Question Answering dataset.

    dataset structure: ::

        data_root
        ├── image
        │   ├── 1.jpg
        │   ├── 2.jpg
        │   └── ...
        └── question_answers.json

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images. Defaults to ``"image"``.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to ``"question_answers.json"``.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str = 'image',
                 ann_file: str = 'question_answers.json',
                 **kwarg):
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwarg,
        )

    def _create_image_index(self) -> dict:
        """Build a mapping from integer image id to image file path.

        Every file found directly under the image prefix is inspected; the
        last run of digits in its name is taken as the image id. Files whose
        names contain no digits are skipped.

        Returns:
            dict: ``{image_id (int): image path (str)}``.
        """
        img_prefix = self.data_prefix['img_path']

        files = mmengine.list_dir_or_file(img_prefix, list_dir=False)
        image_index = {}
        for file in files:
            digit_runs = re.findall(r'\d+', file)
            if not digit_runs:
                # No numeric id in the file name, so it cannot be matched
                # against an ``image_id`` from the annotations.
                continue
            image_id = int(digit_runs[-1])
            image_index[image_id] = mmengine.join_path(img_prefix, file)

        return image_index

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Reads the annotation file, builds the image index (stored on
        ``self.image_index``) and flattens the per-image ``qas`` lists into
        one sample per question-answer pair.

        Returns:
            List[dict]: One dict per QA pair with the keys ``img_path``,
            ``question``, ``question_id``, ``image_id`` and ``gt_answer``
            (a single-element list holding the answer string).

        Raises:
            KeyError: If a QA pair refers to an ``image_id`` for which no
                image file was indexed, or lacks an expected field.
        """
        annotations = mmengine.load(self.ann_file)

        # The Visual Genome annotation file contains only image ids, not
        # image file paths, so the paths are resolved through an index built
        # from the image directory.
        self.image_index = self._create_image_index()

        data_list = []
        for qas in chain.from_iterable(ann['qas'] for ann in annotations):
            # ann example
            # {
            #     'id': 1,
            #     'qas': [
            #         {
            #             'a_objects': [],
            #             'question': 'What color is the clock?',
            #             'image_id': 1,
            #             'qa_id': 986768,
            #             'answer': 'Two.',
            #             'q_objects': [],
            #         }
            #         ...
            #     ]
            # }
            data_info = {
                'img_path': self.image_index[qas['image_id']],
                # The annotation stores the text under 'question' and the
                # identifier under 'qa_id' (see the example above).
                'question': qas['question'],
                'question_id': qas['qa_id'],
                'image_id': qas['image_id'],
                'gt_answer': [qas['answer']],
            }

            data_list.append(data_info)

        return data_list