Spaces:
Running
Running
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.autograd import Variable | |
| import random | |
| import math | |
| from torchvision import transforms | |
| from PIL import Image | |
| __all__ = ['cusSampler','Sampler_uni'] | |
| '''common N-pairs sampler''' | |
def index_dataset(dataset):
    """Group sample positions by the label stored for each position.

    Positions ``0 .. len(dataset) - 1`` are visited in order and the label of
    each one is read from ``dataset.datasets_lbl``.

    :param dataset: object supporting ``len()`` and exposing an indexable
        ``datasets_lbl`` attribute (per the original note the label encodes
        the dataset type, e.g. pascal, atr or cihp).
    :return: dict mapping every label to the ascending list of positions
        that carry it; keys appear in order of first occurrence.
    """
    positions_by_label = {}
    for position in range(len(dataset)):
        label = dataset.datasets_lbl[position]
        # setdefault creates the bucket the first time a label is seen.
        positions_by_label.setdefault(label, []).append(position)
    return positions_by_label
def sample_from_class(dataset, class_id):
    """Return one uniformly random entry of ``dataset[class_id]``.

    :param dataset: mapping (or other indexable) from class id to a
        non-empty sequence of candidates.
    :param class_id: key selecting which candidate sequence to draw from.
    :return: a single element of ``dataset[class_id]``.
    :raises ValueError: from ``random.randrange`` when the sequence is empty.
    """
    candidates = dataset[class_id]
    pick = random.randrange(len(candidates))
    return candidates[pick]
def sampler_npair_K(batch_size, dataset, K=2, label_random_list=[0, 0, 1, 1, 2, 2, 2]):
    """Yield batches of indices, drawing a random label for every slot.

    The dataset is indexed by label once (on the first ``next()``), then
    ``ceil(len(dataset) / batch_size)`` batches are produced. For each slot
    of a batch a label is drawn from ``label_random_list`` and one index of
    that label is sampled, so indices may repeat.

    :param batch_size: number of indices per yielded batch.
    :param dataset: object accepted by ``index_dataset``.
    :param K: unused; kept for signature compatibility.
    :param label_random_list: labels to draw from; repeating a label raises
        its selection probability. The default list is only read, never
        mutated.
    :return: generator of lists, each holding ``batch_size`` indices.
    """
    images_by_class = index_dataset(dataset)
    num_batches = int(math.ceil(len(dataset) * 1.0 / batch_size))
    for _ in range(num_batches):
        batch = []
        for _ in range(batch_size):
            # Draw the label first, then an index inside that label.
            label = label_random_list[random.randrange(len(label_random_list))]
            batch.append(sample_from_class(images_by_class, label))
        yield batch
def sampler_(images_by_class, batch_size, dataset, K=2, label_random_list=[0, 0, 1, 1]):
    """Sample one batch of indices that all share a single random label.

    One label is drawn from ``label_random_list``; then ``batch_size``
    indices are sampled independently from that label's bucket, so indices
    may repeat.

    :param images_by_class: mapping from label to the list of indices with
        that label.
    :param batch_size: number of indices to return.
    :param dataset: unused; kept for signature compatibility.
    :param K: unused; kept for signature compatibility.
    :param label_random_list: labels to draw from; repeating a label raises
        its selection probability. The default list is only read, never
        mutated.
    :return: list of ``batch_size`` indices.
    """
    chosen_label = label_random_list[random.randrange(len(label_random_list))]
    picks = []
    for _ in range(batch_size):
        picks.append(sample_from_class(images_by_class, chosen_label))
    return picks
class cusSampler(torch.utils.data.sampler.Sampler):
    r"""Sampler that yields indices drawn from one randomly chosen label.

    Each iteration picks a single label from ``label_random_list`` and
    yields ``batchsize`` indices sampled from the dataset positions carrying
    that label. Draws are independent, so an index may repeat.

    NOTE(review): ``__len__`` reports ``ceil(len(dataset) / batchsize)``
    while one pass of ``__iter__`` yields ``batchsize`` indices; confirm
    which of the two the consuming ``DataLoader`` setup relies on.

    Arguments:
        dataset: object accepted by ``index_dataset`` (supports ``len()``
            and exposes ``datasets_lbl``).
        batchsize (int): number of indices yielded per iteration.
        label_random_list (sequence): labels to draw from; repeating a label
            raises its selection probability. Every label listed must occur
            in the dataset, otherwise sampling raises ``KeyError``.
    """

    def __init__(self, dataset, batchsize, label_random_list=[0,1,1,1,2,2,2]):
        self.images_by_class = index_dataset(dataset)
        self.batch_size = batchsize
        self.dataset = dataset
        # Copy so the shared default list (or the caller's list) is never
        # aliased onto the instance.
        self.label_random_list = list(label_random_list)
        self.len = int(math.ceil(len(dataset) * 1.0 / batchsize))

    def __iter__(self):
        # label_random_list must go by keyword: the 4th positional parameter
        # of sampler_ is K, so passing it positionally silently fell back to
        # sampler_'s own default label list.
        return iter(sampler_(self.images_by_class,
                             self.batch_size,
                             self.dataset,
                             label_random_list=self.label_random_list))

    def __len__(self):
        return self.len
def shuffle_cus(d1=20, d2=10, d3=5, batch=2):
    """Return a shuffled index order in which every batch is single-source.

    The index space is split into three consecutive groups of sizes ``d1``,
    ``d2`` and ``d3``. Each group is shuffled and cut into full chunks of
    ``batch`` indices (leftovers that do not fill a chunk are dropped). The
    chunks of all groups are then shuffled together and concatenated.

    :param d1: size of the first group (indices ``0 .. d1 - 1``).
    :param d2: size of the second group (indices ``d1 .. d1 + d2 - 1``).
    :param d3: size of the third group (the following ``d3`` indices).
    :param batch: chunk size; every run of ``batch`` consecutive output
        entries comes from the same group.
    :return: flat list of indices.
    """
    first_end = d1
    second_end = d1 + d2
    third_end = d1 + d2 + d3
    groups = [
        list(range(first_end)),
        list(range(first_end, second_end)),
        list(range(second_end, third_end)),
    ]
    for group in groups:
        random.shuffle(group)

    # Cut each shuffled group into full chunks, keeping group order so the
    # chunk list lines up with a 0..n-1 numbering of all batches.
    chunks = []
    for group in groups:
        for k in range(len(group) // batch):
            chunks.append(group[k * batch:(k + 1) * batch])

    # Shuffling the chunks directly permutes them exactly as shuffling their
    # positions would.
    random.shuffle(chunks)

    ordered = []
    for chunk in chunks:
        ordered.extend(chunk)
    return ordered
def shuffle_cus_balance(d1=20, d2=10, d3=5, batch=2, balance_index=1):
    """Like a per-group batch shuffle, but cap groups at a reference size.

    The index space is split into three consecutive groups of sizes ``d1``,
    ``d2`` and ``d3`` and each group is shuffled. Any group longer than the
    group selected by ``balance_index`` is truncated to that group's length
    (shorter groups are left as they are). The groups are then cut into full
    chunks of ``batch`` indices (leftovers dropped), the chunks are shuffled
    together and concatenated.

    :param d1: size of the first group (indices ``0 .. d1 - 1``).
    :param d2: size of the second group (indices ``d1 .. d1 + d2 - 1``).
    :param d3: size of the third group (the following ``d3`` indices).
    :param batch: chunk size; every run of ``batch`` consecutive output
        entries comes from the same group.
    :param balance_index: position (0, 1 or 2) of the group whose length
        caps the other groups.
    :return: flat list of indices.
    """
    first_end = d1
    second_end = d1 + d2
    third_end = d1 + d2 + d3
    groups = [
        list(range(first_end)),
        list(range(first_end, second_end)),
        list(range(second_end, third_end)),
    ]
    for group in groups:
        random.shuffle(group)

    # A capped slice leaves the reference group and any shorter group
    # unchanged and trims only the longer ones.
    cap = len(groups[balance_index])
    groups = [group[:cap] for group in groups]

    # Cut each group into full chunks, keeping group order so the chunk list
    # lines up with a 0..n-1 numbering of all batches.
    chunks = []
    for group in groups:
        for k in range(len(group) // batch):
            chunks.append(group[k * batch:(k + 1) * batch])

    # Shuffling the chunks directly permutes them exactly as shuffling their
    # positions would.
    random.shuffle(chunks)

    ordered = []
    for chunk in chunks:
        ordered.extend(chunk)
    return ordered
class Sampler_uni(torch.utils.data.sampler.Sampler):
    """Sampler over three concatenated index groups with single-group batches.

    The index order comes from ``shuffle_cus`` or, when ``balance_id`` is
    given, from ``shuffle_cus_balance``: every run of ``batchsize``
    consecutive indices belongs to the same group, and partial batches are
    dropped.

    Arguments:
        num1 (int): size of the first group (indices ``0 .. num1 - 1``).
        num2 (int): size of the second group (the next ``num2`` indices).
        num3 (int): size of the third group (the next ``num3`` indices).
        batchsize (int): number of consecutive indices taken from one group.
        balance_id (int, optional): position (0, 1 or 2) of the group whose
            size caps the other groups. ``None`` disables balancing.
    """

    def __init__(self, num1, num2, num3, batchsize, balance_id=None):
        self.num1 = num1
        self.num2 = num2
        self.num3 = num3
        self.batchsize = batchsize
        self.balance_id = balance_id

    def __iter__(self):
        if self.balance_id is not None:
            rlist = shuffle_cus_balance(self.num1, self.num2, self.num3,
                                        self.batchsize,
                                        balance_index=self.balance_id)
        else:
            rlist = shuffle_cus(self.num1, self.num2, self.num3, self.batchsize)
        return iter(rlist)

    def __len__(self):
        """Return the exact number of indices one pass of ``__iter__`` yields.

        Mirrors the arithmetic of the shuffle helpers: with balancing, groups
        are first capped at the size of group ``balance_id``; then each group
        contributes only its full batches.
        """
        sizes = [self.num1, self.num2, self.num3]
        if self.balance_id is not None:
            cap = sizes[self.balance_id]
            sizes = [min(size, cap) for size in sizes]
        full_batches = sum(size // self.batchsize for size in sizes)
        return full_batches * self.batchsize