# -*- coding: utf-8 -*-
"""Gradio_C1_C2_v3.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1KBTZm5X8qNslEbM7sLFu2IO-d2kg1XZY

Gradio demo comparing a baseline (non-DANN) digit classifier with a
domain-adversarially trained (DANN) classifier on MNIST (source domain) and
MNIST-M (target domain).
"""

import gradio as gr
import os
from PIL import Image
from torchvision import datasets, transforms
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Function
from collections import OrderedDict
import pandas as pd
import io
import base64

img_size = 28  # spatial size the network is built for (see shape comments in Network)
cpu_batch_size = 10
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']

transform_to_pil = transforms.ToPILImage()


class CustomDataset(torch.utils.data.Dataset):
    """Wrap an in-memory list of ``(image, label, image_name)`` tuples as a Dataset."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        imgs, labels, image_names = self.data[idx]
        return imgs, labels, image_names


def _sample_pil_batch(dataloader):
    """Draw one batch from *dataloader*; return ``(list of PIL images, list of labels)``.

    Assumes each batch unpacks as ``(images, labels, image_names)``.
    """
    images, labels, _image_names = next(iter(dataloader))
    return [transform_to_pil(image) for image in images], labels.tolist()


# Case 1: curated MNIST-M samples (per the file name: misclassified by non-DANN, classified by DANN).
list_c1 = torch.load('list_mnist_m_non_dann_misclassified_dann_classified_08_07.pt')
dataset_c1 = CustomDataset(list_c1)
dataloader_c1 = torch.utils.data.DataLoader(dataset_c1, batch_size=10, shuffle=True)


def get_images():
    """Return a random batch of Case 1 images as PIL images, plus their labels."""
    return _sample_pil_batch(dataloader_c1)


# Case 2: curated MNIST-M samples (per the file name: misclassified by both models).
list_c2 = torch.load('list_mnist_m_non_dann_misclassified_dann_misclassified_08_07.pt')
dataset_c2 = CustomDataset(list_c2)
dataloader_c2 = torch.utils.data.DataLoader(dataset_c2, batch_size=10, shuffle=True)


def get_images_2():
    """Return a random batch of Case 2 images as PIL images, plus their labels."""
    return _sample_pil_batch(dataloader_c2)


def get_device():
    """Pick the best available torch device: CUDA, then Apple MPS, then CPU."""
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    print("Device Selected:", device)
    return device


device = get_device()


class GradientReversalFn(Function):
    """Identity on the forward pass; multiplies the gradient by ``-alpha`` on the backward pass."""

    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x.view_as(x)

    @staticmethod
    def backward(ctx, grad_output):
        output = grad_output.neg() * ctx.alpha
        return output, None  # no gradient for alpha


class Network(nn.Module):
    """Shared feature extractor followed by a class head and a domain head.

    ``forward`` returns ``(class_logits, domain_logits, features)``. The domain
    head receives the features through ``GradientReversalFn`` scaled by ``alpha``.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        drop_out_value = 0.1

        #---------------------Feature Extractor Network------------------------#
        self.feature_extractor = nn.Sequential(
            # Input Block
            nn.Conv2d(3, 16, 3, bias=False),   # In: 3x28x28, Out: 16x26x26, RF: 3x3, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 2
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x26x26, Out: 16x24x24, RF: 5x5, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 3
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x24x24, Out: 16x22x22, RF: 7x7, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Transition Block 1
            nn.MaxPool2d(kernel_size=2, stride=2),  # In: 16x22x22, Out: 16x11x11, RF: 8x8, Stride: 2
            # Conv Block 4
            nn.Conv2d(16, 16, 3, bias=False),  # In: 16x11x11, Out: 16x9x9, RF: 12x12, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(16),
            nn.Dropout(drop_out_value),
            # Conv Block 5
            nn.Conv2d(16, 32, 3, bias=False),  # In: 16x9x9, Out: 32x7x7, RF: 16x16, Stride: 1
            nn.ReLU(),
            nn.BatchNorm2d(32),
            nn.Dropout(drop_out_value),
            # Output Block
            nn.Conv2d(32, 64, 1, bias=False),  # In: 32x7x7, Out: 64x7x7, RF: 16x16, Stride: 1
            # Global Average Pooling
            nn.AvgPool2d(7),                   # In: 64x7x7, Out: 64x1x1, RF: 16x16, Stride: 7
        )

        #---------------------Class Classifier Network------------------------#
        self.class_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=drop_out_value),
            nn.Linear(64, 50),
            nn.BatchNorm1d(50),  # added batch norm to improve accuracy
            nn.ReLU(),
            nn.Dropout(p=drop_out_value),
            nn.Linear(50, num_classes),
        )

        #---------------------Domain Classifier Network------------------------#
        self.domain_classifier = nn.Sequential(
            nn.ReLU(),
            nn.Dropout(p=drop_out_value),
            nn.Linear(64, 50),
            nn.BatchNorm1d(50),  # added batch norm to improve accuracy
            nn.ReLU(),
            nn.Dropout(p=drop_out_value),
            nn.Linear(50, 2),
        )

    def forward(self, input_data, alpha=1.0):
        if input_data.shape[1] == 1:
            # Single-channel (grayscale) input: repeat the channel to match the 3-channel
            # stem. -1 keeps batch/height/width as they are, so no global size is needed.
            input_data = input_data.expand(-1, 3, -1, -1)
        input_data = self.feature_extractor(input_data)
        features = input_data.view(input_data.size(0), -1)  # flatten for the linear heads
        reverse_features = GradientReversalFn.apply(features, alpha)
        class_output = self.class_classifier(features)
        domain_output = self.domain_classifier(reverse_features)
        return class_output, domain_output, features


def _load_model(checkpoint_path):
    """Build a ``Network`` on ``device``, load *checkpoint_path* and switch to eval mode.

    ``strict=False`` is kept from the original code. Any missing or unexpected
    keys are reported, because a silent mismatch would leave layers randomly
    initialised.
    """
    model = Network().to(device)
    incompatible = model.load_state_dict(
        torch.load(checkpoint_path, map_location=device), strict=False
    )
    if incompatible.missing_keys or incompatible.unexpected_keys:
        print(f"Warning: checkpoint '{checkpoint_path}' does not fully match Network: "
              f"missing keys={incompatible.missing_keys}, "
              f"unexpected keys={incompatible.unexpected_keys}")
    model.eval()
    return model


## NON DANN
loaded_model_non_dann = _load_model('non_dann_08_07.pt')
## DANN
loaded_model_dann = _load_model('dann_08_07.pt')


def _inference_transforms(mean, std):
    """Resize to a square ``img_size`` x ``img_size`` image, convert to a tensor, and normalize.

    A fixed ``(h, w)`` size (instead of resizing only the shorter edge) keeps
    uploaded non-square images compatible with the fixed ``AvgPool2d(7)`` +
    ``Linear(64, ...)`` head.
    """
    return transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ])


mnist_inference_transforms = _inference_transforms((0.1307,), (0.3081,))
target_test_transforms = _inference_transforms((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def _predict_with_both_models(image_tensor, top_k=3):
    """Run *image_tensor* (batch of one) through both models.

    Returns ``(non_dann_confidences, dann_confidences)``. Each is an
    ``OrderedDict`` of the ``top_k`` class names mapped to softmax probability,
    highest first.
    """
    results = []
    with torch.no_grad():  # inference only; avoid building autograd graphs
        for model in (loaded_model_non_dann, loaded_model_dann):
            model.eval()
            logits, _, _ = model(image_tensor)
            probs = F.softmax(logits.view(-1), dim=-1)
            ranked = sorted(
                ((name, float(probs[i])) for i, name in enumerate(class_names)),
                key=lambda pair: pair[1],
                reverse=True,
            )
            results.append(OrderedDict(ranked[:top_k]))
    return results[0], results[1]


def classify_image_both(image):
    """Classify an MNIST-M style (RGB) PIL image with both models.

    Returns the top-3 confidences of the non-DANN and the DANN model.
    Raises ``gr.Error`` when no image was provided.
    """
    if image is None:
        raise gr.Error("Please upload or select an image first.")
    image_tensor = target_test_transforms(image.convert("RGB")).unsqueeze(0).to(device)
    return _predict_with_both_models(image_tensor)


### SOURCE DATA - MNIST
# Test phase transformations (kept un-normalized so examples can be shown as PIL images).
test_transforms = transforms.Compose([
    transforms.ToTensor(),
])
test = datasets.MNIST('./data', train=False, download=True, transform=test_transforms)
dataloader_args = dict(shuffle=True, batch_size=cpu_batch_size)
mnist_loader = torch.utils.data.DataLoader(dataset=test, **dataloader_args)


def get_mnist_images():
    """Return a random batch of MNIST test images as PIL images, plus their labels."""
    images, labels = next(iter(mnist_loader))
    return [transform_to_pil(image) for image in images], labels.tolist()


### TARGET DATA - MNIST-M
splits = {'train': 'data/train-00000-of-00001-571b6b1e2c195186.parquet',
          'test': 'data/test-00000-of-00001-ba3ad971b105ff65.parquet'}
df = pd.read_parquet("hf://datasets/Mike0307/MNIST-M/" + splits["test"])


class MNIST_M(torch.utils.data.Dataset):
    """MNIST-M dataset backed by a DataFrame with ``image`` (``bytes``/``path``) and ``label`` columns."""

    def __init__(self, dataframe, transform=None):
        self.dataframe = dataframe
        self.transform = transform

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        row = self.dataframe.iloc[idx]
        img_data = row['image']['bytes']
        label = row['label']
        img_path = row['image']['path']
        # The bytes are handed straight to PIL, i.e. they are an encoded image file.
        # Convert to RGB so the 3-channel normalization always applies.
        img = Image.open(io.BytesIO(img_data)).convert("RGB")
        if self.transform:
            img = self.transform(img)
        return img, label, img_path


target_test_dataset = MNIST_M(dataframe=df, transform=target_test_transforms)
target_test_dataloader = torch.utils.data.DataLoader(
    target_test_dataset, batch_size=cpu_batch_size, shuffle=True
)


def get_mnist_m_images():
    """Return a random batch of MNIST-M test images as PIL images, plus their labels.

    The dataset tensors are normalized with mean 0.5 / std 0.5, i.e. into [-1, 1].
    ToPILImage expects float tensors in [0, 1], so undo the normalization first.
    Otherwise negative values are corrupted when cast to 8-bit pixels.
    """
    images, labels, _image_names = next(iter(target_test_dataloader))
    pil_images = [transform_to_pil((image * 0.5 + 0.5).clamp(0.0, 1.0)) for image in images]
    return pil_images, labels.tolist()


mnist_images, mnist_labels = get_mnist_images()
mnist_m_images, mnist_m_labels = get_mnist_m_images()


def classify_image_inference(image):
    """Classify an MNIST (mode ``L``) or MNIST-M (colour) PIL image with both models.

    Grayscale images use the MNIST normalization. Anything else is converted to
    RGB and uses the MNIST-M normalization. Raises ``gr.Error`` when no image
    was provided.
    """
    if image is None:
        raise gr.Error("Please upload or select an image first.")
    if image.mode == "L":
        transformed_image = mnist_inference_transforms(image)
    else:
        transformed_image = target_test_transforms(image.convert("RGB"))
    image_tensor = transformed_image.unsqueeze(0).to(device)
    return _predict_with_both_models(image_tensor)


def display_image():
    """Load the introduction picture from the local file ``mnist-m.JPG``."""
    return Image.open("mnist-m.JPG")


_INTRO_SOURCE_MD = '''
Source - MNIST
------
- The MNIST database (Modified National Institute of Standards and Technology database) is a large collection of handwritten digits.
- It has a training set of 60,000 examples, and a test set of 10,000 examples.
- 28 x 28 size
- 1 channel
'''

_INTRO_TARGET_MD = '''
Target - MNIST-M
-------
- MNIST-M is created by combining MNIST digits with patches randomly extracted from color photos of BSDS500 as their background.
- It contains 59,001 training and 9,001 test images.
- 28 x 28 size
- 3 channels
'''

_INTRO_TABS_MD = '''
Please click on the tabs for more functionality
-------
- Inferencing on NonDANN and DANN: infer MNIST or MNIST-M images on both models
- Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify: curated list of images that the non-DANN model misclassifies but the DANN model classifies correctly
- Case 2: MNIST_M_Both_Misclassify: curated list of images that both the non-DANN and the DANN model misclassify
'''


def _build_case_tab(tab_title, example_images, example_labels):
    """Add a tab that classifies curated MNIST-M examples with both models.

    Must be called inside an active ``gr.Blocks`` context.
    """
    with gr.Tab(tab_title):
        with gr.Row():
            with gr.Column():
                input_image = gr.Image(label="Classify Digit", type="pil", height=256, width=256)
                button = gr.Button("Submit to Classify Image with Both Models",
                                   visible=True, size='sm')
            with gr.Column():
                with gr.Row():
                    label_non_dann = gr.Label(label="NON DANN Predicted label",
                                              num_top_classes=2, visible=True)
                with gr.Row():
                    label_dann = gr.Label(label="DANN Predicted label",
                                          num_top_classes=2, visible=True)
        with gr.Row():
            gr.Examples(example_images, inputs=[input_image],
                        label="Select an example MNIST-M Image")
        with gr.Row():
            gr.Markdown(value=f'MNIST- M Ground Truth Label = {example_labels}')
        button.click(fn=classify_image_both, inputs=[input_image],
                     outputs=[label_non_dann, label_dann])


with gr.Blocks() as demo:
    with gr.Tab("Introduction"):
        gr.Markdown("## Domain Adaptation in Deep Networks - Demonstration")
        with gr.Row():
            with gr.Column():
                image_output = gr.Image(value=display_image(), label="source and target",
                                        height=256, width=256, show_label=True)
                gr.Markdown(_INTRO_SOURCE_MD)
                gr.Markdown(_INTRO_TARGET_MD)
                gr.Markdown(_INTRO_TABS_MD)

    ################################################
    with gr.Tab("Inferencing on NonDANN and DANN"):
        with gr.Row():
            with gr.Column():
                input_image_classify_mnist = gr.Image(label="Classify MNIST Digit", type="pil",
                                                      height=256, width=256, image_mode='L')
                button_classify_mnist = gr.Button("Submit to Classify MNIST Image",
                                                  visible=True, size='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_mnist_non_dann = gr.Label(label="NON DANN Predicted MNIST label",
                                                             num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_mnist_dann = gr.Label(label="DANN Predicted MNIST label",
                                                         num_top_classes=2, visible=True)
        with gr.Row():
            gr.Examples(
                [img.convert("L") for img in mnist_images],
                inputs=[input_image_classify_mnist],
                label="Select an example MNIST Image")

        with gr.Row():
            with gr.Column():
                input_image_classify_mnist_m = gr.Image(label="Classify MNIST M Digit", type="pil",
                                                        height=256, width=256, image_mode='RGB')
                button_classify_mnist_m = gr.Button("Submit to Classify MNIST M Image",
                                                    visible=True, size='sm')
            with gr.Column():
                with gr.Row():
                    label_classify_mnist_m_non_dann = gr.Label(label="NON DANN Predicted MNIST M label",
                                                               num_top_classes=2, visible=True)
                with gr.Row():
                    label_classify_mnist_m_dann = gr.Label(label="DANN Predicted MNIST M label",
                                                           num_top_classes=2, visible=True)
        with gr.Row():
            gr.Examples(
                [img.convert("RGB") for img in mnist_m_images],
                inputs=[input_image_classify_mnist_m],
                label="Select an example MNIST M Image")
        with gr.Row():
            gr.Markdown(value=f'MNIST- M Ground Truth Label = {mnist_m_labels}')

        button_classify_mnist.click(fn=classify_image_inference,
                                    inputs=[input_image_classify_mnist],
                                    outputs=[label_classify_mnist_non_dann, label_classify_mnist_dann])
        button_classify_mnist_m.click(fn=classify_image_inference,
                                      inputs=[input_image_classify_mnist_m],
                                      outputs=[label_classify_mnist_m_non_dann, label_classify_mnist_m_dann])

    ######################
    mnist_m_images_1, mnist_m_labels_1 = get_images()
    _build_case_tab("Case 1: MNIST_M_Non_DANN_Misclassify_DANN_Classify",
                    mnist_m_images_1, mnist_m_labels_1)

    ########################################################################
    mnist_m_images_2, mnist_m_labels_2 = get_images_2()
    _build_case_tab("Case 2 - Show both: MNIST_M_Both_Misclassify",
                    mnist_m_images_2, mnist_m_labels_2)

demo.launch()