# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Sequence

import mmengine
from mmengine.dataset import BaseDataset
from mmengine.fileio import get_file_backend

from mmpretrain.registry import DATASETS


@DATASETS.register_module()
class ScienceQA(BaseDataset):
    """Multimodal ScienceQA dataset.

    Each sample is looked up by id in the annotation file, and only the ids
    listed under the requested split in the split file are loaded.

    Args:
        data_root (str): The root directory for ``data_prefix``,
            ``ann_file`` and ``split_file``.
        split (str): Which split to load. One of ``train``, ``val``,
            ``test``, ``trainval``, ``minival`` and ``minitest``.
        split_file (str): Path, relative to ``data_root``, of the file that
            maps each split name to the ids of its samples.
        ann_file (str): Annotation file path.
        image_only (bool): If True, samples whose annotation has no image
            are skipped. Defaults to False.
        data_prefix (dict): Prefix for data field.
            Defaults to ``dict(img_path='')``.
        pipeline (Sequence): Processing pipeline. Defaults to an empty tuple.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 split: str,
                 split_file: str,
                 ann_file: str,
                 image_only: bool = False,
                 data_prefix: dict = dict(img_path=''),
                 pipeline: Sequence[Callable] = (),
                 **kwargs):
        valid_splits = (
            'train',
            'val',
            'test',
            'trainval',
            'minival',
            'minitest',
        )
        assert split in valid_splits, f'Invalid split {split}'

        # These attributes are read by ``load_data_list``, so they are set
        # before the parent constructor runs.
        # NOTE(review): presumably ``BaseDataset.__init__`` may load the data
        # eagerly -- confirm against mmengine before reordering.
        self.image_only = image_only
        self.split_file = os.path.join(data_root, split_file)
        self.split = split

        super().__init__(
            data_root=data_root,
            ann_file=ann_file,
            data_prefix=data_prefix,
            pipeline=pipeline,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Build the list of sample dicts for the configured split.

        Returns:
            List[dict]: One dict per kept sample. Besides the fields copied
            from the annotation, each dict holds ``image_id`` (the sample
            id), ``img_path`` (``None`` when the sample has no image) and
            ``has_image``.
        """
        img_prefix = self.data_prefix['img_path']
        all_annotations = mmengine.load(self.ann_file)
        sample_ids = mmengine.load(self.split_file)[self.split]
        backend = get_file_backend(img_prefix)

        samples = []
        for sample_id in sample_ids:
            record = all_annotations[sample_id]

            skip_sample = self.image_only and record['image'] is None
            if skip_sample:
                continue

            sample = {
                'image_id': sample_id,
                'question': record['question'],
                'choices': record['choices'],
                'gt_answer': record['answer'],
                'hint': record['hint'],
                'image_name': record['image'],
                'task': record['task'],
                'grade': record['grade'],
                'subject': record['subject'],
                'topic': record['topic'],
                'category': record['category'],
                'skill': record['skill'],
                'lecture': record['lecture'],
                'solution': record['solution'],
                'split': record['split'],
            }

            # Image-related fields are appended last so the key order stays
            # ``..., 'split', 'img_path', 'has_image'``.
            image_name = sample['image_name']
            has_image = image_name is not None
            if has_image:
                # Path is built as <img_prefix>/<sample id>/<image name>.
                sample['img_path'] = backend.join_path(
                    img_prefix, sample_id, image_name)
            else:
                sample['img_path'] = None
            sample['has_image'] = has_image

            samples.append(sample)

        return samples