Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Variable | |
import random | |
import math | |
from torchvision import transforms | |
from PIL import Image | |
# Public API for ``from <module> import *``: only the two Sampler classes are
# exported; the helper functions remain importable by explicit name.
__all__ = ['cusSampler','Sampler_uni']
# NOTE(review): the string below is a bare expression, not the module
# docstring, because it is not the first statement in the file.
'''common N-pairs sampler'''
def index_dataset(dataset):
    '''
    Group the positions of ``dataset`` by label.

    Per the original note, the label identifies the source dataset type
    (e.g. pascal, atr or cihp).

    :param dataset: object supporting ``len()`` whose ``datasets_lbl``
        attribute gives the label of position ``i`` as ``datasets_lbl[i]``.
    :return: dict mapping each label to the ascending list of positions that
        carry it; labels appear in order of first occurrence.
    '''
    groups = {}
    for position in range(len(dataset)):
        label = dataset.datasets_lbl[position]
        groups.setdefault(label, []).append(position)
    return groups
def sample_from_class(dataset, class_id):
    """Return one member of ``dataset[class_id]`` picked at random.

    :param dataset: mapping (e.g. the output of ``index_dataset``) from a class
        label to a sequence of positions.
    :param class_id: key looked up in ``dataset``.
    :return: a single element of ``dataset[class_id]``, chosen with
        ``random.randrange`` over its length.
    """
    members = dataset[class_id]
    pick = random.randrange(len(members))
    return members[pick]
def sampler_npair_K(batch_size, dataset, K=2, label_random_list=(0, 0, 1, 1, 2, 2, 2)):
    """Yield batches of dataset positions whose labels are drawn per slot.

    For each of ``ceil(len(dataset) / batch_size)`` batches, every one of the
    ``batch_size`` slots first draws a label (uniformly by position) from
    ``label_random_list`` and then a random position carrying that label, so a
    single batch may mix labels.

    :param batch_size: number of positions in each yielded batch.
    :param dataset: object accepted by ``index_dataset``.
    :param K: not used by the body.
        NOTE(review): the name suggests "K samples per class" was intended —
        confirm before relying on it.
    :param label_random_list: candidate labels; a label listed more often is
        drawn more often. Any indexable sequence works; the default is a tuple
        so it cannot be shared-and-mutated across calls.
    :return: generator of lists of ``batch_size`` positions.
    """
    images_by_class = index_dataset(dataset)
    num_batches = int(math.ceil(len(dataset) * 1.0 / batch_size))
    for _ in range(num_batches):
        batch = []
        for _ in range(batch_size):
            # Label first, then sample: same sequence of RNG calls as before.
            label = label_random_list[random.randrange(len(label_random_list))]
            batch.append(sample_from_class(images_by_class, label))
        yield batch
def sampler_(images_by_class, batch_size, dataset, K=2, label_random_list=(0, 0, 1, 1)):
    """Return one batch of positions that all share a single randomly drawn label.

    :param images_by_class: mapping label -> positions, as built by
        ``index_dataset``.
    :param batch_size: number of positions to return.
    :param dataset: not used by the body.
    :param K: not used by the body.
    :param label_random_list: candidate labels; one is drawn (uniformly by
        position) for the whole batch. The default is a tuple so it cannot be
        shared-and-mutated across calls.
    :return: list of ``batch_size`` positions, each drawn independently (so
        with replacement) from the chosen label's positions.
    """
    label = label_random_list[random.randrange(len(label_random_list))]
    return [sample_from_class(images_by_class, label) for _ in range(batch_size)]
class cusSampler(torch.utils.data.sampler.Sampler):
    r"""Sampler yielding one batch of indices that all share a randomly drawn label.

    On every ``iter()`` one label is drawn (uniformly by position) from
    ``label_random_list``, and ``batchsize`` positions are then drawn
    independently (with replacement) from the positions of ``dataset`` that
    carry that label, using the grouping built by ``index_dataset``.

    Arguments:
        dataset: object accepted by ``index_dataset`` (supports ``len()`` and
            exposes ``datasets_lbl``).
        batchsize (int): number of indices yielded per ``iter()``.
        label_random_list (sequence, optional): candidate labels; repeating a
            label raises its chance of being drawn. Defaults to
            ``[0, 1, 1, 1, 2, 2, 2]``.
    """

    def __init__(self, dataset, batchsize, label_random_list=None):
        if label_random_list is None:
            # Fresh list per instance so the default is never shared between samplers.
            label_random_list = [0, 1, 1, 1, 2, 2, 2]
        self.images_by_class = index_dataset(dataset)
        self.batch_size = batchsize
        self.dataset = dataset
        self.label_random_list = label_random_list
        # NOTE(review): this is a batch count, while __iter__ yields batch_size
        # indices per call — confirm how the sampler is consumed.
        self.len = int(math.ceil(len(dataset) * 1.0 / batchsize))

    def __iter__(self):
        # Pass the labels by keyword: sampler_'s 4th positional parameter is K,
        # so passing them positionally silently fell back to sampler_'s own
        # default label list and ignored this sampler's configuration.
        return iter(sampler_(self.images_by_class, self.batch_size, self.dataset,
                             label_random_list=self.label_random_list))

    def __len__(self):
        return self.len
def shuffle_cus(d1=20, d2=10, d3=5, batch=2):
    """Build an index order made of shuffled, single-range chunks.

    Positions are laid out as three consecutive ranges: ``[0, d1)``,
    ``[d1, d1 + d2)`` and ``[d1 + d2, d1 + d2 + d3)``. Each range is shuffled
    and cut into ``size // batch`` full chunks of ``batch`` positions; any
    leftover positions of a range are dropped. The chunks of all three ranges
    are then emitted in a random order and concatenated.

    :param d1: size of the first range.
    :param d2: size of the second range.
    :param d3: size of the third range.
    :param batch: chunk size; every chunk holds positions from one range only.
    :return: flat list of positions.
    """
    pools = []
    counts = []
    start = 0
    for size in (d1, d2, d3):
        pools.append(list(range(start, start + size)))
        counts.append(size // batch)
        start += size

    for pool in pools:
        random.shuffle(pool)
    order = list(range(sum(counts)))
    random.shuffle(order)

    merged = []
    for chunk_id in order:
        # Walk past the chunk counts of the earlier ranges; whatever remains
        # belongs to the last range.
        local = chunk_id
        source = pools[-1]
        for candidate, count in zip(pools[:-1], counts[:-1]):
            if local < count:
                source = candidate
                break
            local -= count
        merged.extend(source[local * batch:(local + 1) * batch])
    return merged
def shuffle_cus_balance(d1=20, d2=10, d3=5, batch=2, balance_index=1):
    """Build a chunked index order after capping every range to a reference size.

    Positions are laid out as three consecutive ranges: ``[0, d1)``,
    ``[d1, d1 + d2)`` and ``[d1 + d2, d1 + d2 + d3)``. Each range is shuffled,
    and every range other than the ``balance_index``-th one is truncated to at
    most that range's size. The results are then cut into full chunks of
    ``batch`` positions (leftovers dropped), and the chunks are emitted in a
    random order and concatenated.

    :param d1: size of the first range.
    :param d2: size of the second range.
    :param d3: size of the third range.
    :param batch: chunk size; every chunk holds positions from one range only.
    :param balance_index: index of the range whose size caps the others.
    :return: flat list of positions.
    """
    pools = [
        list(range(d1)),
        list(range(d1, d1 + d2)),
        list(range(d1 + d2, d1 + d2 + d3)),
    ]
    for pool in pools:
        random.shuffle(pool)

    # Reference size; an invalid balance_index raises IndexError here.
    cap = len(pools[balance_index])
    pools = [pool if k == balance_index else pool[:cap]
             for k, pool in enumerate(pools)]

    counts = [len(pool) // batch for pool in pools]
    order = list(range(sum(counts)))
    random.shuffle(order)

    merged = []
    for chunk_id in order:
        # Walk past the chunk counts of the earlier ranges; whatever remains
        # belongs to the last range.
        local = chunk_id
        source = pools[-1]
        for candidate, count in zip(pools[:-1], counts[:-1]):
            if local < count:
                source = candidate
                break
            local -= count
        merged.extend(source[local * batch:(local + 1) * batch])
    return merged
class Sampler_uni(torch.utils.data.sampler.Sampler):
    """Sampler yielding indices in shuffled chunks, each taken from one of three ranges.

    The index space is laid out as three consecutive ranges of sizes ``num1``,
    ``num2`` and ``num3``. Each ``iter()`` builds a fresh order with
    ``shuffle_cus``, or with ``shuffle_cus_balance`` when ``balance_id`` is set.

    Arguments:
        num1, num2, num3 (int): sizes of the three consecutive index ranges.
        batchsize (int): chunk size; each chunk comes from a single range.
        balance_id (int, optional): index of the range whose size caps the
            other two; ``None`` disables capping.
    """

    def __init__(self, num1, num2, num3, batchsize, balance_id=None):
        self.num1 = num1
        self.num2 = num2
        self.num3 = num3
        self.batchsize = batchsize
        self.balance_id = balance_id

    def __iter__(self):
        if self.balance_id is not None:
            rlist = shuffle_cus_balance(self.num1, self.num2, self.num3, self.batchsize,
                                        balance_index=self.balance_id)
        else:
            rlist = shuffle_cus(self.num1, self.num2, self.num3, self.batchsize)
        return iter(rlist)

    def __len__(self):
        """Return exactly how many indices one pass of ``__iter__`` yields.

        The index builders keep only ``size // batchsize`` full chunks per range.
        With ``balance_id`` set, every range is first capped at the size of the
        ``balance_id``-th range. The previous formulas (``num1 + num2 + num3``
        and ``num1 * 3``) ignored both effects.
        """
        sizes = [self.num1, self.num2, self.num3]
        if self.balance_id is not None:
            cap = sizes[self.balance_id]
            sizes = [min(size, cap) for size in sizes]
        # Only full chunks are emitted; each range's remainder is dropped.
        return sum(size // self.batchsize * self.batchsize for size in sizes)