Spaces:
Running
Running
File size: 6,270 Bytes
8a6df40 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random
import math
from torchvision import transforms
from PIL import Image
__all__ = ['cusSampler','Sampler_uni']
'''common N-pairs sampler'''
def index_dataset(dataset):
    """Group sample positions by the label stored for each position.

    Positions ``0 .. len(dataset) - 1`` are visited in order and the label
    of each one is read from ``dataset.datasets_lbl``.

    :param dataset: object that supports ``len()`` and exposes an indexable
        ``datasets_lbl`` attribute. Per the original note the label tells
        which source dataset a sample comes from (e.g. pascal, atr or cihp).
    :return: plain dict mapping every label seen to the list of positions
        that carry it, each list in ascending order.
    """
    positions_by_label = {}
    for position in range(len(dataset)):
        label = dataset.datasets_lbl[position]
        # setdefault creates the bucket the first time a label shows up.
        positions_by_label.setdefault(label, []).append(position)
    return positions_by_label
def sample_from_class(dataset,class_id):
    """Return one randomly chosen entry of the bucket stored under ``class_id``.

    :param dataset: container indexed by ``class_id`` whose entries are
        sequences; in this file it is the dict built by ``index_dataset``.
    :param class_id: key of the bucket to draw from.
    :return: a single element of ``dataset[class_id]``.
    :raises ValueError: propagated from ``random.randrange`` when the
        bucket is empty.
    """
    bucket = dataset[class_id]
    pick = random.randrange(len(bucket))
    return bucket[pick]
def sampler_npair_K(batch_size,dataset,K=2,label_random_list = [0,0,1,1,2,2,2]):
    """Generate batches of sample positions with a freshly drawn label per slot.

    The generator produces ``ceil(len(dataset) / batch_size)`` batches. For
    every slot of a batch a label is first drawn at random from
    ``label_random_list`` (a label listed several times is proportionally
    more likely), then one position carrying that label is drawn with
    ``sample_from_class``. Positions may repeat.

    :param batch_size: number of positions in every yielded batch.
    :param dataset: object accepted by ``index_dataset`` and ``len()``.
    :param K: accepted for interface compatibility; not used here.
    :param label_random_list: pool of labels to draw from. The default list
        is shared between calls but is only read, never modified.
    :return: generator yielding lists of ``batch_size`` positions.
    """
    positions_by_label = index_dataset(dataset)
    batch_count = int(math.ceil(len(dataset) * 1.0 / batch_size))
    for _batch in range(batch_count):
        batch = []
        for _slot in range(batch_size):
            # Label first, sample second: keeps the random-number order intact.
            label = label_random_list[random.randrange(len(label_random_list))]
            batch.append(sample_from_class(positions_by_label, label))
        yield batch
def sampler_(images_by_class,batch_size,dataset,K=2,label_random_list = [0,0,1,1,]):
    """Draw one batch of positions that all share a single random label.

    One label is drawn at random from ``label_random_list`` (labels listed
    more than once are proportionally more likely); ``batch_size`` positions
    are then drawn from that label's bucket via ``sample_from_class``.
    Positions may repeat.

    :param images_by_class: mapping from label to list of positions, as
        produced by ``index_dataset``.
    :param batch_size: number of positions to return.
    :param dataset: accepted for interface compatibility; not used here.
    :param K: accepted for interface compatibility; not used here.
    :param label_random_list: pool of labels to draw from. The default list
        is shared between calls but is only read, never modified.
    :return: list of ``batch_size`` positions.
    """
    chosen_label = label_random_list[random.randrange(len(label_random_list))]
    batch = []
    for _slot in range(batch_size):
        batch.append(sample_from_class(images_by_class, chosen_label))
    return batch
class cusSampler(torch.utils.data.sampler.Sampler):
    r"""Sampler whose every pass yields one batch drawn from a single label.

    Each call to ``__iter__`` picks one label at random from
    ``label_random_list`` and yields ``batchsize`` positions of samples
    carrying that label (see ``sampler_``). Positions may repeat.

    Arguments:
        dataset: object accepted by ``index_dataset`` and ``len()``.
        batchsize (int): number of positions yielded per pass.
        label_random_list (sequence): pool of labels to draw from; a label
            listed several times is proportionally more likely. The default
            list is shared between instances but is only read.

    NOTE(review): ``__len__`` reports ``ceil(len(dataset) / batchsize)``,
    which is not the number of indices a single pass yields (``batchsize``).
    Confirm against how the surrounding DataLoader is configured before
    relying on it.
    """

    def __init__(self, dataset, batchsize, label_random_list=[0,1,1,1,2,2,2]):
        # Label -> positions lookup, built once up front.
        self.images_by_class = index_dataset(dataset)
        self.batch_size = batchsize
        self.dataset = dataset
        self.label_random_list = label_random_list
        self.len = int(math.ceil(len(dataset) * 1.0 / batchsize))

    def __iter__(self):
        # The label pool must go by keyword: the fourth positional slot of
        # sampler_ is ``K``, so passing it positionally silently discarded
        # the configured pool and fell back to sampler_'s own default.
        batch = sampler_(
            self.images_by_class,
            self.batch_size,
            self.dataset,
            label_random_list=self.label_random_list,
        )
        return iter(batch)

    def __len__(self):
        return self.len
def shuffle_cus(d1=20,d2=10,d3=5,batch=2):
    """Return a shuffled index order in which every chunk stays in one dataset.

    Three datasets of sizes ``d1``, ``d2`` and ``d3`` are laid out
    back to back, so their indices are ``[0, d1)``, ``[d1, d1 + d2)`` and
    ``[d1 + d2, d1 + d2 + d3)``. Each range is shuffled and cut into
    chunks of ``batch`` indices; leftovers that do not fill a whole chunk
    are dropped. The chunks of all three datasets are then emitted in one
    random order.

    :param d1: size of the first dataset.
    :param d2: size of the second dataset.
    :param d3: size of the third dataset.
    :param batch: number of indices per chunk.
    :return: flat list of indices, chunk after chunk.
    """
    sizes = (d1, d2, d3)

    # Consecutive index ranges, one pool per dataset.
    pools = []
    start = 0
    for size in sizes:
        pools.append(list(range(start, start + size)))
        start += size

    chunk_counts = [size // batch for size in sizes]

    for pool in pools:
        random.shuffle(pool)

    # One slot per chunk; the shuffled slot order decides the emission order.
    order = list(range(sum(chunk_counts)))
    random.shuffle(order)

    first_end = chunk_counts[0]
    second_end = first_end + chunk_counts[1]

    result = []
    for slot in order:
        if slot < first_end:
            pool, local = pools[0], slot
        elif slot < second_end:
            pool, local = pools[1], slot - first_end
        else:
            pool, local = pools[2], slot - second_end
        result.extend(pool[local * batch:(local + 1) * batch])
    return result
def shuffle_cus_balance(d1=20,d2=10,d3=5,batch=2,balance_index=1):
    """Like ``shuffle_cus`` but caps the other datasets at one dataset's size.

    Three datasets of sizes ``d1``, ``d2`` and ``d3`` occupy the consecutive
    index ranges ``[0, d1)``, ``[d1, d1 + d2)`` and
    ``[d1 + d2, d1 + d2 + d3)``. Each range is shuffled; every range longer
    than the one selected by ``balance_index`` is then truncated to that
    length (shorter ranges are left as they are). The ranges are cut into
    chunks of ``batch`` indices, incomplete trailing chunks are dropped, and
    all chunks are emitted in one random order.

    :param d1: size of the first dataset.
    :param d2: size of the second dataset.
    :param d3: size of the third dataset.
    :param batch: number of indices per chunk.
    :param balance_index: position (0, 1 or 2) of the dataset whose size
        caps the others.
    :return: flat list of indices, chunk after chunk.
    """
    pools = [
        list(range(d1)),
        list(range(d1, d1 + d2)),
        list(range(d1 + d2, d1 + d2 + d3)),
    ]
    for pool in pools:
        random.shuffle(pool)

    # Cap every other pool at the reference pool's length.
    cap = len(pools[balance_index])
    for position, pool in enumerate(pools):
        if position != balance_index and len(pool) > cap:
            pools[position] = pool[:cap]

    # Cut each pool into full chunks, keeping dataset order across pools.
    chunks = []
    for pool in pools:
        for k in range(len(pool) // batch):
            chunks.append(pool[k * batch:(k + 1) * batch])

    # Emit the chunks in a random order.
    order = list(range(len(chunks)))
    random.shuffle(order)

    result = []
    for slot in order:
        result.extend(chunks[slot])
    return result
class Sampler_uni(torch.utils.data.sampler.Sampler):
    """Sampler over three concatenated datasets with single-dataset chunks.

    Indices are produced by ``shuffle_cus`` or, when ``balance_id`` is
    given, by ``shuffle_cus_balance``: every run of ``batchsize``
    consecutive indices comes from one dataset only.

    Arguments:
        num1 (int): size of the first dataset.
        num2 (int): size of the second dataset.
        num3 (int): size of the third dataset.
        batchsize (int): number of indices per single-dataset chunk.
        balance_id (int, optional): position (0, 1 or 2) of the dataset
            whose size caps the other two; ``None`` disables balancing.
    """

    def __init__(self, num1, num2, num3, batchsize,balance_id=None):
        self.num1 = num1
        self.num2 = num2
        self.num3 = num3
        self.batchsize = batchsize
        self.balance_id = balance_id

    def __iter__(self):
        if self.balance_id is not None:
            rlist = shuffle_cus_balance(self.num1, self.num2, self.num3, self.batchsize, balance_index=self.balance_id)
        else:
            rlist = shuffle_cus(self.num1, self.num2, self.num3, self.batchsize)
        return iter(rlist)

    def __len__(self):
        """Return the exact number of indices one pass of ``__iter__`` yields.

        Both shuffle helpers drop the trailing indices of a dataset that do
        not fill a whole chunk, and the balanced variant first caps each
        dataset at the size of the one selected by ``balance_id``. The
        previous ``num1 * 3`` / plain-sum values ignored both effects.
        """
        sizes = [self.num1, self.num2, self.num3]
        if self.balance_id is not None:
            cap = sizes[self.balance_id]
            sizes = [min(size, cap) for size in sizes]
        # Only whole chunks are emitted for each dataset.
        return sum((size // self.batchsize) * self.batchsize for size in sizes)
|