# --- Hugging Face file-viewer residue (author/commit/size metadata), not code;
# --- kept as comments so the file parses as Python.
# peeyushsinghal's picture
# files from da
# 2adcc85 verified
# raw / history blame / 18.3 kB
# -*- coding: utf-8 -*-
"""Gradio_C1_C2_v3.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1KBTZm5X8qNslEbM7sLFu2IO-d2kg1XZY
"""
import gradio as gr
import os
from PIL import Image
from torchvision import datasets,transforms
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from collections import OrderedDict
import pandas as pd
import io
import base64
# # checking the mounted drive and mounting if not done
# if not os.path.exists('/content/gdrive'):
#     from google.colab import drive
#     drive.mount('/content/gdrive')
# else:
#     print("Google Drive is already mounted.")
# Case 1 samples: curated (image, label, image_name) records that the plain
# (non-DANN) model misclassifies but the DANN model classifies correctly.
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
list_c1 = torch.load('list_mnist_m_non_dann_misclassified_dann_classified.pt')
class CustomDataset(torch.utils.data.Dataset):
    """Thin Dataset wrapper over a pre-built list of (image, label, name) records."""

    def __init__(self, data):
        # `data` is a sequence of (image, label, image_name) triples.
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Records are stored fully prepared; just unpack and hand them back.
        img, label, name = self.data[idx]
        return img, label, name
# Wrap the curated Case-1 list in a Dataset
dataset_c1 = CustomDataset(list_c1)
# Create a dataloader with the filtered dataset
dataloader_c1 = torch.utils.data.DataLoader(dataset_c1, batch_size=10, shuffle=True)
# Tensor (C, H, W) -> PIL image; used whenever images are handed to Gradio
transform_to_pil = transforms.ToPILImage()
def get_images():
    """Return one shuffled Case-1 batch as (list of PIL images, list of labels)."""
    batch_images, batch_labels, _names = next(iter(dataloader_c1))
    as_pil = [transform_to_pil(img) for img in batch_images]
    return as_pil, batch_labels.tolist()
# Case 2 samples: curated records that BOTH models misclassify.
list_c2 = torch.load('list_mnist_m_non_dann_misclassified_dann_misclassified.pt')
dataset_c2 = CustomDataset(list_c2)
dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)
def get_images_2():
    """Return one shuffled Case-2 batch as (list of PIL images, list of labels)."""
    batch_images, batch_labels, _names = next(iter(dataloader_c2))
    return [transform_to_pil(img) for img in batch_images], batch_labels.tolist()
# next(iter(dataloader_c1))
def get_device():
    """Pick the best available torch device: cuda, then Apple mps, then cpu."""
    if torch.cuda.is_available():
        chosen = "cuda"
    elif torch.backends.mps.is_available():
        chosen = "mps"
    else:
        chosen = "cpu"
    print("Device Selected:", chosen)
    return chosen

# Resolve once at import time; everything below runs on this device.
device = get_device()
class GradientReversalFn(Function):
    """Gradient-reversal layer (DANN): identity forward, -alpha * grad backward.

    Flipping the domain classifier's gradient before it reaches the feature
    extractor pushes the extractor toward domain-invariant features.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        # Stash the scale factor for backward; pass the input through unchanged.
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        # Reverse and scale the incoming gradient; `alpha` receives no gradient.
        return -ctx.alpha * grad_output, None
class Network(nn.Module):
    """CNN with a shared feature extractor and two heads (DANN architecture).

    - ``class_classifier``: 10-way digit prediction from the features.
    - ``domain_classifier``: 2-way source/target prediction from gradient-
      reversed features, so training pushes the extractor toward
      domain-invariant representations.

    NOTE(review): module attribute names and layer order must stay exactly as
    below — the saved state_dicts loaded later in this file key off them.
    """

    def __init__(self, num_classes = 10):
        super(Network, self).__init__() # Initialize the parent class
        drop_out_value = 0.1
        #---------------------Feature Extractor Network------------------------#
        self.feature_extractor = nn.Sequential(
            # Input Block
            nn.Conv2d(3, 16, 3, bias=False), # In: 3x28x28, Out: 16x26x26, RF: 3x3, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 2
            nn.Conv2d(16, 16, 3, bias=False), # In: 16x26x26, Out: 16x24x24, RF: 5x5, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 3
            nn.Conv2d(16, 16, 3, bias=False), # In: 16x24x24, Out: 16x22x22, RF: 7x7, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Transition Block 1
            nn.MaxPool2d(kernel_size=2, stride=2), # In: 16x22x22, Out: 16x11x11, RF: 8x8, Stride: 2
            # Conv Block 4
            nn.Conv2d(16, 16, 3, bias=False), # In: 16x11x11, Out: 16x9x9, RF: 12x12, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 5
            nn.Conv2d(16, 32, 3, bias=False), # In: 16x9x9, Out: 32x7x7, RF: 16x16, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(drop_out_value),
            # Output Block
            nn.Conv2d(32, 64, 1, bias=False), # In: 32x7x7, Out: 64x7x7, RF: 16x16, Stride: 1
            # Global Average Pooling
            nn.AvgPool2d(7) # In: 64x7x7, Out: 64x1x1, RF: 16x16, Stride: 7
        )
        #---------------------Class Classifier Network------------------------#
        self.class_classifier = nn.Sequential(nn.ReLU(),
                                              nn.Dropout(p=drop_out_value),
                                              nn.Linear(64,50),
                                              nn.BatchNorm1d(50), # added batch norm to improve accuracy
                                              nn.ReLU(),
                                              nn.Dropout(p=drop_out_value),
                                              nn.Linear(50,num_classes))
        #---------------------Label Classifier Network------------------------#
        # NOTE(review): despite the banner above, this is the DOMAIN classifier
        # (2 outputs: source vs. target).
        self.domain_classifier = nn.Sequential(nn.ReLU(),
                                               nn.Dropout(p=drop_out_value),
                                               nn.Linear(64,50),
                                               nn.BatchNorm1d(50), # added batch norm to improve accuracy
                                               nn.ReLU(),
                                               nn.Dropout(p=drop_out_value),
                                               nn.Linear(50,2))

    def forward(self, input_data, alpha = 1.0):
        """Return (class_logits, domain_logits, features).

        ``alpha`` scales the reversed gradient flowing back from the domain head.
        """
        # Single-channel (grayscale MNIST) input is tiled to 3 channels so the
        # RGB feature extractor can consume it. Relies on the module-level
        # ``img_size`` global (28) defined later in this file — valid at call
        # time, but fragile; TODO confirm it is always set before inference.
        if input_data.data.shape[1] == 1:
            input_data = input_data.expand(input_data.data.shape[0], 3, img_size, img_size)
        input_data = self.feature_extractor(input_data)
        features = input_data.view(input_data.size(0), -1) # Flatten the output for fully connected layer
        # Gradient reversal: identity in forward, -alpha * grad in backward.
        reverse_features = GradientReversalFn.apply(features, alpha)
        class_output = self.class_classifier(features)
        domain_output = self.domain_classifier(reverse_features)
        return class_output, domain_output, features
## NON DANN
# Instantiate the model (make sure it has the same architecture)
loaded_model_non_dann = Network()
loaded_model_non_dann = loaded_model_non_dann.to(device)
# Load the saved state dictionary.
# NOTE(review): strict=False silently ignores missing/unexpected keys — verify
# the checkpoint really matches this architecture.
loaded_model_non_dann.load_state_dict(torch.load('non_dann_26_06.pt', map_location=device), strict=False)
loaded_model_non_dann.eval()
## DANN
# Instantiate the model (make sure it has the same architecture)
loaded_model_dann = Network()
loaded_model_dann = loaded_model_dann.to(device)
# Load the saved state dictionary (same strict=False caveat as above).
loaded_model_dann.load_state_dict(torch.load('dann_26_06.pt', map_location=device), strict=False)
loaded_model_dann.eval()
img_size = 28 # for mnist (also read by Network.forward via the module global)
cpu_batch_size = 10
# Digit class names, index-aligned with the models' 10 output logits.
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
def classify_image_both(image):
    """Classify one MNIST-M PIL image with both the non-DANN and DANN models.

    Args:
        image: PIL image (RGB, MNIST-M style).

    Returns:
        (non_dann_top3, dann_top3): for each model, an OrderedDict mapping the
        three most confident class names to their softmax probability.
    """
    target_test_transforms = transforms.Compose([
        transforms.Resize(img_size),
        transforms.ToTensor(),  # converts to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    # Add the batch dimension and move to the module-level device.
    image_tensor = target_test_transforms(image).to(device).unsqueeze(0)
    list_confidences = []
    # Inference only: disable autograd bookkeeping (was missing before — each
    # click built an unused computation graph).
    with torch.no_grad():
        for model in [loaded_model_non_dann, loaded_model_dann]:
            model.eval()
            logits, _, _ = model(image_tensor)
            output = F.softmax(logits.view(-1), dim=-1)
            confidences = [(class_names[i], float(output[i])) for i in range(len(class_names))]
            confidences.sort(key=lambda x: x[1], reverse=True)
            # Keep only the top-3 classes for the Gradio Label widget.
            list_confidences.append(OrderedDict(confidences[:3]))
    return list_confidences[0], list_confidences[1]
### SOURCE DATA - MNIST
# Test Phase transformations
test_transforms = transforms.Compose([
    # transforms.Resize(img_size),
    transforms.ToTensor(),  # converts to tensor
    # transforms.Normalize((0.1307,), (0.3081,))
])
transform_to_pil = transforms.ToPILImage()
# Downloads the MNIST test split to ./data on first run.
test = datasets.MNIST('./data',
                      train=False,
                      download=True,
                      transform=test_transforms)
dataloader_args = dict(shuffle=True, batch_size=cpu_batch_size)
mnist_loader = torch.utils.data.DataLoader(
    dataset = test,
    **dataloader_args
)
def get_mnist_images():
    """Return one shuffled MNIST test batch as (list of PIL images, list of labels)."""
    batch_images, batch_labels = next(iter(mnist_loader))
    return [transform_to_pil(img) for img in batch_images], batch_labels.tolist()
# Parquet shards of the MNIST-M dataset hosted on the Hugging Face Hub.
splits = {'train': 'data/train-00000-of-00001-571b6b1e2c195186.parquet', 'test': 'data/test-00000-of-00001-ba3ad971b105ff65.parquet'}
# Requires network access plus the hf:// filesystem support (fsspec/huggingface_hub).
df = pd.read_parquet("hf://datasets/Mike0307/MNIST-M/" + splits["test"])
class MNIST_M(torch.utils.data.Dataset):
    """MNIST-M dataset backed by a pandas DataFrame.

    Expects columns ``image`` (a dict with raw encoded-image ``bytes`` and a
    ``path``) and ``label``. Yields (image, label, image_path) per item, with
    the optional transform applied to the image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform  # optional torchvision transform pipeline

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        # Fetch the row once instead of three separate .iloc lookups.
        row = self.dataframe.iloc[idx]
        label = row['label']
        img_path = row['image']['path']
        # The bytes are a raw encoded image (e.g. PNG/JPEG) — PIL decodes it
        # from an in-memory buffer. (Not base64, despite the old comment.)
        img = Image.open(io.BytesIO(row['image']['bytes']))
        # Apply transformations if any
        if self.transform:
            img = self.transform(img)
        return img, label, img_path
# Test Phase transformations for the target domain (MNIST-M).
target_test_transforms = transforms.Compose([
    transforms.Resize(img_size),
    transforms.ToTensor(),  # converts to tensor
    transforms.Normalize((0.5,0.5,0.5), (0.5,0.5,0.5))
])
transform_to_pil = transforms.ToPILImage()
# Create dataset
target_test_dataset = MNIST_M(dataframe=df, transform=target_test_transforms)
target_test_dataloader = torch.utils.data.DataLoader(target_test_dataset, batch_size=cpu_batch_size, shuffle=True)
def get_mnist_m_images():
    """Return one shuffled MNIST-M test batch as (list of PIL images, list of labels)."""
    batch_images, batch_labels, _paths = next(iter(target_test_dataloader))
    return [transform_to_pil(img) for img in batch_images], batch_labels.tolist()

# One fixed example batch per domain, reused by the Gradio example galleries.
mnist_images, mnist_labels = get_mnist_images()
mnist_m_images, mnist_m_labels = get_mnist_m_images()
def classify_image_inference(image):
    """Classify a PIL image with both models, normalizing per source domain.

    Grayscale ("L") images are treated as MNIST and normalized with the MNIST
    mean/std; anything else is treated as MNIST-M and scaled to [-1, 1].

    Args:
        image: PIL image, mode "L" (MNIST) or RGB (MNIST-M).

    Returns:
        (non_dann_top3, dann_top3): per model, an OrderedDict mapping the three
        most confident class names to their softmax probability.
    """
    if image.mode == "L":
        # MNIST: single channel, dataset mean/std.
        image_transforms = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),  # converts to tensor
            transforms.Normalize((0.1307,), (0.3081,)),
        ])
    else:
        # MNIST-M: 3 channels, scaled to [-1, 1].
        image_transforms = transforms.Compose([
            transforms.Resize(img_size),
            transforms.ToTensor(),  # converts to tensor
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ])
    image_tensor = image_transforms(image).to(device).unsqueeze(0)
    list_confidences = []
    # Inference only: disable autograd bookkeeping (was missing before).
    with torch.no_grad():
        for model in [loaded_model_non_dann, loaded_model_dann]:
            model.eval()
            logits, _, _ = model(image_tensor)
            output = F.softmax(logits.view(-1), dim=-1)
            confidences = [(class_names[i], float(output[i])) for i in range(len(class_names))]
            confidences.sort(key=lambda x: x[1], reverse=True)
            # Keep only the top-3 classes for the Gradio Label widget.
            list_confidences.append(OrderedDict(confidences[:3]))
    return list_confidences[0], list_confidences[1]
def display_image():
    """Load the static source-vs-target illustration shipped next to this script."""
    return Image.open("mnist-m.JPG")
# ---------------------------------------------------------------------------
# Gradio UI: an intro tab plus three interactive tabs (free inference, and the
# two curated misclassification case studies).
# NOTE(review): the original file's indentation was lost; the nesting of the
# `with` blocks below was reconstructed from the widget structure — confirm
# the layout matches the intended design.
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    with gr.Tab("Introduction"):
        gr.Markdown("## Domain Adaptation in Deep Networks - Demonstration")
        with gr.Row():
            with gr.Column():
                image_output = gr.Image(value=display_image(), label = "source and target",height = 256, width = 256, show_label = True)
            gr.Markdown(
'''
Source - MNIST
------
- The MNIST database (Modified National Institute of Standards and Technology database) is a large collection of handwritten digits.
- It has a training set of 60,000 examples, and a test set of 10,000 examples.
- 28 x 28 size
- 1 channel
'''
            )
            gr.Markdown(
'''
Target - MNIST-M
-------
- MNIST-M is created by combining MNIST digits with the patches randomly extracted from color photos of BSDS500 as their background.
- It contains 59,001 training and 90,001 test images.
- 28 x 28 size
- 3 channels
'''
            )
        gr.Markdown(
'''
Please click on the tabs, for more functionality
-------
- Inferencing on NonDANN and DANN : Infer MNIST or MNISTM on both Models
- Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify : Curated list which misclassify on NON DANN but classifies well on NonDANN
- Case 2: MNIST_M_Both_Misclassify : Curated list which misclassify Both on NON DANN and DANN
'''
        )
    ################################################
    with gr.Tab("Inferencing on NonDANN and DANN"):
        # MNIST block: grayscale input on the left, both models' top-2 on the right.
        with gr.Row():
            with gr.Column():
                input_image_classify_mnist = gr.Image(label="Classify MNIST Digit", type = "pil", height = 256, width = 256, image_mode = 'L')
                button_classify_mnist = gr.Button("Submit to Classify MNIST Image", visible = True, size ='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_mnist_non_dann = gr.Label(label = "NON DANN Predicted MNIST label", num_top_classes=2, visible = True)
                with gr.Row():
                    label_classify_mnist_dann = gr.Label(label = "DANN Predicted MNIST label", num_top_classes=2, visible = True)
        with gr.Row():
            gr.Examples( [img.convert("L") for img in mnist_images],
                         inputs=[input_image_classify_mnist], label = "Select an example MNIST Image")
        # MNIST-M block: RGB input, same dual-model output layout.
        with gr.Row():
            with gr.Column():
                input_image_classify_mnist_m = gr.Image(label="Classify MNIST M Digit", type = "pil", height = 256, width = 256, image_mode = 'RGB')
                button_classify_mnist_m = gr.Button("Submit to Classify MNIST M Image", visible = True, size ='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_mnist_m_non_dann = gr.Label(label = "NON DANN Predicted MNIST M label", num_top_classes=2, visible = True)
                with gr.Row():
                    label_classify_mnist_m_dann = gr.Label(label = "DANN Predicted MNIST M label", num_top_classes=2, visible = True)
        with gr.Row():
            gr.Examples( [img.convert("RGB") for img in mnist_m_images],
                         inputs=[input_image_classify_mnist_m], label = "Select an example MNIST M Image")
        with gr.Row():
            gr.Markdown(value = f'MNIST- M Ground Truth Label = {[label for label in mnist_m_labels]}')
        # Both buttons route through the mode-aware dual-model classifier.
        button_classify_mnist.click(fn=classify_image_inference,
                                    inputs=[input_image_classify_mnist],
                                    outputs=[label_classify_mnist_non_dann, label_classify_mnist_dann])
        button_classify_mnist_m.click(fn=classify_image_inference,
                                      inputs=[input_image_classify_mnist_m],
                                      outputs=[label_classify_mnist_m_non_dann, label_classify_mnist_m_dann])
    ######################
    with gr.Tab("Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify"):
        # with gr.Row():
        #     radio_model = gr.Radio(["Baseline (Non-DANN)", "DANN"],
        #                            label="Select the model you want to use.",
        #                            value="Baseline (Non-DANN)", # Set default value
        #                            scale=2)
        with gr.Row():
            with gr.Column():
                input_image_classify_both = gr.Image(label="Classify Digit", type = "pil", height = 256, width = 256)
                button_classify_both = gr.Button("Submit to Classify Image with Both Models", visible = True, size ='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_non_dann = gr.Label(label = "NON DANN Predicted label", num_top_classes=2, visible = True)
                with gr.Row():
                    label_classify_dann = gr.Label(label = "DANN Predicted label", num_top_classes=2, visible = True)
        # Curated Case-1 batch: samples only the DANN model classifies correctly.
        mnist_m_images_1,mnist_m_labels_1 = get_images()
        with gr.Row():
            gr.Examples(mnist_m_images_1,inputs=[input_image_classify_both], label = "Select an example MNIST-M Image") #working
        with gr.Row():
            gr.Markdown(value = f'MNIST- M Ground Truth Label = {[label for label in mnist_m_labels_1]}')
        button_classify_both.click(fn=classify_image_both,
                                   inputs=[input_image_classify_both],
                                   outputs=[label_classify_non_dann,label_classify_dann])
    ########################################################################
    with gr.Tab("Case 2 - Show both: MNIST_M_Both_Misclassify"):
        # NOTE(review): this tab rebinds the same variable names as Case 1
        # (input_image_classify_both, button_classify_both, ...). That is safe
        # because Case 1's .click() was registered against the earlier
        # component objects before the rebinding, but it is fragile.
        with gr.Row():
            with gr.Column():
                input_image_classify_both = gr.Image(label="Classify Digit", type = "pil", height = 256, width = 256)
                button_classify_both = gr.Button("Submit to Classify Image with Both Models", visible = True, size ='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_non_dann = gr.Label(label = "NON DANN Predicted label", num_top_classes=2, visible = True)
                with gr.Row():
                    label_classify_dann = gr.Label(label = "DANN Predicted label", num_top_classes=2, visible = True)
        # Curated Case-2 batch: samples both models misclassify.
        mnist_m_images_2,mnist_m_labels_2 = get_images_2()
        with gr.Row():
            gr.Examples(mnist_m_images_2,inputs=[input_image_classify_both], label = "Select an example MNIST-M Image") #working
        with gr.Row():
            gr.Markdown(value = f'MNIST- M Ground Truth Label = {[label for label in mnist_m_labels_2]}')
        button_classify_both.click(fn=classify_image_both,
                                   inputs=[input_image_classify_both],
                                   outputs=[label_classify_non_dann,label_classify_dann])

# Start the Gradio server (blocks until shutdown).
demo.launch()