Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import json | |
from collections import OrderedDict | |
from typing import List | |
from mmengine import get_file_backend | |
from mmpretrain.registry import DATASETS | |
from .base_dataset import BaseDataset | |
class COCORetrieval(BaseDataset): | |
"""COCO Retrieval dataset. | |
Args: | |
ann_file (str): Annotation file path. | |
test_mode (bool): Whether dataset is used for evaluation. This will | |
decide the annotation format in data list annotations. | |
Defaults to False. | |
data_root (str): The root directory for ``data_prefix`` and | |
``ann_file``. Defaults to ''. | |
data_prefix (str | dict): Prefix for training data. Defaults to ''. | |
pipeline (Sequence): Processing pipeline. Defaults to an empty tuple. | |
**kwargs: Other keyword arguments in :class:`BaseDataset`. | |
""" | |
def load_data_list(self) -> List[dict]: | |
"""Load data list.""" | |
# get file backend | |
img_prefix = self.data_prefix['img_path'] | |
file_backend = get_file_backend(img_prefix) | |
anno_info = json.load(open(self.ann_file, 'r')) | |
# mapping img_id to img filename | |
img_dict = OrderedDict() | |
for idx, img in enumerate(anno_info['images']): | |
if img['id'] not in img_dict: | |
img_rel_path = img['coco_url'].rsplit('/', 2)[-2:] | |
img_path = file_backend.join_path(img_prefix, *img_rel_path) | |
# create new idx for image | |
img_dict[img['id']] = dict( | |
ori_id=img['id'], | |
image_id=idx, # will be used for evaluation | |
img_path=img_path, | |
text=[], | |
gt_text_id=[], | |
gt_image_id=[], | |
) | |
train_list = [] | |
for idx, anno in enumerate(anno_info['annotations']): | |
anno['text'] = anno.pop('caption') | |
anno['ori_id'] = anno.pop('id') | |
anno['text_id'] = idx # will be used for evaluation | |
# 1. prepare train data list item | |
train_data = anno.copy() | |
train_image = img_dict[train_data['image_id']] | |
train_data['img_path'] = train_image['img_path'] | |
train_data['image_ori_id'] = train_image['ori_id'] | |
train_data['image_id'] = train_image['image_id'] | |
train_data['is_matched'] = True | |
train_list.append(train_data) | |
# 2. prepare eval data list item based on img dict | |
img_dict[anno['image_id']]['gt_text_id'].append(anno['text_id']) | |
img_dict[anno['image_id']]['text'].append(anno['text']) | |
img_dict[anno['image_id']]['gt_image_id'].append( | |
train_image['image_id']) | |
self.img_size = len(img_dict) | |
self.text_size = len(anno_info['annotations']) | |
# return needed format data list | |
if self.test_mode: | |
return list(img_dict.values()) | |
return train_list | |