darklord25 commited on
Commit
96ed3dd
·
verified ·
1 Parent(s): 1e236f1

Initial Commit

Browse files
README.md CHANGED
@@ -1,3 +1,7 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
1
+ # FSL_Subspace
2
+
3
+ Codes for work on few-shot learning on chest x-ray images ([paper](https://openreview.net/pdf?id=AF97JZpgPe)).
4
+
5
+ Check our [website](https://few-shot-learning-on-chest-x-ray.github.io/Project-Page/) for a brief summary of the paper.
6
+
7
+ tl;dr : We propose a computationally efficient few-shot learning method for diagnosing chest X-rays, which uses an ensemble of random subspaces and a novel loss function to create well-separated clusters of training data in discriminative subspaces. Our method is almost 1.8 times faster than the popular t-SVD method for subspace decomposition and yields promising results on large-scale CXR datasets.
dataloader/CIFAR_FS copy.py ADDED
@@ -0,0 +1,470 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataloader of Gidaris & Komodakis, CVPR 2018
2
+ # Adapted from:
3
+ # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import os.path
8
+ import numpy as np
9
+ import random
10
+ import pickle
11
+ import json
12
+ import math
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+ import torchvision
17
+ import torchvision.datasets as datasets
18
+ import torchvision.transforms as transforms
19
+ import torchnet as tnt
20
+
21
+ import h5py
22
+
23
+ import cv2
24
+ from PIL import Image
25
+ from PIL import ImageEnhance
26
+
27
+ from pdb import set_trace as breakpoint
28
+
29
+ from torchvision.transforms.transforms import ToPILImage
30
+
31
+
32
+ # Set the appropriate paths of the datasets here.
33
+ _CIFAR_FS_DATASET_DIR = './cifar/CIFAR-FS/'
34
+
35
+
36
+ def buildLabelIndex(labels):
37
+ label2inds = {}
38
+ for idx, label in enumerate(labels):
39
+ if label not in label2inds:
40
+ label2inds[label] = []
41
+ label2inds[label].append(idx)
42
+
43
+ return label2inds
44
+
45
+
46
+ def load_data(file):
47
+ try:
48
+ with open(file, 'rb') as fo:
49
+ data = pickle.load(fo)
50
+ return data
51
+ except:
52
+ with open(file, 'rb') as f:
53
+ u = pickle._Unpickler(f)
54
+ u.encoding = 'latin1'
55
+ data = u.load()
56
+ return data
57
+
58
+
59
+ class CIFAR_FS(data.Dataset):
60
+ def __init__(self, phase='train', do_not_use_random_transf=False):
61
+
62
+ assert(phase == 'train' or phase == 'val' or phase ==
63
+ 'test' or phase == 'trainval')
64
+ self.phase = phase
65
+ self.name = 'CIFAR_FS_' + phase
66
+
67
+ print('Loading CIFAR-FS dataset - phase {0}'.format(phase))
68
+ file_train_categories_train_phase = os.path.join(
69
+ _CIFAR_FS_DATASET_DIR,
70
+ 'CIFAR_FS_train.pickle')
71
+ file_train_categories_val_phase = os.path.join(
72
+ _CIFAR_FS_DATASET_DIR,
73
+ 'CIFAR_FS_train.pickle')
74
+ file_train_categories_test_phase = os.path.join(
75
+ _CIFAR_FS_DATASET_DIR,
76
+ 'CIFAR_FS_train.pickle')
77
+ file_val_categories_val_phase = os.path.join(
78
+ _CIFAR_FS_DATASET_DIR,
79
+ 'CIFAR_FS_val.pickle')
80
+ file_test_categories_test_phase = os.path.join(
81
+ _CIFAR_FS_DATASET_DIR,
82
+ 'CIFAR_FS_test.pickle')
83
+
84
+ if self.phase == 'train':
85
+ # During training phase we only load the training phase images
86
+ # of the training categories (aka base categories).
87
+ data_train = load_data(file_train_categories_train_phase)
88
+ self.data = data_train['data']
89
+ self.labels = data_train['labels']
90
+
91
+ self.label2ind = buildLabelIndex(self.labels)
92
+ self.labelIds = sorted(self.label2ind.keys())
93
+ self.num_cats = len(self.labelIds)
94
+ self.labelIds_base = self.labelIds
95
+ self.num_cats_base = len(self.labelIds_base)
96
+ elif self.phase == 'trainval':
97
+ # During training phase we only load the training phase images
98
+ # of the training categories (aka base categories).
99
+ data_train = load_data(file_train_categories_train_phase)
100
+ self.data = data_train['data']
101
+ self.labels = data_train['labels']
102
+ data_base = load_data(file_train_categories_val_phase)
103
+ data_novel = load_data(file_val_categories_val_phase)
104
+ self.data = np.concatenate(
105
+ [self.data, data_novel['data']], axis=0)
106
+ self.data = np.concatenate(
107
+ [self.data, data_base['data']], axis=0)
108
+
109
+ self.labels = np.concatenate(
110
+ [self.labels, data_novel['labels']], axis=0)
111
+ self.labels = np.concatenate(
112
+ [self.labels, data_base['labels']], axis=0)
113
+
114
+ self.label2ind = buildLabelIndex(self.labels)
115
+ self.labelIds = sorted(self.label2ind.keys())
116
+ self.num_cats = len(self.labelIds)
117
+ self.labelIds_base = self.labelIds
118
+ self.num_cats_base = len(self.labelIds_base)
119
+ elif self.phase == 'val' or self.phase == 'test':
120
+ if self.phase == 'test':
121
+ # load data that will be used for evaluating the recognition
122
+ # accuracy of the base categories.
123
+ data_base = load_data(file_train_categories_test_phase)
124
+ # load data that will be use for evaluating the few-shot recogniton
125
+ # accuracy on the novel categories.
126
+ data_novel = load_data(file_test_categories_test_phase)
127
+ else: # phase=='val'
128
+ # load data that will be used for evaluating the recognition
129
+ # accuracy of the base categories.
130
+ data_base = load_data(file_train_categories_val_phase)
131
+ # load data that will be use for evaluating the few-shot recogniton
132
+ # accuracy on the novel categories.
133
+ data_novel = load_data(file_val_categories_val_phase)
134
+
135
+ self.data = np.concatenate(
136
+ [data_base['data'], data_novel['data']], axis=0)
137
+ self.labels = data_base['labels'] + data_novel['labels']
138
+
139
+ self.label2ind = buildLabelIndex(self.labels)
140
+ self.labelIds = sorted(self.label2ind.keys())
141
+ self.num_cats = len(self.labelIds)
142
+
143
+ self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
144
+ self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
145
+ self.num_cats_base = len(self.labelIds_base)
146
+ self.num_cats_novel = len(self.labelIds_novel)
147
+ intersection = set(self.labelIds_base) & set(self.labelIds_novel)
148
+ assert(len(intersection) == 0)
149
+ else:
150
+ raise ValueError('Not valid phase {0}'.format(self.phase))
151
+
152
+ mean_pix = [x/255.0 for x in [129.37731888,
153
+ 124.10583864, 112.47758569]]
154
+
155
+ std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
156
+
157
+ normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
158
+
159
+ if (self.phase == 'test' or self.phase == 'val') or (do_not_use_random_transf == True):
160
+
161
+ self.transform = transforms.Compose([
162
+ transforms.ToPILImage(),
163
+ # lambda x: np.asarray(x),
164
+ transforms.ToTensor(),
165
+ normalize
166
+ ])
167
+ else:
168
+ self.transform = transforms.Compose([
169
+ transforms.ToPILImage(),
170
+ transforms.RandomCrop(32, padding=4),
171
+ transforms.ColorJitter(
172
+ brightness=0.4, contrast=0.4, saturation=0.4),
173
+ transforms.RandomHorizontalFlip(),
174
+ transforms.ToTensor(),
175
+ # lambda x: np.asarray(x),
176
+ normalize
177
+ ])
178
+
179
+ def __getitem__(self, index):
180
+ img, label = self.data[index], self.labels[index]
181
+ # doing this so that it is consistent with all other datasets
182
+ # to return a PIL Image
183
+
184
+ # img = Image.fromarray(img)
185
+ if self.transform is not None:
186
+ img = self.transform(img)
187
+ return img, label
188
+
189
+ def __len__(self):
190
+ return len(self.data)
191
+
192
+
193
+ class FewShotDataloader():
194
+ def __init__(self,
195
+ dataset,
196
+ nKnovel=5, # number of novel categories.
197
+ nKbase=-1, # number of base categories.
198
+ # number of training examples per novel category.
199
+ nExemplars=1,
200
+ # number of test examples for all the novel categories.
201
+ nTestNovel=15*5,
202
+ # number of test examples for all the base categories.
203
+ nTestBase=15*5,
204
+ batch_size=1, # number of training episodes per batch.
205
+ num_workers=4,
206
+ epoch_size=2000, # number of batches per epoch.
207
+ ):
208
+
209
+ self.dataset = dataset
210
+ self.phase = self.dataset.phase
211
+ max_possible_nKnovel = (self.dataset.num_cats_base if self.phase == 'train' or self.phase == 'trainval'
212
+ else self.dataset.num_cats_novel)
213
+ assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
214
+ self.nKnovel = nKnovel
215
+
216
+ max_possible_nKbase = self.dataset.num_cats_base
217
+ nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
218
+ if (self.phase == 'train' or self.phase == 'trainval') and nKbase > 0:
219
+ nKbase -= self.nKnovel
220
+ max_possible_nKbase -= self.nKnovel
221
+
222
+ assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
223
+ self.nKbase = nKbase
224
+
225
+ self.nExemplars = nExemplars
226
+ self.nTestNovel = nTestNovel
227
+ self.nTestBase = nTestBase
228
+ self.batch_size = batch_size
229
+ self.epoch_size = epoch_size
230
+ self.num_workers = num_workers
231
+ self.is_eval_mode = (self.phase == 'test') or (self.phase == 'val')
232
+
233
+ def sampleImageIdsFrom(self, cat_id, sample_size=1):
234
+ """
235
+ Samples `sample_size` number of unique image ids picked from the
236
+ category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
237
+
238
+ Args:
239
+ cat_id: a scalar with the id of the category from which images will
240
+ be sampled.
241
+ sample_size: number of images that will be sampled.
242
+
243
+ Returns:
244
+ image_ids: a list of length `sample_size` with unique image ids.
245
+ """
246
+ assert(cat_id in self.dataset.label2ind)
247
+ assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
248
+ # Note: random.sample samples elements without replacement.
249
+ # seed = random.randint(1,10000000)
250
+ # random.seed(seed)
251
+ return random.sample(self.dataset.label2ind[cat_id], sample_size)
252
+
253
+ def sampleCategories(self, cat_set, sample_size=1):
254
+ """
255
+ Samples `sample_size` number of unique categories picked from the
256
+ `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
257
+
258
+ Args:
259
+ cat_set: string that specifies the set of categories from which
260
+ categories will be sampled.
261
+ sample_size: number of categories that will be sampled.
262
+
263
+ Returns:
264
+ cat_ids: a list of length `sample_size` with unique category ids.
265
+ """
266
+ if cat_set == 'base':
267
+ labelIds = self.dataset.labelIds_base
268
+ elif cat_set == 'novel':
269
+ labelIds = self.dataset.labelIds_novel
270
+ else:
271
+ raise ValueError('Not recognized category set {}'.format(cat_set))
272
+
273
+ assert(len(labelIds) >= sample_size)
274
+ # return sample_size unique categories chosen from labelIds set of
275
+ # categories (that can be either self.labelIds_base or self.labelIds_novel)
276
+ # Note: random.sample samples elements without replacement.
277
+ return random.sample(labelIds, sample_size)
278
+
279
+ def sample_base_and_novel_categories(self, nKbase, nKnovel):
280
+ """
281
+ Samples `nKbase` number of base categories and `nKnovel` number of novel
282
+ categories.
283
+
284
+ Args:
285
+ nKbase: number of base categories
286
+ nKnovel: number of novel categories
287
+
288
+ Returns:
289
+ Kbase: a list of length 'nKbase' with the ids of the sampled base
290
+ categories.
291
+ Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel
292
+ categories.
293
+ """
294
+ if self.is_eval_mode:
295
+ assert(nKnovel <= self.dataset.num_cats_novel)
296
+ # sample from the set of base categories 'nKbase' number of base
297
+ # categories.
298
+ Kbase = sorted(self.sampleCategories('base', nKbase))
299
+ # sample from the set of novel categories 'nKnovel' number of novel
300
+ # categories.
301
+ Knovel = sorted(self.sampleCategories('novel', nKnovel))
302
+ else:
303
+ # sample from the set of base categories 'nKnovel' + 'nKbase' number
304
+ # of categories.
305
+ cats_ids = self.sampleCategories('base', nKnovel+nKbase)
306
+ assert(len(cats_ids) == (nKnovel+nKbase))
307
+ # Randomly pick 'nKnovel' number of fake novel categories and keep
308
+ # the rest as base categories.
309
+ random.shuffle(cats_ids)
310
+ Knovel = sorted(cats_ids[:nKnovel])
311
+ Kbase = sorted(cats_ids[nKnovel:])
312
+
313
+ return Kbase, Knovel
314
+
315
+ def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
316
+ """
317
+ Sample `nTestBase` number of images from the `Kbase` categories.
318
+
319
+ Args:
320
+ Kbase: a list of length `nKbase` with the ids of the categories from
321
+ where the images will be sampled.
322
+ nTestBase: the total number of images that will be sampled.
323
+
324
+ Returns:
325
+ Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
326
+ element of each tuple is the image id that was sampled and the
327
+ 2nd elemend is its category label (which is in the range
328
+ [0, len(Kbase)-1]).
329
+ """
330
+ Tbase = []
331
+ if len(Kbase) > 0:
332
+ # Sample for each base category a number images such that the total
333
+ # number sampled images of all categories to be equal to `nTestBase`.
334
+ KbaseIndices = np.random.choice(
335
+ np.arange(len(Kbase)), size=nTestBase, replace=True)
336
+ KbaseIndices, NumImagesPerCategory = np.unique(
337
+ KbaseIndices, return_counts=True)
338
+
339
+ for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
340
+ imd_ids = self.sampleImageIdsFrom(
341
+ Kbase[Kbase_idx], sample_size=NumImages)
342
+ Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
343
+
344
+ assert(len(Tbase) == nTestBase)
345
+
346
+ return Tbase
347
+
348
+ def sample_train_and_test_examples_for_novel_categories(
349
+ self, Knovel, nTestNovel, nExemplars, nKbase):
350
+ """Samples train and test examples of the novel categories.
351
+
352
+ Args:
353
+ Knovel: a list with the ids of the novel categories.
354
+ nTestNovel: the total number of test images that will be sampled
355
+ from all the novel categories.
356
+ nExemplars: the number of training examples per novel category that
357
+ will be sampled.
358
+ nKbase: the number of base categories. It is used as offset of the
359
+ category index of each sampled image.
360
+
361
+ Returns:
362
+ Tnovel: a list of length `nTestNovel` with 2-element tuples. The
363
+ 1st element of each tuple is the image id that was sampled and
364
+ the 2nd element is its category label (which is in the range
365
+ [nKbase, nKbase + len(Knovel) - 1]).
366
+ Exemplars: a list of length len(Knovel) * nExemplars of 2-element
367
+ tuples. The 1st element of each tuple is the image id that was
368
+ sampled and the 2nd element is its category label (which is in
369
+ the ragne [nKbase, nKbase + len(Knovel) - 1]).
370
+ """
371
+
372
+ if len(Knovel) == 0:
373
+ return [], []
374
+
375
+ nKnovel = len(Knovel)
376
+ Tnovel = []
377
+ Exemplars = []
378
+ assert((nTestNovel % nKnovel) == 0)
379
+ nEvalExamplesPerClass = int(nTestNovel / nKnovel)
380
+
381
+ for Knovel_idx in range(len(Knovel)):
382
+ imd_ids = self.sampleImageIdsFrom(
383
+ Knovel[Knovel_idx],
384
+ sample_size=(nEvalExamplesPerClass + nExemplars))
385
+
386
+ imds_tnovel = imd_ids[:nEvalExamplesPerClass]
387
+ imds_ememplars = imd_ids[nEvalExamplesPerClass:]
388
+
389
+ Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
390
+ Exemplars += [(img_id, nKbase+Knovel_idx)
391
+ for img_id in imds_ememplars]
392
+ assert(len(Tnovel) == nTestNovel)
393
+ assert(len(Exemplars) == len(Knovel) * nExemplars)
394
+ random.shuffle(Exemplars)
395
+
396
+ return Tnovel, Exemplars
397
+
398
+ def sample_episode(self):
399
+ """Samples a training episode."""
400
+ nKnovel = self.nKnovel
401
+ nKbase = self.nKbase
402
+ nTestNovel = self.nTestNovel
403
+ nTestBase = self.nTestBase
404
+ nExemplars = self.nExemplars
405
+
406
+ Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
407
+ Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
408
+ Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
409
+ Knovel, nTestNovel, nExemplars, nKbase)
410
+
411
+ # concatenate the base and novel category examples.
412
+ Test = Tbase + Tnovel
413
+ random.shuffle(Test)
414
+ Kall = Kbase + Knovel
415
+
416
+ return Exemplars, Test, Kall, nKbase
417
+
418
+ def createExamplesTensorData(self, examples):
419
+ """
420
+ Creates the examples image and label tensor data.
421
+
422
+ Args:
423
+ examples: a list of 2-element tuples, each representing a
424
+ train or test example. The 1st element of each tuple
425
+ is the image id of the example and 2nd element is the
426
+ category label of the example, which is in the range
427
+ [0, nK - 1], where nK is the total number of categories
428
+ (both novel and base).
429
+
430
+ Returns:
431
+ images: a tensor of shape [nExamples, Height, Width, 3] with the
432
+ example images, where nExamples is the number of examples
433
+ (i.e., nExamples = len(examples)).
434
+ labels: a tensor of shape [nExamples] with the category label
435
+ of each example.
436
+ """
437
+ images = torch.stack(
438
+ [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
439
+ labels = torch.LongTensor([label for _, label in examples])
440
+ return images, labels
441
+
442
+ def get_iterator(self, epoch=0):
443
+ rand_seed = epoch
444
+ random.seed(rand_seed)
445
+ np.random.seed(rand_seed)
446
+
447
+ def load_function(iter_idx):
448
+ Exemplars, Test, Kall, nKbase = self.sample_episode()
449
+ Xt, Yt = self.createExamplesTensorData(Test)
450
+ Kall = torch.LongTensor(Kall)
451
+ if len(Exemplars) > 0:
452
+ Xe, Ye = self.createExamplesTensorData(Exemplars)
453
+ return Xe, Ye, Xt, Yt, Kall, nKbase
454
+ else:
455
+ return Xt, Yt, Kall, nKbase
456
+
457
+ tnt_dataset = tnt.dataset.ListDataset(
458
+ elem_list=range(self.epoch_size), load=load_function)
459
+ data_loader = tnt_dataset.parallel(
460
+ batch_size=self.batch_size,
461
+ num_workers=(0 if self.is_eval_mode else self.num_workers),
462
+ shuffle=(False if self.is_eval_mode else True))
463
+
464
+ return data_loader
465
+
466
+ def __call__(self, epoch=0):
467
+ return self.get_iterator(epoch)
468
+
469
+ def __len__(self):
470
+ return int(self.epoch_size / self.batch_size)
dataloader/CIFAR_FS.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataloader of Gidaris & Komodakis, CVPR 2018
2
+ # Adapted from:
3
+ # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import os.path
8
+ import numpy as np
9
+ import random
10
+ import pickle
11
+ import json
12
+ import math
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+ import torchvision
17
+ import torchvision.datasets as datasets
18
+ import torchvision.transforms as transforms
19
+ import torchnet as tnt
20
+
21
+ import h5py
22
+
23
+ import cv2
24
+ from PIL import Image
25
+ from PIL import ImageEnhance
26
+ import matplotlib.pyplot as plt
27
+
28
+ from pdb import set_trace as breakpoint
29
+
30
+ from torchvision.transforms.transforms import ToPILImage
31
+
32
+
33
+ # Set the appropriate paths of the datasets here.
34
+ _CIFAR_FS_DATASET_DIR = './cifar/CIFAR-FS/'
35
+
36
+
37
+ def buildLabelIndex(labels):
38
+ label2inds = {}
39
+ for idx, label in enumerate(labels):
40
+ if label not in label2inds:
41
+ label2inds[label] = []
42
+ label2inds[label].append(idx)
43
+
44
+ return label2inds
45
+
46
+
47
+ def load_data(file):
48
+ try:
49
+ with open(file, 'rb') as fo:
50
+ data = pickle.load(fo)
51
+ return data
52
+ except:
53
+ with open(file, 'rb') as f:
54
+ u = pickle._Unpickler(f)
55
+ u.encoding = 'latin1'
56
+ data = u.load()
57
+ return data
58
+
59
+
60
+ class CIFAR_FS(data.Dataset):
61
+ def __init__(self, phase='train', do_not_use_random_transf=False):
62
+
63
+ assert(phase == 'train' or phase == 'val' or phase ==
64
+ 'test' or phase == 'trainval')
65
+ self.phase = phase
66
+ self.name = 'CIFAR_FS_' + phase
67
+
68
+ print('Loading CIFAR-FS dataset - phase {0}'.format(phase))
69
+ file_train_categories_train_phase = os.path.join(
70
+ _CIFAR_FS_DATASET_DIR,
71
+ 'CIFAR_FS_train.pickle')
72
+ file_train_categories_val_phase = os.path.join(
73
+ _CIFAR_FS_DATASET_DIR,
74
+ 'CIFAR_FS_train.pickle')
75
+ file_train_categories_test_phase = os.path.join(
76
+ _CIFAR_FS_DATASET_DIR,
77
+ 'CIFAR_FS_train.pickle')
78
+ file_val_categories_val_phase = os.path.join(
79
+ _CIFAR_FS_DATASET_DIR,
80
+ 'CIFAR_FS_val.pickle')
81
+ file_test_categories_test_phase = os.path.join(
82
+ _CIFAR_FS_DATASET_DIR,
83
+ 'CIFAR_FS_test.pickle')
84
+
85
+ if self.phase == 'train':
86
+ # During training phase we only load the training phase images
87
+ # of the training categories (aka base categories).
88
+ data_train = load_data(file_train_categories_train_phase)
89
+ self.data = data_train['data']
90
+ self.labels = data_train['labels']
91
+
92
+ self.label2ind = buildLabelIndex(self.labels)
93
+ self.labelIds = sorted(self.label2ind.keys())
94
+
95
+ self.num_cats = len(self.labelIds)
96
+ self.labelIds_base = self.labelIds
97
+ self.num_cats_base = len(self.labelIds_base)
98
+ elif self.phase == 'trainval':
99
+ # During training phase we only load the training phase images
100
+ # of the training categories (aka base categories).
101
+ data_train = load_data(file_train_categories_train_phase)
102
+ self.data = data_train['data']
103
+ self.labels = data_train['labels']
104
+ data_base = load_data(file_train_categories_val_phase)
105
+ data_novel = load_data(file_val_categories_val_phase)
106
+ self.data = np.concatenate(
107
+ [self.data, data_novel['data']], axis=0)
108
+ self.data = np.concatenate(
109
+ [self.data, data_base['data']], axis=0)
110
+
111
+ self.labels = np.concatenate(
112
+ [self.labels, data_novel['labels']], axis=0)
113
+ self.labels = np.concatenate(
114
+ [self.labels, data_base['labels']], axis=0)
115
+
116
+ self.label2ind = buildLabelIndex(self.labels)
117
+ self.labelIds = sorted(self.label2ind.keys())
118
+ self.num_cats = len(self.labelIds)
119
+ self.labelIds_base = self.labelIds
120
+ self.num_cats_base = len(self.labelIds_base)
121
+ elif self.phase == 'val' or self.phase == 'test':
122
+ if self.phase == 'test':
123
+ # load data that will be used for evaluating the recognition
124
+ # accuracy of the base categories.
125
+ data_base = load_data(file_train_categories_test_phase)
126
+ # load data that will be use for evaluating the few-shot recogniton
127
+ # accuracy on the novel categories.
128
+ data_novel = load_data(file_test_categories_test_phase)
129
+ else: # phase=='val'
130
+ # load data that will be used for evaluating the recognition
131
+ # accuracy of the base categories.
132
+ data_base = load_data(file_train_categories_val_phase)
133
+ # load data that will be use for evaluating the few-shot recogniton
134
+ # accuracy on the novel categories.
135
+ data_novel = load_data(file_val_categories_val_phase)
136
+
137
+ self.data = np.concatenate(
138
+ [data_base['data'], data_novel['data']], axis=0)
139
+ self.labels = data_base['labels'] + data_novel['labels']
140
+
141
+ self.label2ind = buildLabelIndex(self.labels)
142
+ self.labelIds = sorted(self.label2ind.keys())
143
+ self.num_cats = len(self.labelIds)
144
+
145
+ self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
146
+ self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
147
+ self.num_cats_base = len(self.labelIds_base)
148
+ self.num_cats_novel = len(self.labelIds_novel)
149
+ intersection = set(self.labelIds_base) & set(self.labelIds_novel)
150
+ assert(len(intersection) == 0)
151
+ else:
152
+ raise ValueError('Not valid phase {0}'.format(self.phase))
153
+
154
+ mean_pix = [x/255.0 for x in [129.37731888,
155
+ 124.10583864, 112.47758569]]
156
+
157
+ std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
158
+
159
+ normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
160
+
161
+ if (self.phase == 'test' or self.phase == 'val') or (do_not_use_random_transf == True):
162
+
163
+ self.transform = transforms.Compose([
164
+ transforms.ToPILImage(),
165
+ # lambda x: np.asarray(x),
166
+ transforms.ToTensor(),
167
+ normalize
168
+ ])
169
+ else:
170
+ self.transform = transforms.Compose([
171
+ transforms.ToPILImage(),
172
+ transforms.RandomCrop(32, padding=4),
173
+ transforms.ColorJitter(
174
+ brightness=0.4, contrast=0.4, saturation=0.4),
175
+ transforms.RandomHorizontalFlip(),
176
+ transforms.ToTensor(),
177
+ # lambda x: np.asarray(x),
178
+ normalize
179
+ ])
180
+
181
+ def __getitem__(self, index):
182
+ img, label = self.data[index], self.labels[index]
183
+ # doing this so that it is consistent with all other datasets
184
+ # to return a PIL Image
185
+
186
+ # img = Image.fromarray(img)
187
+ if self.transform is not None:
188
+ img = self.transform(img)
189
+ return img, label
190
+
191
+ def __len__(self):
192
+ return len(self.data)
193
+
194
+
195
+ class FewShotDataloader():
196
+ def __init__(self,
197
+ dataset,
198
+ nKnovel=5, # number of novel categories.
199
+ nKbase=-1, # number of base categories.
200
+ # number of training examples per novel category.
201
+ nExemplars=1,
202
+ # number of test examples for all the novel categories.
203
+ nTestNovel=15*5,
204
+ # number of test examples for all the base categories.
205
+ nTestBase=15*5,
206
+ batch_size=1, # number of training episodes per batch.
207
+ num_workers=4,
208
+ epoch_size=2000, # number of batches per epoch.
209
+ ):
210
+
211
+ self.dataset = dataset
212
+ self.phase = self.dataset.phase
213
+ max_possible_nKnovel = (self.dataset.num_cats_base if self.phase == 'train' or self.phase == 'trainval'
214
+ else self.dataset.num_cats_novel)
215
+ assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
216
+ self.nKnovel = nKnovel
217
+
218
+ max_possible_nKbase = self.dataset.num_cats_base
219
+ nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
220
+ if (self.phase == 'train' or self.phase == 'trainval') and nKbase > 0:
221
+ nKbase -= self.nKnovel
222
+ max_possible_nKbase -= self.nKnovel
223
+
224
+ assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
225
+ self.nKbase = nKbase
226
+
227
+ self.nExemplars = nExemplars
228
+ self.nTestNovel = nTestNovel
229
+ self.nTestBase = nTestBase
230
+ self.batch_size = batch_size
231
+ self.epoch_size = epoch_size
232
+ self.num_workers = num_workers
233
+ self.is_eval_mode = (self.phase == 'test') or (self.phase == 'val')
234
+
235
+ def sampleImageIdsFrom(self, cat_id, sample_size=1):
236
+ """
237
+ Samples `sample_size` number of unique image ids picked from the
238
+ category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
239
+
240
+ Args:
241
+ cat_id: a scalar with the id of the category from which images will
242
+ be sampled.
243
+ sample_size: number of images that will be sampled.
244
+
245
+ Returns:
246
+ image_ids: a list of length `sample_size` with unique image ids.
247
+ """
248
+ assert(cat_id in self.dataset.label2ind)
249
+ assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
250
+ # Note: random.sample samples elements without replacement.
251
+ # seed = random.randint(1,10000000)
252
+ # random.seed(seed)
253
+ return random.sample(self.dataset.label2ind[cat_id], sample_size)
254
+
255
+ def sampleCategories(self, cat_set, sample_size=1):
256
+ """
257
+ Samples `sample_size` number of unique categories picked from the
258
+ `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
259
+
260
+ Args:
261
+ cat_set: string that specifies the set of categories from which
262
+ categories will be sampled.
263
+ sample_size: number of categories that will be sampled.
264
+
265
+ Returns:
266
+ cat_ids: a list of length `sample_size` with unique category ids.
267
+ """
268
+ if cat_set == 'base':
269
+ labelIds = self.dataset.labelIds_base
270
+ elif cat_set == 'novel':
271
+ labelIds = self.dataset.labelIds_novel
272
+ else:
273
+ raise ValueError('Not recognized category set {}'.format(cat_set))
274
+
275
+ assert(len(labelIds) >= sample_size)
276
+ # return sample_size unique categories chosen from labelIds set of
277
+ # categories (that can be either self.labelIds_base or self.labelIds_novel)
278
+ # Note: random.sample samples elements without replacement.
279
+ return random.sample(labelIds, sample_size)
280
+
281
+ def sample_base_and_novel_categories(self, nKbase, nKnovel):
282
+ """
283
+ Samples `nKbase` number of base categories and `nKnovel` number of novel
284
+ categories.
285
+
286
+ Args:
287
+ nKbase: number of base categories
288
+ nKnovel: number of novel categories
289
+
290
+ Returns:
291
+ Kbase: a list of length 'nKbase' with the ids of the sampled base
292
+ categories.
293
+ Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel
294
+ categories.
295
+ """
296
+ if self.is_eval_mode:
297
+ assert(nKnovel <= self.dataset.num_cats_novel)
298
+ # sample from the set of base categories 'nKbase' number of base
299
+ # categories.
300
+ Kbase = sorted(self.sampleCategories('base', nKbase))
301
+ # sample from the set of novel categories 'nKnovel' number of novel
302
+ # categories.
303
+ Knovel = sorted(self.sampleCategories('novel', nKnovel))
304
+ else:
305
+ # sample from the set of base categories 'nKnovel' + 'nKbase' number
306
+ # of categories.
307
+ cats_ids = self.sampleCategories('base', nKnovel+nKbase)
308
+ assert(len(cats_ids) == (nKnovel+nKbase))
309
+ # Randomly pick 'nKnovel' number of fake novel categories and keep
310
+ # the rest as base categories.
311
+ random.shuffle(cats_ids)
312
+ Knovel = sorted(cats_ids[:nKnovel])
313
+ Kbase = sorted(cats_ids[nKnovel:])
314
+
315
+ return Kbase, Knovel
316
+
317
+ def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
318
+ """
319
+ Sample `nTestBase` number of images from the `Kbase` categories.
320
+
321
+ Args:
322
+ Kbase: a list of length `nKbase` with the ids of the categories from
323
+ where the images will be sampled.
324
+ nTestBase: the total number of images that will be sampled.
325
+
326
+ Returns:
327
+ Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
328
+ element of each tuple is the image id that was sampled and the
329
+ 2nd elemend is its category label (which is in the range
330
+ [0, len(Kbase)-1]).
331
+ """
332
+ Tbase = []
333
+ if len(Kbase) > 0:
334
+ # Sample for each base category a number images such that the total
335
+ # number sampled images of all categories to be equal to `nTestBase`.
336
+ KbaseIndices = np.random.choice(
337
+ np.arange(len(Kbase)), size=nTestBase, replace=True)
338
+ KbaseIndices, NumImagesPerCategory = np.unique(
339
+ KbaseIndices, return_counts=True)
340
+
341
+ for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
342
+ imd_ids = self.sampleImageIdsFrom(
343
+ Kbase[Kbase_idx], sample_size=NumImages)
344
+ Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
345
+
346
+ assert(len(Tbase) == nTestBase)
347
+
348
+ return Tbase
349
+
350
+ def sample_train_and_test_examples_for_novel_categories(
351
+ self, Knovel, nTestNovel, nExemplars, nKbase):
352
+ """Samples train and test examples of the novel categories.
353
+
354
+ Args:
355
+ Knovel: a list with the ids of the novel categories.
356
+ nTestNovel: the total number of test images that will be sampled
357
+ from all the novel categories.
358
+ nExemplars: the number of training examples per novel category that
359
+ will be sampled.
360
+ nKbase: the number of base categories. It is used as offset of the
361
+ category index of each sampled image.
362
+
363
+ Returns:
364
+ Tnovel: a list of length `nTestNovel` with 2-element tuples. The
365
+ 1st element of each tuple is the image id that was sampled and
366
+ the 2nd element is its category label (which is in the range
367
+ [nKbase, nKbase + len(Knovel) - 1]).
368
+ Exemplars: a list of length len(Knovel) * nExemplars of 2-element
369
+ tuples. The 1st element of each tuple is the image id that was
370
+ sampled and the 2nd element is its category label (which is in
371
+ the ragne [nKbase, nKbase + len(Knovel) - 1]).
372
+ """
373
+
374
+ if len(Knovel) == 0:
375
+ return [], []
376
+
377
+ nKnovel = len(Knovel)
378
+ Tnovel = []
379
+ Exemplars = []
380
+ assert((nTestNovel % nKnovel) == 0)
381
+ nEvalExamplesPerClass = int(nTestNovel / nKnovel)
382
+
383
+ for Knovel_idx in range(nKnovel):
384
+ imd_ids = self.sampleImageIdsFrom(
385
+ Knovel[Knovel_idx],
386
+ sample_size=(nEvalExamplesPerClass + nExemplars))
387
+
388
+ imds_tnovel = imd_ids[:nEvalExamplesPerClass]
389
+ imds_ememplars = imd_ids[nEvalExamplesPerClass:]
390
+
391
+ Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
392
+
393
+ Exemplars += [(img_id, nKbase+Knovel_idx)
394
+ for img_id in imds_ememplars]
395
+
396
+ # print('='*60)
397
+ # print(Tnovel)
398
+ # print(Exemplars)
399
+ # print('='*60)
400
+ assert(len(Tnovel) == nTestNovel)
401
+ assert(len(Exemplars) == len(Knovel) * nExemplars)
402
+
403
+ # random.shuffle(Exemplars) # shuffle commented by me
404
+
405
+ # print(Exemplars)
406
+
407
+ return Tnovel, Exemplars
408
+
409
+ def sample_episode(self):
410
+ """Samples a training episode."""
411
+ nKnovel = self.nKnovel
412
+ nKbase = self.nKbase
413
+ nTestNovel = self.nTestNovel
414
+ nTestBase = self.nTestBase
415
+ nExemplars = self.nExemplars
416
+
417
+ Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
418
+
419
+ Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
420
+
421
+ # print(Tbase,Knovel)
422
+
423
+ Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
424
+ Knovel, nTestNovel, nExemplars, nKbase)
425
+
426
+ # concatenate the base and novel category examples.
427
+ Test = Tbase + Tnovel
428
+ # random.shuffle(Test)
429
+
430
+ # print(Test)
431
+
432
+ Kall = Kbase + Knovel
433
+
434
+ return Exemplars, Test, Kall, nKbase
435
+
436
+ def createExamplesTensorData(self, examples):
437
+ """
438
+ Creates the examples image and label tensor data.
439
+
440
+ Args:
441
+ examples: a list of 2-element tuples, each representing a
442
+ train or test example. The 1st element of each tuple
443
+ is the image id of the example and 2nd element is the
444
+ category label of the example, which is in the range
445
+ [0, nK - 1], where nK is the total number of categories
446
+ (both novel and base).
447
+
448
+ Returns:
449
+ images: a tensor of shape [nExamples, Height, Width, 3] with the
450
+ example images, where nExamples is the number of examples
451
+ (i.e., nExamples = len(examples)).
452
+ labels: a tensor of shape [nExamples] with the category label
453
+ of each example.
454
+ """
455
+ images = torch.stack(
456
+ [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
457
+ labels = torch.LongTensor([label for _, label in examples])
458
+ return images, labels
459
+
460
+ def get_iterator(self, epoch=0):
461
+ rand_seed = epoch
462
+ random.seed(rand_seed)
463
+ np.random.seed(rand_seed)
464
+
465
+ def load_function(iter_idx):
466
+ Exemplars, Test, Kall, nKbase = self.sample_episode()
467
+ Xt, Yt = self.createExamplesTensorData(Test)
468
+ Kall = torch.LongTensor(Kall)
469
+ if len(Exemplars) > 0:
470
+ Xe, Ye = self.createExamplesTensorData(Exemplars)
471
+ return Xe, Ye, Xt, Yt, Kall, nKbase
472
+ else:
473
+ return Xt, Yt, Kall, nKbase
474
+
475
+ tnt_dataset = tnt.dataset.ListDataset(
476
+ elem_list=range(self.epoch_size), load=load_function)
477
+ data_loader = tnt_dataset.parallel(
478
+ batch_size=self.batch_size,
479
+ num_workers=(1 if self.is_eval_mode else self.num_workers),
480
+ shuffle=(False if self.is_eval_mode else True))
481
+
482
+ return data_loader
483
+
484
+ def __call__(self, epoch=0):
485
+ return self.get_iterator(epoch)
486
+
487
+ def __len__(self):
488
+ return int(self.epoch_size / self.batch_size)
dataloader/FC100.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataloader of Gidaris & Komodakis, CVPR 2018
2
+ # Adapted from:
3
+ # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import os.path
8
+ import numpy as np
9
+ import random
10
+ import pickle
11
+ import json
12
+ import math
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+ import torchvision
17
+ import torchvision.datasets as datasets
18
+ import torchvision.transforms as transforms
19
+ import torchnet as tnt
20
+
21
+ import h5py
22
+
23
+ from PIL import Image
24
+ from PIL import ImageEnhance
25
+
26
+ from pdb import set_trace as breakpoint
27
+
28
+
29
+ # Set the appropriate paths of the datasets here.
30
+ _FC100_DATASET_DIR = './cifar/FC100/'
31
+
32
+ def buildLabelIndex(labels):
33
+ label2inds = {}
34
+ for idx, label in enumerate(labels):
35
+ if label not in label2inds:
36
+ label2inds[label] = []
37
+ label2inds[label].append(idx)
38
+
39
+ return label2inds
40
+
41
+ def load_data(file):
42
+ try:
43
+ with open(file, 'rb') as fo:
44
+ data = pickle.load(fo)
45
+ return data
46
+ except:
47
+ with open(file, 'rb') as f:
48
+ u = pickle._Unpickler(f)
49
+ u.encoding = 'latin1'
50
+ data = u.load()
51
+ return data
52
+
53
+ class FC100(data.Dataset):
54
+ def __init__(self, phase='train', do_not_use_random_transf=False):
55
+
56
+ assert(phase=='train' or phase=='val' or phase=='test'or phase=='trainval')
57
+ self.phase = phase
58
+ self.name = 'FC100_' + phase
59
+
60
+ print('Loading FC100 dataset - phase {0}'.format(phase))
61
+ file_train_categories_train_phase = os.path.join(
62
+ _FC100_DATASET_DIR,
63
+ 'FC100_train.pickle')
64
+ file_train_categories_val_phase = os.path.join(
65
+ _FC100_DATASET_DIR,
66
+ 'FC100_train.pickle')
67
+ file_train_categories_test_phase = os.path.join(
68
+ _FC100_DATASET_DIR,
69
+ 'FC100_train.pickle')
70
+ file_val_categories_val_phase = os.path.join(
71
+ _FC100_DATASET_DIR,
72
+ 'FC100_val.pickle')
73
+ file_test_categories_test_phase = os.path.join(
74
+ _FC100_DATASET_DIR,
75
+ 'FC100_test.pickle')
76
+
77
+ if self.phase=='train':
78
+ # During training phase we only load the training phase images
79
+ # of the training categories (aka base categories).
80
+ data_train = load_data(file_train_categories_train_phase)
81
+ self.data = data_train['data']
82
+ self.labels = data_train['labels']
83
+
84
+ #print (self.labels)
85
+ self.label2ind = buildLabelIndex(self.labels)
86
+ self.labelIds = sorted(self.label2ind.keys())
87
+ self.num_cats = len(self.labelIds)
88
+ self.labelIds_base = self.labelIds
89
+ self.num_cats_base = len(self.labelIds_base)
90
+ #print (self.data.shape)
91
+ elif self.phase == 'trainval':
92
+ # During training phase we only load the training phase images
93
+ # of the training categories (aka base categories).
94
+ data_train = load_data(file_train_categories_train_phase)
95
+ self.data = data_train['data']
96
+ self.labels = data_train['labels']
97
+ data_base = load_data(file_train_categories_val_phase)
98
+ data_novel = load_data(file_val_categories_val_phase)
99
+ self.data = np.concatenate(
100
+ [self.data, data_novel['data']], axis=0)
101
+ self.data = np.concatenate(
102
+ [self.data, data_base['data']], axis=0)
103
+
104
+ self.labels = np.concatenate(
105
+ [self.labels, data_novel['labels']], axis=0)
106
+ self.labels = np.concatenate(
107
+ [self.labels, data_base['labels']], axis=0)
108
+
109
+ # print (self.labels)
110
+ self.label2ind = buildLabelIndex(self.labels)
111
+ self.labelIds = sorted(self.label2ind.keys())
112
+ self.num_cats = len(self.labelIds)
113
+ self.labelIds_base = self.labelIds
114
+ self.num_cats_base = len(self.labelIds_base)
115
+ elif self.phase=='val' or self.phase=='test':
116
+ if self.phase=='test':
117
+ # load data that will be used for evaluating the recognition
118
+ # accuracy of the base categories.
119
+ data_base = load_data(file_train_categories_test_phase)
120
+ # load data that will be use for evaluating the few-shot recogniton
121
+ # accuracy on the novel categories.
122
+ data_novel = load_data(file_test_categories_test_phase)
123
+ else: # phase=='val'
124
+ # load data that will be used for evaluating the recognition
125
+ # accuracy of the base categories.
126
+ data_base = load_data(file_train_categories_val_phase)
127
+ # load data that will be use for evaluating the few-shot recogniton
128
+ # accuracy on the novel categories.
129
+ data_novel = load_data(file_val_categories_val_phase)
130
+
131
+ self.data = np.concatenate(
132
+ [data_base['data'], data_novel['data']], axis=0)
133
+ self.labels = data_base['labels'] + data_novel['labels']
134
+
135
+ self.label2ind = buildLabelIndex(self.labels)
136
+ self.labelIds = sorted(self.label2ind.keys())
137
+ self.num_cats = len(self.labelIds)
138
+
139
+ self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
140
+ self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
141
+ self.num_cats_base = len(self.labelIds_base)
142
+ self.num_cats_novel = len(self.labelIds_novel)
143
+ intersection = set(self.labelIds_base) & set(self.labelIds_novel)
144
+ assert(len(intersection) == 0)
145
+ else:
146
+ raise ValueError('Not valid phase {0}'.format(self.phase))
147
+
148
+ mean_pix = [x/255.0 for x in [129.37731888, 124.10583864, 112.47758569]]
149
+
150
+ std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
151
+
152
+ normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
153
+
154
+ if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
155
+ self.transform = transforms.Compose([
156
+ lambda x: np.asarray(x),
157
+ transforms.ToTensor(),
158
+ normalize
159
+ ])
160
+ else:
161
+ self.transform = transforms.Compose([
162
+ transforms.RandomCrop(32, padding=4),
163
+ transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
164
+ transforms.RandomHorizontalFlip(),
165
+ lambda x: np.asarray(x),
166
+ transforms.ToTensor(),
167
+ normalize
168
+ ])
169
+
170
+ def __getitem__(self, index):
171
+ img, label = self.data[index], self.labels[index]
172
+ # doing this so that it is consistent with all other datasets
173
+ # to return a PIL Image
174
+ img = Image.fromarray(img)
175
+ if self.transform is not None:
176
+ img = self.transform(img)
177
+ return img, label
178
+
179
+ def __len__(self):
180
+ return len(self.data)
181
+
182
+
183
+ class FewShotDataloader():
184
+ def __init__(self,
185
+ dataset,
186
+ nKnovel=5, # number of novel categories.
187
+ nKbase=-1, # number of base categories.
188
+ nExemplars=1, # number of training examples per novel category.
189
+ nTestNovel=15*5, # number of test examples for all the novel categories.
190
+ nTestBase=15*5, # number of test examples for all the base categories.
191
+ batch_size=1, # number of training episodes per batch.
192
+ num_workers=4,
193
+ epoch_size=2000, # number of batches per epoch.
194
+ ):
195
+
196
+ self.dataset = dataset
197
+ self.phase = self.dataset.phase
198
+ max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train' or self.phase=='trainval'
199
+ else self.dataset.num_cats_novel)
200
+ assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
201
+ self.nKnovel = nKnovel
202
+
203
+ max_possible_nKbase = self.dataset.num_cats_base
204
+ nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
205
+ if (self.phase=='train' or self.phase=='trainval') and nKbase > 0:
206
+ nKbase -= self.nKnovel
207
+ max_possible_nKbase -= self.nKnovel
208
+
209
+ assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
210
+ self.nKbase = nKbase
211
+
212
+ self.nExemplars = nExemplars
213
+ self.nTestNovel = nTestNovel
214
+ self.nTestBase = nTestBase
215
+ self.batch_size = batch_size
216
+ self.epoch_size = epoch_size
217
+ self.num_workers = num_workers
218
+ self.is_eval_mode = (self.phase=='test') or (self.phase=='val')
219
+
220
+ def sampleImageIdsFrom(self, cat_id, sample_size=1):
221
+ """
222
+ Samples `sample_size` number of unique image ids picked from the
223
+ category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
224
+
225
+ Args:
226
+ cat_id: a scalar with the id of the category from which images will
227
+ be sampled.
228
+ sample_size: number of images that will be sampled.
229
+
230
+ Returns:
231
+ image_ids: a list of length `sample_size` with unique image ids.
232
+ """
233
+ assert(cat_id in self.dataset.label2ind)
234
+ assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
235
+ # Note: random.sample samples elements without replacement.
236
+ return random.sample(self.dataset.label2ind[cat_id], sample_size)
237
+
238
+ def sampleCategories(self, cat_set, sample_size=1):
239
+ """
240
+ Samples `sample_size` number of unique categories picked from the
241
+ `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
242
+
243
+ Args:
244
+ cat_set: string that specifies the set of categories from which
245
+ categories will be sampled.
246
+ sample_size: number of categories that will be sampled.
247
+
248
+ Returns:
249
+ cat_ids: a list of length `sample_size` with unique category ids.
250
+ """
251
+ if cat_set=='base':
252
+ labelIds = self.dataset.labelIds_base
253
+ elif cat_set=='novel':
254
+ labelIds = self.dataset.labelIds_novel
255
+ else:
256
+ raise ValueError('Not recognized category set {}'.format(cat_set))
257
+
258
+ assert(len(labelIds) >= sample_size)
259
+ # return sample_size unique categories chosen from labelIds set of
260
+ # categories (that can be either self.labelIds_base or self.labelIds_novel)
261
+ # Note: random.sample samples elements without replacement.
262
+ return random.sample(labelIds, sample_size)
263
+
264
+ def sample_base_and_novel_categories(self, nKbase, nKnovel):
265
+ """
266
+ Samples `nKbase` number of base categories and `nKnovel` number of novel
267
+ categories.
268
+
269
+ Args:
270
+ nKbase: number of base categories
271
+ nKnovel: number of novel categories
272
+
273
+ Returns:
274
+ Kbase: a list of length 'nKbase' with the ids of the sampled base
275
+ categories.
276
+ Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel
277
+ categories.
278
+ """
279
+ if self.is_eval_mode:
280
+ assert(nKnovel <= self.dataset.num_cats_novel)
281
+ # sample from the set of base categories 'nKbase' number of base
282
+ # categories.
283
+ Kbase = sorted(self.sampleCategories('base', nKbase))
284
+ # sample from the set of novel categories 'nKnovel' number of novel
285
+ # categories.
286
+ Knovel = sorted(self.sampleCategories('novel', nKnovel))
287
+ else:
288
+ # sample from the set of base categories 'nKnovel' + 'nKbase' number
289
+ # of categories.
290
+ cats_ids = self.sampleCategories('base', nKnovel+nKbase)
291
+ assert(len(cats_ids) == (nKnovel+nKbase))
292
+ # Randomly pick 'nKnovel' number of fake novel categories and keep
293
+ # the rest as base categories.
294
+ random.shuffle(cats_ids)
295
+ Knovel = sorted(cats_ids[:nKnovel])
296
+ Kbase = sorted(cats_ids[nKnovel:])
297
+
298
+ return Kbase, Knovel
299
+
300
+ def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
301
+ """
302
+ Sample `nTestBase` number of images from the `Kbase` categories.
303
+
304
+ Args:
305
+ Kbase: a list of length `nKbase` with the ids of the categories from
306
+ where the images will be sampled.
307
+ nTestBase: the total number of images that will be sampled.
308
+
309
+ Returns:
310
+ Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
311
+ element of each tuple is the image id that was sampled and the
312
+ 2nd elemend is its category label (which is in the range
313
+ [0, len(Kbase)-1]).
314
+ """
315
+ Tbase = []
316
+ if len(Kbase) > 0:
317
+ # Sample for each base category a number images such that the total
318
+ # number sampled images of all categories to be equal to `nTestBase`.
319
+ KbaseIndices = np.random.choice(
320
+ np.arange(len(Kbase)), size=nTestBase, replace=True)
321
+ KbaseIndices, NumImagesPerCategory = np.unique(
322
+ KbaseIndices, return_counts=True)
323
+
324
+ for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
325
+ imd_ids = self.sampleImageIdsFrom(
326
+ Kbase[Kbase_idx], sample_size=NumImages)
327
+ Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
328
+
329
+ assert(len(Tbase) == nTestBase)
330
+
331
+ return Tbase
332
+
333
+ def sample_train_and_test_examples_for_novel_categories(
334
+ self, Knovel, nTestNovel, nExemplars, nKbase):
335
+ """Samples train and test examples of the novel categories.
336
+
337
+ Args:
338
+ Knovel: a list with the ids of the novel categories.
339
+ nTestNovel: the total number of test images that will be sampled
340
+ from all the novel categories.
341
+ nExemplars: the number of training examples per novel category that
342
+ will be sampled.
343
+ nKbase: the number of base categories. It is used as offset of the
344
+ category index of each sampled image.
345
+
346
+ Returns:
347
+ Tnovel: a list of length `nTestNovel` with 2-element tuples. The
348
+ 1st element of each tuple is the image id that was sampled and
349
+ the 2nd element is its category label (which is in the range
350
+ [nKbase, nKbase + len(Knovel) - 1]).
351
+ Exemplars: a list of length len(Knovel) * nExemplars of 2-element
352
+ tuples. The 1st element of each tuple is the image id that was
353
+ sampled and the 2nd element is its category label (which is in
354
+ the ragne [nKbase, nKbase + len(Knovel) - 1]).
355
+ """
356
+
357
+ if len(Knovel) == 0:
358
+ return [], []
359
+
360
+ nKnovel = len(Knovel)
361
+ Tnovel = []
362
+ Exemplars = []
363
+ assert((nTestNovel % nKnovel) == 0)
364
+ nEvalExamplesPerClass = int(nTestNovel / nKnovel)
365
+
366
+ for Knovel_idx in range(len(Knovel)):
367
+ imd_ids = self.sampleImageIdsFrom(
368
+ Knovel[Knovel_idx],
369
+ sample_size=(nEvalExamplesPerClass + nExemplars))
370
+
371
+ imds_tnovel = imd_ids[:nEvalExamplesPerClass]
372
+ imds_ememplars = imd_ids[nEvalExamplesPerClass:]
373
+
374
+ Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
375
+ Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars]
376
+ assert(len(Tnovel) == nTestNovel)
377
+ assert(len(Exemplars) == len(Knovel) * nExemplars)
378
+ random.shuffle(Exemplars)
379
+
380
+ return Tnovel, Exemplars
381
+
382
+ def sample_episode(self):
383
+ """Samples a training episode."""
384
+ nKnovel = self.nKnovel
385
+ nKbase = self.nKbase
386
+ nTestNovel = self.nTestNovel
387
+ nTestBase = self.nTestBase
388
+ nExemplars = self.nExemplars
389
+
390
+ Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
391
+ Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
392
+ Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
393
+ Knovel, nTestNovel, nExemplars, nKbase)
394
+
395
+ # concatenate the base and novel category examples.
396
+ Test = Tbase + Tnovel
397
+ random.shuffle(Test)
398
+ Kall = Kbase + Knovel
399
+
400
+ return Exemplars, Test, Kall, nKbase
401
+
402
+ def createExamplesTensorData(self, examples):
403
+ """
404
+ Creates the examples image and label tensor data.
405
+
406
+ Args:
407
+ examples: a list of 2-element tuples, each representing a
408
+ train or test example. The 1st element of each tuple
409
+ is the image id of the example and 2nd element is the
410
+ category label of the example, which is in the range
411
+ [0, nK - 1], where nK is the total number of categories
412
+ (both novel and base).
413
+
414
+ Returns:
415
+ images: a tensor of shape [nExamples, Height, Width, 3] with the
416
+ example images, where nExamples is the number of examples
417
+ (i.e., nExamples = len(examples)).
418
+ labels: a tensor of shape [nExamples] with the category label
419
+ of each example.
420
+ """
421
+ images = torch.stack(
422
+ [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
423
+ labels = torch.LongTensor([label for _, label in examples])
424
+ return images, labels
425
+
426
+ def get_iterator(self, epoch=0):
427
+ rand_seed = epoch
428
+ random.seed(rand_seed)
429
+ np.random.seed(rand_seed)
430
+ def load_function(iter_idx):
431
+ Exemplars, Test, Kall, nKbase = self.sample_episode()
432
+ Xt, Yt = self.createExamplesTensorData(Test)
433
+ Kall = torch.LongTensor(Kall)
434
+ if len(Exemplars) > 0:
435
+ Xe, Ye = self.createExamplesTensorData(Exemplars)
436
+ return Xe, Ye, Xt, Yt, Kall, nKbase
437
+ else:
438
+ return Xt, Yt, Kall, nKbase
439
+
440
+ tnt_dataset = tnt.dataset.ListDataset(
441
+ elem_list=range(self.epoch_size), load=load_function)
442
+ data_loader = tnt_dataset.parallel(
443
+ batch_size=self.batch_size,
444
+ num_workers=(0 if self.is_eval_mode else self.num_workers),
445
+ shuffle=(False if self.is_eval_mode else True))
446
+
447
+ return data_loader
448
+
449
+ def __call__(self, epoch=0):
450
+ return self.get_iterator(epoch)
451
+
452
+ def __len__(self):
453
+ return int(self.epoch_size / self.batch_size)
dataloader/__pycache__/chest.cpython-36.pyc ADDED
Binary file (13.2 kB). View file
 
dataloader/__pycache__/chest.cpython-37.pyc ADDED
Binary file (13.3 kB). View file
 
dataloader/__pycache__/chest.cpython-38.pyc ADDED
Binary file (13.4 kB). View file
 
dataloader/chest.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataloader of Gidaris & Komodakis, CVPR 2018
2
+ # Adapted from:
3
+ # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import os.path
8
+ import numpy as npw
9
+ import random
10
+ import pickle
11
+ import json
12
+ import math
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+ import torchvision
17
+ import torchvision.datasets as datasets
18
+ import torchvision.transforms as transforms
19
+ import torchnet as tnt
20
+ import numpy as np
21
+ import pandas as pd
22
+
23
+
24
+ import h5py
25
+
26
+ import cv2
27
+ from PIL import Image
28
+ from PIL import ImageEnhance
29
+ import matplotlib.pyplot as plt
30
+
31
+
32
+ from torchvision.transforms.transforms import ToPILImage
33
+
34
+
35
+ # Set the appropriate paths of the datasets here.
36
+ # _CIFAR_FS_DATASET_DIR = './cifar/CIFAR-FS/'
37
+ _CHEST_DATASET_DIR = './NIH'
38
+ image_path = './NIH/images'
39
+
40
+
41
+ label_dict = {'Cardiomegaly': 0, 'Edema': 1, 'Effusion': 2, 'Emphysema': 3, 'Infiltration': 4, 'Mass': 5, 'Atelectasis': 6, 'Consolidation': 7,
42
+ 'Pleural_Thickening': 8, 'Fibrosis': 9, 'Hernia': 10, 'Pneumonia': 11, 'Nodule': 12, 'Pneumothorax': 13, 'No Finding': 14}
43
+
44
+
45
+ def buildLabelIndex(labels):
46
+ label2inds = {}
47
+ for idx, label in enumerate(labels):
48
+ label = label_dict[label]
49
+ if label not in label2inds:
50
+ label2inds[label] = []
51
+ label2inds[label].append(idx)
52
+
53
+ return label2inds
54
+
55
+
56
+ def load_data(file):
57
+ try:
58
+ with open(file, 'rb') as fo:
59
+ data = pickle.load(fo)
60
+ return data
61
+ except:
62
+ with open(file, 'rb') as f:
63
+ u = pickle._Unpickler(f)
64
+ u.encoding = 'latin1'
65
+ data = u.load()
66
+ return data
67
+
68
+
69
+ class Chest(data.Dataset):
70
+ def __init__(self, phase='train', idx = 1, do_not_use_random_transf=False):
71
+
72
+ assert(phase == 'train' or phase == 'val' or phase ==
73
+ 'test' or phase == 'trainval')
74
+ self.phase = phase
75
+ # self.name = phase + '.csv'
76
+
77
+ # idx = 3 # represents group for experimentation
78
+
79
+ print('Loading Chest-XRay dataset - phase {0}'.format(phase))
80
+
81
+ train_path = os.path.join(_CHEST_DATASET_DIR, f'train{idx}.csv')
82
+ val_path = os.path.join(_CHEST_DATASET_DIR, f'val{idx}.csv')
83
+ test_path = os.path.join(_CHEST_DATASET_DIR, f'test{idx}.csv')
84
+
85
+ if self.phase == 'train':
86
+ # # During training phase we only load the training phase images
87
+ # # of the training categories (aka base categories).
88
+ # data_train = load_data(file_train_categories_train_phase)
89
+ # # self.data = data_train['data']
90
+ # self.labels = data_train['labels']
91
+
92
+ file = pd.read_csv(train_path)
93
+
94
+ self.data = file['image_id'].values
95
+
96
+ self.labels = file['class_name'].values
97
+
98
+ self.label2ind = buildLabelIndex(self.labels)
99
+
100
+ self.labelIds = sorted(self.label2ind.keys())
101
+ self.num_cats = len(self.labelIds)
102
+ self.labelIds_base = self.labelIds
103
+ self.num_cats_base = len(self.labelIds_base)
104
+
105
+ # elif self.phase == 'trainval':
106
+ # # During training phase we only load the training phase images
107
+ # # of the training categories (aka base categories).
108
+ # data_train = load_data(file_train_categories_train_phase)
109
+ # self.data = data_train['data']
110
+ # self.labels = data_train['labels']
111
+ # data_base = load_data(file_train_categories_val_phase)
112
+ # data_novel = load_data(file_val_categories_val_phase)
113
+ # self.data = np.concatenate(
114
+ # [self.data, data_novel['data']], axis=0)
115
+ # self.data = np.concatenate(
116
+ # [self.data, data_base['data']], axis=0)
117
+
118
+ # self.labels = np.concatenate(
119
+ # [self.labels, data_novel['labels']], axis=0)
120
+ # self.labels = np.concatenate(
121
+ # [self.labels, data_base['labels']], axis=0)
122
+
123
+ # self.label2ind = buildLabelIndex(self.labels)
124
+ # self.labelIds = sorted(self.label2ind.keys())
125
+ # self.num_cats = len(self.labelIds)
126
+ # self.labelIds_base = self.labelIds
127
+ # self.num_cats_base = len(self.labelIds_base)
128
+
129
+ elif self.phase == 'val' or self.phase == 'test':
130
+ if self.phase == 'test':
131
+ # # load data that will be used for evaluating the recognition
132
+ # # accuracy of the base categories.
133
+ # data_base = load_data(file_train_categories_test_phase)
134
+ # # load data that will be use for evaluating the few-shot recogniton
135
+ # # accuracy on the novel categories.
136
+ # data_novel = load_data(file_test_categories_test_phase)
137
+
138
+ train_file = pd.read_csv(train_path)
139
+ file = pd.read_csv(test_path)
140
+ else: # phase=='val'
141
+ # # load data that will be used for evaluating the recognition
142
+ # # accuracy of the base categories.
143
+ # data_base = load_data(file_train_categories_val_phase)
144
+ # # load data that will be use for evaluating the few-shot recogniton
145
+ # # accuracy on the novel categories.
146
+ # data_novel = load_data(file_val_categories_val_phase)
147
+
148
+ train_file = pd.read_csv(train_path)
149
+ file = pd.read_csv(val_path)
150
+
151
+ # self.data = np.concatenate(
152
+ # [data_base['data'], data_novel['data']], axis=0)
153
+ # self.labels = data_base['labels'] + data_novel['labels']
154
+
155
+ train_labels = train_file['class_name'].values
156
+ novel_labels = file['class_name'].values
157
+
158
+ self.data = np.concatenate(
159
+ [train_file['image_id'].values, file['image_id'].values], axis=0)
160
+ self.labels = np.concatenate(
161
+ [train_file['class_name'].values, file['class_name'].values], axis=0)
162
+
163
+ self.label2ind = buildLabelIndex(self.labels)
164
+ self.labelIds = sorted(self.label2ind.keys())
165
+ self.num_cats = len(self.labelIds)
166
+
167
+ # self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
168
+ # self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
169
+
170
+ self.labelIds_base = buildLabelIndex(train_labels).keys()
171
+ self.labelIds_novel = buildLabelIndex(novel_labels).keys()
172
+ print('='*60)
173
+ print(self.labelIds_novel)
174
+ print('='*60)
175
+
176
+ self.num_cats_base = len(self.labelIds_base)
177
+ self.num_cats_novel = len(self.labelIds_novel)
178
+ # print(self.labelIds_novel)
179
+ # print(self.num_cats_novel)
180
+ intersection = set(self.labelIds_base) & set(self.labelIds_novel)
181
+ assert(len(intersection) == 0)
182
+ else:
183
+ raise ValueError('Not valid phase {0}'.format(self.phase))
184
+
185
+ # mean_pix = [x/255.0 for x in [129.37731888,
186
+ # 124.10583864, 112.47758569]]
187
+
188
+ # std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
189
+
190
+ mean_pix = [0.52024849, 0.52024849, 0.52024849]
191
+ std_pix = [0.22699496, 0.22699496, 0.22699496]
192
+
193
+
194
+ normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
195
+
196
+ if (self.phase == 'test' or self.phase == 'val') or (do_not_use_random_transf == True):
197
+
198
+ self.transform = transforms.Compose([
199
+ transforms.ToPILImage(),
200
+ # lambda x: np.asarray(x),
201
+ transforms.ToTensor(),
202
+ # lambda x: x/255.0,
203
+ normalize
204
+ ])
205
+ else:
206
+ self.transform = transforms.Compose([
207
+ transforms.ToPILImage(),
208
+ # transforms.RandomCrop(32, padding=4),
209
+ # transforms.ColorJitter(
210
+ # brightness=0.4, contrast=0.4, saturation=0.4),
211
+ transforms.RandomHorizontalFlip(),
212
+ transforms.ToTensor(),
213
+ # lambda x: np.asarray(x),
214
+ # lambda x: x/255.0,
215
+ normalize
216
+ ])
217
+
218
+ def __getitem__(self, index):
219
+ img, label = cv2.imread(os.path.join(
220
+ image_path, self.data[index]))[:,:,::-1], self.labels[index]
221
+ img = cv2.resize(img,(128,128)) # resize by Garvit
222
+ # img = cv2.resize(img,(84, 84)) # resize by kshitiz
223
+
224
+ # img = Image.fromarray(img)
225
+ if self.transform is not None:
226
+ img = self.transform(img)
227
+ return img, label
228
+
229
+ def __len__(self):
230
+ return len(self.data)
231
+
232
+
233
+ class FewShotDataloader():
234
+ def __init__(self,
235
+ dataset,
236
+ nKnovel=5, # number of novel categories.
237
+ nKbase=-1, # number of base categories.
238
+ # number of training examples per novel category.
239
+ nExemplars=1,
240
+ # number of test examples for all the novel categories.
241
+ nTestNovel=15*5,
242
+ # number of test examples for all the base categories.
243
+ nTestBase=15*5,
244
+ batch_size=1, # number of training episodes per batch.
245
+ num_workers=4,
246
+ epoch_size=2000, # number of batches per epoch.
247
+ ):
248
+
249
+ self.dataset = dataset
250
+ self.phase = self.dataset.phase
251
+ max_possible_nKnovel = (self.dataset.num_cats_base if self.phase == 'train' or self.phase == 'trainval'
252
+ else self.dataset.num_cats_novel)
253
+
254
+ assert(nKnovel >= 0 and nKnovel <= max_possible_nKnovel)
255
+ self.nKnovel = nKnovel
256
+
257
+ max_possible_nKbase = self.dataset.num_cats_base
258
+ nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
259
+ if (self.phase == 'train' or self.phase == 'trainval') and nKbase > 0:
260
+ nKbase -= self.nKnovel
261
+ max_possible_nKbase -= self.nKnovel
262
+
263
+ assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
264
+ self.nKbase = nKbase
265
+
266
+ self.nExemplars = nExemplars
267
+ self.nTestNovel = nTestNovel
268
+ self.nTestBase = nTestBase
269
+ self.batch_size = batch_size
270
+ self.epoch_size = epoch_size
271
+ self.num_workers = num_workers
272
+ self.is_eval_mode = (self.phase == 'test') or (self.phase == 'val')
273
+
274
+ def sampleImageIdsFrom(self, cat_id, sample_size=1):
275
+ """
276
+ Samples `sample_size` number of unique image ids picked from the
277
+ category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
278
+
279
+ Args:
280
+ cat_id: a scalar with the id of the category from which images will
281
+ be sampled.
282
+ sample_size: number of images that will be sampled.
283
+
284
+ Returns:
285
+ image_ids: a list of length `sample_size` with unique image ids.
286
+ """
287
+ assert(cat_id in self.dataset.label2ind)
288
+ assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
289
+ # Note: random.sample samples elements without replacement.
290
+ # seed = random.randint(1,10000000)
291
+ # random.seed(seed)
292
+ return random.sample(self.dataset.label2ind[cat_id], sample_size)
293
+
294
+ def sampleCategories(self, cat_set, sample_size=1):
295
+ """
296
+ Samples `sample_size` number of unique categories picked from the
297
+ `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
298
+
299
+ Args:
300
+ cat_set: string that specifies the set of categories from which
301
+ categories will be sampled.
302
+ sample_size: number of categories that will be sampled.
303
+
304
+ Returns:
305
+ cat_ids: a list of length `sample_size` with unique category ids.
306
+ """
307
+ if cat_set == 'base':
308
+ labelIds = self.dataset.labelIds_base
309
+ elif cat_set == 'novel':
310
+ labelIds = self.dataset.labelIds_novel
311
+ else:
312
+ raise ValueError('Not recognized category set {}'.format(cat_set))
313
+
314
+ assert(len(labelIds) >= sample_size)
315
+ # return sample_size unique categories chosen from labelIds set of
316
+ # categories (that can be either self.labelIds_base or self.labelIds_novel)
317
+ # Note: random.sample samples elements without replacement.
318
+ return random.sample(labelIds, sample_size)
319
+
320
+ def sample_base_and_novel_categories(self, nKbase, nKnovel):
321
+ """
322
+ Samples `nKbase` number of base categories and `nKnovel` number of novel
323
+ categories.
324
+
325
+ Args:
326
+ nKbase: number of base categories
327
+ nKnovel: number of novel categories
328
+
329
+ Returns:
330
+ Kbase: a list of length 'nKbase' with the ids of the sampled base
331
+ categories.
332
+ Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel
333
+ categories.
334
+ """
335
+ if self.is_eval_mode:
336
+ assert(nKnovel <= self.dataset.num_cats_novel)
337
+ # sample from the set of base categories 'nKbase' number of base
338
+ # categories.
339
+ Kbase = sorted(self.sampleCategories('base', nKbase))
340
+ # sample from the set of novel categories 'nKnovel' number of novel
341
+ # categories.
342
+ Knovel = sorted(self.sampleCategories('novel', nKnovel))
343
+ else:
344
+ # sample from the set of base categories 'nKnovel' + 'nKbase' number
345
+ # of categories.
346
+ cats_ids = self.sampleCategories('base', nKnovel+nKbase)
347
+ assert(len(cats_ids) == (nKnovel+nKbase))
348
+ # Randomly pick 'nKnovel' number of fake novel categories and keep
349
+ # the rest as base categories.
350
+ random.shuffle(cats_ids)
351
+ Knovel = sorted(cats_ids[:nKnovel])
352
+ Kbase = sorted(cats_ids[nKnovel:])
353
+
354
+
355
+ return Kbase, Knovel
356
+
357
+ def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
358
+ """
359
+ Sample `nTestBase` number of images from the `Kbase` categories.
360
+
361
+ Args:
362
+ Kbase: a list of length `nKbase` with the ids of the categories from
363
+ where the images will be sampled.
364
+ nTestBase: the total number of images that will be sampled.
365
+
366
+ Returns:
367
+ Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
368
+ element of each tuple is the image id that was sampled and the
369
+ 2nd elemend is its category label (which is in the range
370
+ [0, len(Kbase)-1]).
371
+ """
372
+ Tbase = []
373
+ if len(Kbase) > 0:
374
+ # Sample for each base category a number images such that the total
375
+ # number sampled images of all categories to be equal to `nTestBase`.
376
+ KbaseIndices = np.random.choice(
377
+ np.arange(len(Kbase)), size=nTestBase, replace=True)
378
+ KbaseIndices, NumImagesPerCategory = np.unique(
379
+ KbaseIndices, return_counts=True)
380
+
381
+ for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
382
+ imd_ids = self.sampleImageIdsFrom(
383
+ Kbase[Kbase_idx], sample_size=NumImages)
384
+ Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
385
+
386
+ assert(len(Tbase) == nTestBase)
387
+
388
+ return Tbase
389
+
390
+ def sample_train_and_test_examples_for_novel_categories(
391
+ self, Knovel, nTestNovel, nExemplars, nKbase):
392
+ """Samples train and test examples of the novel categories.
393
+
394
+ Args:
395
+ Knovel: a list with the ids of the novel categories.
396
+ nTestNovel: the total number of test images that will be sampled
397
+ from all the novel categories.
398
+ nExemplars: the number of training examples per novel category that
399
+ will be sampled.
400
+ nKbase: the number of base categories. It is used as offset of the
401
+ category index of each sampled image.
402
+
403
+ Returns:
404
+ Tnovel: a list of length `nTestNovel` with 2-element tuples. The
405
+ 1st element of each tuple is the image id that was sampled and
406
+ the 2nd element is its category label (which is in the range
407
+ [nKbase, nKbase + len(Knovel) - 1]).
408
+ Exemplars: a list of length len(Knovel) * nExemplars of 2-element
409
+ tuples. The 1st element of each tuple is the image id that was
410
+ sampled and the 2nd element is its category label (which is in
411
+ the ragne [nKbase, nKbase + len(Knovel) - 1]).
412
+ """
413
+
414
+ if len(Knovel) == 0:
415
+ return [], []
416
+
417
+ nKnovel = len(Knovel)
418
+ Tnovel = []
419
+ Exemplars = []
420
+ assert((nTestNovel % nKnovel) == 0)
421
+ nEvalExamplesPerClass = int(nTestNovel / nKnovel)
422
+
423
+ for Knovel_idx in range(len(Knovel)):
424
+ imd_ids = self.sampleImageIdsFrom(
425
+ Knovel[Knovel_idx],
426
+ sample_size=(nEvalExamplesPerClass + nExemplars))
427
+
428
+ imds_tnovel = imd_ids[:nEvalExamplesPerClass]
429
+ imds_ememplars = imd_ids[nEvalExamplesPerClass:]
430
+
431
+ Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
432
+ Exemplars += [(img_id, nKbase+Knovel_idx)
433
+ for img_id in imds_ememplars]
434
+ assert(len(Tnovel) == nTestNovel)
435
+ assert(len(Exemplars) == len(Knovel) * nExemplars)
436
+ # random.shuffle(Exemplars)
437
+
438
+ return Tnovel, Exemplars
439
+
440
+ def sample_episode(self):
441
+ """Samples a training episode."""
442
+ nKnovel = self.nKnovel
443
+ nKbase = self.nKbase
444
+ nTestNovel = self.nTestNovel
445
+ nTestBase = self.nTestBase
446
+ nExemplars = self.nExemplars
447
+
448
+ Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
449
+ Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
450
+ Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
451
+ Knovel, nTestNovel, nExemplars, nKbase)
452
+
453
+ # concatenate the base and novel category examples.
454
+ Test = Tbase + Tnovel
455
+ # random.shuffle(Test)
456
+ Kall = Kbase + Knovel
457
+
458
+ return Exemplars, Test, Kall, nKbase
459
+
460
+ def createExamplesTensorData(self, examples):
461
+ """
462
+ Creates the examples image and label tensor data.
463
+
464
+ Args:
465
+ examples: a list of 2-element tuples, each representing a
466
+ train or test example. The 1st element of each tuple
467
+ is the image id of the example and 2nd element is the
468
+ category label of the example, which is in the range
469
+ [0, nK - 1], where nK is the total number of categories
470
+ (both novel and base).
471
+
472
+ Returns:
473
+ images: a tensor of shape [nExamples, Height, Width, 3] with the
474
+ example images, where nExamples is the number of examples
475
+ (i.e., nExamples = len(examples)).
476
+ labels: a tensor of shape [nExamples] with the category label
477
+ of each example.
478
+ """
479
+ images = torch.stack(
480
+ [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
481
+ labels = torch.LongTensor([label for _, label in examples])
482
+ return images, labels
483
+
484
+ def get_iterator(self, epoch=0):
485
+ rand_seed = epoch
486
+ random.seed(rand_seed)
487
+ np.random.seed(rand_seed)
488
+
489
+ def load_function(iter_idx):
490
+ Exemplars, Test, Kall, nKbase = self.sample_episode()
491
+ Xt, Yt = self.createExamplesTensorData(Test)
492
+ Kall = torch.LongTensor(Kall)
493
+ if len(Exemplars) > 0:
494
+ Xe, Ye = self.createExamplesTensorData(Exemplars)
495
+ return Xe, Ye, Xt, Yt, Kall, nKbase
496
+ else:
497
+ return Xt, Yt, Kall, nKbase
498
+
499
+ tnt_dataset = tnt.dataset.ListDataset(
500
+ elem_list=range(self.epoch_size), load=load_function)
501
+ data_loader = tnt_dataset.parallel(
502
+ batch_size=self.batch_size,
503
+ num_workers=(0 if self.is_eval_mode else self.num_workers),
504
+ shuffle=(False if self.is_eval_mode else True),)
505
+
506
+ return data_loader
507
+
508
+ def __call__(self, epoch=0):
509
+ return self.get_iterator(epoch)
510
+
511
+ def __len__(self):
512
+ return int(self.epoch_size / self.batch_size)
dataloader/chest1.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataloader of Gidaris & Komodakis, CVPR 2018
2
+ # Adapted from:
3
+ # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import os.path
8
+ import numpy as npw
9
+ import random
10
+ import pickle
11
+ import json
12
+ import math
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+ import torchvision
17
+ import torchvision.datasets as datasets
18
+ import torchvision.transforms as transforms
19
+ import torchnet as tnt
20
+ import numpy as np
21
+ import pandas as pd
22
+
23
+
24
+ import h5py
25
+
26
+ import cv2
27
+ from PIL import Image
28
+ from PIL import ImageEnhance
29
+ import matplotlib.pyplot as plt
30
+
31
+
32
+ from torchvision.transforms.transforms import ToPILImage
33
+
34
+
35
+ # Set the appropriate paths of the datasets here.
36
+ # _CIFAR_FS_DATASET_DIR = './cifar/CIFAR-FS/'
37
+ _CHEST_DATASET_DIR = './NIH'
38
+ image_path = './NIH/images'
39
+
40
+
41
+ label_dict = {'Cardiomegaly': 0, 'Edema': 1, 'Effusion': 2, 'Emphysema': 3, 'Infiltration': 4, 'Mass': 5, 'Atelectasis': 6, 'Consolidation': 7,
42
+ 'Pleural_Thickening': 8, 'Fibrosis': 9, 'Hernia': 10, 'Pneumonia': 11, 'Nodule': 12, 'Pneumothorax': 13, 'No Finding': 14}
43
+
44
+
45
+ def buildLabelIndex(labels):
46
+ label2inds = {}
47
+ for idx, label in enumerate(labels):
48
+ label = label_dict[label]
49
+ if label not in label2inds:
50
+ label2inds[label] = []
51
+ label2inds[label].append(idx)
52
+
53
+ return label2inds
54
+
55
+
56
+ def load_data(file):
57
+ try:
58
+ with open(file, 'rb') as fo:
59
+ data = pickle.load(fo)
60
+ return data
61
+ except:
62
+ with open(file, 'rb') as f:
63
+ u = pickle._Unpickler(f)
64
+ u.encoding = 'latin1'
65
+ data = u.load()
66
+ return data
67
+
68
+
69
+ class Chest(data.Dataset):
70
+ def __init__(self, phase='train', do_not_use_random_transf=False):
71
+
72
+ assert(phase == 'train' or phase == 'val' or phase ==
73
+ 'test' or phase == 'trainval')
74
+ self.phase = phase
75
+ # self.name = phase + '.csv'
76
+
77
+ idx = 1 # represents group for experimentation
78
+
79
+ print('Loading Chest-XRay dataset - phase {0}'.format(phase))
80
+
81
+ train_path = os.path.join(_CHEST_DATASET_DIR, f'train{idx}.csv')
82
+ val_path = os.path.join(_CHEST_DATASET_DIR, f'val{idx}.csv')
83
+ test_path = os.path.join(_CHEST_DATASET_DIR, f'test{idx}.csv')
84
+
85
+ if self.phase == 'train':
86
+ # # During training phase we only load the training phase images
87
+ # # of the training categories (aka base categories).
88
+ # data_train = load_data(file_train_categories_train_phase)
89
+ # # self.data = data_train['data']
90
+ # self.labels = data_train['labels']
91
+
92
+ file = pd.read_csv(train_path)
93
+
94
+ self.data = file['image_id'].values
95
+
96
+ self.labels = file['class_name'].values
97
+
98
+ self.label2ind = buildLabelIndex(self.labels)
99
+
100
+ self.labelIds = sorted(self.label2ind.keys())
101
+ self.num_cats = len(self.labelIds)
102
+ self.labelIds_base = self.labelIds
103
+ self.num_cats_base = len(self.labelIds_base)
104
+
105
+ # elif self.phase == 'trainval':
106
+ # # During training phase we only load the training phase images
107
+ # # of the training categories (aka base categories).
108
+ # data_train = load_data(file_train_categories_train_phase)
109
+ # self.data = data_train['data']
110
+ # self.labels = data_train['labels']
111
+ # data_base = load_data(file_train_categories_val_phase)
112
+ # data_novel = load_data(file_val_categories_val_phase)
113
+ # self.data = np.concatenate(
114
+ # [self.data, data_novel['data']], axis=0)
115
+ # self.data = np.concatenate(
116
+ # [self.data, data_base['data']], axis=0)
117
+
118
+ # self.labels = np.concatenate(
119
+ # [self.labels, data_novel['labels']], axis=0)
120
+ # self.labels = np.concatenate(
121
+ # [self.labels, data_base['labels']], axis=0)
122
+
123
+ # self.label2ind = buildLabelIndex(self.labels)
124
+ # self.labelIds = sorted(self.label2ind.keys())
125
+ # self.num_cats = len(self.labelIds)
126
+ # self.labelIds_base = self.labelIds
127
+ # self.num_cats_base = len(self.labelIds_base)
128
+
129
+ elif self.phase == 'val' or self.phase == 'test':
130
+ if self.phase == 'test':
131
+ # # load data that will be used for evaluating the recognition
132
+ # # accuracy of the base categories.
133
+ # data_base = load_data(file_train_categories_test_phase)
134
+ # # load data that will be use for evaluating the few-shot recogniton
135
+ # # accuracy on the novel categories.
136
+ # data_novel = load_data(file_test_categories_test_phase)
137
+
138
+ train_file = pd.read_csv(train_path)
139
+ file = pd.read_csv(test_path)
140
+ else: # phase=='val'
141
+ # # load data that will be used for evaluating the recognition
142
+ # # accuracy of the base categories.
143
+ # data_base = load_data(file_train_categories_val_phase)
144
+ # # load data that will be use for evaluating the few-shot recogniton
145
+ # # accuracy on the novel categories.
146
+ # data_novel = load_data(file_val_categories_val_phase)
147
+
148
+ train_file = pd.read_csv(train_path)
149
+ file = pd.read_csv(val_path)
150
+
151
+ # self.data = np.concatenate(
152
+ # [data_base['data'], data_novel['data']], axis=0)
153
+ # self.labels = data_base['labels'] + data_novel['labels']
154
+
155
+ train_labels = train_file['class_name'].values
156
+ novel_labels = file['class_name'].values
157
+
158
+ self.data = np.concatenate(
159
+ [train_file['image_id'].values, file['image_id'].values], axis=0)
160
+ self.labels = np.concatenate(
161
+ [train_file['class_name'].values, file['class_name'].values], axis=0)
162
+
163
+
164
+ self.label2ind = buildLabelIndex(self.labels)
165
+ self.labelIds = sorted(self.label2ind.keys())
166
+ self.num_cats = len(self.labelIds)
167
+
168
+ # self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
169
+ # self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
170
+
171
+ self.labelIds_base = buildLabelIndex(train_labels).keys()
172
+ self.labelIds_novel = buildLabelIndex(novel_labels).keys()
173
+ print('='*60)
174
+ print(self.labelIds_novel)
175
+ print('='*60)
176
+
177
+ self.num_cats_base = len(self.labelIds_base)
178
+ self.num_cats_novel = len(self.labelIds_novel)
179
+ # print(self.labelIds_novel)
180
+ # print(self.num_cats_novel)
181
+ intersection = set(self.labelIds_base) & set(self.labelIds_novel)
182
+ assert(len(intersection) == 0)
183
+ else:
184
+ raise ValueError('Not valid phase {0}'.format(self.phase))
185
+
186
+ # mean_pix = [x/255.0 for x in [129.37731888,
187
+ # 124.10583864, 112.47758569]]
188
+
189
+ # std_pix = [x/255.0 for x in [68.20947949, 65.43124043, 70.45866994]]
190
+
191
+ mean_pix = [0.52024849, 0.52024849, 0.52024849]
192
+ std_pix = [0.22699496, 0.22699496, 0.22699496]
193
+
194
+
195
+ normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
196
+
197
+ if (self.phase == 'test' or self.phase == 'val') or (do_not_use_random_transf == True):
198
+
199
+ self.transform = transforms.Compose([
200
+ transforms.ToPILImage(),
201
+ # lambda x: np.asarray(x),
202
+ transforms.ToTensor(),
203
+ # lambda x: x/255.0,
204
+ normalize
205
+ ])
206
+ else:
207
+ self.transform = transforms.Compose([
208
+ transforms.ToPILImage(),
209
+ # transforms.RandomCrop(32, padding=4),
210
+ # transforms.ColorJitter(
211
+ # brightness=0.4, contrast=0.4, saturation=0.4),
212
+ transforms.RandomHorizontalFlip(),
213
+ transforms.ToTensor(),
214
+ # lambda x: np.asarray(x),
215
+ # lambda x: x/255.0,
216
+ normalize
217
+ ])
218
+
219
+ def __getitem__(self, index):
220
+ img, label = cv2.imread(os.path.join(
221
+ image_path, self.data[index]))[:,:,::-1], self.labels[index]
222
+ img = cv2.resize(img,(128,128)) # resize by Garvit
223
+ # img = cv2.resize(img,(84, 84)) # resize by kshitiz
224
+
225
+ # img = Image.fromarray(img)
226
+ if self.transform is not None:
227
+ img = self.transform(img)
228
+ return img, label, self.data[index]
229
+ # return img, label
230
+
231
+ def __len__(self):
232
+ return len(self.data)
233
+
234
+
235
+ class FewShotDataloader():
236
+ def __init__(self,
237
+ dataset,
238
+ nKnovel=5, # number of novel categories.
239
+ nKbase=-1, # number of base categories.
240
+ # number of training examples per novel category.
241
+ nExemplars=1,
242
+ # number of test examples for all the novel categories.
243
+ nTestNovel=15*5,
244
+ # number of test examples for all the base categories.
245
+ nTestBase=15*5,
246
+ batch_size=1, # number of training episodes per batch.
247
+ num_workers=4,
248
+ epoch_size=2000, # number of batches per epoch.
249
+ ):
250
+
251
+ self.dataset = dataset
252
+ self.phase = self.dataset.phase
253
+ max_possible_nKnovel = (self.dataset.num_cats_base if self.phase == 'train' or self.phase == 'trainval'
254
+ else self.dataset.num_cats_novel)
255
+
256
+ assert(nKnovel >= 0 and nKnovel <= max_possible_nKnovel)
257
+ self.nKnovel = nKnovel
258
+
259
+ max_possible_nKbase = self.dataset.num_cats_base
260
+ nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
261
+ if (self.phase == 'train' or self.phase == 'trainval') and nKbase > 0:
262
+ nKbase -= self.nKnovel
263
+ max_possible_nKbase -= self.nKnovel
264
+
265
+ assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
266
+ self.nKbase = nKbase
267
+
268
+ self.nExemplars = nExemplars
269
+ self.nTestNovel = nTestNovel
270
+ self.nTestBase = nTestBase
271
+ self.batch_size = batch_size
272
+ self.epoch_size = epoch_size
273
+ self.num_workers = num_workers
274
+ self.is_eval_mode = (self.phase == 'test') or (self.phase == 'val')
275
+
276
+ def sampleImageIdsFrom(self, cat_id, sample_size=1):
277
+ """
278
+ Samples `sample_size` number of unique image ids picked from the
279
+ category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
280
+
281
+ Args:
282
+ cat_id: a scalar with the id of the category from which images will
283
+ be sampled.
284
+ sample_size: number of images that will be sampled.
285
+
286
+ Returns:
287
+ image_ids: a list of length `sample_size` with unique image ids.
288
+ """
289
+ assert(cat_id in self.dataset.label2ind)
290
+ assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
291
+ # Note: random.sample samples elements without replacement.
292
+ # seed = random.randint(1,10000000)
293
+ # random.seed(seed)
294
+ return random.sample(self.dataset.label2ind[cat_id], sample_size)
295
+
296
+ def sampleCategories(self, cat_set, sample_size=1):
297
+ """
298
+ Samples `sample_size` number of unique categories picked from the
299
+ `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
300
+
301
+ Args:
302
+ cat_set: string that specifies the set of categories from which
303
+ categories will be sampled.
304
+ sample_size: number of categories that will be sampled.
305
+
306
+ Returns:
307
+ cat_ids: a list of length `sample_size` with unique category ids.
308
+ """
309
+ if cat_set == 'base':
310
+ labelIds = self.dataset.labelIds_base
311
+ elif cat_set == 'novel':
312
+ labelIds = self.dataset.labelIds_novel
313
+ else:
314
+ raise ValueError('Not recognized category set {}'.format(cat_set))
315
+
316
+ assert(len(labelIds) >= sample_size)
317
+ # return sample_size unique categories chosen from labelIds set of
318
+ # categories (that can be either self.labelIds_base or self.labelIds_novel)
319
+ # Note: random.sample samples elements without replacement.
320
+ return random.sample(labelIds, sample_size)
321
+
322
+ def sample_base_and_novel_categories(self, nKbase, nKnovel):
323
+ """
324
+ Samples `nKbase` number of base categories and `nKnovel` number of novel
325
+ categories.
326
+
327
+ Args:
328
+ nKbase: number of base categories
329
+ nKnovel: number of novel categories
330
+
331
+ Returns:
332
+ Kbase: a list of length 'nKbase' with the ids of the sampled base
333
+ categories.
334
+ Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel
335
+ categories.
336
+ """
337
+ if self.is_eval_mode:
338
+ assert(nKnovel <= self.dataset.num_cats_novel)
339
+ # sample from the set of base categories 'nKbase' number of base
340
+ # categories.
341
+ Kbase = sorted(self.sampleCategories('base', nKbase))
342
+ # sample from the set of novel categories 'nKnovel' number of novel
343
+ # categories.
344
+ Knovel = sorted(self.sampleCategories('novel', nKnovel))
345
+ else:
346
+ # sample from the set of base categories 'nKnovel' + 'nKbase' number
347
+ # of categories.
348
+ cats_ids = self.sampleCategories('base', nKnovel+nKbase)
349
+ assert(len(cats_ids) == (nKnovel+nKbase))
350
+ # Randomly pick 'nKnovel' number of fake novel categories and keep
351
+ # the rest as base categories.
352
+ random.shuffle(cats_ids)
353
+ Knovel = sorted(cats_ids[:nKnovel])
354
+ Kbase = sorted(cats_ids[nKnovel:])
355
+
356
+
357
+ return Kbase, Knovel
358
+
359
+ def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
360
+ """
361
+ Sample `nTestBase` number of images from the `Kbase` categories.
362
+
363
+ Args:
364
+ Kbase: a list of length `nKbase` with the ids of the categories from
365
+ where the images will be sampled.
366
+ nTestBase: the total number of images that will be sampled.
367
+
368
+ Returns:
369
+ Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
370
+ element of each tuple is the image id that was sampled and the
371
+ 2nd elemend is its category label (which is in the range
372
+ [0, len(Kbase)-1]).
373
+ """
374
+ Tbase = []
375
+ if len(Kbase) > 0:
376
+ # Sample for each base category a number images such that the total
377
+ # number sampled images of all categories to be equal to `nTestBase`.
378
+ KbaseIndices = np.random.choice(
379
+ np.arange(len(Kbase)), size=nTestBase, replace=True)
380
+ KbaseIndices, NumImagesPerCategory = np.unique(
381
+ KbaseIndices, return_counts=True)
382
+
383
+ for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
384
+ imd_ids = self.sampleImageIdsFrom(
385
+ Kbase[Kbase_idx], sample_size=NumImages)
386
+ Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
387
+
388
+ assert(len(Tbase) == nTestBase)
389
+
390
+ return Tbase
391
+
392
+ def sample_train_and_test_examples_for_novel_categories(
393
+ self, Knovel, nTestNovel, nExemplars, nKbase):
394
+ """Samples train and test examples of the novel categories.
395
+
396
+ Args:
397
+ Knovel: a list with the ids of the novel categories.
398
+ nTestNovel: the total number of test images that will be sampled
399
+ from all the novel categories.
400
+ nExemplars: the number of training examples per novel category that
401
+ will be sampled.
402
+ nKbase: the number of base categories. It is used as offset of the
403
+ category index of each sampled image.
404
+
405
+ Returns:
406
+ Tnovel: a list of length `nTestNovel` with 2-element tuples. The
407
+ 1st element of each tuple is the image id that was sampled and
408
+ the 2nd element is its category label (which is in the range
409
+ [nKbase, nKbase + len(Knovel) - 1]).
410
+ Exemplars: a list of length len(Knovel) * nExemplars of 2-element
411
+ tuples. The 1st element of each tuple is the image id that was
412
+ sampled and the 2nd element is its category label (which is in
413
+ the ragne [nKbase, nKbase + len(Knovel) - 1]).
414
+ """
415
+
416
+ if len(Knovel) == 0:
417
+ return [], []
418
+
419
+ nKnovel = len(Knovel)
420
+ Tnovel = []
421
+ Exemplars = []
422
+ assert((nTestNovel % nKnovel) == 0)
423
+ nEvalExamplesPerClass = int(nTestNovel / nKnovel)
424
+
425
+ for Knovel_idx in range(len(Knovel)):
426
+ imd_ids = self.sampleImageIdsFrom(
427
+ Knovel[Knovel_idx],
428
+ sample_size=(nEvalExamplesPerClass + nExemplars))
429
+
430
+ imds_tnovel = imd_ids[:nEvalExamplesPerClass]
431
+ imds_ememplars = imd_ids[nEvalExamplesPerClass:]
432
+
433
+ Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
434
+ Exemplars += [(img_id, nKbase+Knovel_idx)
435
+ for img_id in imds_ememplars]
436
+ assert(len(Tnovel) == nTestNovel)
437
+ assert(len(Exemplars) == len(Knovel) * nExemplars)
438
+ # random.shuffle(Exemplars)
439
+
440
+ return Tnovel, Exemplars
441
+
442
+ def sample_episode(self):
443
+ """Samples a training episode."""
444
+ nKnovel = self.nKnovel
445
+ nKbase = self.nKbase
446
+ nTestNovel = self.nTestNovel
447
+ nTestBase = self.nTestBase
448
+ nExemplars = self.nExemplars
449
+
450
+ Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
451
+ Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
452
+ Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
453
+ Knovel, nTestNovel, nExemplars, nKbase)
454
+
455
+ # concatenate the base and novel category examples.
456
+ Test = Tbase + Tnovel
457
+ # random.shuffle(Test)
458
+ Kall = Kbase + Knovel
459
+
460
+ return Exemplars, Test, Kall, nKbase
461
+
462
+ def createExamplesTensorData(self, examples):
463
+ """
464
+ Creates the examples image and label tensor data.
465
+
466
+ Args:
467
+ examples: a list of 2-element tuples, each representing a
468
+ train or test example. The 1st element of each tuple
469
+ is the image id of the example and 2nd element is the
470
+ category label of the example, which is in the range
471
+ [0, nK - 1], where nK is the total number of categories
472
+ (both novel and base).
473
+
474
+ Returns:
475
+ images: a tensor of shape [nExamples, Height, Width, 3] with the
476
+ example images, where nExamples is the number of examples
477
+ (i.e., nExamples = len(examples)).
478
+ labels: a tensor of shape [nExamples] with the category label
479
+ of each example.
480
+ """
481
+ images = torch.stack(
482
+ [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
483
+ names = np.stack(
484
+ [self.dataset[img_idx][-1] for img_idx, _ in examples], axis=0)
485
+ print(names)
486
+ labels = torch.LongTensor([label for _, label in examples])
487
+ return images, labels
488
+
489
+ def get_iterator(self, epoch=0):
490
+ rand_seed = epoch
491
+ random.seed(rand_seed)
492
+ np.random.seed(rand_seed)
493
+
494
+ def load_function(iter_idx):
495
+ Exemplars, Test, Kall, nKbase = self.sample_episode()
496
+ Xt, Yt = self.createExamplesTensorData(Test)
497
+ Kall = torch.LongTensor(Kall)
498
+ if len(Exemplars) > 0:
499
+ Xe, Ye = self.createExamplesTensorData(Exemplars)
500
+ return Xe, Ye, Xt, Yt, Kall, nKbase
501
+ else:
502
+ return Xt, Yt, Kall, nKbase
503
+
504
+ tnt_dataset = tnt.dataset.ListDataset(
505
+ elem_list=range(self.epoch_size), load=load_function)
506
+ data_loader = tnt_dataset.parallel(
507
+ batch_size=self.batch_size,
508
+ num_workers=(0 if self.is_eval_mode else self.num_workers),
509
+ shuffle=(False if self.is_eval_mode else True),)
510
+
511
+ return data_loader
512
+
513
+ def __call__(self, epoch=0):
514
+ return self.get_iterator(epoch)
515
+
516
+ def __len__(self):
517
+ return int(self.epoch_size / self.batch_size)
dataloader/mini_imagenet.py ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataloader of Gidaris & Komodakis, CVPR 2018
2
+ # Adapted from:
3
+ # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import os.path
8
+ import numpy as np
9
+ import random
10
+ import pickle
11
+ import json
12
+ import math
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+ import torchvision
17
+ import torchvision.datasets as datasets
18
+ import torchvision.transforms as transforms
19
+ import torchnet as tnt
20
+
21
+ import h5py
22
+
23
+ from PIL import Image
24
+ from PIL import ImageEnhance
25
+
26
+ from pdb import set_trace as breakpoint
27
+
28
+ from torchvision.transforms.transforms import ToPILImage
29
+
30
+ # Set the appropriate paths of the datasets here.
31
+ _MINI_IMAGENET_DATASET_DIR = './miniimagenet/' ## your miniimagenet folder
32
+
33
+
34
+ def buildLabelIndex(labels):
35
+ label2inds = {}
36
+ for idx, label in enumerate(labels):
37
+ if label not in label2inds:
38
+ label2inds[label] = []
39
+ label2inds[label].append(idx)
40
+
41
+ return label2inds
42
+
43
+
44
+ def load_data(file):
45
+ try:
46
+ with open(file, 'rb') as fo:
47
+ data = pickle.load(fo)
48
+ return data
49
+ except:
50
+ with open(file, 'rb') as f:
51
+ u = pickle._Unpickler(f)
52
+ u.encoding = 'latin1'
53
+ data = u.load()
54
+ return data
55
+
56
+ class MiniImageNet(data.Dataset):
57
+ def __init__(self, phase='train', do_not_use_random_transf=False):
58
+
59
+ self.base_folder = 'miniImagenet'
60
+ #assert(phase=='train' or phase=='val' or phase=='test' or ph)
61
+ self.phase = phase
62
+ self.name = 'MiniImageNet_' + phase
63
+
64
+ print('Loading mini ImageNet dataset - phase {0}'.format(phase))
65
+ file_train_categories_train_phase = os.path.join(
66
+ _MINI_IMAGENET_DATASET_DIR,
67
+ 'miniImageNet_category_split_train_phase_train.pickle')
68
+ file_train_categories_val_phase = os.path.join(
69
+ _MINI_IMAGENET_DATASET_DIR,
70
+ 'miniImageNet_category_split_train_phase_val.pickle')
71
+ file_train_categories_test_phase = os.path.join(
72
+ _MINI_IMAGENET_DATASET_DIR,
73
+ 'miniImageNet_category_split_train_phase_test.pickle')
74
+ file_val_categories_val_phase = os.path.join(
75
+ _MINI_IMAGENET_DATASET_DIR,
76
+ 'miniImageNet_category_split_val.pickle')
77
+ file_test_categories_test_phase = os.path.join(
78
+ _MINI_IMAGENET_DATASET_DIR,
79
+ 'miniImageNet_category_split_test.pickle')
80
+
81
+ if self.phase=='train':
82
+ # During training phase we only load the training phase images
83
+ # of the training categories (aka base categories).
84
+ data_train = load_data(file_train_categories_train_phase)
85
+ self.data = data_train['data']
86
+ self.labels = data_train['labels']
87
+
88
+ self.label2ind = buildLabelIndex(self.labels)
89
+ self.labelIds = sorted(self.label2ind.keys())
90
+ self.num_cats = len(self.labelIds)
91
+ self.labelIds_base = self.labelIds
92
+ self.num_cats_base = len(self.labelIds_base)
93
+ elif self.phase == 'trainval':
94
+ # During training phase we only load the training phase images
95
+ # of the training categories (aka base categories).
96
+ data_train = load_data(file_train_categories_train_phase)
97
+ self.data = data_train['data']
98
+ self.labels = data_train['labels']
99
+ data_base = load_data(file_train_categories_val_phase)
100
+ data_novel = load_data(file_val_categories_val_phase)
101
+ self.data = np.concatenate(
102
+ [self.data, data_novel['data']], axis=0)
103
+ self.data = np.concatenate(
104
+ [self.data, data_base['data']], axis=0)
105
+ self.labels = np.concatenate(
106
+ [self.labels, data_novel['labels']], axis=0)
107
+ self.labels = np.concatenate(
108
+ [self.labels, data_base['labels']], axis=0)
109
+
110
+ self.label2ind = buildLabelIndex(self.labels)
111
+ self.labelIds = sorted(self.label2ind.keys())
112
+ self.num_cats = len(self.labelIds)
113
+ self.labelIds_base = self.labelIds
114
+ self.num_cats_base = len(self.labelIds_base)
115
+ elif self.phase=='val' or self.phase=='test':
116
+ if self.phase=='test':
117
+ # load data that will be used for evaluating the recognition
118
+ # accuracy of the base categories.
119
+ data_base = load_data(file_train_categories_test_phase)
120
+ # load data that will be use for evaluating the few-shot recogniton
121
+ # accuracy on the novel categories.
122
+ data_novel = load_data(file_test_categories_test_phase)
123
+ else: # phase=='val'
124
+ # load data that will be used for evaluating the recognition
125
+ # accuracy of the base categories.
126
+ data_base = load_data(file_train_categories_val_phase)
127
+ # load data that will be use for evaluating the few-shot recogniton
128
+ # accuracy on the novel categories.
129
+ data_novel = load_data(file_val_categories_val_phase)
130
+
131
+ self.data = np.concatenate(
132
+ [data_base['data'], data_novel['data']], axis=0)
133
+ self.labels = data_base['labels'] + data_novel['labels']
134
+
135
+ self.label2ind = buildLabelIndex(self.labels)
136
+ self.labelIds = sorted(self.label2ind.keys())
137
+ self.num_cats = len(self.labelIds)
138
+
139
+ self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
140
+ self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
141
+ self.num_cats_base = len(self.labelIds_base)
142
+ self.num_cats_novel = len(self.labelIds_novel)
143
+ intersection = set(self.labelIds_base) & set(self.labelIds_novel)
144
+ assert(len(intersection) == 0)
145
+ else:
146
+ raise ValueError('Not valid phase {0}'.format(self.phase))
147
+
148
+ mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
149
+ std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
150
+ normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
151
+
152
+ if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
153
+ self.transform = transforms.Compose([
154
+ # transforms.ToPILImage(),
155
+ # lambda x: np.asarray(x),
156
+ transforms.ToTensor(),
157
+ normalize
158
+ ])
159
+ else:
160
+ self.transform = transforms.Compose([
161
+ # transforms.ToPILImage(),
162
+ transforms.RandomCrop(84, padding=8),
163
+ transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
164
+ transforms.RandomHorizontalFlip(),
165
+ # lambda x: np.asarray(x),
166
+ transforms.ToTensor(),
167
+ normalize
168
+ ])
169
+
170
+ def __getitem__(self, index):
171
+ img, label = self.data[index], self.labels[index]
172
+ # doing this so that it is consistent with all other datasets
173
+ # to return a PIL Image
174
+ img = Image.fromarray(img)
175
+ if self.transform is not None:
176
+ img = self.transform(img)
177
+ return img, label
178
+
179
+ def __len__(self):
180
+ return len(self.data)
181
+
182
+
183
+ class FewShotDataloader():
184
+ def __init__(self,
185
+ dataset,
186
+ nKnovel=5, # number of novel categories.
187
+ nKbase=-1, # number of base categories.
188
+ nExemplars=1, # number of training examples per novel category.
189
+ nTestNovel=15*5, # number of test examples for all the novel categories.
190
+ nTestBase=15*5, # number of test examples for all the base categories.
191
+ batch_size=1, # number of training episodes per batch.
192
+ num_workers=0,
193
+ epoch_size=2000, # number of batches per epoch.
194
+ ):
195
+
196
+ self.dataset = dataset
197
+ self.phase = self.dataset.phase
198
+ max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train' or self.phase=='trainval'
199
+ else self.dataset.num_cats_novel)
200
+ assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
201
+ self.nKnovel = nKnovel
202
+
203
+ max_possible_nKbase = self.dataset.num_cats_base
204
+ nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
205
+ if (self.phase=='train'or self.phase=='trainval') and nKbase > 0:
206
+ nKbase -= self.nKnovel
207
+ max_possible_nKbase -= self.nKnovel
208
+
209
+ assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
210
+ self.nKbase = nKbase
211
+
212
+ self.nExemplars = nExemplars
213
+ self.nTestNovel = nTestNovel
214
+ self.nTestBase = nTestBase
215
+ self.batch_size = batch_size
216
+ self.epoch_size = epoch_size
217
+ self.num_workers = num_workers
218
+ self.is_eval_mode = (self.phase=='test') or (self.phase=='val')
219
+
220
+ def sampleImageIdsFrom(self, cat_id, sample_size=1):
221
+ """
222
+ Samples `sample_size` number of unique image ids picked from the
223
+ category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
224
+
225
+ Args:
226
+ cat_id: a scalar with the id of the category from which images will
227
+ be sampled.
228
+ sample_size: number of images that will be sampled.
229
+
230
+ Returns:
231
+ image_ids: a list of length `sample_size` with unique image ids.
232
+ """
233
+ assert(cat_id in self.dataset.label2ind)
234
+ assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
235
+ # Note: random.sample samples elements without replacement.
236
+ return random.sample(self.dataset.label2ind[cat_id], sample_size)
237
+
238
+ def sampleCategories(self, cat_set, sample_size=1):
239
+ """
240
+ Samples `sample_size` number of unique categories picked from the
241
+ `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
242
+
243
+ Args:
244
+ cat_set: string that specifies the set of categories from which
245
+ categories will be sampled.
246
+ sample_size: number of categories that will be sampled.
247
+
248
+ Returns:
249
+ cat_ids: a list of length `sample_size` with unique category ids.
250
+ """
251
+ if cat_set=='base':
252
+ labelIds = self.dataset.labelIds_base
253
+ elif cat_set=='novel':
254
+ labelIds = self.dataset.labelIds_novel
255
+ else:
256
+ raise ValueError('Not recognized category set {}'.format(cat_set))
257
+
258
+ assert(len(labelIds) >= sample_size)
259
+ # return sample_size unique categories chosen from labelIds set of
260
+ # categories (that can be either self.labelIds_base or self.labelIds_novel)
261
+ # Note: random.sample samples elements without replacement.
262
+ return random.sample(labelIds, sample_size)
263
+
264
+ def sample_base_and_novel_categories(self, nKbase, nKnovel):
265
+ """
266
+ Samples `nKbase` number of base categories and `nKnovel` number of novel
267
+ categories.
268
+
269
+ Args:
270
+ nKbase: number of base categories
271
+ nKnovel: number of novel categories
272
+
273
+ Returns:
274
+ Kbase: a list of length 'nKbase' with the ids of the sampled base
275
+ categories.
276
+ Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel
277
+ categories.
278
+ """
279
+ if self.is_eval_mode:
280
+ assert(nKnovel <= self.dataset.num_cats_novel)
281
+ # sample from the set of base categories 'nKbase' number of base
282
+ # categories.
283
+ Kbase = sorted(self.sampleCategories('base', nKbase))
284
+ # sample from the set of novel categories 'nKnovel' number of novel
285
+ # categories.
286
+ Knovel = sorted(self.sampleCategories('novel', nKnovel))
287
+ else:
288
+ # sample from the set of base categories 'nKnovel' + 'nKbase' number
289
+ # of categories.
290
+ cats_ids = self.sampleCategories('base', nKnovel+nKbase)
291
+ assert(len(cats_ids) == (nKnovel+nKbase))
292
+ # Randomly pick 'nKnovel' number of fake novel categories and keep
293
+ # the rest as base categories.
294
+ random.shuffle(cats_ids)
295
+ Knovel = sorted(cats_ids[:nKnovel])
296
+ Kbase = sorted(cats_ids[nKnovel:])
297
+
298
+ return Kbase, Knovel
299
+
300
+ def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
301
+ """
302
+ Sample `nTestBase` number of images from the `Kbase` categories.
303
+
304
+ Args:
305
+ Kbase: a list of length `nKbase` with the ids of the categories from
306
+ where the images will be sampled.
307
+ nTestBase: the total number of images that will be sampled.
308
+
309
+ Returns:
310
+ Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
311
+ element of each tuple is the image id that was sampled and the
312
+ 2nd elemend is its category label (which is in the range
313
+ [0, len(Kbase)-1]).
314
+ """
315
+ Tbase = []
316
+ if len(Kbase) > 0:
317
+ # Sample for each base category a number images such that the total
318
+ # number sampled images of all categories to be equal to `nTestBase`.
319
+ KbaseIndices = np.random.choice(
320
+ np.arange(len(Kbase)), size=nTestBase, replace=True)
321
+ KbaseIndices, NumImagesPerCategory = np.unique(
322
+ KbaseIndices, return_counts=True)
323
+
324
+ for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
325
+ imd_ids = self.sampleImageIdsFrom(
326
+ Kbase[Kbase_idx], sample_size=NumImages)
327
+ Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
328
+
329
+ assert(len(Tbase) == nTestBase)
330
+
331
+ return Tbase
332
+
333
+ def sample_train_and_test_examples_for_novel_categories(
334
+ self, Knovel, nTestNovel, nExemplars, nKbase):
335
+ """Samples train and test examples of the novel categories.
336
+
337
+ Args:
338
+ Knovel: a list with the ids of the novel categories.
339
+ nTestNovel: the total number of test images that will be sampled
340
+ from all the novel categories.
341
+ nExemplars: the number of training examples per novel category that
342
+ will be sampled.
343
+ nKbase: the number of base categories. It is used as offset of the
344
+ category index of each sampled image.
345
+
346
+ Returns:
347
+ Tnovel: a list of length `nTestNovel` with 2-element tuples. The
348
+ 1st element of each tuple is the image id that was sampled and
349
+ the 2nd element is its category label (which is in the range
350
+ [nKbase, nKbase + len(Knovel) - 1]).
351
+ Exemplars: a list of length len(Knovel) * nExemplars of 2-element
352
+ tuples. The 1st element of each tuple is the image id that was
353
+ sampled and the 2nd element is its category label (which is in
354
+ the ragne [nKbase, nKbase + len(Knovel) - 1]).
355
+ """
356
+
357
+ if len(Knovel) == 0:
358
+ return [], []
359
+
360
+ nKnovel = len(Knovel)
361
+ Tnovel = []
362
+ Exemplars = []
363
+ assert((nTestNovel % nKnovel) == 0)
364
+ nEvalExamplesPerClass = int(nTestNovel / nKnovel)
365
+
366
+ for Knovel_idx in range(nKnovel):
367
+ imd_ids = self.sampleImageIdsFrom(
368
+ Knovel[Knovel_idx],
369
+ sample_size=(nEvalExamplesPerClass + nExemplars))
370
+
371
+ imds_tnovel = imd_ids[:nEvalExamplesPerClass]
372
+ imds_ememplars = imd_ids[nEvalExamplesPerClass:]
373
+
374
+ Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
375
+ Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars]
376
+ assert(len(Tnovel) == nTestNovel)
377
+ assert(len(Exemplars) == len(Knovel) * nExemplars)
378
+
379
+ # random.shuffle(Exemplars)
380
+
381
+ return Tnovel, Exemplars
382
+
383
+ def sample_episode(self):
384
+ """Samples a training episode."""
385
+ nKnovel = self.nKnovel
386
+ nKbase = self.nKbase
387
+ nTestNovel = self.nTestNovel
388
+ nTestBase = self.nTestBase
389
+ nExemplars = self.nExemplars
390
+
391
+ Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
392
+ Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
393
+ # print(Kbase,Knovel,Tbase)
394
+ Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
395
+ Knovel, nTestNovel, nExemplars, nKbase)
396
+ # concatenate the base and novel category examples.
397
+ Test = Tbase + Tnovel
398
+ # random.shuffle(Test)
399
+ Kall = Kbase + Knovel
400
+
401
+ return Exemplars, Test, Kall, nKbase
402
+
403
+ def createExamplesTensorData(self, examples):
404
+ """
405
+ Creates the examples image and label tensor data.
406
+
407
+ Args:
408
+ examples: a list of 2-element tuples, each representing a
409
+ train or test example. The 1st element of each tuple
410
+ is the image id of the example and 2nd element is the
411
+ category label of the example, which is in the range
412
+ [0, nK - 1], where nK is the total number of categories
413
+ (both novel and base).
414
+
415
+ Returns:
416
+ images: a tensor of shape [nExamples, Height, Width, 3] with the
417
+ example images, where nExamples is the number of examples
418
+ (i.e., nExamples = len(examples)).
419
+ labels: a tensor of shape [nExamples] with the category label
420
+ of each example.
421
+ """
422
+ images = torch.stack(
423
+ [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
424
+ labels = torch.LongTensor([label for _, label in examples])
425
+ return images, labels
426
+
427
+ def get_iterator(self, epoch=0):
428
+ rand_seed = epoch
429
+ random.seed(rand_seed)
430
+ np.random.seed(rand_seed)
431
+ def load_function(iter_idx):
432
+ Exemplars, Test, Kall, nKbase = self.sample_episode()
433
+ Xt, Yt = self.createExamplesTensorData(Test)
434
+ Kall = torch.LongTensor(Kall)
435
+ if len(Exemplars) > 0:
436
+ Xe, Ye = self.createExamplesTensorData(Exemplars)
437
+ return Xe, Ye, Xt, Yt, Kall, nKbase
438
+ else:
439
+ return Xt, Yt, Kall, nKbase
440
+
441
+ tnt_dataset = tnt.dataset.ListDataset(
442
+ elem_list=range(self.epoch_size), load=load_function)
443
+ data_loader = tnt_dataset.parallel(
444
+ batch_size=self.batch_size,
445
+ num_workers=(0 if self.is_eval_mode else self.num_workers),
446
+ shuffle=(False if self.is_eval_mode else True))
447
+
448
+ return data_loader
449
+
450
+ def __call__(self, epoch=0):
451
+ return self.get_iterator(epoch)
452
+
453
+ def __len__(self):
454
+ return int(self.epoch_size / self.batch_size)
dataloader/simple_datamanager.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from abc import abstractmethod
3
+ import os
4
+ from PIL import Image
5
+ import json
6
+
7
+ class DataManager:
8
+ @abstractmethod
9
+ def get_data_loader(self, data_file, aug):
10
+ pass
11
+
12
+
13
+ class SimpleDataset:
14
+ def __init__(self, data_file, transform):
15
+ with open(data_file, 'r') as f:
16
+ self.meta = json.load(f)
17
+ self.transform = transform
18
+ #self.target_transform = target_transform
19
+
20
+
21
+ def __getitem__(self,i):
22
+ image_path = os.path.join(self.meta['image_names'][i])
23
+ img = Image.open(image_path).convert('RGB')
24
+ img = self.transform(img)
25
+ target = self.target_transform(self.meta['image_labels'][i])
26
+ return img, target
27
+
28
+ def __len__(self):
29
+ return len(self.meta['image_names'])
30
+
31
+
32
+ class SimpleDataManager(DataManager):
33
+ def __init__(self, dataset, batch_size):
34
+ super(SimpleDataManager, self).__init__()
35
+ self.batch_size = batch_size
36
+ self.dataset = dataset
37
+
38
+ def get_data_loader(self): #parameters that would change on train/val set
39
+ dataset = self.dataset#SimpleDataset(data_file, transform)
40
+ data_loader_params = dict(batch_size = self.batch_size, shuffle = True, num_workers = 12, pin_memory = True)
41
+ data_loader = torch.utils.data.DataLoader(dataset, **data_loader_params)
42
+
43
+ return data_loader
dataloader/tiered_imagenet.py ADDED
@@ -0,0 +1,512 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Dataloader of Gidaris & Komodakis, CVPR 2018
2
+ # Adapted from:
3
+ # https://github.com/gidariss/FewShotWithoutForgetting/blob/master/dataloader.py
4
+ from __future__ import print_function
5
+
6
+ import os
7
+ import os.path
8
+ import numpy as np
9
+ import random
10
+ import pickle
11
+ import json
12
+ import math
13
+
14
+ import torch
15
+ import torch.utils.data as data
16
+ import torchvision
17
+ import torchvision.datasets as datasets
18
+ import torchvision.transforms as transforms
19
+ import torchnet as tnt
20
+
21
+ import h5py
22
+
23
+ from PIL import Image
24
+ from PIL import ImageEnhance
25
+
26
+ from pdb import set_trace as breakpoint
27
+
28
+ from torchvision.transforms.transforms import ToPILImage
29
+
30
+ # Set the appropriate paths of the datasets here.
31
+ _TIERED_IMAGENET_DATASET_DIR = './tieredimagenet/' # your tiered imagenet folder
32
+
33
+ def buildLabelIndex(labels):
34
+ label2inds = {}
35
+ for idx, label in enumerate(labels):
36
+ if label not in label2inds:
37
+ label2inds[label] = []
38
+ label2inds[label].append(idx)
39
+
40
+ return label2inds
41
+
42
+
43
+ def load_data(file):
44
+ try:
45
+ with open(file, 'rb') as fo:
46
+ data = pickle.load(fo)
47
+ return data
48
+ except:
49
+ with open(file, 'rb') as f:
50
+ u = pickle._Unpickler(f)
51
+ u.encoding = 'latin1'
52
+ data = u.load()
53
+ return data
54
+
55
+ class tieredImageNet(data.Dataset):
56
+ def __init__(self, phase='train', do_not_use_random_transf=False):
57
+
58
+ assert(phase=='train' or phase=='val' or phase=='test' or phase=='trainval')
59
+ self.phase = phase
60
+ self.name = 'tieredImageNet_' + phase
61
+
62
+ print('Loading tiered ImageNet dataset - phase {0}'.format(phase))
63
+ file_train_categories_train_phase = os.path.join(
64
+ _TIERED_IMAGENET_DATASET_DIR,
65
+ 'train_images.npz')
66
+ label_train_categories_train_phase = os.path.join(
67
+ _TIERED_IMAGENET_DATASET_DIR,
68
+ 'train_labels.pkl')
69
+ file_train_categories_val_phase = os.path.join(
70
+ _TIERED_IMAGENET_DATASET_DIR,
71
+ 'train_images.npz')
72
+ label_train_categories_val_phase = os.path.join(
73
+ _TIERED_IMAGENET_DATASET_DIR,
74
+ 'train_labels.pkl')
75
+ file_train_categories_test_phase = os.path.join(
76
+ _TIERED_IMAGENET_DATASET_DIR,
77
+ 'train_images.npz')
78
+ label_train_categories_test_phase = os.path.join(
79
+ _TIERED_IMAGENET_DATASET_DIR,
80
+ 'train_labels.pkl')
81
+
82
+ file_val_categories_val_phase = os.path.join(
83
+ _TIERED_IMAGENET_DATASET_DIR,
84
+ 'val_images.npz')
85
+ label_val_categories_val_phase = os.path.join(
86
+ _TIERED_IMAGENET_DATASET_DIR,
87
+ 'val_labels.pkl')
88
+ file_test_categories_test_phase = os.path.join(
89
+ _TIERED_IMAGENET_DATASET_DIR,
90
+ 'test_images.npz')
91
+ label_test_categories_test_phase = os.path.join(
92
+ _TIERED_IMAGENET_DATASET_DIR,
93
+ 'test_labels.pkl')
94
+
95
+ if self.phase == 'train':
96
+ # During training phase we only load the training phase images
97
+ # of the training categories (aka base categories).
98
+ data_train = load_data(label_train_categories_train_phase)
99
+ # self.data = data_train['data']
100
+ self.labels = data_train['labels']
101
+ self.data = np.load(file_train_categories_train_phase)[
102
+ 'images'] # np.array(load_data(file_train_categories_train_phase))
103
+ # self.labels = load_data(file_train_categories_train_phase)#data_train['labels']
104
+
105
+ self.label2ind = buildLabelIndex(self.labels)
106
+ self.labelIds = sorted(self.label2ind.keys())
107
+ self.num_cats = len(self.labelIds)
108
+ self.labelIds_base = self.labelIds
109
+ self.num_cats_base = len(self.labelIds_base)
110
+ # if self.phase=='train':
111
+ # # During training phase we only load the training phase images
112
+ # # of the training categories (aka base categories).
113
+ # data_train = load_data(label_train_categories_train_phase)
114
+ # #self.data = data_train['data']
115
+ # self.labels = data_train['labels']
116
+ # self.data = np.load(file_train_categories_train_phase)['images']#np.array(load_data(file_train_categories_train_phase))
117
+ # #self.labels = load_data(file_train_categories_train_phase)#data_train['labels']
118
+ #
119
+ #
120
+ # data_base = load_data(label_train_categories_val_phase)['labels']
121
+ # data_base_images = np.load(file_train_categories_val_phase)['images']
122
+ # data_novel = load_data(label_val_categories_val_phase)['labels']
123
+ # data_novel_images = np.load(file_val_categories_val_phase)['images']
124
+ #
125
+ # self.data = np.concatenate(
126
+ # [self.data, data_base_images], axis=0)
127
+ # self.data = np.concatenate(
128
+ # [self.data, data_novel_images], axis=0)
129
+ # self.labels = np.concatenate(
130
+ # [self.labels, data_base], axis=0)
131
+ # self.labels = np.concatenate(
132
+ # [self.labels, data_novel], axis=0)
133
+ #
134
+ #
135
+ # self.label2ind = buildLabelIndex(self.labels)
136
+ # self.labelIds = sorted(self.label2ind.keys())
137
+ # self.num_cats = len(self.labelIds)
138
+ # self.labelIds_base = self.labelIds
139
+ # self.num_cats_base = len(self.labelIds_base)
140
+ elif self.phase == 'trainval':
141
+ # During training phase we only load the training phase images
142
+ # of the training categories (aka base categories).
143
+ data_train = load_data(file_train_categories_train_phase)
144
+ #self.data = data_train['data']
145
+ self.data = np.load(file_train_categories_train_phase)['images']
146
+ self.labels = data_train['labels']
147
+
148
+ data_base = load_data(label_train_categories_val_phase)['labels']
149
+ data_base_images = np.load(file_train_categories_val_phase)['images']
150
+ data_novel = load_data(label_val_categories_val_phase)['labels']
151
+ data_novel_images = np.load(file_val_categories_val_phase)['images']
152
+
153
+ self.data = np.concatenate(
154
+ [self.data, data_base_images], axis=0)
155
+ self.data = np.concatenate(
156
+ [self.data, data_novel_images], axis=0)
157
+ self.labels = np.concatenate(
158
+ [self.labels, data_base], axis=0)
159
+ self.labels = np.concatenate(
160
+ [self.labels, data_novel], axis=0)
161
+
162
+ self.label2ind = buildLabelIndex(self.labels)
163
+ self.labelIds = sorted(self.label2ind.keys())
164
+ self.num_cats = len(self.labelIds)
165
+ self.labelIds_base = self.labelIds
166
+ self.num_cats_base = len(self.labelIds_base)
167
+ elif self.phase=='val' or self.phase=='test':
168
+ if self.phase=='test':
169
+ # load data that will be used for evaluating the recognition
170
+ # accuracy of the base categories.
171
+ data_base = load_data(label_train_categories_test_phase)
172
+ data_base_images = np.load(file_train_categories_test_phase)['images']
173
+
174
+ # load data that will be use for evaluating the few-shot recogniton
175
+ # accuracy on the novel categories.
176
+ data_novel = load_data(label_test_categories_test_phase)
177
+ data_novel_images = np.load(file_test_categories_test_phase)['images']
178
+ else: # phase=='val'
179
+ # load data that will be used for evaluating the recognition
180
+ # accuracy of the base categories.
181
+ data_base = load_data(label_train_categories_val_phase)
182
+ data_base_images = np.load(file_train_categories_val_phase)['images']
183
+ #print (data_base_images)
184
+ #print (data_base_images.shape)
185
+ # load data that will be use for evaluating the few-shot recogniton
186
+ # accuracy on the novel categories.
187
+ data_novel = load_data(label_val_categories_val_phase)
188
+ data_novel_images = np.load(file_val_categories_val_phase)['images']
189
+
190
+ self.data = np.concatenate(
191
+ [data_base_images, data_novel_images], axis=0)
192
+ self.labels = data_base['labels'] + data_novel['labels']
193
+
194
+ self.label2ind = buildLabelIndex(self.labels)
195
+ self.labelIds = sorted(self.label2ind.keys())
196
+ self.num_cats = len(self.labelIds)
197
+
198
+ self.labelIds_base = buildLabelIndex(data_base['labels']).keys()
199
+ self.labelIds_novel = buildLabelIndex(data_novel['labels']).keys()
200
+ self.num_cats_base = len(self.labelIds_base)
201
+ self.num_cats_novel = len(self.labelIds_novel)
202
+ intersection = set(self.labelIds_base) & set(self.labelIds_novel)
203
+ print (intersection)
204
+ assert(len(intersection) == 0)
205
+ else:
206
+ raise ValueError('Not valid phase {0}'.format(self.phase))
207
+
208
+ mean_pix = [x/255.0 for x in [120.39586422, 115.59361427, 104.54012653]]
209
+ std_pix = [x/255.0 for x in [70.68188272, 68.27635443, 72.54505529]]
210
+ normalize = transforms.Normalize(mean=mean_pix, std=std_pix)
211
+
212
+ if (self.phase=='test' or self.phase=='val') or (do_not_use_random_transf==True):
213
+ self.transform = transforms.Compose([
214
+ # lambda x: np.asarray(x),
215
+ transforms.ToTensor(),
216
+ normalize
217
+ ])
218
+ else:
219
+ self.transform = transforms.Compose([
220
+ transforms.RandomCrop(84, padding=8),
221
+ transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4),
222
+ transforms.RandomHorizontalFlip(),
223
+ # lambda x: np.asarray(x),
224
+ transforms.ToTensor(),
225
+ normalize
226
+ ])
227
+
228
+ def __getitem__(self, index):
229
+ img, label = self.data[index], self.labels[index]
230
+ # doing this so that it is consistent with all other datasets
231
+ # to return a PIL Image
232
+ img = Image.fromarray(img)
233
+ if self.transform is not None:
234
+ img = self.transform(img)
235
+ return img, label
236
+
237
+ def __len__(self):
238
+ return len(self.data)
239
+
240
+
241
+ class FewShotDataloader():
242
+ def __init__(self,
243
+ dataset,
244
+ nKnovel=5, # number of novel categories.
245
+ nKbase=-1, # number of base categories.
246
+ nExemplars=1, # number of training examples per novel category.
247
+ nTestNovel=15*5, # number of test examples for all the novel categories.
248
+ nTestBase=15*5, # number of test examples for all the base categories.
249
+ batch_size=1, # number of training episodes per batch.
250
+ num_workers=1,
251
+ epoch_size=2000, # number of batches per epoch.
252
+ ):
253
+
254
+ self.dataset = dataset
255
+ self.phase = self.dataset.phase
256
+ max_possible_nKnovel = (self.dataset.num_cats_base if self.phase=='train' or self.phase=='trainval'
257
+ else self.dataset.num_cats_novel)
258
+ assert(nKnovel >= 0 and nKnovel < max_possible_nKnovel)
259
+ self.nKnovel = nKnovel
260
+
261
+ max_possible_nKbase = self.dataset.num_cats_base
262
+ nKbase = nKbase if nKbase >= 0 else max_possible_nKbase
263
+ if (self.phase=='train'or self.phase=='trainval') and nKbase > 0:
264
+ nKbase -= self.nKnovel
265
+ max_possible_nKbase -= self.nKnovel
266
+
267
+ assert(nKbase >= 0 and nKbase <= max_possible_nKbase)
268
+ self.nKbase = nKbase
269
+
270
+ self.nExemplars = nExemplars
271
+ self.nTestNovel = nTestNovel
272
+ self.nTestBase = nTestBase
273
+ self.batch_size = batch_size
274
+ self.epoch_size = epoch_size
275
+ self.num_workers = num_workers
276
+ self.is_eval_mode = (self.phase=='test') or (self.phase=='val')
277
+
278
+ def sampleImageIdsFrom(self, cat_id, sample_size=1):
279
+ """
280
+ Samples `sample_size` number of unique image ids picked from the
281
+ category `cat_id` (i.e., self.dataset.label2ind[cat_id]).
282
+
283
+ Args:
284
+ cat_id: a scalar with the id of the category from which images will
285
+ be sampled.
286
+ sample_size: number of images that will be sampled.
287
+
288
+ Returns:
289
+ image_ids: a list of length `sample_size` with unique image ids.
290
+ """
291
+ assert(cat_id in self.dataset.label2ind)
292
+ assert(len(self.dataset.label2ind[cat_id]) >= sample_size)
293
+ # Note: random.sample samples elements without replacement.
294
+ return random.sample(self.dataset.label2ind[cat_id], sample_size)
295
+
296
+ def sampleCategories(self, cat_set, sample_size=1):
297
+ """
298
+ Samples `sample_size` number of unique categories picked from the
299
+ `cat_set` set of categories. `cat_set` can be either 'base' or 'novel'.
300
+
301
+ Args:
302
+ cat_set: string that specifies the set of categories from which
303
+ categories will be sampled.
304
+ sample_size: number of categories that will be sampled.
305
+
306
+ Returns:
307
+ cat_ids: a list of length `sample_size` with unique category ids.
308
+ """
309
+ if cat_set=='base':
310
+ labelIds = self.dataset.labelIds_base
311
+ elif cat_set=='novel':
312
+ labelIds = self.dataset.labelIds_novel
313
+ else:
314
+ raise ValueError('Not recognized category set {}'.format(cat_set))
315
+
316
+ assert(len(labelIds) >= sample_size)
317
+ # return sample_size unique categories chosen from labelIds set of
318
+ # categories (that can be either self.labelIds_base or self.labelIds_novel)
319
+ # Note: random.sample samples elements without replacement.
320
+ return random.sample(labelIds, sample_size)
321
+
322
+ def sample_base_and_novel_categories(self, nKbase, nKnovel):
323
+ """
324
+ Samples `nKbase` number of base categories and `nKnovel` number of novel
325
+ categories.
326
+
327
+ Args:
328
+ nKbase: number of base categories
329
+ nKnovel: number of novel categories
330
+
331
+ Returns:
332
+ Kbase: a list of length 'nKbase' with the ids of the sampled base
333
+ categories.
334
+ Knovel: a list of lenght 'nKnovel' with the ids of the sampled novel
335
+ categories.
336
+ """
337
+ if self.is_eval_mode:
338
+ assert(nKnovel <= self.dataset.num_cats_novel)
339
+ # sample from the set of base categories 'nKbase' number of base
340
+ # categories.
341
+ Kbase = sorted(self.sampleCategories('base', nKbase))
342
+ # sample from the set of novel categories 'nKnovel' number of novel
343
+ # categories.
344
+ Knovel = sorted(self.sampleCategories('novel', nKnovel))
345
+ else:
346
+ # sample from the set of base categories 'nKnovel' + 'nKbase' number
347
+ # of categories.
348
+ cats_ids = self.sampleCategories('base', nKnovel+nKbase)
349
+ assert(len(cats_ids) == (nKnovel+nKbase))
350
+ # Randomly pick 'nKnovel' number of fake novel categories and keep
351
+ # the rest as base categories.
352
+ random.shuffle(cats_ids)
353
+ Knovel = sorted(cats_ids[:nKnovel])
354
+ Kbase = sorted(cats_ids[nKnovel:])
355
+
356
+ return Kbase, Knovel
357
+
358
+ def sample_test_examples_for_base_categories(self, Kbase, nTestBase):
359
+ """
360
+ Sample `nTestBase` number of images from the `Kbase` categories.
361
+
362
+ Args:
363
+ Kbase: a list of length `nKbase` with the ids of the categories from
364
+ where the images will be sampled.
365
+ nTestBase: the total number of images that will be sampled.
366
+
367
+ Returns:
368
+ Tbase: a list of length `nTestBase` with 2-element tuples. The 1st
369
+ element of each tuple is the image id that was sampled and the
370
+ 2nd elemend is its category label (which is in the range
371
+ [0, len(Kbase)-1]).
372
+ """
373
+ Tbase = []
374
+ if len(Kbase) > 0:
375
+ # Sample for each base category a number images such that the total
376
+ # number sampled images of all categories to be equal to `nTestBase`.
377
+ KbaseIndices = np.random.choice(
378
+ np.arange(len(Kbase)), size=nTestBase, replace=True)
379
+ KbaseIndices, NumImagesPerCategory = np.unique(
380
+ KbaseIndices, return_counts=True)
381
+
382
+ for Kbase_idx, NumImages in zip(KbaseIndices, NumImagesPerCategory):
383
+ imd_ids = self.sampleImageIdsFrom(
384
+ Kbase[Kbase_idx], sample_size=NumImages)
385
+ Tbase += [(img_id, Kbase_idx) for img_id in imd_ids]
386
+
387
+ assert(len(Tbase) == nTestBase)
388
+
389
+ return Tbase
390
+
391
+ def sample_train_and_test_examples_for_novel_categories(
392
+ self, Knovel, nTestNovel, nExemplars, nKbase):
393
+ """Samples train and test examples of the novel categories.
394
+
395
+ Args:
396
+ Knovel: a list with the ids of the novel categories.
397
+ nTestNovel: the total number of test images that will be sampled
398
+ from all the novel categories.
399
+ nExemplars: the number of training examples per novel category that
400
+ will be sampled.
401
+ nKbase: the number of base categories. It is used as offset of the
402
+ category index of each sampled image.
403
+
404
+ Returns:
405
+ Tnovel: a list of length `nTestNovel` with 2-element tuples. The
406
+ 1st element of each tuple is the image id that was sampled and
407
+ the 2nd element is its category label (which is in the range
408
+ [nKbase, nKbase + len(Knovel) - 1]).
409
+ Exemplars: a list of length len(Knovel) * nExemplars of 2-element
410
+ tuples. The 1st element of each tuple is the image id that was
411
+ sampled and the 2nd element is its category label (which is in
412
+ the ragne [nKbase, nKbase + len(Knovel) - 1]).
413
+ """
414
+
415
+ if len(Knovel) == 0:
416
+ return [], []
417
+
418
+ nKnovel = len(Knovel)
419
+ Tnovel = []
420
+ Exemplars = []
421
+ assert((nTestNovel % nKnovel) == 0)
422
+ nEvalExamplesPerClass = int(nTestNovel / nKnovel)
423
+
424
+ for Knovel_idx in range(len(Knovel)):
425
+ imd_ids = self.sampleImageIdsFrom(
426
+ Knovel[Knovel_idx],
427
+ sample_size=(nEvalExamplesPerClass + nExemplars))
428
+
429
+ imds_tnovel = imd_ids[:nEvalExamplesPerClass]
430
+ imds_ememplars = imd_ids[nEvalExamplesPerClass:]
431
+
432
+ Tnovel += [(img_id, nKbase+Knovel_idx) for img_id in imds_tnovel]
433
+ Exemplars += [(img_id, nKbase+Knovel_idx) for img_id in imds_ememplars]
434
+ assert(len(Tnovel) == nTestNovel)
435
+ assert(len(Exemplars) == len(Knovel) * nExemplars)
436
+
437
+ # random.shuffle(Exemplars)
438
+
439
+ return Tnovel, Exemplars
440
+
441
+ def sample_episode(self):
442
+ """Samples a training episode."""
443
+ nKnovel = self.nKnovel
444
+ nKbase = self.nKbase
445
+ nTestNovel = self.nTestNovel
446
+ nTestBase = self.nTestBase
447
+ nExemplars = self.nExemplars
448
+
449
+ Kbase, Knovel = self.sample_base_and_novel_categories(nKbase, nKnovel)
450
+ Tbase = self.sample_test_examples_for_base_categories(Kbase, nTestBase)
451
+ Tnovel, Exemplars = self.sample_train_and_test_examples_for_novel_categories(
452
+ Knovel, nTestNovel, nExemplars, nKbase)
453
+
454
+ # concatenate the base and novel category examples.
455
+ Test = Tbase + Tnovel
456
+ # random.shuffle(Test)
457
+ Kall = Kbase + Knovel
458
+
459
+ return Exemplars, Test, Kall, nKbase
460
+
461
+ def createExamplesTensorData(self, examples):
462
+ """
463
+ Creates the examples image and label tensor data.
464
+
465
+ Args:
466
+ examples: a list of 2-element tuples, each representing a
467
+ train or test example. The 1st element of each tuple
468
+ is the image id of the example and 2nd element is the
469
+ category label of the example, which is in the range
470
+ [0, nK - 1], where nK is the total number of categories
471
+ (both novel and base).
472
+
473
+ Returns:
474
+ images: a tensor of shape [nExamples, Height, Width, 3] with the
475
+ example images, where nExamples is the number of examples
476
+ (i.e., nExamples = len(examples)).
477
+ labels: a tensor of shape [nExamples] with the category label
478
+ of each example.
479
+ """
480
+ images = torch.stack(
481
+ [self.dataset[img_idx][0] for img_idx, _ in examples], dim=0)
482
+ labels = torch.LongTensor([label for _, label in examples])
483
+ return images, labels
484
+
485
+ def get_iterator(self, epoch=0):
486
+ rand_seed = epoch
487
+ random.seed(rand_seed)
488
+ np.random.seed(rand_seed)
489
+ def load_function(iter_idx):
490
+ Exemplars, Test, Kall, nKbase = self.sample_episode()
491
+ Xt, Yt = self.createExamplesTensorData(Test)
492
+ Kall = torch.LongTensor(Kall)
493
+ if len(Exemplars) > 0:
494
+ Xe, Ye = self.createExamplesTensorData(Exemplars)
495
+ return Xe, Ye, Xt, Yt, Kall, nKbase
496
+ else:
497
+ return Xt, Yt, Kall, nKbase
498
+
499
+ tnt_dataset = tnt.dataset.ListDataset(
500
+ elem_list=range(self.epoch_size), load=load_function)
501
+ data_loader = tnt_dataset.parallel(
502
+ batch_size=self.batch_size,
503
+ num_workers=(0 if self.is_eval_mode else self.num_workers),
504
+ shuffle=(False if self.is_eval_mode else True))
505
+
506
+ return data_loader
507
+
508
+ def __call__(self, epoch=0):
509
+ return self.get_iterator(epoch)
510
+
511
+ def __len__(self):
512
+ return int(self.epoch_size / self.batch_size)
norm.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import numpy as np
3
+ import cv2
4
+ import statistics
5
+ from tqdm import tqdm
6
+ from glob import glob
7
+
8
+ def calculate_normalization_parameters(path=None):
9
+ # data = pd.read_csv(path_to_train_csv)
10
+ data = glob('NIH/images/*.png')
11
+ mean = 0
12
+ std = 0
13
+ height = []
14
+ width = []
15
+ for i in tqdm(data):
16
+ image = cv2.imread(i)[:, :, ::-1]
17
+ h, w, _ = image.shape
18
+ image = image.reshape(-1, 3)
19
+ mean += np.mean(image, axis=0)
20
+ std += np.std(image, axis=0)
21
+ height.append(h)
22
+ width.append(w)
23
+ mean = mean / (255 * len(data))
24
+ std = std / (255 * len(data))
25
+ print("median height:", statistics.median(height))
26
+ print("median width:", statistics.median(width))
27
+ print("mean:", mean)
28
+ print("std:", std)
29
+ return mean, std
30
+
31
+
32
+ calculate_normalization_parameters()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ einops
2
+ timm
3
+ torchinfo
4
+ torchsummary
5
+ torchnet
6
+ wandb
7
+ adabelief_pytorch
8
+ scikit-plot
9
+ pandas
10
+ h5py
test.py ADDED
@@ -0,0 +1,345 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import argparse
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch.utils.data import DataLoader
7
+
8
+ from torch.autograd import Variable
9
+
10
+ from tqdm import tqdm
11
+
12
+ from models.protonet_embedding import ProtoNetEmbedding
13
+ from models.R2D2_embedding import R2D2Embedding
14
+ from models.ResNet12_embedding import resnet12
15
+
16
+ from models.classification_heads import ClassificationHead
17
+
18
+ from utils import pprint, set_gpu, Timer, count_accuracy, log
19
+ from sklearn.metrics import confusion_matrix, f1_score, roc_curve, auc
20
+ import scikitplot as skplt
21
+ import matplotlib.pyplot as plt
22
+
23
+
24
+ import numpy as np
25
+ import os
26
+ import random
27
+
28
+ import pickle
29
+
30
+ from dataloader.chest import label_dict
31
+
32
+
33
+ import pandas as pd
34
+
35
+ def multiclass_roc(y_test, y_score,n_classes = 3):
36
+
37
+
38
+ # structures
39
+ fpr = dict()
40
+ tpr = dict()
41
+ roc_auc = dict()
42
+
43
+ # calculate dummies once
44
+ y_test_dummies = pd.get_dummies(y_test, drop_first=False).values
45
+ for i in range(n_classes):
46
+ fpr[i], tpr[i], _ = roc_curve(y_test_dummies[:, i], y_score[:, i])
47
+ roc_auc[i] = auc(fpr[i], tpr[i])
48
+
49
+ return fpr,tpr,roc_auc
50
+
51
+ # os.environ['CUDA_VISIBLE_DEVICES'] = "0"
52
+
53
+ def seed_everything(seed: int):
54
+ random.seed(seed)
55
+ os.environ["PYTHONHASHSEED"] = str(seed)
56
+ np.random.seed(seed)
57
+ torch.manual_seed(seed)
58
+ torch.cuda.manual_seed(seed)
59
+ torch.backends.cudnn.deterministic = True
60
+ torch.backends.cudnn.benchmark = True
61
+
62
+ def euclidean_dist(x, y):
63
+
64
+ # x: N x D
65
+ # y: M x D
66
+ n = x.size(0)
67
+ m = y.size(0)
68
+ d = x.size(1)
69
+
70
+ assert d == y.size(1)
71
+
72
+ x = x.unsqueeze(1).expand(n, m, d)
73
+ y = y.unsqueeze(0).expand(n, m, d)
74
+
75
+ return torch.pow(x - y, 2).sum(2)
76
+
77
+ def flip(x, dim):
78
+ xsize = x.size()
79
+ dim = x.dim() + dim if dim < 0 else dim
80
+ x = x.view(-1, *xsize[dim:])
81
+ x = x.view(x.size(0), x.size(1), -1)[:, getattr(torch.arange(x.size(1)-1,
82
+ -1, -1), ('cpu','cuda')[x.is_cuda])().long(), :]
83
+ return x.view(xsize)
84
+
85
+
86
+ def get_model(options):
87
+ # Choose the embedding network
88
+ if options.network == 'ProtoNet':
89
+ network = ProtoNetEmbedding().cuda()
90
+ elif options.network == 'R2D2':
91
+ network = R2D2Embedding().cuda()
92
+ elif options.network == 'ResNet':
93
+ if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
94
+ network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=5,num_layer=options.num_layer).cuda()
95
+ network = torch.nn.DataParallel(network)
96
+ else:
97
+ network = resnet12(avg_pool=False, drop_rate=0.1, dropblock_size=2,num_layer=options.num_layer).cuda()
98
+ else:
99
+ print ("Cannot recognize the network type")
100
+ assert(False)
101
+
102
+ # Choose the classification head
103
+ if opt.head == 'ProtoNet':
104
+ cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
105
+ elif options.head == 'SubspaceTrans':
106
+ cls_head = ClassificationHead(base_learner='SubspaceTrans').cuda()
107
+ elif options.head == 'Subspace':
108
+ cls_head = ClassificationHead(base_learner='Subspace').cuda()
109
+ elif options.head == 'SubspaceFast':
110
+ cls_head = ClassificationHead(base_learner='SubspaceFast').cuda()
111
+ elif opt.head == 'Ridge':
112
+ cls_head = ClassificationHead(base_learner='Ridge').cuda()
113
+ elif opt.head == 'R2D2':
114
+ cls_head = ClassificationHead(base_learner='R2D2').cuda()
115
+ elif opt.head == 'SVM':
116
+ cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
117
+ else:
118
+ print ("Cannot recognize the classification head type")
119
+ assert(False)
120
+
121
+ return (network, cls_head)
122
+
123
+ def get_dataset(options):
124
+ # Choose the embedding network
125
+ if options.dataset == 'miniImageNet':
126
+ from dataloader.mini_imagenet import MiniImageNet, FewShotDataloader
127
+ dataset_test = MiniImageNet(phase='test')
128
+ data_loader = FewShotDataloader
129
+ elif options.dataset == 'tieredImageNet':
130
+ from dataloader.tiered_imagenet import tieredImageNet, FewShotDataloader
131
+ dataset_test = tieredImageNet(phase='test')
132
+ data_loader = FewShotDataloader
133
+ elif options.dataset == 'CIFAR_FS':
134
+ from dataloader.CIFAR_FS import CIFAR_FS, FewShotDataloader
135
+ dataset_test = CIFAR_FS(phase='test')
136
+ data_loader = FewShotDataloader
137
+ elif options.dataset == 'FC100':
138
+ from dataloader.FC100 import FC100, FewShotDataloader
139
+ dataset_test = FC100(phase='test')
140
+ data_loader = FewShotDataloader
141
+ elif options.dataset == 'Chest':
142
+ from dataloader.chest import Chest, FewShotDataloader
143
+ dataset_test = Chest(phase='test')
144
+ data_loader = FewShotDataloader
145
+ else:
146
+ print ("Cannot recognize the dataset type")
147
+ assert(False)
148
+
149
+ return (dataset_test, data_loader)
150
+
151
+ #
152
+ if __name__ == '__main__':
153
+ parser = argparse.ArgumentParser()
154
+
155
+ #Changes
156
+ parser.add_argument('--gpu', default='3')
157
+ #Changes
158
+ parser.add_argument('--load',
159
+ default='experiments/group2_subspace30_CE_train/best_model.pth', ## your best model
160
+ help='path of the checkpoint file')
161
+ #Changes
162
+ parser.add_argument('--num_layer', type=int, default=30,
163
+ help='num of layer')
164
+
165
+ parser.add_argument('--episode', type=int, default=1000,
166
+ help='number of episodes to test')
167
+ parser.add_argument('--way', type=int, default=3,
168
+ help='number of classes in one test episode')
169
+ parser.add_argument('--shot', type=int, default=5,
170
+ help='number of support examples per training class')
171
+ parser.add_argument('--query', type=int, default=5,
172
+ help='number of query examples per training class')
173
+ parser.add_argument('--network', type=str, default='ResNet',
174
+ help='choose which embedding network to use. ProtoNet, R2D2, ResNet')
175
+ parser.add_argument('--head', type=str, default='Subspace',
176
+ help='choose which embedding network to use. ProtoNet, Ridge, R2D2, SVM')
177
+ parser.add_argument('--dataset', type=str, default='Chest',
178
+ help='choose which classification head to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
179
+
180
+
181
+ opt = parser.parse_args()
182
+
183
+ seed_everything(42)
184
+
185
+ (dataset_test, data_loader) = get_dataset(opt)
186
+
187
+ set_gpu(opt.gpu)
188
+
189
+ # Define the models
190
+ (embedding_net, cls_head) = get_model(opt)
191
+
192
+ # Load saved model checkpoints
193
+ saved_models = torch.load(opt.load)
194
+ embedding_net.load_state_dict(saved_models['embedding'])
195
+ embedding_net.eval()
196
+ cls_head.load_state_dict(saved_models['head'])
197
+ cls_head.eval()
198
+
199
+
200
+ aug=False
201
+
202
+ label_dict_inv = {v:k for k,v in label_dict.items()}
203
+
204
+ test_accuracies = []
205
+ per_class_accuracies = []
206
+ y_pred_list = []
207
+ y_list = []
208
+ dloader_test = data_loader(
209
+ dataset=dataset_test,
210
+ nKnovel=opt.way,
211
+ nKbase=0,
212
+ nExemplars=opt.shot, # num training examples per novel category
213
+ nTestNovel=opt.query * opt.way, # num test examples for all the novel categories
214
+ nTestBase=0, # num test examples for all the base categories
215
+ batch_size=1,
216
+ num_workers=1,
217
+ epoch_size=opt.episode, # num of batches per epoch
218
+ )
219
+
220
+ #print("epp: ", epp)
221
+
222
+ with torch.no_grad():
223
+ for i, batch in enumerate(tqdm(dloader_test()), 1):
224
+ data_support, labels_support, data_query, labels_query, _, _ = [x.cuda() for x in batch]
225
+
226
+ n_support = opt.way * opt.shot
227
+ n_query = opt.way * opt.query
228
+
229
+ if opt.shot == 1 and aug:
230
+ flipped_data_support = flip(data_support, 3)
231
+ data_support = torch.cat((data_support, flipped_data_support), dim=0)
232
+ labels_support = torch.cat((labels_support, labels_support), dim=0)
233
+
234
+ list_emb_support = embedding_net(data_support.reshape([-1] + list(data_support.shape[-3:])))
235
+ list_emb_query = embedding_net(data_query.reshape([-1] + list(data_query.shape[-3:])))
236
+
237
+ logits = torch.zeros(n_query, opt.way).cuda()
238
+
239
+ for emb_support, emb_query in zip(list_emb_support, list_emb_query):
240
+
241
+
242
+ emb_support = emb_support.view(1, opt.way, opt.shot, -1).mean(2)
243
+
244
+ emb_query = emb_query.reshape(1, n_query, -1)
245
+
246
+ dists = euclidean_dist(emb_query[0], emb_support[0])
247
+
248
+
249
+ logits += F.softmax(-dists, dim=1).view(1 * opt.way * opt.query, -1)
250
+
251
+
252
+
253
+ logits /= opt.num_layer
254
+
255
+ logits = logits.reshape(-1, opt.way)
256
+ labels_query = labels_query.reshape(-1)
257
+
258
+
259
+ acc,pca = count_accuracy(logits, labels_query)
260
+ test_accuracies.append(acc.item())
261
+ per_class_accuracies.append(pca)
262
+
263
+ y_pred_list.append(logits.detach().cpu().numpy())
264
+ y_list.append(labels_query.detach().cpu().numpy())
265
+
266
+ avg = np.mean(np.array(test_accuracies))
267
+ std = np.std(np.array(test_accuracies))
268
+ ci95 = 1.96 * std / np.sqrt(i + 1)
269
+
270
+ if i % 10 == 0:
271
+
272
+ # print(logits.detach().cpu().numpy())
273
+ # print(torch.argmax(logits, dim=1).view(-1))
274
+ # print(labels_query.detach().cpu().numpy())
275
+
276
+ pca = np.array(per_class_accuracies).mean(0)
277
+ pcs = np.array(per_class_accuracies).std(0)
278
+
279
+ print('Episode [{}/{}]:\t\t\tAccuracy: {:.2f} ± {:.2f} ({:.2f}) % ({:.2f} %)'\
280
+ .format(i, opt.episode, avg, ci95,std, acc))
281
+ print(f'{label_dict_inv[9]}: {pca[0]:.2f} ± {pcs[0]:.2f} % | {label_dict_inv[10]}: {pca[1]:.2f} ± {pcs[1]:.2f} % | {label_dict_inv[11]}: {pca[2]:.2f} ± {pcs[2]:.2f}%')
282
+
283
+
284
+
285
+ pca = np.array(per_class_accuracies).mean(0)
286
+ pcs = np.array(per_class_accuracies).std(0)
287
+
288
+ print("Mean")
289
+ print(pca)
290
+ print('Standard Deviation')
291
+ print(pcs)
292
+
293
+
294
+ y_pred_proba = np.array(
295
+ y_pred_list).reshape(-1, 3)
296
+
297
+ y_pred = np.argmax(y_pred_proba, axis=1)
298
+
299
+ y_true = np.array(y_list).reshape(-1)
300
+
301
+ f1 = f1_score(y_true, y_pred, average=None)
302
+
303
+ print('F1 Score')
304
+ print(f1)
305
+
306
+ fpr,tpr, auc = multiclass_roc(y_true,y_pred_proba)
307
+ save_tuple = (fpr,tpr,auc)
308
+
309
+ print(auc)
310
+
311
+ # Plots
312
+
313
+ #Changes
314
+ # with open('plot/group5_subspace25.pickle', 'wb') as f:
315
+ # pickle.dump(save_tuple, f)
316
+
317
+ #Changes
318
+ class_dict = {'Fibrosis': 0, 'Hernia': 1, 'Pneumonia': 2}
319
+ # class_dict = {'Mass': 0, 'Nodule': 1, 'Pleural_Thickening': 2}
320
+ # class_dict = {'Cardiomegaly': 0, 'Edema': 1, 'Emphysema': 2}
321
+ # class_dict = {'Consolidation': 0, 'Effusion': 1, 'Pneumothorax': 2}
322
+ # class_dict = {'Atelectasis': 0, 'Infiltration': 1, 'No Finding': 2}
323
+
324
+ class_dict_inv = {v: k for k, v in class_dict.items()}
325
+
326
+ y_true = np.array([class_dict_inv[i]
327
+ for i in np.array(y_list).reshape(-1)])
328
+
329
+ # print(np.array(y_pred_list).reshape(-1, 3).shape)
330
+ # print(np.array(y_list).reshape(-1).shape)
331
+ # print(y_list)
332
+ # print(np.array(y_pred_list).reshape(-1, 3))
333
+
334
+
335
+ # skplt.metrics.plot_roc(y_true, y_pred_proba,plot_micro=False, plot_macro=False)
336
+
337
+ #Changes
338
+ # plt.savefig('plot/group5_subspace25.png', dpi=1000)
339
+ # plt.show()
340
+
341
+
342
+
343
+
344
+ # python test_ortho_bcs.py --gpu 2 --load experiments/chest_exp1/best_model.pth --way 3 --dataset Chest
345
+ # python test_ortho_bcs.py --gpu 2 --load experiments/chest_exp1/best_model.pth --way 3 --dataset Chest
train.py ADDED
@@ -0,0 +1,458 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import timm
3
+ import os
4
+ import sys
5
+ import argparse
6
+ import random
7
+ import numpy as np
8
+ from tqdm import tqdm
9
+ import torch
10
+ import torch.nn.functional as F
11
+ from torch import linalg as LA
12
+ from models.classification_heads import ClassificationHead
13
+ from models.R2D2_embedding import R2D2Embedding
14
+ from models.protonet_embedding import ProtoNetEmbedding
15
+ from models.ResNet12_embedding import resnet12
16
+ import torch.nn as nn
17
+ from utils import set_gpu, Timer, count_accuracy, check_dir, log
18
+ import warnings
19
+ import wandb
20
+ from itertools import combinations
21
+
22
+ from torchsummary import summary
23
+ warnings.filterwarnings("ignore")
24
+
25
+
26
+ def one_hot(indices, depth):
27
+ """
28
+ Returns a one-hot tensor.
29
+ This is a PyTorch equivalent of Tensorflow's tf.one_hot.
30
+
31
+ Parameters:
32
+ indices: a (n_batch, m) Tensor or (m) Tensor.
33
+ depth: a scalar. Represents the depth of the one hot dimension.
34
+ Returns: a (n_batch, m, depth) Tensor or (m, depth) Tensor.
35
+ """
36
+
37
+ encoded_indicies = torch.zeros(indices.size() + torch.Size([depth])).cuda()
38
+ index = indices.view(indices.size()+torch.Size([1]))
39
+ encoded_indicies = encoded_indicies.scatter_(1, index, 1)
40
+
41
+ return encoded_indicies
42
+
43
+ def seed_everything(seed: int):
44
+ random.seed(seed)
45
+ os.environ["PYTHONHASHSEED"] = str(seed)
46
+ np.random.seed(seed)
47
+ torch.manual_seed(seed)
48
+ torch.cuda.manual_seed(seed)
49
+ torch.backends.cudnn.deterministic = True
50
+ torch.backends.cudnn.benchmark = True
51
+
52
+
53
+ def euclidean_dist(x, y):
54
+
55
+ # x: N x D
56
+ # y: M x D
57
+ n = x.size(0)
58
+ m = y.size(0)
59
+ d = x.size(1)
60
+
61
+ assert d == y.size(1)
62
+
63
+ x = x.unsqueeze(1).expand(n, m, d)
64
+ y = y.unsqueeze(0).expand(n, m, d)
65
+
66
+
67
+ return torch.pow(x - y, 2).sum(2)
68
+
69
+ def cosine_dist(x, y):
70
+
71
+ # x: N x D
72
+ # y: M x D
73
+ n = x.size(0)
74
+ m = y.size(0)
75
+ d = x.size(1)
76
+
77
+ assert d == y.size(1)
78
+
79
+ x = x.unsqueeze(1).expand(n, m, d)
80
+ y = y.unsqueeze(0).expand(n, m, d)
81
+
82
+
83
+
84
+ cos = nn.CosineSimilarity(dim=2, eps=1e-6)
85
+ out = 1 - cos(x,y)
86
+
87
+
88
+ return out
89
+
90
+
91
+ def get_model(options):
92
+ # Choose the embedding network
93
+ if options.network == 'ProtoNet':
94
+ network = ProtoNetEmbedding().cuda()
95
+ elif options.network == 'R2D2':
96
+ network = R2D2Embedding().cuda()
97
+ elif options.network == 'ResNet':
98
+ if options.dataset == 'miniImageNet' or options.dataset == 'tieredImageNet':
99
+ network = resnet12(avg_pool=False, drop_rate=0.1,
100
+ dropblock_size=5,num_layer=options.num_layer).cuda()
101
+ network = torch.nn.DataParallel(network) # , device_ids=[1, 2])
102
+ else:
103
+ network = resnet12(avg_pool=False, drop_rate=0.1,
104
+ dropblock_size=2,num_layer=options.num_layer).cuda()
105
+ else:
106
+ print("Cannot recognize the network type")
107
+ assert(False)
108
+
109
+ # Choose the classification head
110
+ if options.head == 'Subspace':
111
+ cls_head = ClassificationHead(base_learner='Subspace').cuda()
112
+ elif options.head == 'ProtoNet':
113
+ cls_head = ClassificationHead(base_learner='ProtoNet').cuda()
114
+ elif options.head == 'Ridge':
115
+ cls_head = ClassificationHead(base_learner='Ridge').cuda()
116
+ elif options.head == 'R2D2':
117
+ cls_head = ClassificationHead(base_learner='R2D2').cuda()
118
+ elif options.head == 'SVM':
119
+ cls_head = ClassificationHead(base_learner='SVM-CS').cuda()
120
+ else:
121
+ print("Cannot recognize the dataset type")
122
+ assert(False)
123
+
124
+ return (network, cls_head)
125
+
126
+ def get_dataset(options):
127
+ # Choose the embedding network
128
+ if options.dataset == 'miniImageNet':
129
+ from dataloader.mini_imagenet import MiniImageNet, FewShotDataloader
130
+ # change it to train only, this is including the validation set
131
+ dataset_train = MiniImageNet(phase='trainval')
132
+ dataset_val = MiniImageNet(phase='test')
133
+ data_loader = FewShotDataloader
134
+ elif options.dataset == 'tieredImageNet':
135
+ from dataloader.tiered_imagenet import tieredImageNet, FewShotDataloader
136
+ dataset_train = tieredImageNet(phase='train')
137
+ dataset_val = tieredImageNet(phase='test')
138
+ data_loader = FewShotDataloader
139
+ elif options.dataset == 'CIFAR_FS':
140
+ from dataloader.CIFAR_FS import CIFAR_FS, FewShotDataloader
141
+ dataset_train = CIFAR_FS(phase='train')
142
+ dataset_val = CIFAR_FS(phase='test')
143
+ data_loader = FewShotDataloader
144
+ elif options.dataset == 'Chest':
145
+ from dataloader.chest import Chest, FewShotDataloader
146
+ dataset_train = Chest(phase='train')
147
+ dataset_val = Chest(phase='val')
148
+ data_loader = FewShotDataloader
149
+ else:
150
+ print("Cannot recognize the dataset type")
151
+ assert(False)
152
+
153
+ return (dataset_train, dataset_val, data_loader)
154
+
155
+
156
+ if __name__ == '__main__':
157
+ parser = argparse.ArgumentParser()
158
+ parser.add_argument('--num-epoch', type=int, default=80,
159
+ help='number of training epochs')
160
+ parser.add_argument('--save-epoch', type=int, default=5,
161
+ help='frequency of model saving')
162
+ parser.add_argument('--train-shot', type=int, default=5,
163
+ help='number of support examples per training class')
164
+ parser.add_argument('--val-shot', type=int, default=5,
165
+ help='number of support examples per validation class')
166
+ parser.add_argument('--train-query', type=int, default=5,
167
+ help='number of query examples per training class')
168
+ parser.add_argument('--val-episode', type=int, default=600,
169
+ help='number of episodes per validation')
170
+ parser.add_argument('--val-query', type=int, default=5,
171
+ help='number of query examples per validation class')
172
+ parser.add_argument('--train-way', type=int, default=3,
173
+ help='number of classes in one training episode')
174
+ parser.add_argument('--test-way', type=int, default=3,
175
+ help='number of classes in one test (or validation) episode')
176
+ parser.add_argument('--save-path', default='experiments')
177
+
178
+ parser.add_argument('--wandbexperiment', default="group5_subspace30",type=str)
179
+ parser.add_argument('--gpu', default='0') # using 4 gpus
180
+ parser.add_argument('--num_layer', type=int, default=30,
181
+ help='number of linear layer')
182
+
183
+ # parser.add_argument('--gpu', default='0,1,2,3') # using 4 gpus
184
+ parser.add_argument('--network', type=str, default='ResNet',
185
+ help='choose which embedding network to use. ResNet')
186
+ parser.add_argument('--head', type=str, default='Subspace',
187
+ help='choose which classification head to use. Subspace, ProtoNet, R2D2, SVM')
188
+ parser.add_argument('--dataset', type=str, default='Chest',
189
+ help='choose which classification head to use. miniImageNet, tieredImageNet, CIFAR_FS, FC100')
190
+ parser.add_argument('--episodes-per-batch', type=int, default=1,
191
+ help='number of episodes per batch')
192
+ parser.add_argument('--eps', type=float, default=0.0,
193
+ help='epsilon of label smoothing')
194
+ parser.add_argument('--wandb', action="store_true")
195
+ parser.add_argument("--wandbkey", type=str,
196
+ default='db1158429a436f94565ac9eadecc6afe9e5a0b8f',
197
+ help='Wandb project key')
198
+
199
+
200
+ # python train_my.py --gpu 2 --dataset Chest --num_layer 5
201
+
202
+
203
+ opt = parser.parse_args()
204
+ seed_everything(42)
205
+ print(opt)
206
+ opt.save_path = os.path.join(opt.save_path,opt.wandbexperiment)
207
+
208
+
209
+ if opt.wandb:
210
+ os.system('wandb login {}'.format(opt.wandbkey))
211
+ wandb.init(name=opt.wandbexperiment,
212
+ project='chest-few-shot-final')
213
+ wandb.config.update(opt)
214
+
215
+ (dataset_train, dataset_val, data_loader) = get_dataset(opt)
216
+
217
+ # Dataloader of Gidaris & Komodakis (CVPR 2018)
218
+ dloader_train = data_loader(
219
+ dataset=dataset_train,
220
+ nKnovel=opt.train_way,
221
+ nKbase=0,
222
+ nExemplars=opt.train_shot, # num training examples per novel category
223
+ # num test examples for all the novel categories
224
+ nTestNovel=opt.train_way * opt.train_query,
225
+ nTestBase=0, # num test examples for all the base categories
226
+ batch_size=opt.episodes_per_batch,
227
+ num_workers=15,
228
+ epoch_size=opt.episodes_per_batch * 1000, # num of batches per epoch
229
+ )
230
+
231
+ dloader_val = data_loader(
232
+ dataset=dataset_val,
233
+ nKnovel=opt.test_way,
234
+ nKbase=0,
235
+ nExemplars=opt.val_shot, # num training examples per novel category
236
+ # num test examples for all the novel categories
237
+ nTestNovel=opt.val_query * opt.test_way,
238
+ nTestBase=0, # num test examples for all the base categories
239
+ batch_size=1,
240
+ num_workers=15,
241
+ epoch_size=1 * opt.val_episode, # num of batches per epoch
242
+ )
243
+
244
+ set_gpu(opt.gpu)
245
+ check_dir('./experiments/')
246
+ check_dir(opt.save_path)
247
+
248
+ log_file_path = os.path.join(opt.save_path, "train_log.txt")
249
+ log(log_file_path, str(vars(opt)))
250
+
251
+ (embedding_net, cls_head) = get_model(opt)
252
+
253
+ optimizer = torch.optim.SGD(embedding_net.parameters(),lr=3e-3)
254
+
255
+
256
+ def lambda_epoch(e): return 1.0 if e < 12 else (
257
+ 0.025 if e < 30 else 0.0032 if e < 45 else (0.0014 if e < 57 else (0.00052)))
258
+
259
+ ## tieredimagenet###
260
+ # lambda_epoch = lambda e: 1.0 if e < 20 else (
261
+ # 0.012 if e < 45 else 0.0052 if e < 59 else (0.00054 if e < 68 else (0.00012)))
262
+
263
+ lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
264
+ optimizer, lr_lambda=lambda_epoch, last_epoch=-1)
265
+
266
+ max_val_acc = 0.0
267
+
268
+ timer = Timer()
269
+ x_entropy = torch.nn.CrossEntropyLoss()
270
+
271
+
272
+ index = list(combinations([i for i in range(opt.num_layer)], 2))
273
+
274
+ for epoch in range(1, opt.num_epoch + 1):
275
+
276
+
277
+ for param_group in optimizer.param_groups:
278
+ epoch_learning_rate = param_group['lr']
279
+
280
+ log(log_file_path, 'Train Epoch: {}\tLearning Rate: {:.4f}'.format(
281
+ epoch, epoch_learning_rate))
282
+
283
+ _, _ = [x.train() for x in (embedding_net, cls_head)]
284
+
285
+ train_accuracies = []
286
+ train_losses = []
287
+
288
+ train_n_support = opt.train_way * opt.train_shot
289
+ train_n_query = opt.train_way * opt.train_query
290
+
291
+
292
+
293
+
294
+ for i, batch in enumerate(tqdm(dloader_train(epoch)), 1):
295
+
296
+ data_support, labels_support, data_query, labels_query, _, _ = [
297
+ x.cuda() for x in batch]
298
+
299
+ list_emb_query = embedding_net(data_query.view(
300
+ [-1] + list(data_query.shape[-3:]))) # [100, 2560]
301
+ list_emb_support = embedding_net(data_support.view(
302
+ [-1] + list(data_support.shape[-3:]))) # [100, 3, 32, 32] -> [100, 2560]
303
+
304
+
305
+ loss_weights = 0.
306
+ for ind in index:
307
+
308
+ loss_weights += torch.abs(F.cosine_similarity(getattr(embedding_net,f'linear{ind[0]}_1').weight.view(-1),getattr(embedding_net,f'linear{ind[1]}_1').weight.view(-1),dim=0))
309
+
310
+
311
+ log_p_y = torch.zeros(
312
+ opt.episodes_per_batch * opt.train_way * opt.train_query, opt.train_way).cuda()
313
+
314
+ for emb_support,emb_query in zip(list_emb_support, list_emb_query):
315
+ # emb_support = emb_support.view(
316
+ # opt.episodes_per_batch, train_n_support, -1) # [4, 25, 2560]
317
+ if opt.train_shot == 1:
318
+ emb_support = emb_support.view(
319
+ opt.episodes_per_batch, opt.train_way, -1) # [4,5,5,2560] --> [4, 5, 20]
320
+ else:
321
+ emb_support = emb_support.view(
322
+ opt.episodes_per_batch, opt.train_way, opt.train_shot, -1).mean(2) # [4,5,5,2560] --> [4, 5, 20]
323
+
324
+ emb_query = emb_query.view(
325
+ opt.episodes_per_batch, train_n_query, -1) # [4, 25, 2560]
326
+
327
+
328
+ dists = torch.stack(
329
+ [euclidean_dist(emb_query[i], emb_support[i]) for i in range(opt.episodes_per_batch)]) # [4,25,5]
330
+
331
+
332
+
333
+ log_p_y += F.softmax(-dists,
334
+ dim=2).view(opt.episodes_per_batch* opt.train_way* opt.train_query, -1) # [100,5]
335
+
336
+
337
+ log_p_y /= opt.num_layer
338
+
339
+
340
+ smoothed_one_hot = one_hot(
341
+ labels_query.view(-1), opt.train_way) # [100,5]
342
+
343
+ loss = x_entropy(
344
+ log_p_y.view(-1, opt.train_way), labels_query.view(-1))
345
+
346
+
347
+ acc, _ = count_accuracy(
348
+ log_p_y.view(-1, opt.train_way), labels_query.view(-1))
349
+
350
+ train_accuracies.append(acc.item())
351
+ train_losses.append(loss.item())
352
+
353
+ if (i % 100 == 0):
354
+ train_acc_avg = np.mean(np.array(train_accuracies))
355
+ log(log_file_path, 'Train Epoch: {}\tBatch: [{}/{}]\tLoss: {:.4f}\tAccuracy: {:.2f} % ({:.2f} %)'.format(
356
+ epoch, i, len(dloader_train), loss.item(), train_acc_avg, acc))
357
+ if opt.wandb:
358
+
359
+ wandb.log({'Epoch': epoch,
360
+ 'lr': optimizer.param_groups[0]['lr'],"Loss":loss.item(),"Avg Accuracy":train_acc_avg,'Accuracy':acc,
361
+ 'cosine loss':loss_weights})
362
+
363
+
364
+ optimizer.zero_grad()
365
+
366
+ loss += loss_weights
367
+ loss.backward()
368
+
369
+ optimizer.step()
370
+
371
+ # Evaluate on the validation split
372
+ _, _ = [x.eval() for x in (embedding_net, cls_head)]
373
+
374
+ val_accuracies = []
375
+ val_losses = []
376
+
377
+
378
+ with torch.no_grad():
379
+
380
+ for i, batch in enumerate(tqdm(dloader_val(epoch)), 1):
381
+ data_support, labels_support, data_query, labels_query, _, _ = [
382
+ x.cuda() for x in batch]
383
+
384
+ test_n_support = opt.test_way * opt.val_shot
385
+ test_n_query = opt.test_way * opt.val_query
386
+
387
+
388
+ list_emb_support = embedding_net(data_support.view(
389
+ [-1] + list(data_support.shape[-3:])))
390
+ list_emb_query = embedding_net(data_query.view(
391
+ [-1] + list(data_query.shape[-3:])))
392
+
393
+
394
+ logit_query = torch.zeros(test_n_query, opt.test_way).cuda()
395
+
396
+ for emb_support, emb_query in zip(list_emb_support, list_emb_query):
397
+
398
+ # print(emb_support.size())
399
+ emb_support = emb_support.view(1, test_n_support, -1)
400
+ # print(emb_support.size())
401
+
402
+ emb_support = emb_support.view(
403
+ 1, opt.train_way, opt.train_shot, -1).mean(2) # [4, 5, 20]
404
+
405
+ emb_query = emb_query.view(1, test_n_query, -1)
406
+
407
+ # print(emb_support.size(),emb_query.size())
408
+
409
+ dists = torch.stack(
410
+ [euclidean_dist(emb_query[i], emb_support[i]) for i in range(emb_query.size(0))])
411
+
412
+ logit_query += F.softmax(-dists, dim=2).view(1 *
413
+ opt.test_way * opt.val_query, -1) # []
414
+
415
+ logit_query /= opt.num_layer
416
+
417
+
418
+ loss = x_entropy(
419
+ logit_query.view(-1, opt.test_way), labels_query.view(-1))
420
+ acc, _ = count_accuracy(
421
+ logit_query.view(-1, opt.test_way), labels_query.view(-1))
422
+
423
+ val_accuracies.append(acc.item())
424
+ val_losses.append(loss.item())
425
+
426
+ val_acc_avg = np.mean(np.array(val_accuracies))
427
+ val_acc_ci95 = 1.96 * \
428
+ np.std(np.array(val_accuracies)) / np.sqrt(opt.val_episode)
429
+
430
+ val_loss_avg = np.mean(np.array(val_losses))
431
+
432
+ if val_acc_avg > max_val_acc:
433
+ max_val_acc = val_acc_avg
434
+ torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict()},
435
+ os.path.join(opt.save_path, 'best_model.pth'))
436
+
437
+
438
+
439
+ log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} % (Best)'
440
+ .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
441
+ else:
442
+ log(log_file_path, 'Validation Epoch: {}\t\t\tLoss: {:.4f}\tAccuracy: {:.2f} ± {:.2f} %'
443
+ .format(epoch, val_loss_avg, val_acc_avg, val_acc_ci95))
444
+
445
+ if opt.wandb:
446
+ wandb.log({"Validation Loss":val_loss_avg,"Val Avg Accuracy":val_acc_avg})
447
+
448
+ torch.save({'embedding': embedding_net.state_dict(
449
+ ), 'head': cls_head.state_dict()}, os.path.join(opt.save_path, 'last_epoch.pth'))
450
+
451
+ if epoch % opt.save_epoch == 0:
452
+ torch.save({'embedding': embedding_net.state_dict(), 'head': cls_head.state_dict(
453
+ )}, os.path.join(opt.save_path, 'epoch_{}.pth'.format(epoch)))
454
+
455
+ log(log_file_path, 'Elapsed Time: {}/{}\n'.format(timer.measure(),
456
+ timer.measure(epoch / float(opt.num_epoch))))
457
+
458
+ # lr_scheduler.step()
utils.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import pprint
4
+ import torch
5
+ from sklearn.metrics import confusion_matrix
6
+
7
+ def set_gpu(x):
8
+ os.environ['CUDA_VISIBLE_DEVICES'] = x
9
+ print('using gpu:', x)
10
+
11
+ def check_dir(path):
12
+ '''
13
+ Create directory if it does not exist.
14
+ path: Path of directory.
15
+ '''
16
+ if not os.path.exists(path):
17
+ os.mkdir(path)
18
+
19
+ def count_accuracy(logits, label):
20
+ pred = torch.argmax(logits, dim=1).view(-1)
21
+ label = label.view(-1)
22
+
23
+ acc = [0 for c in range(3)]
24
+ for c in range(3):
25
+ acc[c] = (pred.eq(label) * label.eq(c)).float() / max((label.eq(c)).sum(), 1)
26
+
27
+
28
+ matrix = confusion_matrix(label.cpu().detach().numpy(), pred.cpu().detach().numpy())
29
+ pca = matrix.diagonal()/matrix.sum(axis=1)
30
+
31
+ accuracy = 100 * pred.eq(label).float().mean()
32
+ return accuracy, pca * 100
33
+
34
+ class Timer():
35
+ def __init__(self):
36
+ self.o = time.time()
37
+
38
+ def measure(self, p=1):
39
+ x = (time.time() - self.o) / float(p)
40
+ x = int(x)
41
+ if x >= 3600:
42
+ return '{:.1f}h'.format(x / 3600)
43
+ if x >= 60:
44
+ return '{}m'.format(round(x / 60))
45
+ return '{}s'.format(x)
46
+
47
+ def log(log_file_path, string):
48
+ '''
49
+ Write one line of log into screen and file.
50
+ log_file_path: Path of log file.
51
+ string: String to write in log file.
52
+ '''
53
+ with open(log_file_path, 'a+') as f:
54
+ f.write(string + '\n')
55
+ f.flush()
56
+ print(string)