WangA commited on
Commit
14b2abc
1 Parent(s): 67a51e8

Delete dataset.py

Browse files
Files changed (1) hide show
  1. dataset.py +0 -80
dataset.py DELETED
@@ -1,80 +0,0 @@
1
- from torchvision.datasets import MNIST
2
- import os
3
- import numpy as np
4
- import random
5
-
6
- train_dataset = MNIST(os.path.join('./', "MNIST"), train=True, download=True)
7
- test_dataset = MNIST(os.path.join('./', "MNIST"), train=False, download=True)
8
-
9
-
10
- class MNIST_DS(object):
11
-
12
- def __init__(self, train_dataset, test_dataset):
13
- self.__train_labels_idx_map = {}
14
- self.__test_labels_idx_map = {}
15
-
16
- self.__train_data = train_dataset.data
17
- self.__test_data = test_dataset.data
18
- self.__train_labels = train_dataset.targets
19
- self.__test_labels = test_dataset.targets
20
-
21
- self.__train_labels_np = self.__train_labels.numpy()
22
- self.__train_unique_labels = np.unique(self.__train_labels_np)
23
-
24
- self.__test_labels_np = self.__test_labels.numpy()
25
- self.__test_unique_labels = np.unique(self.__test_labels_np)
26
-
27
- def load(self):
28
- self.__train_labels_idx_map = {}
29
- for label in self.__train_unique_labels:
30
- self.__train_labels_idx_map[label] = np.where(self.__train_labels_np == label)[0]
31
-
32
- self.__test_labels_idx_map = {}
33
- for label in self.__test_unique_labels:
34
- self.__test_labels_idx_map[label] = np.where(self.__test_labels_np == label)[0]
35
-
36
- def getTriplet(self, split="train"):
37
- pos_label = 0
38
- neg_label = 0
39
- label_idx_map = None
40
- data = None
41
-
42
- if split == 'train':
43
- pos_label = self.__train_unique_labels[random.randint(0, len(self.__train_unique_labels) - 1)]
44
- neg_label = pos_label
45
- while neg_label is pos_label:
46
- neg_label = self.__train_unique_labels[random.randint(0, len(self.__train_unique_labels) - 1)]
47
- label_idx_map = self.__train_labels_idx_map
48
- data = self.__train_data
49
- else:
50
- pos_label = self.__test_unique_labels[random.randint(0, len(self.__test_unique_labels) - 1)]
51
- neg_label = pos_label
52
- while neg_label is pos_label:
53
- neg_label = self.__test_unique_labels[random.randint(0, len(self.__test_unique_labels) - 1)]
54
- label_idx_map = self.__test_labels_idx_map
55
- data = self.__test_data
56
-
57
- pos_label_idx_map = label_idx_map[pos_label]
58
- pos_img_anchor_idx = pos_label_idx_map[random.randint(0, len(pos_label_idx_map) - 1)]
59
- pos_img_idx = pos_img_anchor_idx
60
- while pos_img_idx is pos_img_anchor_idx:
61
- pos_img_idx = pos_label_idx_map[random.randint(0, len(pos_label_idx_map) - 1)]
62
-
63
- neg_label_idx_map = label_idx_map[neg_label]
64
- neg_img_idx = neg_label_idx_map[random.randint(0, len(neg_label_idx_map) - 1)]
65
-
66
- pos_anchor_img = data[pos_img_anchor_idx].numpy()
67
- pos_img = data[pos_img_idx].numpy()
68
- neg_img = data[neg_img_idx].numpy()
69
-
70
- return pos_anchor_img, pos_img, neg_img
71
-
72
-
73
- dset_obj = MNIST_DS(train_dataset, test_dataset)
74
- dset_obj.load()
75
- train_triplets = []
76
- pos_anchor_img, pos_img, neg_img = dset_obj.getTriplet()
77
- train_triplets.append([pos_anchor_img, pos_img, neg_img])
78
-
79
- print(train_triplets[0][0])
80
-