# Page residue from the hosting site (not part of the module):
# uploader "wizzseen", commit message "Upload 948 files", commit 8a6df40 (verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
import random
import math
from torchvision import transforms
from PIL import Image
# Names exported by a star-import of this module.
__all__ = ['cusSampler','Sampler_uni']
# NOTE(review): the string below is a bare expression statement, not a module
# docstring, because it is not the first statement in the file.
'''common N-pairs sampler'''
def index_dataset(dataset):
    """Group sample positions by their label.

    Reads ``dataset.datasets_lbl[i]`` for every ``i`` in
    ``range(len(dataset))`` and collects the positions that share a label.
    (Presumably the label identifies the source dataset, e.g. pascal / atr /
    cihp -- confirm against the dataset class.)

    :param dataset: object supporting ``len()`` and exposing an indexable
        ``datasets_lbl`` attribute with hashable entries.
    :return: plain ``dict`` mapping each label to the ascending list of
        positions carrying that label; keys appear in first-seen order.
    """
    positions_by_label = {}
    for position in range(len(dataset)):
        label = dataset.datasets_lbl[position]
        # setdefault creates the bucket the first time a label is seen.
        positions_by_label.setdefault(label, []).append(position)
    return positions_by_label
def sample_from_class(dataset,class_id):
    """Return one uniformly chosen element of ``dataset[class_id]``.

    :param dataset: mapping (or sequence) from class id to a non-empty
        sequence of candidates.
    :param class_id: key selecting the candidate pool.
    :return: a single element of the selected pool.
    :raises ValueError: from ``random.randrange`` when the pool is empty.
    """
    pool = dataset[class_id]
    chosen_position = random.randrange(len(pool))
    return pool[chosen_position]
def sampler_npair_K(batch_size,dataset,K=2,label_random_list = [0,0,1,1,2,2,2]):
    """Yield batches of indices whose labels are drawn from a weighted list.

    For every slot in a batch, a label is first picked uniformly from
    ``label_random_list`` (repeated entries act as weights), then one index
    with that label is picked uniformly via ``sample_from_class``. Sampling
    is with replacement.

    :param batch_size: number of indices per yielded batch.
    :param dataset: object accepted by ``index_dataset`` (needs ``len()`` and
        ``datasets_lbl``).
    :param K: accepted but not used by this function.
    :param label_random_list: labels to draw from; it is only read here.
    :return: generator producing ``ceil(len(dataset) / batch_size)`` lists of
        ``batch_size`` indices each.
    """
    images_by_class = index_dataset(dataset)
    batch_count = int(math.ceil(len(dataset) * 1.0 / batch_size))
    for _batch in range(batch_count):
        batch_indices = []
        for _slot in range(batch_size):
            # Label first, then the sample: keeps the order of random draws.
            label_position = random.randrange(len(label_random_list))
            class_label = label_random_list[label_position]
            batch_indices.append(sample_from_class(images_by_class, class_label))
        yield batch_indices
def sampler_(images_by_class,batch_size,dataset,K=2,label_random_list = [0,0,1,1,]):
    """Draw one batch of indices that all share a single random label.

    One label is picked uniformly from ``label_random_list`` (repeated
    entries act as weights); then ``batch_size`` indices of that class are
    drawn with replacement through ``sample_from_class``.

    :param images_by_class: mapping from label to the list of indices
        carrying it.
    :param batch_size: number of indices to return.
    :param dataset: accepted but not used by this function.
    :param K: accepted but not used by this function. It sits before
        ``label_random_list`` positionally, so pass the label list by keyword.
    :param label_random_list: labels to pick the batch's class from.
    :return: list of ``batch_size`` indices.
    """
    label_position = random.randrange(len(label_random_list))
    batch_label = label_random_list[label_position]
    drawn = []
    for _slot in range(batch_size):
        drawn.append(sample_from_class(images_by_class, batch_label))
    return drawn
class cusSampler(torch.utils.data.sampler.Sampler):
    r"""Sampler yielding one batch worth of indices from a single random class.

    Every call to ``__iter__`` picks one label from ``label_random_list``
    (repeated entries act as weights) and yields ``batch_size`` dataset
    indices carrying that label, sampled with replacement.

    Arguments:
        dataset: object supporting ``len()`` and exposing ``datasets_lbl``
            (see ``index_dataset``).
        batchsize (int): number of indices produced per iteration.
        label_random_list (sequence, optional): labels to choose the class
            from. Defaults to ``[0, 1, 1, 1, 2, 2, 2]``.
    """

    def __init__(self, dataset, batchsize, label_random_list=None):
        # A fresh list per instance avoids sharing one mutable default
        # between all samplers.
        if label_random_list is None:
            label_random_list = [0, 1, 1, 1, 2, 2, 2]
        self.images_by_class = index_dataset(dataset)
        self.batch_size = batchsize
        self.dataset = dataset
        self.label_random_list = label_random_list
        self.len = int(math.ceil(len(dataset) * 1.0 / batchsize))

    def __iter__(self):
        # The label list must go by keyword: the fourth positional parameter
        # of ``sampler_`` is ``K``, so passing it positionally silently fell
        # back to ``sampler_``'s own default labels.
        return iter(sampler_(self.images_by_class,
                             self.batch_size,
                             self.dataset,
                             label_random_list=self.label_random_list))

    def __len__(self):
        # NOTE(review): this is ceil(len(dataset) / batchsize), while one pass
        # of __iter__ yields ``batch_size`` indices -- confirm which of the
        # two the training loop relies on.
        return self.len
def shuffle_cus(d1=20,d2=10,d3=5,batch=2):
    """Return a shuffled index order in which every batch stays in one group.

    The index space is split into three consecutive groups of sizes ``d1``,
    ``d2`` and ``d3``. Each group is shuffled and cut into full chunks of
    ``batch`` indices; a trailing remainder shorter than ``batch`` is
    dropped. The chunks of all groups are then emitted in random order.

    :param d1: size of the first group (indices ``0 .. d1-1``).
    :param d2: size of the second group (indices ``d1 .. d1+d2-1``).
    :param d3: size of the third group (the following ``d3`` indices).
    :param batch: chunk length.
    :return: flat list of indices, chunk after chunk.
    """
    # Chunk counts come first so a zero ``batch`` fails before any shuffling.
    chunk_counts = (d1 // batch, d2 // batch, d3 // batch)
    groups = (
        list(range(d1)),
        list(range(d1, d1 + d2)),
        list(range(d1 + d2, d1 + d2 + d3)),
    )
    for members in groups:
        random.shuffle(members)

    # All full chunks, first group first; position in this list is the chunk id.
    chunks = [
        members[number * batch:(number + 1) * batch]
        for members, count in zip(groups, chunk_counts)
        for number in range(count)
    ]

    emit_order = list(range(len(chunks)))
    random.shuffle(emit_order)

    flattened = []
    for chunk_id in emit_order:
        flattened.extend(chunks[chunk_id])
    return flattened
def shuffle_cus_balance(d1=20,d2=10,d3=5,batch=2,balance_index=1):
    """Like ``shuffle_cus``, but cap every group at the balance group's size.

    The index space is split into three consecutive groups of sizes ``d1``,
    ``d2`` and ``d3`` and each group is shuffled. Groups longer than the
    group selected by ``balance_index`` are truncated to its length (after
    shuffling, so a random subset survives). Every group is then cut into
    full chunks of ``batch`` indices (remainders are dropped) and the chunks
    are emitted in random order.

    :param d1: size of the first group (indices ``0 .. d1-1``).
    :param d2: size of the second group (indices ``d1 .. d1+d2-1``).
    :param d3: size of the third group (the following ``d3`` indices).
    :param batch: chunk length.
    :param balance_index: position (0, 1 or 2) of the group whose size caps
        the others.
    :return: flat list of indices, chunk after chunk.
    :raises IndexError: if ``balance_index`` does not address a group.
    """
    groups = [
        list(range(d1)),
        list(range(d1, d1 + d2)),
        list(range(d1 + d2, d1 + d2 + d3)),
    ]
    for members in groups:
        random.shuffle(members)

    # Slicing to the cap leaves shorter groups (and the balance group) intact.
    cap = len(groups[balance_index])
    balanced = [members[:cap] for members in groups]

    # All full chunks, first group first; position in this list is the chunk id.
    chunks = [
        members[number * batch:(number + 1) * batch]
        for members in balanced
        for number in range(len(members) // batch)
    ]

    emit_order = list(range(len(chunks)))
    random.shuffle(emit_order)

    flattened = []
    for chunk_id in emit_order:
        flattened.extend(chunks[chunk_id])
    return flattened
class Sampler_uni(torch.utils.data.sampler.Sampler):
    """Sampler over three concatenated groups that keeps batches group-pure.

    Indices ``0 .. num1-1`` form the first group, the next ``num2`` the
    second and the next ``num3`` the third. Iteration delegates to
    ``shuffle_cus`` (or to ``shuffle_cus_balance`` when ``balance_id`` is
    given), so every run of ``batchsize`` consecutive indices comes from one
    group.

    Arguments:
        num1, num2, num3 (int): sizes of the three groups.
        batchsize (int): chunk length used when shuffling.
        balance_id (int, optional): position of the group whose size caps the
            other groups; ``None`` disables balancing.
    """

    def __init__(self, num1, num2, num3, batchsize,balance_id=None):
        self.num1 = num1
        self.num2 = num2
        self.num3 = num3
        self.batchsize = batchsize
        self.balance_id = balance_id

    def __iter__(self):
        if self.balance_id is not None:
            rlist = shuffle_cus_balance(self.num1, self.num2, self.num3,
                                        self.batchsize,
                                        balance_index=self.balance_id)
        else:
            rlist = shuffle_cus(self.num1, self.num2, self.num3, self.batchsize)
        return iter(rlist)

    def __len__(self):
        """Return the number of indices one pass of ``__iter__`` yields."""
        sizes = [self.num1, self.num2, self.num3]
        if self.balance_id is not None:
            # Balancing truncates every group to the balance group's size.
            cap = sizes[self.balance_id]
            sizes = [min(size, cap) for size in sizes]
        # Only full chunks are emitted; each group's remainder is dropped.
        return sum((size // self.batchsize) * self.batchsize for size in sizes)