Spaces:
Sleeping
Sleeping
# -*- coding: utf-8 -*- | |
"""Gradio_C1_C2_v3.ipynb | |
Automatically generated by Colab. | |
Original file is located at | |
https://colab.research.google.com/drive/1KBTZm5X8qNslEbM7sLFu2IO-d2kg1XZY | |
""" | |
import gradio as gr | |
import os | |
from PIL import Image | |
from torchvision import datasets,transforms | |
import random | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.autograd import Function | |
from collections import OrderedDict | |
import pandas as pd | |
import io | |
import base64 | |
# # checking the mounted drive and mounting if not done | |
# if not os.path.exists('/content/gdrive'): | |
# from google.colab import drive | |
# drive.mount('/content/gdrive') | |
# else: | |
# print("Google Drive is already mounted.") | |
# Case 1 samples, as named by the file: MNIST-M items the non-DANN model
# misclassified and the DANN model classified.
# NOTE(review): torch.load unpickles the file; only load trusted artefacts.
list_c1 = torch.load(
    'list_mnist_m_non_dann_misclassified_dann_classified.pt'
)
class CustomDataset(torch.utils.data.Dataset):
    """Map-style dataset over an in-memory sequence of 3-item records.

    Each record is unpacked as ``(imgs, labels, image_names)`` and handed
    back as a tuple in that order.
    """

    def __init__(self, data):
        # Indexable collection of 3-item records.
        self.data = data

    def __len__(self):
        """Return the number of stored records."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return record ``idx`` as an ``(imgs, labels, image_names)`` tuple."""
        record = self.data[idx]
        # Unpacking enforces that every record has exactly three fields.
        imgs, labels, image_names = record
        return (imgs, labels, image_names)
# Tensor -> PIL converter used when preparing example images for the UI.
transform_to_pil = transforms.ToPILImage()

# Wrap the case-1 list in a dataset and serve it in shuffled batches of 10.
dataset_c1 = CustomDataset(list_c1)
dataloader_c1 = torch.utils.data.DataLoader(
    dataset_c1,
    batch_size=10,
    shuffle=True,
)
def get_images():
    """Draw one shuffled batch of case-1 samples.

    Returns:
        A ``(pil_images, labels)`` pair: the batch images converted to PIL
        images and the batch labels as a plain Python list.
    """
    # A fresh iterator per call, so each call yields a new shuffled batch.
    images, labels, _names = next(iter(dataloader_c1))
    # NOTE(review): ToPILImage expects float tensors in [0, 1]; confirm the
    # stored tensors are not normalised to another range.
    pil_images = list(map(transform_to_pil, images))
    return pil_images, labels.tolist()
# Case 2 samples, as named by the file: MNIST-M items misclassified by both
# the non-DANN and the DANN model.
# NOTE(review): torch.load unpickles the file; only load trusted artefacts.
list_c2 = torch.load(
    'list_mnist_m_non_dann_misclassified_dann_misclassified.pt'
)
dataset_c2 = CustomDataset(list_c2)
dataloader_c2 = torch.utils.data.DataLoader(
    dataset_c2,
    batch_size=10,
    shuffle=True,
)
def get_images_2():
    """Draw one shuffled batch of case-2 samples.

    Returns:
        A ``(pil_images, labels)`` pair: the batch images converted to PIL
        images and the batch labels as a plain Python list.
    """
    # A fresh iterator per call, so each call yields a new shuffled batch.
    images, labels, _names = next(iter(dataloader_c2))
    # NOTE(review): ToPILImage expects float tensors in [0, 1]; confirm the
    # stored tensors are not normalised to another range.
    pil_images = list(map(transform_to_pil, images))
    return pil_images, labels.tolist()
# next(iter(dataloader_c1)) | |
def get_device():
    """Pick the compute device name, preferring CUDA, then MPS, then CPU.

    Prints the choice and returns it as one of ``"cuda"``, ``"mps"`` or
    ``"cpu"``.
    """
    if torch.cuda.is_available():
        selected = "cuda"
    else:
        # MPS is only probed when CUDA is unavailable.
        selected = "mps" if torch.backends.mps.is_available() else "cpu"
    print("Device Selected:", selected)
    return selected


# Resolved once at import time and shared by the models and input tensors.
device = get_device()
class GradientReversalFn(Function):
    """Identity in the forward pass; negated, scaled gradient in the backward pass.

    Use via ``GradientReversalFn.apply(x, alpha)``.
    """

    @staticmethod
    def forward(ctx, x, alpha):
        """Return ``x`` unchanged and remember ``alpha`` for the backward pass."""
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        """Return ``-alpha * grad_output`` for ``x`` and no gradient for ``alpha``."""
        output = grad_output.neg() * ctx.alpha
        # One gradient per forward() input: x gets the reversed gradient,
        # alpha is a plain scalar and gets None.
        return output, None
class Network(nn.Module):
    """CNN with a shared feature extractor and two heads (class and domain).

    The feature extractor maps a 3-channel image to a 64-dimensional
    feature vector (for 28x28 inputs, per the layer comments). The class
    head predicts ``num_classes`` logits; the domain head predicts 2 logits
    from the same features routed through ``GradientReversalFn``.

    Args:
        num_classes: Number of outputs of the class head.
    """

    def __init__(self, num_classes = 10):
        super().__init__()  # Initialize the parent class
        drop_out_value = 0.1

        #---------------------Feature Extractor Network------------------------#
        self.feature_extractor = nn.Sequential(
            # Input Block
            nn.Conv2d(3, 16, 3, bias=False),  # In: 3x28x28, Out: 16x26x26, RF: 3x3, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 2
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x26x26, Out: 16x24x24, RF: 5x5, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 3
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x24x24, Out: 16x22x22, RF: 7x7, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Transition Block 1
            nn.MaxPool2d(kernel_size=2, stride=2),  # In: 16x22x22, Out: 16x11x11, RF: 8x8, Stride: 2
            # Conv Block 4
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x11x11, Out: 16x9x9, RF: 12x12, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 5
            nn.Conv2d(16, 32, 3, bias=False),  # In: 16x9x9, Out: 32x7x7, RF: 16x16, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(drop_out_value),
            # Output Block
            nn.Conv2d(32, 64, 1, bias=False),  # In: 32x7x7, Out: 64x7x7, RF: 16x16, Stride: 1
            # Global Average Pooling
            nn.AvgPool2d(7)  # In: 64x7x7, Out: 64x1x1, RF: 16x16, Stride: 7
        )

        #---------------------Class Classifier Network------------------------#
        self.class_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=drop_out_value),
            nn.Linear(64, 50),
            nn.BatchNorm1d(50),  # added batch norm to improve accuracy
            nn.ReLU(),
            nn.Dropout(p=drop_out_value),
            nn.Linear(50, num_classes),
        )

        #---------------------Domain Classifier Network------------------------#
        self.domain_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=drop_out_value),
            nn.Linear(64, 50),
            nn.BatchNorm1d(50),  # added batch norm to improve accuracy
            nn.ReLU(),
            nn.Dropout(p=drop_out_value),
            nn.Linear(50, 2),
        )

    def forward(self, input_data, alpha = 1.0):
        """Run both heads on a batch of images.

        Args:
            input_data: Image batch indexed as (batch, channels, height,
                width). A single-channel batch is broadcast to 3 channels.
            alpha: Scale passed to ``GradientReversalFn`` for the domain
                branch.

        Returns:
            ``(class_output, domain_output, features)``: class logits,
            domain logits and the flattened feature vectors.
        """
        if input_data.shape[1] == 1:
            # Broadcast grey-scale input to 3 channels. -1 keeps the batch
            # and spatial sizes, so this no longer depends on a module-level
            # ``img_size`` being defined before the first call.
            input_data = input_data.expand(-1, 3, -1, -1)
        input_data = self.feature_extractor(input_data)
        features = input_data.view(input_data.size(0), -1)  # Flatten the output for fully connected layer
        reverse_features = GradientReversalFn.apply(features, alpha)
        class_output = self.class_classifier(features)
        domain_output = self.domain_classifier(reverse_features)
        return class_output, domain_output, features
def _restore_network(weights_path):
    """Build a Network on ``device``, load weights from ``weights_path``, set eval mode."""
    net = Network().to(device)
    # strict=False: missing or unexpected keys in the checkpoint are ignored
    # rather than raising, so a mismatched file will not fail loudly here.
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    net.load_state_dict(torch.load(weights_path, map_location=device), strict=False)
    net.eval()
    return net


## NON DANN
loaded_model_non_dann = _restore_network('non_dann_26_06.pt')

## DANN
loaded_model_dann = _restore_network('dann_26_06.pt')

img_size = 28  # for mnist
cpu_batch_size = 10
# Digit class names '0'..'9', indexed by class id.
class_names = [str(digit) for digit in range(10)]
def classify_image_both(image):
    """Classify one RGB image with the non-DANN and the DANN model.

    Args:
        image: PIL image. It is normalised with 3-channel statistics, so it
            is expected to be in RGB mode.

    Returns:
        ``(non_dann, dann)``: for each model an ``OrderedDict`` mapping the
        three most probable class names to their softmax probabilities,
        most probable first.
    """
    preprocess = transforms.Compose([
        # An (h, w) pair forces an exact img_size x img_size input. A bare
        # int would only match the shorter edge and leave non-square uploads
        # at a size the network's 7x7 pooling and linear head cannot consume.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),  # converts to tensor
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    image_tensor = preprocess(image).to(device).unsqueeze(0)

    list_confidences = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for model in (loaded_model_non_dann, loaded_model_dann):
            model.eval()
            logits, _, _ = model(image_tensor)
            probabilities = F.softmax(logits.view(-1), dim=-1).tolist()
            ranked = sorted(
                zip(class_names, probabilities),
                key=lambda pair: pair[1],
                reverse=True,
            )
            list_confidences.append(OrderedDict(ranked[:3]))
    return list_confidences[0], list_confidences[1]
### SOURCE DATA - MNIST
# Test-phase preprocessing: tensor conversion only (no resize, no normalisation).
test_transforms = transforms.Compose([transforms.ToTensor()])

transform_to_pil = transforms.ToPILImage()

# MNIST test split, downloaded into ./data when not already present.
test = datasets.MNIST(
    './data',
    train=False,
    download=True,
    transform=test_transforms,
)

dataloader_args = {"shuffle": True, "batch_size": cpu_batch_size}
mnist_loader = torch.utils.data.DataLoader(dataset=test, **dataloader_args)
def get_mnist_images():
    """Draw one shuffled batch from the MNIST test loader.

    Returns:
        A ``(pil_images, labels)`` pair: the batch images converted to PIL
        images and the batch labels as a plain Python list.
    """
    batch_images, batch_labels = next(iter(mnist_loader))
    pil_images = list(map(transform_to_pil, batch_images))
    return pil_images, batch_labels.tolist()
# Parquet shards of the MNIST-M dataset, relative to the dataset repository root.
splits = {
    'train': 'data/train-00000-of-00001-571b6b1e2c195186.parquet',
    'test': 'data/test-00000-of-00001-ba3ad971b105ff65.parquet',
}
# Only the test shard is loaded; the hf:// path is read remotely by pandas.
df = pd.read_parquet(f"hf://datasets/Mike0307/MNIST-M/{splits['test']}")
class MNIST_M(torch.utils.data.Dataset):
    """Dataset over a dataframe of encoded MNIST-M images.

    Each row is read as ``row['image']['bytes']`` (encoded image file
    content), ``row['image']['path']`` and ``row['label']``.

    Args:
        dataframe: Table with ``image`` and ``label`` columns as described.
        transform: Optional callable applied to the decoded PIL image.
    """

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the dataframe."""
        return len(self.dataframe)

    def __getitem__(self, idx):
        """Return ``(image, label, image_path)`` for row ``idx``."""
        # Fetch the row once instead of re-indexing the dataframe per field.
        row = self.dataframe.iloc[idx]
        image_record = row['image']
        img_data = image_record['bytes']
        label = row['label']
        img_path = image_record['path']
        # The bytes are raw encoded image data; PIL decodes them directly.
        img = Image.open(io.BytesIO(img_data))
        # Apply transformations if any
        if self.transform:
            img = self.transform(img)
        return img, label, img_path
transform_to_pil = transforms.ToPILImage()

# Test-phase preprocessing for the target domain: resize, convert to tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
target_test_transforms = transforms.Compose(
    [
        transforms.Resize(img_size),
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ]
)

# MNIST-M test dataset and its shuffled loader.
target_test_dataset = MNIST_M(dataframe=df, transform=target_test_transforms)
target_test_dataloader = torch.utils.data.DataLoader(
    target_test_dataset,
    batch_size=cpu_batch_size,
    shuffle=True,
)
def get_mnist_m_images():
    """Draw one shuffled batch from the MNIST-M test loader as viewable images.

    Returns:
        A ``(pil_images, labels)`` pair: the batch images converted to PIL
        images and the batch labels as a plain Python list.
    """
    images, labels, _paths = next(iter(target_test_dataloader))
    # The loader's transform applies Normalize(mean=0.5, std=0.5), which
    # puts pixel values in [-1, 1]. ToPILImage expects floats in [0, 1], so
    # undo the normalisation first; otherwise negative values are mangled
    # in the uint8 conversion and the example images come out corrupted.
    images = (images * 0.5 + 0.5).clamp(0.0, 1.0)
    pil_images = [transform_to_pil(image) for image in images]
    return pil_images, labels.tolist()
# One random example batch per domain, drawn once at start-up; the
# "Inferencing" tab offers these images as clickable examples.
mnist_images, mnist_labels = get_mnist_images()
mnist_m_images, mnist_m_labels = get_mnist_m_images()
def classify_image_inference(image):
    """Classify one image with the non-DANN and the DANN model.

    Grey-scale images (PIL mode ``"L"``) are treated as MNIST and
    normalised with single-channel statistics; every other mode is treated
    as MNIST-M and normalised with 3-channel statistics.

    Args:
        image: PIL image to classify.

    Returns:
        ``(non_dann, dann)``: for each model an ``OrderedDict`` mapping the
        three most probable class names to their softmax probabilities,
        most probable first.
    """
    if image.mode == "L":
        # MNIST input: one channel.
        normalize = transforms.Normalize((0.1307,), (0.3081,))
    else:
        # MNIST-M input: three channels.
        normalize = transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

    image_transforms = transforms.Compose([
        # An (h, w) pair forces an exact img_size x img_size input. A bare
        # int would only match the shorter edge and leave non-square uploads
        # at a size the network cannot consume.
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),  # converts to tensor
        normalize,
    ])
    image_tensor = image_transforms(image).to(device).unsqueeze(0)

    list_confidences = []
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        for model in (loaded_model_non_dann, loaded_model_dann):
            model.eval()
            logits, _, _ = model(image_tensor)
            probabilities = F.softmax(logits.view(-1), dim=-1).tolist()
            ranked = sorted(
                zip(class_names, probabilities),
                key=lambda pair: pair[1],
                reverse=True,
            )
            list_confidences.append(OrderedDict(ranked[:3]))
    return list_confidences[0], list_confidences[1]
def display_image():
    """Open the local file ``mnist-m.JPG`` and return it as a PIL image."""
    return Image.open("mnist-m.JPG")
# Gradio UI: an introduction tab, a free inference tab, and two tabs with
# curated MNIST-M samples (case 1 and case 2).
# NOTE(review): the nesting of the layout containers below was reconstructed
# from an un-indented view of the file; confirm Row/Column/Tab membership
# against the deployed layout.
# The multi-line Markdown literals are indented with the code; gr.Markdown
# is expected to strip the common indentation before rendering.
with gr.Blocks() as demo:
    with gr.Tab("Introduction"):
        gr.Markdown("## Domain Adaptation in Deep Networks - Demonstration")
        with gr.Row():
            with gr.Column():
                image_output = gr.Image(value=display_image(), label = "source and target",height = 256, width = 256, show_label = True)
                gr.Markdown(
                    '''
                    Source - MNIST
                    ------
                    - The MNIST database (Modified National Institute of Standards and Technology database) is a large collection of handwritten digits.
                    - It has a training set of 60,000 examples, and a test set of 10,000 examples.
                    - 28 x 28 size
                    - 1 channel
                    '''
                )
                # Test-set size corrected from "90,001" to "9,001".
                gr.Markdown(
                    '''
                    Target - MNIST-M
                    -------
                    - MNIST-M is created by combining MNIST digits with the patches randomly extracted from color photos of BSDS500 as their background.
                    - It contains 59,001 training and 9,001 test images.
                    - 28 x 28 size
                    - 3 channels
                    '''
                )
                # Case 1 description corrected: the samples are classified
                # well by DANN (the original text said "NonDANN").
                gr.Markdown(
                    '''
                    Please click on the tabs, for more functionality
                    -------
                    - Inferencing on NonDANN and DANN : Infer MNIST or MNISTM on both Models
                    - Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify : Curated list which misclassify on NON DANN but classifies well on DANN
                    - Case 2: MNIST_M_Both_Misclassify : Curated list which misclassify Both on NON DANN and DANN
                    '''
                )

    ################################################
    with gr.Tab("Inferencing on NonDANN and DANN"):
        # --- MNIST (grey-scale) input and predictions ---
        with gr.Row():
            with gr.Column():
                input_image_classify_mnist = gr.Image(label="Classify MNIST Digit", type = "pil", height = 256, width = 256, image_mode = 'L')
                button_classify_mnist = gr.Button("Submit to Classify MNIST Image", visible = True, size ='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_mnist_non_dann = gr.Label(label = "NON DANN Predicted MNIST label", num_top_classes=2, visible = True)
                with gr.Row():
                    label_classify_mnist_dann = gr.Label(label = "DANN Predicted MNIST label", num_top_classes=2, visible = True)
        with gr.Row():
            gr.Examples([img.convert("L") for img in mnist_images],
                        inputs=[input_image_classify_mnist], label = "Select an example MNIST Image")

        # --- MNIST-M (RGB) input and predictions ---
        with gr.Row():
            with gr.Column():
                input_image_classify_mnist_m = gr.Image(label="Classify MNIST M Digit", type = "pil", height = 256, width = 256, image_mode = 'RGB')
                button_classify_mnist_m = gr.Button("Submit to Classify MNIST M Image", visible = True, size ='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_mnist_m_non_dann = gr.Label(label = "NON DANN Predicted MNIST M label", num_top_classes=2, visible = True)
                with gr.Row():
                    label_classify_mnist_m_dann = gr.Label(label = "DANN Predicted MNIST M label", num_top_classes=2, visible = True)
        with gr.Row():
            gr.Examples([img.convert("RGB") for img in mnist_m_images],
                        inputs=[input_image_classify_mnist_m], label = "Select an example MNIST M Image")
        with gr.Row():
            gr.Markdown(value = f'MNIST- M Ground Truth Label = {mnist_m_labels}')

        button_classify_mnist.click(fn=classify_image_inference,
                                    inputs=[input_image_classify_mnist],
                                    outputs=[label_classify_mnist_non_dann, label_classify_mnist_dann])
        button_classify_mnist_m.click(fn=classify_image_inference,
                                      inputs=[input_image_classify_mnist_m],
                                      outputs=[label_classify_mnist_m_non_dann, label_classify_mnist_m_dann])

    ######################
    with gr.Tab("Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify"):
        with gr.Row():
            with gr.Column():
                input_image_classify_both = gr.Image(label="Classify Digit", type = "pil", height = 256, width = 256)
                button_classify_both = gr.Button("Submit to Classify Image with Both Models", visible = True, size ='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_non_dann = gr.Label(label = "NON DANN Predicted label", num_top_classes=2, visible = True)
                with gr.Row():
                    label_classify_dann = gr.Label(label = "DANN Predicted label", num_top_classes=2, visible = True)
        # One batch of curated case-1 samples, drawn when the UI is built.
        mnist_m_images_1, mnist_m_labels_1 = get_images()
        with gr.Row():
            gr.Examples(mnist_m_images_1, inputs=[input_image_classify_both], label = "Select an example MNIST-M Image")
        with gr.Row():
            gr.Markdown(value = f'MNIST- M Ground Truth Label = {mnist_m_labels_1}')
        # Wired before the next tab rebinds the same variable names.
        button_classify_both.click(fn=classify_image_both,
                                   inputs=[input_image_classify_both],
                                   outputs=[label_classify_non_dann, label_classify_dann])

    ########################################################################
    with gr.Tab("Case 2 - Show both: MNIST_M_Both_Misclassify"):
        with gr.Row():
            with gr.Column():
                input_image_classify_both = gr.Image(label="Classify Digit", type = "pil", height = 256, width = 256)
                button_classify_both = gr.Button("Submit to Classify Image with Both Models", visible = True, size ='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_non_dann = gr.Label(label = "NON DANN Predicted label", num_top_classes=2, visible = True)
                with gr.Row():
                    label_classify_dann = gr.Label(label = "DANN Predicted label", num_top_classes=2, visible = True)
        # One batch of curated case-2 samples, drawn when the UI is built.
        mnist_m_images_2, mnist_m_labels_2 = get_images_2()
        with gr.Row():
            gr.Examples(mnist_m_images_2, inputs=[input_image_classify_both], label = "Select an example MNIST-M Image")
        with gr.Row():
            gr.Markdown(value = f'MNIST- M Ground Truth Label = {mnist_m_labels_2}')
        button_classify_both.click(fn=classify_image_both,
                                   inputs=[input_image_classify_both],
                                   outputs=[label_classify_non_dann, label_classify_dann])

demo.launch()