# Copyright (c) OpenMMLab. All rights reserved.
from collections import OrderedDict
from typing import List

import mmengine
from mmengine import get_file_backend

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class Flickr30kRetrieval(BaseDataset):
    """Flickr30k Retrieval dataset.

    Loads the Karpathy-style Flickr30k annotation file (one entry per image,
    each with several caption sentences) and exposes it either as a flat list
    of image-text pairs (training) or as one entry per image with all of its
    captions attached (evaluation, i.e. ``test_mode``).

    Args:
        data_root (str): The root directory for ``data_prefix`` and
            ``ann_file``.
        data_prefix (str): The directory of images.
        ann_file (str): Annotation file path for training and validation.
        split (str): 'train', 'val' or 'test'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.
    """

    def __init__(self,
                 data_root: str,
                 data_prefix: str,
                 ann_file: str,
                 split: str,
                 **kwargs):
        assert split in ['train', 'val', 'test'], \
            '`split` must be train, val or test'
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=dict(img_path=data_prefix),
            ann_file=ann_file,
            **kwargs,
        )

    def load_data_list(self) -> List[dict]:
        """Load data list.

        Returns:
            List[dict]: In ``test_mode``, one dict per image carrying all of
            its captions (``text``) plus the ground-truth text/image id lists
            used for retrieval evaluation. Otherwise, one dict per caption
            (image-text pair) for training.
        """
        # get file backend
        img_prefix = self.data_prefix['img_path']
        file_backend = get_file_backend(img_prefix)

        annotations = mmengine.load(self.ann_file)

        # mapping img_id to img filename
        img_dict = OrderedDict()
        img_idx = 0
        sentence_idx = 0
        train_list = []
        for img in annotations['images']:
            # img_example={
            #     "sentids": [0, 1, 2],
            #     "imgid": 0,
            #     "sentences": [
            #         {"raw": "Two men in green shirts standing in a yard.",
            #          "imgid": 0, "sentid": 0},
            #         {"raw": "A man in a blue shirt standing in a garden.",
            #          "imgid": 0, "sentid": 1},
            #         {"raw": "Two friends enjoy time spent together.",
            #          "imgid": 0, "sentid": 2}
            #     ],
            #     "split": "train",
            #     "filename": "1000092795.jpg"
            # },
            # only keep images belonging to the requested split
            if img['split'] != self.split:
                continue

            # create new idx for image
            train_image = dict(
                ori_id=img['imgid'],
                image_id=img_idx,  # used for evaluation
                img_path=file_backend.join_path(img_prefix,
                                                img['filename']),
                text=[],
                gt_text_id=[],
                gt_image_id=[],
            )

            for sentence in img['sentences']:
                ann = {}
                ann['text'] = sentence['raw']
                ann['ori_id'] = sentence['sentid']
                ann['text_id'] = sentence_idx  # used for evaluation
                ann['image_ori_id'] = train_image['ori_id']
                ann['image_id'] = train_image['image_id']
                ann['img_path'] = train_image['img_path']
                ann['is_matched'] = True

                # 1. prepare train data list item
                train_list.append(ann)
                # 2. prepare eval data list item based on img dict
                train_image['text'].append(ann['text'])
                train_image['gt_text_id'].append(ann['text_id'])
                train_image['gt_image_id'].append(ann['image_id'])
                sentence_idx += 1

            img_dict[img['imgid']] = train_image
            img_idx += 1

        # sizes are kept for retrieval metrics (e.g. recall@k) downstream
        self.img_size = len(img_dict)
        self.text_size = len(train_list)

        # return needed format data list
        if self.test_mode:
            return list(img_dict.values())
        return train_list