File size: 6,270 Bytes
8a6df40
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random
import math
from torchvision import transforms
from PIL import Image

# Names exported by ``from <module> import *``: the two sampler classes.
__all__ = ['cusSampler','Sampler_uni']

'''common N-pairs sampler'''
def index_dataset(dataset):
    '''
    Group sample positions by the label stored for each sample.

    The label of sample ``i`` is read from ``dataset.datasets_lbl[i]``
    (the original note says it identifies the dataset type, e.g. pascal,
    atr or cihp).

    :param dataset: object supporting ``len()`` and exposing an indexable
        ``datasets_lbl`` attribute with one hashable label per sample
    :return: dict mapping each label to the list of sample positions that
        carry it, positions in ascending order
    '''
    positions_by_label = {}
    for position in range(len(dataset)):
        label = dataset.datasets_lbl[position]
        # setdefault creates the bucket the first time a label is seen.
        positions_by_label.setdefault(label, []).append(position)
    return positions_by_label

def sample_from_class(dataset,class_id):
    '''
    Return one uniformly chosen element of ``dataset[class_id]``.

    :param dataset: mapping (or sequence) from class id to a non-empty
        sequence of candidates
    :param class_id: key selecting which candidate sequence to draw from
    :return: a single element of the selected sequence
    '''
    position = random.randrange(len(dataset[class_id]))
    return dataset[class_id][position]

def sampler_npair_K(batch_size,dataset,K=2,label_random_list = [0,0,1,1,2,2,2]):
    '''
    Generate batches of sample indices, choosing a label for every slot.

    For each of the ``ceil(len(dataset) / batch_size)`` batches, every slot
    first picks a label at random from ``label_random_list`` (repeated
    entries make a label more likely) and then picks a random sample index
    carrying that label.

    :param batch_size: number of indices per yielded batch
    :param dataset: object accepted by ``index_dataset``
    :param K: not used by this function
    :param label_random_list: labels to pick from; never mutated here
    :return: generator yielding lists of ``batch_size`` sample indices
    '''
    indices_by_label = index_dataset(dataset)
    num_batches = int(math.ceil(len(dataset) * 1.0 / batch_size))
    for _ in range(num_batches):
        batch = []
        for _slot in range(batch_size):
            # Label first, then a sample of that label.
            label = label_random_list[random.randrange(len(label_random_list))]
            batch.append(sample_from_class(indices_by_label, label))
        yield batch

def sampler_(images_by_class,batch_size,dataset,K=2,label_random_list = [0,0,1,1,]):
    '''
    Build one batch of sample indices that all share a single random label.

    One label is picked at random from ``label_random_list``; then
    ``batch_size`` indices are drawn independently from the samples stored
    under that label, so an index may repeat.

    :param images_by_class: mapping from label to a sequence of sample indices
    :param batch_size: number of indices to return
    :param dataset: not used by this function
    :param K: not used by this function
    :param label_random_list: labels to pick from; never mutated here
    :return: list of ``batch_size`` sample indices
    '''
    chosen_label = label_random_list[random.randrange(len(label_random_list))]
    batch = []
    for _ in range(batch_size):
        batch.append(sample_from_class(images_by_class, chosen_label))
    return batch

class cusSampler(torch.utils.data.sampler.Sampler):
    r"""Yields a batch worth of indices that all come from one label group.

    Each call to ``__iter__`` picks one label at random from
    ``label_random_list`` (repeated entries make a label more likely) and
    then draws ``batchsize`` indices, independently and with possible
    repeats, from the samples carrying that label.

    Arguments:
        dataset: object supporting ``len()`` and exposing ``datasets_lbl``,
            an indexable with one label per sample
        batchsize (int): number of indices produced per iteration pass
        label_random_list (sequence): labels to pick from; every label in
            it must occur in ``dataset.datasets_lbl``
    """

    def __init__(self, dataset, batchsize, label_random_list=[0,1,1,1,2,2,2]):
        self.images_by_class = index_dataset(dataset)
        self.batch_size = batchsize
        self.dataset = dataset
        # Copy so neither the shared default list nor the caller's list is
        # aliased by this instance.
        self.label_random_list = list(label_random_list)
        self.len = int(math.ceil(len(dataset) * 1.0 / batchsize))

    def __iter__(self):
        # label_random_list must be passed by keyword: the fourth positional
        # parameter of sampler_ is K, so passing it positionally made the
        # helper fall back to its own default label list.
        return iter(sampler_(self.images_by_class, self.batch_size, self.dataset,
                             label_random_list=self.label_random_list))

    def __len__(self):
        # NOTE(review): this is the number of batches in the dataset
        # (ceil(len(dataset) / batchsize)), while one __iter__ pass yields
        # batchsize indices -- confirm which length callers expect.
        return self.len

def shuffle_cus(d1=20,d2=10,d3=5,batch=2):
    '''
    Produce a shuffled index order in which every group of ``batch``
    consecutive indices comes from the same one of three datasets.

    The three datasets occupy the consecutive index ranges ``[0, d1)``,
    ``[d1, d1+d2)`` and ``[d1+d2, d1+d2+d3)``. Each range is shuffled, cut
    into groups of ``batch`` indices (a trailing incomplete group is
    dropped), and the groups of all three ranges are emitted in random order.

    :param d1: size of the first dataset
    :param d2: size of the second dataset
    :param d3: size of the third dataset
    :param batch: number of indices per group
    :return: flat list of indices
    '''
    starts = (0, d1, d1 + d2)
    sizes = (d1, d2, d3)
    pools = [list(range(start, start + size)) for start, size in zip(starts, sizes)]
    for pool in pools:
        random.shuffle(pool)

    # Cut every pool into its complete groups, keeping pool order so that a
    # group's position matches first-pool groups, then second, then third.
    groups = []
    for pool in pools:
        full_groups = len(pool) // batch
        groups.extend(pool[k * batch:(k + 1) * batch] for k in range(full_groups))

    order = list(range(len(groups)))
    random.shuffle(order)

    flat = []
    for position in order:
        flat.extend(groups[position])
    return flat

def shuffle_cus_balance(d1=20,d2=10,d3=5,batch=2,balance_index=1):
    '''
    Like a per-dataset grouped shuffle, but first caps every dataset at the
    size of the dataset selected by ``balance_index``.

    The three datasets occupy the consecutive index ranges ``[0, d1)``,
    ``[d1, d1+d2)`` and ``[d1+d2, d1+d2+d3)``. Each range is shuffled; any
    range longer than the selected one is truncated to its length. The
    ranges are then cut into groups of ``batch`` indices (a trailing
    incomplete group is dropped) and all groups are emitted in random order.

    :param d1: size of the first dataset
    :param d2: size of the second dataset
    :param d3: size of the third dataset
    :param batch: number of indices per group
    :param balance_index: position (0, 1 or 2) of the dataset whose size
        is the cap for the other two
    :return: flat list of indices
    '''
    pools = [
        list(range(d1)),
        list(range(d1, d1 + d2)),
        list(range(d1 + d2, d1 + d2 + d3)),
    ]
    for pool in pools:
        random.shuffle(pool)

    # Slicing to the cap leaves shorter pools (and the reference pool
    # itself) unchanged and trims only the longer ones.
    cap = len(pools[balance_index])
    pools = [pool[:cap] for pool in pools]

    groups = []
    for pool in pools:
        full_groups = len(pool) // batch
        groups.extend(pool[k * batch:(k + 1) * batch] for k in range(full_groups))

    order = list(range(len(groups)))
    random.shuffle(order)

    flat = []
    for position in order:
        flat.extend(groups[position])
    return flat

class Sampler_uni(torch.utils.data.sampler.Sampler):
    """Sampler over three concatenated datasets that keeps each group of
    ``batchsize`` consecutive indices inside a single dataset.

    Index order is produced by ``shuffle_cus`` or, when ``balance_id`` is
    given, by ``shuffle_cus_balance`` with that dataset as the size cap.

    Arguments:
        num1 (int): size of the first dataset
        num2 (int): size of the second dataset
        num3 (int): size of the third dataset
        batchsize (int): number of consecutive indices taken from one dataset
        balance_id (int or None): position (0, 1 or 2) of the dataset whose
            size caps the other two; ``None`` disables balancing
    """

    def __init__(self, num1, num2, num3, batchsize,balance_id=None):
        self.num1 = num1
        self.num2 = num2
        self.num3 = num3
        self.batchsize = batchsize
        self.balance_id = balance_id

    def __iter__(self):
        if self.balance_id is not None:
            rlist = shuffle_cus_balance(self.num1, self.num2, self.num3, self.batchsize, balance_index=self.balance_id)
        else:
            rlist = shuffle_cus(self.num1, self.num2, self.num3, self.batchsize)
        return iter(rlist)

    def __len__(self):
        # Report exactly what __iter__ yields: with balancing, each dataset is
        # capped at the size of the dataset picked by balance_id; in both
        # modes only complete groups of batchsize indices are kept.
        sizes = [self.num1, self.num2, self.num3]
        if self.balance_id is not None:
            cap = sizes[self.balance_id]
            sizes = [min(size, cap) for size in sizes]
        return sum((size // self.batchsize) * self.batchsize for size in sizes)