File size: 5,153 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import copy
import os.path as osp
from collections import defaultdict
from typing import Any, Dict, List

import numpy as np
from mmengine.dataset import BaseDataset
from mmengine.utils import check_file_exist

from mmdet.registry import DATASETS


@DATASETS.register_module()
class ReIDDataset(BaseDataset):
    """Dataset for ReID.

    Each non-empty line of the annotation file holds an image filename and
    an integer person id separated by whitespace, e.g. ``img_001.jpg 3``.

    Args:
        triplet_sampler (dict, optional): The sampler for hard mining
            triplet loss. When given, every item returned by
            :meth:`prepare_data` is a group of ``num_ids * ins_per_id``
            images produced by :meth:`triplet_sampling`. Supported keys:

            - num_ids (int): The number of person ids per item.
            - ins_per_id (int): The number of images for each person id.

            Defaults to None.
    """

    def __init__(self, triplet_sampler: dict = None, *args, **kwargs):
        self.triplet_sampler = triplet_sampler
        super().__init__(*args, **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotations from the annotation file ``self.ann_file``.

        Every non-empty line must have the form ``<filename> <person_id>``;
        blank lines (e.g. a trailing newline at the end of the file) are
        ignored.

        Returns:
            list[dict]: A list of annotation dicts with the keys
            ``img_prefix``, ``img_path`` and ``gt_label``.

        Raises:
            ValueError: If a non-empty line does not contain exactly two
                whitespace-separated fields.
        """
        assert isinstance(self.ann_file, str)
        check_file_exist(self.ann_file)
        img_root = self.data_prefix.get('img_path')
        data_list = []
        with open(self.ann_file) as f:
            for line_no, line in enumerate(f, start=1):
                fields = line.split()
                if not fields:
                    continue
                if len(fields) != 2:
                    raise ValueError(
                        f'Invalid annotation at line {line_no} of '
                        f'{self.ann_file}: expected "<filename> <pid>", '
                        f'got {line.rstrip()!r}')
                filename, gt_label = fields
                info = dict(img_prefix=self.data_prefix)
                if img_root is not None:
                    info['img_path'] = osp.join(img_root, filename)
                else:
                    info['img_path'] = filename
                info['gt_label'] = np.array(int(gt_label), dtype=np.int64)
                data_list.append(info)
        self._parse_ann_info(data_list)
        return data_list

    def _parse_ann_info(self, data_list: List[dict]) -> None:
        """Group sample indices by person id.

        Sets ``self.index_dic`` (pid -> int64 array of indices into
        ``data_list``) and ``self.pids`` (int64 array of the distinct pids,
        in first-seen order).
        """
        index_tmp_dic = defaultdict(list)  # pid -> [idx1, ..., idxN]
        for idx, info in enumerate(data_list):
            index_tmp_dic[int(info['gt_label'])].append(idx)
        self.index_dic = {
            pid: np.asarray(idxs, dtype=np.int64)
            for pid, idxs in index_tmp_dic.items()
        }
        self.pids = np.asarray(list(self.index_dic.keys()), dtype=np.int64)

    def prepare_data(self, idx: int) -> Any:
        """Get data processed by ``self.pipeline``.

        Args:
            idx (int): The index of ``data_info``.

        Returns:
            Any: Depends on ``self.pipeline``.
        """
        data_info = self.get_data_info(idx)
        if self.triplet_sampler is not None:
            # triplet_sampling builds a fresh dict (of lists) from
            # deep-copied infos, so no additional copy is needed here.
            data_info = self.triplet_sampling(data_info['gt_label'],
                                              **self.triplet_sampler)
        else:
            data_info = copy.deepcopy(data_info)  # no triplet -> dict
        return self.pipeline(data_info)

    def triplet_sampling(self,
                         pos_pid,
                         num_ids: int = 8,
                         ins_per_id: int = 4) -> Dict:
        """Triplet sampler for hard mining triplet loss.

        First, randomly sample ``ins_per_id`` images (with replacement) of
        the anchor person id ``pos_pid``. Then, randomly choose
        ``num_ids - 1`` distinct negative person ids (ids different from
        ``pos_pid``). Finally, randomly sample ``ins_per_id`` images (with
        replacement) for each negative id.

        Args:
            pos_pid (ndarray | int): The person id of the anchor.
            num_ids (int): The number of person ids. Defaults to 8.
            ins_per_id (int): The number of images for each person.
                Defaults to 4.

        Returns:
            Dict: Annotation information of ``num_ids * ins_per_id``
            images, mapping each info key to a list with one entry per
            image (positive images first, then negatives grouped by id).
        """
        assert len(self.pids) >= num_ids, \
            'The number of person ids in the training set must ' \
            'be greater than or equal to the number of person ids in the ' \
            'sample.'

        pos_pid = int(pos_pid)
        # Select positive samples among all images of the anchor id.
        pos_idxs = self.index_dic[pos_pid]
        idxs_list = list(pos_idxs[np.random.choice(
            pos_idxs.shape[0], ins_per_id, replace=True)])
        # Select negative ids from the actual pid values (not their
        # positions in ``self.pids``): this keeps non-contiguous ids
        # working and guarantees the anchor id is excluded.
        neg_candidates = self.pids[self.pids != pos_pid]
        neg_pids = np.random.choice(
            neg_candidates, num_ids - 1, replace=False)
        # Select negative samples for each negative id.
        for neg_pid in neg_pids:
            neg_idxs = self.index_dic[int(neg_pid)]
            idxs_list.extend(neg_idxs[np.random.choice(
                neg_idxs.shape[0], ins_per_id, replace=True)])
        triplet_img_infos = [
            copy.deepcopy(self.get_data_info(int(idx))) for idx in idxs_list
        ]
        # Collect data_list scatters (list of dict -> dict of list).
        return {
            key: [_info[key] for _info in triplet_img_infos]
            for key in triplet_img_infos[0]
        }