peeyushsinghal committed on
Commit 2adcc85
1 Parent(s): 9f95236

files from da

Files changed (2)
  1. app.py +474 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,474 @@
+ # -*- coding: utf-8 -*-
+ """Gradio_C1_C2_v3.ipynb
+
+ Automatically generated by Colab.
+
+ Original file is located at
+     https://colab.research.google.com/drive/1KBTZm5X8qNslEbM7sLFu2IO-d2kg1XZY
+ """
+
+ import gradio as gr
+ import os
+ from PIL import Image
+ from torchvision import datasets, transforms
+ import random
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.autograd import Function
+ from collections import OrderedDict
+ import pandas as pd
+ import io
+ import base64
+
+ # # checking the mounted drive and mounting if not done
+ # if not os.path.exists('/content/gdrive'):
+ #     from google.colab import drive
+ #     drive.mount('/content/gdrive')
+ # else:
+ #     print("Google Drive is already mounted.")
+
+ # Curated Case 1 samples: MNIST-M images the non-DANN model misclassifies
+ # but the DANN model classifies correctly
+ list_c1 = torch.load('list_mnist_m_non_dann_misclassified_dann_classified.pt')
+
+ class CustomDataset(torch.utils.data.Dataset):
+     def __init__(self, data):
+         self.data = data
+
+     def __len__(self):
+         return len(self.data)
+
+     def __getitem__(self, idx):
+         imgs, labels, image_names = self.data[idx]
+         return imgs, labels, image_names
+
+ dataset_c1 = CustomDataset(list_c1)
+
+ # Create a dataloader with the filtered dataset
+ dataloader_c1 = torch.utils.data.DataLoader(dataset_c1, batch_size=10, shuffle=True)
+
+ transform_to_pil = transforms.ToPILImage()
+
+ def get_images():
+     images, labels, image_names = next(iter(dataloader_c1))
+     pil_images = [transform_to_pil(image) for image in images]
+     return pil_images, labels.tolist()
+
+ # Curated Case 2 samples: MNIST-M images both models misclassify
+ list_c2 = torch.load('list_mnist_m_non_dann_misclassified_dann_misclassified.pt')
+ dataset_c2 = CustomDataset(list_c2)
+ dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
+
+ def get_images_2():
+     images, labels, image_names = next(iter(dataloader_c2))
+     pil_images = [transform_to_pil(image) for image in images]
+     return pil_images, labels.tolist()
+
+ def get_device():
+     if torch.cuda.is_available():
+         device = "cuda"
+     elif torch.backends.mps.is_available():
+         device = "mps"
+     else:
+         device = "cpu"
+     print("Device Selected:", device)
+     return device
+
+ device = get_device()
+
+ class GradientReversalFn(Function):
+     @staticmethod
+     def forward(ctx, x, alpha):
+         # Identity in the forward pass; stash alpha for the backward pass
+         ctx.alpha = alpha
+         return x.view_as(x)
+
+     @staticmethod
+     def backward(ctx, grad_output):
+         # Negate and scale the gradient flowing back to the feature extractor
+         output = grad_output.neg() * ctx.alpha
+         return output, None
+
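+ # Gradient reversal is the core DANN trick: during training the feature
+ # extractor is pushed to *maximize* the domain-classifier loss (yielding
+ # domain-invariant features) while the domain classifier minimizes it.
+ # A quick sanity check of the function above (an illustrative sketch, not
+ # part of the app's logic):
+ #   x = torch.ones(1, requires_grad=True)
+ #   GradientReversalFn.apply(x, 0.5).sum().backward()
+ #   x.grad  # tensor([-0.5]): identity forward, negated & scaled backward
+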
+ class Network(nn.Module):
+     def __init__(self, num_classes=10):
+         super(Network, self).__init__()  # Initialize the parent class
+
+         drop_out_value = 0.1
+
+         #---------------------Feature Extractor Network------------------------#
+         self.feature_extractor = nn.Sequential(
+             # Input Block
+             nn.Conv2d(3, 16, 3, bias=False),  # In: 3x28x28, Out: 16x26x26, RF: 3x3, Stride: 1
+             nn.ReLU(),
+             nn.BatchNorm2d(16),
+             nn.Dropout(drop_out_value),
+
+             # Conv Block 2
+             nn.Conv2d(16, 16, 3, bias=False),  # In: 16x26x26, Out: 16x24x24, RF: 5x5, Stride: 1
+             nn.ReLU(),
+             nn.BatchNorm2d(16),
+             nn.Dropout(drop_out_value),
+
+             # Conv Block 3
+             nn.Conv2d(16, 16, 3, bias=False),  # In: 16x24x24, Out: 16x22x22, RF: 7x7, Stride: 1
+             nn.ReLU(),
+             nn.BatchNorm2d(16),
+             nn.Dropout(drop_out_value),
+
+             # Transition Block 1
+             nn.MaxPool2d(kernel_size=2, stride=2),  # In: 16x22x22, Out: 16x11x11, RF: 8x8, Stride: 2
+
+             # Conv Block 4
+             nn.Conv2d(16, 16, 3, bias=False),  # In: 16x11x11, Out: 16x9x9, RF: 12x12, Stride: 1
+             nn.ReLU(),
+             nn.BatchNorm2d(16),
+             nn.Dropout(drop_out_value),
+
+             # Conv Block 5
+             nn.Conv2d(16, 32, 3, bias=False),  # In: 16x9x9, Out: 32x7x7, RF: 16x16, Stride: 1
+             nn.ReLU(),
+             nn.BatchNorm2d(32),
+             nn.Dropout(drop_out_value),
+
+             # Output Block
+             nn.Conv2d(32, 64, 1, bias=False),  # In: 32x7x7, Out: 64x7x7, RF: 16x16, Stride: 1
+
+             # Global Average Pooling
+             nn.AvgPool2d(7)  # In: 64x7x7, Out: 64x1x1, RF: 16x16, Stride: 7
+         )
+
+         #---------------------Class Classifier Network------------------------#
+         self.class_classifier = nn.Sequential(
+             nn.ReLU(),
+             nn.Dropout(p=drop_out_value),
+             nn.Linear(64, 50),
+             nn.BatchNorm1d(50),  # added batch norm to improve accuracy
+             nn.ReLU(),
+             nn.Dropout(p=drop_out_value),
+             nn.Linear(50, num_classes))
+
+         #---------------------Domain Classifier Network------------------------#
+         self.domain_classifier = nn.Sequential(
+             nn.ReLU(),
+             nn.Dropout(p=drop_out_value),
+             nn.Linear(64, 50),
+             nn.BatchNorm1d(50),  # added batch norm to improve accuracy
+             nn.ReLU(),
+             nn.Dropout(p=drop_out_value),
+             nn.Linear(50, 2))
+
+     def forward(self, input_data, alpha=1.0):
+         # Grayscale (MNIST) inputs are expanded to 3 channels to match the
+         # extractor; relies on the module-level img_size defined below
+         if input_data.data.shape[1] == 1:
+             input_data = input_data.expand(input_data.data.shape[0], 3, img_size, img_size)
+
+         input_data = self.feature_extractor(input_data)
+         features = input_data.view(input_data.size(0), -1)  # Flatten for the fully connected heads
+
+         reverse_features = GradientReversalFn.apply(features, alpha)
+         class_output = self.class_classifier(features)
+         domain_output = self.domain_classifier(reverse_features)
+
+         return class_output, domain_output, features
+
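+ # Quick shape check for the network above (an illustrative sketch; assumes
+ # a batch of two 3x28x28 images):
+ #   net = Network().eval()
+ #   cls_out, dom_out, feats = net(torch.randn(2, 3, 28, 28))
+ #   # cls_out: (2, 10), dom_out: (2, 2), feats: (2, 64)
+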
+ ## NON DANN
+ # Instantiate the model (make sure it has the same architecture)
+ loaded_model_non_dann = Network()
+ loaded_model_non_dann = loaded_model_non_dann.to(device)
+ # Load the saved state dictionary
+ loaded_model_non_dann.load_state_dict(torch.load('non_dann_26_06.pt', map_location=device), strict=False)
+ loaded_model_non_dann.eval()
+
+ ## DANN
+ # Instantiate the model (make sure it has the same architecture)
+ loaded_model_dann = Network()
+ loaded_model_dann = loaded_model_dann.to(device)
+ # Load the saved state dictionary
+ loaded_model_dann.load_state_dict(torch.load('dann_26_06.pt', map_location=device), strict=False)
+ loaded_model_dann.eval()
+
+ img_size = 28  # for MNIST / MNIST-M
+ cpu_batch_size = 10
+ class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
+
+ def classify_image_both(image):
+     target_test_transforms = transforms.Compose([
+         transforms.Resize(img_size),
+         transforms.ToTensor(),  # converts to tensor
+         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+     ])
+
+     target_transformed_image = target_test_transforms(image)
+     image_tensor = target_transformed_image.to(device).unsqueeze(0)
+
+     list_confidences = []
+     for model in [loaded_model_non_dann, loaded_model_dann]:
+         model.eval()
+         logits, _, _ = model(image_tensor)
+         output = F.softmax(logits.view(-1), dim=-1)
+
+         # Keep only the top-3 classes as a label -> confidence mapping
+         confidences = [(class_names[i], float(output[i])) for i in range(len(class_names))]
+         confidences.sort(key=lambda x: x[1], reverse=True)
+         confidences = OrderedDict(confidences[:3])
+         list_confidences.append(confidences)
+
+     return list_confidences[0], list_confidences[1]
+
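+ # Each returned OrderedDict is a label -> confidence mapping that gr.Label
+ # renders directly, e.g. (illustrative values only):
+ #   {'7': 0.91, '1': 0.05, '9': 0.02}
+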
+ ### SOURCE DATA - MNIST
+
+ # Test Phase transformations
+ test_transforms = transforms.Compose([
+     # transforms.Resize(img_size),
+     transforms.ToTensor(),  # converts to tensor
+     # transforms.Normalize((0.1307,), (0.3081,))
+ ])
+ transform_to_pil = transforms.ToPILImage()
+ test = datasets.MNIST('./data',
+                       train=False,
+                       download=True,
+                       transform=test_transforms)
+
+ dataloader_args = dict(shuffle=True, batch_size=cpu_batch_size)
+
+ mnist_loader = torch.utils.data.DataLoader(
+     dataset=test,
+     **dataloader_args
+ )
+
+ def get_mnist_images():
+     images, labels = next(iter(mnist_loader))
+     pil_images = [transform_to_pil(image) for image in images]
+     return pil_images, labels.tolist()
+
+ ### TARGET DATA - MNIST-M (test split hosted on the Hugging Face Hub)
+ splits = {'train': 'data/train-00000-of-00001-571b6b1e2c195186.parquet', 'test': 'data/test-00000-of-00001-ba3ad971b105ff65.parquet'}
+ df = pd.read_parquet("hf://datasets/Mike0307/MNIST-M/" + splits["test"])
+
+ class MNIST_M(torch.utils.data.Dataset):
+     def __init__(self, dataframe, transform=None):
+         self.dataframe = dataframe
+         self.transform = transform
+
+     def __len__(self):
+         return len(self.dataframe)
+
+     def __getitem__(self, idx):
+         # Get image and label from the dataframe
+         img_data = self.dataframe.iloc[idx]['image']['bytes']
+         label = self.dataframe.iloc[idx]['label']
+         img_path = self.dataframe.iloc[idx]['image']['path']
+
+         # Decode the raw image bytes
+         img = Image.open(io.BytesIO(img_data))
+
+         # Apply transformations, if any
+         if self.transform:
+             img = self.transform(img)
+
+         return img, label, img_path
+
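+ # Each parquet row is assumed to hold an 'image' struct with raw encoded
+ # 'bytes' and a 'path', plus an integer 'label' (a sketch, not app logic):
+ #   row = df.iloc[0]
+ #   img = Image.open(io.BytesIO(row['image']['bytes']))  # decoded PIL image
+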
+ # Test Phase transformations
+ target_test_transforms = transforms.Compose([
+     transforms.Resize(img_size),
+     transforms.ToTensor(),  # converts to tensor
+     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+ ])
+
+ transform_to_pil = transforms.ToPILImage()
+
+ # Create dataset and dataloader for the MNIST-M test split
+ target_test_dataset = MNIST_M(dataframe=df, transform=target_test_transforms)
+ target_test_dataloader = torch.utils.data.DataLoader(target_test_dataset, batch_size=cpu_batch_size, shuffle=True)
+
+ def get_mnist_m_images():
+     images, labels, image_names = next(iter(target_test_dataloader))
+     pil_images = [transform_to_pil(image) for image in images]
+     return pil_images, labels.tolist()
+
+ mnist_images, mnist_labels = get_mnist_images()
+ mnist_m_images, mnist_m_labels = get_mnist_m_images()
+
+ def classify_image_inference(image):
+     # Pick source-specific normalization: single-channel (mode "L") inputs
+     # are treated as MNIST, everything else as MNIST-M
+     if image.mode == "L":
+         source = 'MNIST'
+         image_transforms = transforms.Compose([
+             transforms.Resize(img_size),
+             transforms.ToTensor(),  # converts to tensor
+             transforms.Normalize((0.1307,), (0.3081,))
+         ])
+     else:
+         source = 'MNIST-M'
+         image_transforms = transforms.Compose([
+             transforms.Resize(img_size),
+             transforms.ToTensor(),  # converts to tensor
+             transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
+         ])
+
+     transformed_image = image_transforms(image)
+     image_tensor = transformed_image.to(device).unsqueeze(0)
+
+     list_confidences = []
+     for model in [loaded_model_non_dann, loaded_model_dann]:
+         model.eval()
+         logits, _, _ = model(image_tensor)
+         output = F.softmax(logits.view(-1), dim=-1)
+
+         confidences = [(class_names[i], float(output[i])) for i in range(len(class_names))]
+         confidences.sort(key=lambda x: x[1], reverse=True)
+         confidences = OrderedDict(confidences[:3])
+         list_confidences.append(confidences)
+
+     return list_confidences[0], list_confidences[1]
+
+ def display_image():
+     # Load the intro illustration from a local file
+     image = Image.open("mnist-m.JPG")
+     return image
+
+ with gr.Blocks() as demo:
+     with gr.Tab("Introduction"):
+         gr.Markdown("## Domain Adaptation in Deep Networks - Demonstration")
+         with gr.Row():
+             with gr.Column():
+                 image_output = gr.Image(value=display_image(), label="source and target", height=256, width=256, show_label=True)
+                 gr.Markdown(
+                     '''
+                     Source - MNIST
+                     ------
+                     - The MNIST database (Modified National Institute of Standards and Technology database) is a large collection of handwritten digits.
+                     - It has a training set of 60,000 examples and a test set of 10,000 examples.
+                     - 28 x 28 size
+                     - 1 channel
+                     '''
+                 )
+                 gr.Markdown(
+                     '''
+                     Target - MNIST-M
+                     -------
+                     - MNIST-M is created by combining MNIST digits with patches randomly extracted from color photos of BSDS500 as their background.
+                     - It contains 59,001 training and 90,001 test images.
+                     - 28 x 28 size
+                     - 3 channels
+                     '''
+                 )
+
+         gr.Markdown(
+             '''
+             Please click on the tabs for more functionality
+             -------
+             - Inferencing on NonDANN and DANN: infer MNIST or MNIST-M images on both models
+             - Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify: a curated list that the non-DANN model misclassifies but the DANN model classifies correctly
+             - Case 2: MNIST_M_Both_Misclassify: a curated list that both the non-DANN and the DANN model misclassify
+             '''
+         )
+
+     ################################################
+     with gr.Tab("Inferencing on NonDANN and DANN"):
+         with gr.Row():
+             with gr.Column():
+                 input_image_classify_mnist = gr.Image(label="Classify MNIST Digit", type="pil", height=256, width=256, image_mode='L')
+                 button_classify_mnist = gr.Button("Submit to Classify MNIST Image", visible=True, size='sm')
+             with gr.Column():
+                 with gr.Row():
+                     label_classify_mnist_non_dann = gr.Label(label="NON DANN Predicted MNIST label", num_top_classes=2, visible=True)
+                 with gr.Row():
+                     label_classify_mnist_dann = gr.Label(label="DANN Predicted MNIST label", num_top_classes=2, visible=True)
+         with gr.Row():
+             gr.Examples([img.convert("L") for img in mnist_images],
+                         inputs=[input_image_classify_mnist], label="Select an example MNIST Image")
+
+         with gr.Row():
+             with gr.Column():
+                 input_image_classify_mnist_m = gr.Image(label="Classify MNIST M Digit", type="pil", height=256, width=256, image_mode='RGB')
+                 button_classify_mnist_m = gr.Button("Submit to Classify MNIST M Image", visible=True, size='sm')
+             with gr.Column():
+                 with gr.Row():
+                     label_classify_mnist_m_non_dann = gr.Label(label="NON DANN Predicted MNIST M label", num_top_classes=2, visible=True)
+                 with gr.Row():
+                     label_classify_mnist_m_dann = gr.Label(label="DANN Predicted MNIST M label", num_top_classes=2, visible=True)
+         with gr.Row():
+             gr.Examples([img.convert("RGB") for img in mnist_m_images],
+                         inputs=[input_image_classify_mnist_m], label="Select an example MNIST M Image")
+         with gr.Row():
+             gr.Markdown(value=f'MNIST-M Ground Truth Labels = {mnist_m_labels}')
+
+         button_classify_mnist.click(fn=classify_image_inference,
+                                     inputs=[input_image_classify_mnist],
+                                     outputs=[label_classify_mnist_non_dann, label_classify_mnist_dann])
+
+         button_classify_mnist_m.click(fn=classify_image_inference,
+                                       inputs=[input_image_classify_mnist_m],
+                                       outputs=[label_classify_mnist_m_non_dann, label_classify_mnist_m_dann])
+
+     ######################
+     with gr.Tab("Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify"):
+         with gr.Row():
+             with gr.Column():
+                 input_image_classify_both = gr.Image(label="Classify Digit", type="pil", height=256, width=256)
+                 button_classify_both = gr.Button("Submit to Classify Image with Both Models", visible=True, size='sm')
+
+             with gr.Column():
+                 with gr.Row():
+                     label_classify_non_dann = gr.Label(label="NON DANN Predicted label", num_top_classes=2, visible=True)
+                 with gr.Row():
+                     label_classify_dann = gr.Label(label="DANN Predicted label", num_top_classes=2, visible=True)
+
+         mnist_m_images_1, mnist_m_labels_1 = get_images()
+
+         with gr.Row():
+             gr.Examples(mnist_m_images_1, inputs=[input_image_classify_both], label="Select an example MNIST-M Image")
+
+         with gr.Row():
+             gr.Markdown(value=f'MNIST-M Ground Truth Labels = {mnist_m_labels_1}')
+
+         button_classify_both.click(fn=classify_image_both,
+                                    inputs=[input_image_classify_both],
+                                    outputs=[label_classify_non_dann, label_classify_dann])
+
+     ########################################################################
+     with gr.Tab("Case 2 - Show both: MNIST_M_Both_Misclassify"):
+         with gr.Row():
+             with gr.Column():
+                 # Note: these names shadow the Case 1 components; each .click()
+                 # call below binds whatever the names point to at that moment,
+                 # so both tabs keep their own handlers.
+                 input_image_classify_both = gr.Image(label="Classify Digit", type="pil", height=256, width=256)
+                 button_classify_both = gr.Button("Submit to Classify Image with Both Models", visible=True, size='sm')
+
+             with gr.Column():
+                 with gr.Row():
+                     label_classify_non_dann = gr.Label(label="NON DANN Predicted label", num_top_classes=2, visible=True)
+                 with gr.Row():
+                     label_classify_dann = gr.Label(label="DANN Predicted label", num_top_classes=2, visible=True)
+
+         mnist_m_images_2, mnist_m_labels_2 = get_images_2()
+
+         with gr.Row():
+             gr.Examples(mnist_m_images_2, inputs=[input_image_classify_both], label="Select an example MNIST-M Image")
+
+         with gr.Row():
+             gr.Markdown(value=f'MNIST-M Ground Truth Labels = {mnist_m_labels_2}')
+
+         button_classify_both.click(fn=classify_image_both,
+                                    inputs=[input_image_classify_both],
+                                    outputs=[label_classify_non_dann, label_classify_dann])
+
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ torch
+ torchvision
+ numpy
+ grad-cam
+ pandas
+ gradio
+ Pillow
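With these dependencies installed (for example via pip install -r requirements.txt), running python app.py should launch the Gradio demo locally, assuming the two model checkpoints and the two curated .pt lists sit alongside app.py.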