Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from collections import Counter | |
from typing import List | |
import mmengine | |
from mmengine.dataset import BaseDataset | |
from mmpretrain.registry import DATASETS | |
@DATASETS.register_module()
class VizWiz(BaseDataset):
    """VizWiz VQA dataset.

    Each record of the annotation file must provide ``image`` and
    ``question``. How a record is turned into a sample depends on whether
    it carries an ``answerable`` flag:

    - without ``answerable`` (unlabelled records): the sample keeps only
      ``question`` and ``img_path``;
    - with ``answerable == 1``: the raw ``answers`` are reduced to the
      distinct answers given with ``answer_confidence == 'yes'`` (the
      literal ``'unanswerable'`` is discarded) and stored as ``gt_answer``
      with their relative frequencies in ``gt_answer_weight``;
    - with any other ``answerable`` value: the record is skipped.

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str, optional): Annotation file path for training and
            validation. Defaults to an empty string.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str = '',
                 **kwargs):
        super().__init__(
            data_root=data_root,
            # BaseDataset expects a dict of prefixes; images are the only one.
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    @staticmethod
    def _parse_answers(answers: List[dict]) -> tuple:
        """Deduplicate confident answers and weight them by frequency.

        Args:
            answers (List[dict]): Raw ``answers`` entries, each with
                ``answer`` and ``answer_confidence`` keys.

        Returns:
            tuple: ``(gt_answer, gt_answer_weight)``. The first element is
            the list of distinct answers in first-seen order. The second is
            the list of their counts divided by the number of kept answers.
            Both lists are empty when no entry is a confident, non-
            ``'unanswerable'`` answer.
        """
        confident = [
            item['answer'] for item in answers
            if item['answer_confidence'] == 'yes'
            and item['answer'] != 'unanswerable'
        ]
        count = Counter(confident)
        total = len(confident)
        # When ``total`` is 0, ``count`` is empty too, so no division by zero
        # can happen here.
        weights = [n / total for n in count.values()]
        return list(count.keys()), weights

    def load_data_list(self) -> List[dict]:
        """Load the annotation file and build the data list.

        Returns:
            List[dict]: One dict per kept record with ``question`` and
            ``img_path``. Labelled records also carry ``gt_answer`` and
            ``gt_answer_weight`` (see :meth:`_parse_answers`).
        """
        # Example record:
        # {
        #     "image": "VizWiz_val_00000001.jpg",
        #     "question": "Can you tell me what this medicine is please?",
        #     "answers": [
        #         {"answer": "no", "answer_confidence": "yes"},
        #         {"answer": "unanswerable", "answer_confidence": "yes"},
        #         {"answer": "night time", "answer_confidence": "maybe"},
        #         {"answer": "unanswerable", "answer_confidence": "yes"},
        #         {"answer": "night time", "answer_confidence": "maybe"},
        #         {"answer": "night time cold medicine",
        #          "answer_confidence": "maybe"},
        #         {"answer": "night time", "answer_confidence": "maybe"},
        #         {"answer": "night time", "answer_confidence": "maybe"},
        #         {"answer": "night time", "answer_confidence": "maybe"},
        #         {"answer": "night time medicine", "answer_confidence": "yes"}
        #     ],
        #     "answer_type": "other",
        #     "answerable": 1
        # }
        # yields gt_answer=['no', 'night time medicine'] and
        # gt_answer_weight=[0.5, 0.5].
        annotations = mmengine.load(self.ann_file)

        data_list = []
        for ann in annotations:
            data_info = dict(
                question=ann['question'],
                img_path=mmengine.join_path(self.data_prefix['img_path'],
                                            ann['image']),
            )

            if 'answerable' in ann:
                # Labelled record: samples not marked answerable are dropped.
                if ann['answerable'] != 1:
                    continue
                # Read (not pop) the answers so the loaded record is left
                # untouched.
                gt_answer, gt_answer_weight = self._parse_answers(
                    ann['answers'])
                data_info['gt_answer'] = gt_answer
                data_info['gt_answer_weight'] = gt_answer_weight

            data_list.append(data_info)

        return data_list