File size: 3,088 Bytes
0da05b1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
from torchvision.datasets import MNIST
import os
import numpy as np
import random

# Both splits live under the same on-disk root; they are constructed with
# download=True at import time, training split first.
_mnist_root = os.path.join('./', "MNIST")
train_dataset = MNIST(_mnist_root, train=True, download=True)
test_dataset = MNIST(_mnist_root, train=False, download=True)


class MNIST_DS(object):
    """Sample (anchor, positive, negative) image triplets from two labelled splits.

    The anchor and the positive share a label; the negative has a different
    label. Randomness comes from the standard ``random`` module, so seeding it
    makes sampling reproducible.
    """

    def __init__(self, train_dataset, test_dataset):
        """Store both splits and cache their label arrays.

        Args:
            train_dataset: Object exposing ``data`` and ``targets``.
                ``targets`` must provide ``.numpy()`` and every item of
                ``data`` must provide ``.numpy()`` as well.
            test_dataset: Same contract as ``train_dataset``; used for any
                split other than ``'train'``.
        """
        # Label -> array of sample indices; filled in by load().
        self.__train_labels_idx_map = {}
        self.__test_labels_idx_map = {}

        self.__train_data = train_dataset.data
        self.__test_data = test_dataset.data
        self.__train_labels = train_dataset.targets
        self.__test_labels = test_dataset.targets

        self.__train_labels_np = self.__train_labels.numpy()
        self.__train_unique_labels = np.unique(self.__train_labels_np)

        self.__test_labels_np = self.__test_labels.numpy()
        self.__test_unique_labels = np.unique(self.__test_labels_np)

    def load(self):
        """Build, for each split, the map from label to indices of its samples."""
        self.__train_labels_idx_map = {
            label: np.where(self.__train_labels_np == label)[0]
            for label in self.__train_unique_labels
        }
        self.__test_labels_idx_map = {
            label: np.where(self.__test_labels_np == label)[0]
            for label in self.__test_unique_labels
        }

    @staticmethod
    def _other_position(count, taken):
        """Return a uniformly random position in ``range(count)`` except ``taken``.

        Requires ``count >= 2``. Drawing from the ``count - 1`` remaining
        slots and shifting past ``taken`` avoids a rejection loop.
        """
        position = random.randrange(count - 1)
        if position >= taken:
            position += 1
        return position

    def getTriplet(self, split="train"):
        """Draw one random triplet from the requested split.

        Args:
            split: ``'train'`` selects the training split; any other value
                selects the test split.

        Returns:
            Tuple ``(pos_anchor_img, pos_img, neg_img)`` of arrays obtained
            via ``.numpy()`` on the selected samples. ``pos_img`` has the
            anchor's label but is a different sample (unless that label has
            only one sample, in which case the anchor is reused);
            ``neg_img`` always has a different label.

        Raises:
            ValueError: If the split has fewer than two distinct labels, so
                no negative can exist.
        """
        use_train = split == 'train'
        unique_labels = (self.__train_unique_labels if use_train
                         else self.__test_unique_labels)
        data = self.__train_data if use_train else self.__test_data

        label_count = len(unique_labels)
        if label_count < 2:
            raise ValueError(
                "need at least two distinct labels to sample a triplet, "
                "got %d" % label_count
            )

        label_idx_map = (self.__train_labels_idx_map if use_train
                         else self.__test_labels_idx_map)
        if not label_idx_map:
            # load() has not been called yet; build the index on demand
            # instead of failing with a KeyError below.
            self.load()
            label_idx_map = (self.__train_labels_idx_map if use_train
                             else self.__test_labels_idx_map)

        # NOTE: labels and indices are NumPy scalars. Indexing an array
        # creates a fresh scalar object each time, so they must never be
        # compared with `is`. Distinctness is guaranteed here by choosing
        # distinct *positions* in arrays whose entries are unique.
        pos_label_at = random.randrange(label_count)
        neg_label_at = self._other_position(label_count, pos_label_at)
        pos_label = unique_labels[pos_label_at]
        neg_label = unique_labels[neg_label_at]

        pos_indices = label_idx_map[pos_label]
        anchor_at = random.randrange(len(pos_indices))
        if len(pos_indices) > 1:
            positive_at = self._other_position(len(pos_indices), anchor_at)
        else:
            # Only one sample carries this label: reuse the anchor rather
            # than looping forever looking for a second one.
            positive_at = anchor_at

        neg_indices = label_idx_map[neg_label]
        neg_img_idx = neg_indices[random.randrange(len(neg_indices))]

        pos_anchor_img = data[pos_indices[anchor_at]].numpy()
        pos_img = data[pos_indices[positive_at]].numpy()
        neg_img = data[neg_img_idx].numpy()

        return pos_anchor_img, pos_img, neg_img


# Demo: index both splits, draw a single training triplet and show its anchor.
dset_obj = MNIST_DS(train_dataset, test_dataset)
dset_obj.load()

# Each entry of train_triplets is [anchor, positive, negative].
train_triplets = [list(dset_obj.getTriplet())]
pos_anchor_img, pos_img, neg_img = train_triplets[0]

print(train_triplets[0][0])