File size: 6,188 Bytes
4d0eb62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
# Copyright (c) OpenMMLab. All rights reserved.
from mmengine import get_file_backend, list_from_file

from mmpretrain.registry import DATASETS
from .base_dataset import BaseDataset


@DATASETS.register_module()
class InShop(BaseDataset):
    """InShop Dataset for Image Retrieval.

    Please download the images from the homepage
    'https://mmlab.ie.cuhk.edu.hk/projects/DeepFashion/InShopRetrieval.html'
    (In-shop Clothes Retrieval Benchmark -> Img -> img.zip,
    Eval/list_eval_partition.txt), and organize them in the following way: ::

        In-shop Clothes Retrieval Benchmark (data_root)/
           β”œβ”€β”€ Eval /
           β”‚    └── list_eval_partition.txt (ann_file)
           β”œβ”€β”€ Img (img_prefix)
           β”‚    └── img/
           β”œβ”€β”€ README.txt
           └── .....

    Args:
        data_root (str): The root directory for dataset.
        split (str): Choose from 'train', 'query' and 'gallery'.
            Defaults to 'train'.
        data_prefix (str | dict): Prefix for training data.
            Defaults to 'Img'.
        ann_file (str): Annotation file path, path relative to
            ``data_root``. Defaults to 'Eval/list_eval_partition.txt'.
        **kwargs: Other keyword arguments in :class:`BaseDataset`.

    Examples:
        >>> from mmpretrain.datasets import InShop
        >>>
        >>> # build train InShop dataset
        >>> inshop_train_cfg = dict(data_root='data/inshop', split='train')
        >>> inshop_train = InShop(**inshop_train_cfg)
        >>> inshop_train
        Dataset InShop
            Number of samples:  25882
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
        >>>
        >>> # build query InShop dataset
        >>> inshop_query_cfg =  dict(data_root='data/inshop', split='query')
        >>> inshop_query = InShop(**inshop_query_cfg)
        >>> inshop_query
        Dataset InShop
            Number of samples:  14218
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
        >>>
        >>> # build gallery InShop dataset
        >>> inshop_gallery_cfg = dict(data_root='data/inshop', split='gallery')
        >>> inshop_gallery = InShop(**inshop_gallery_cfg)
        >>> inshop_gallery
        Dataset InShop
            Number of samples:  12612
            The `CLASSES` meta info is not set.
            Root of dataset:    data/inshop
    """

    def __init__(self,
                 data_root: str,
                 split: str = 'train',
                 data_prefix: str = 'Img',
                 ann_file: str = 'Eval/list_eval_partition.txt',
                 **kwargs):

        # Fixed typo in the error message ("bu get" -> "but got").
        assert split in ('train', 'query', 'gallery'), "'split' of `InShop`" \
            f" must be one of ['train', 'query', 'gallery'], but got '{split}'"
        self.backend = get_file_backend(data_root, enable_singleton=True)
        self.split = split
        super().__init__(
            data_root=data_root,
            data_prefix=data_prefix,
            ann_file=ann_file,
            **kwargs)

    def _process_annotations(self):
        """Parse ``ann_file`` and build the annotation dict for ``self.split``.

        Returns:
            dict: A dict with ``metainfo`` and ``data_list`` keys. The
            content of each sample in ``data_list`` depends on the split:
            train samples carry a class label, gallery samples carry their
            gallery index, and query samples carry the list of gallery
            indices sharing the same item id.
        """
        lines = list_from_file(self.ann_file)

        anno_train = dict(metainfo=dict(), data_list=list())
        anno_gallery = dict(metainfo=dict(), data_list=list())

        # item_id to label, each item corresponds to one class label
        class_num = 0
        gt_label_train = {}

        # item_id to gallery indices, each item may have several images
        gallery_num = 0
        gt_label_gallery = {}

        # Query samples are buffered as (img_path, item_id) pairs because
        # their labels (gallery indices) are only complete once the whole
        # gallery has been scanned; this avoids a second pass over `lines`.
        query_buffer = []

        # (lines[0], lines[1]) is the image number and the field name;
        # Each line format as 'image_name, item_id, evaluation_status'
        for line in lines[2:]:
            img_name, item_id, status = line.split()
            img_path = self.backend.join_path(self.img_prefix, img_name)
            if status == 'train':
                if item_id not in gt_label_train:
                    gt_label_train[item_id] = class_num
                    class_num += 1
                # item_id to class_id (for the training set)
                anno_train['data_list'].append(
                    dict(img_path=img_path, gt_label=gt_label_train[item_id]))
            elif status == 'gallery':
                # Since there are multiple images for each item,
                # record the corresponding item for each image.
                gt_label_gallery.setdefault(item_id, []).append(gallery_num)
                anno_gallery['data_list'].append(
                    dict(img_path=img_path, sample_idx=gallery_num))
                gallery_num += 1
            elif status == 'query':
                query_buffer.append((img_path, item_id))

        if self.split == 'train':
            anno_train['metainfo']['class_number'] = class_num
            anno_train['metainfo']['sample_number'] = \
                len(anno_train['data_list'])
            return anno_train
        elif self.split == 'gallery':
            anno_gallery['metainfo']['sample_number'] = gallery_num
            return anno_gallery

        # Generate the label for the query(val) set: the ground truth of a
        # query image is the list of gallery indices with the same item_id.
        anno_query = dict(metainfo=dict(), data_list=list())
        for img_path, item_id in query_buffer:
            anno_query['data_list'].append(
                dict(img_path=img_path, gt_label=gt_label_gallery[item_id]))

        anno_query['metainfo']['sample_number'] = \
            len(anno_query['data_list'])
        return anno_query

    def load_data_list(self):
        """Load data list.

        For the train set, return image and ground truth label. For the query
        set, return image and ids of images in gallery. For the gallery set,
        return image and its id.
        """
        data_info = self._process_annotations()
        data_list = data_info['data_list']
        return data_list

    def extra_repr(self):
        """The extra repr information of the dataset."""
        body = [f'Root of dataset: \t{self.data_root}']
        return body