File size: 3,088 Bytes
0da05b1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
from torchvision.datasets import MNIST
import os
import numpy as np
import random
# Both splits share one root directory; download=True is passed so the
# MNIST constructor may fetch the files itself.
mnist_root = os.path.join('./', "MNIST")
train_dataset = MNIST(mnist_root, download=True, train=True)
test_dataset = MNIST(mnist_root, download=True, train=False)
class MNIST_DS(object):
    """Draw (anchor, positive, negative) image triplets from a train/test pair.

    Each dataset passed in must expose ``data`` and ``targets``.  ``targets``
    must provide ``.numpy()`` returning the label array, and every item
    ``data[i]`` must provide ``.numpy()`` returning the image array.
    """

    def __init__(self, train_dataset, test_dataset):
        """Keep references to the data and labels of both splits.

        Args:
            train_dataset: dataset used when ``split == 'train'``.
            test_dataset: dataset used for any other ``split`` value.
        """
        self.__train_labels_idx_map = {}
        self.__test_labels_idx_map = {}
        self.__train_data = train_dataset.data
        self.__test_data = test_dataset.data
        self.__train_labels = train_dataset.targets
        self.__test_labels = test_dataset.targets
        self.__train_labels_np = self.__train_labels.numpy()
        self.__train_unique_labels = np.unique(self.__train_labels_np)
        self.__test_labels_np = self.__test_labels.numpy()
        self.__test_unique_labels = np.unique(self.__test_labels_np)

    @staticmethod
    def _index_by_label(labels_np, unique_labels):
        """Return ``{label: array of positions in labels_np holding that label}``."""
        return {
            label: np.where(labels_np == label)[0]
            for label in unique_labels
        }

    def load(self):
        """(Re)build the label -> sample-indices maps for both splits."""
        self.__train_labels_idx_map = self._index_by_label(
            self.__train_labels_np, self.__train_unique_labels
        )
        self.__test_labels_idx_map = self._index_by_label(
            self.__test_labels_np, self.__test_unique_labels
        )

    def getTriplet(self, split="train"):
        """Sample one random triplet from the requested split.

        Args:
            split: ``'train'`` selects the training data; any other value
                selects the test data.

        Returns:
            A tuple ``(pos_anchor_img, pos_img, neg_img)`` of arrays obtained
            via ``data[i].numpy()``.  The anchor and the positive share a
            label; the negative has a different label.  The anchor and the
            positive are distinct samples unless their class contains a
            single sample, in which case the same sample is used for both.

        Raises:
            ValueError: if the split has fewer than two distinct labels, so
                no negative can be chosen.
        """
        # Build the index maps on first use so that forgetting load() does
        # not surface as a bare KeyError below.
        if not self.__train_labels_idx_map and not self.__test_labels_idx_map:
            self.load()

        if split == 'train':
            unique_labels = self.__train_unique_labels
            label_idx_map = self.__train_labels_idx_map
            data = self.__train_data
        else:
            unique_labels = self.__test_unique_labels
            label_idx_map = self.__test_labels_idx_map
            data = self.__test_data

        if len(unique_labels) < 2:
            raise ValueError(
                "at least two distinct labels are required to build a triplet"
            )

        # random.sample draws two different positions, hence two different
        # labels.  (Comparing numpy scalars with `is` never detects equal
        # values, because every array access creates a new scalar object.)
        pos_label, neg_label = random.sample(list(unique_labels), 2)

        pos_candidates = label_idx_map[pos_label]
        if len(pos_candidates) >= 2:
            anchor_pos, other_pos = random.sample(range(len(pos_candidates)), 2)
            pos_img_anchor_idx = pos_candidates[anchor_pos]
            pos_img_idx = pos_candidates[other_pos]
        else:
            # Only one sample of this class exists: reuse it rather than
            # looping forever looking for a second one.
            pos_img_anchor_idx = pos_img_idx = pos_candidates[0]

        neg_candidates = label_idx_map[neg_label]
        neg_img_idx = neg_candidates[random.randrange(len(neg_candidates))]

        pos_anchor_img = data[pos_img_anchor_idx].numpy()
        pos_img = data[pos_img_idx].numpy()
        neg_img = data[neg_img_idx].numpy()
        return pos_anchor_img, pos_img, neg_img
# Wrap both splits and build the per-label index maps.
dset_obj = MNIST_DS(train_dataset, test_dataset)
dset_obj.load()

# Draw a single triplet from the default ('train') split, store it as the
# first entry of the triplet list, and print that entry's anchor image.
pos_anchor_img, pos_img, neg_img = dset_obj.getTriplet()
train_triplets = [[pos_anchor_img, pos_img, neg_img]]
print(train_triplets[0][0])
|