# TSAIGradcam / gradio_utils.py
# ibrim's picture
# Update gradio_utils.py
# 764fcd0 verified
# raw
# history blame
# No virus
# 4.89 kB
# Import all the required modules
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import math
from collections import OrderedDict
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch_lr_finder import LRFinder
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from utils import display_gradcam_output
from torchmetrics import Accuracy
from torchvision.datasets import CIFAR10
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
import matplotlib.pyplot as plt
import random
from resnet import *
# Dataset root directory; taken from the PATH_DATASETS environment variable,
# falling back to the current directory when it is not set.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")

# Human-readable names for the ten classes (presumably ordered by label index).
classes = (
    "plane",
    "car",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
)
class LitCIFAR(LightningModule):
    """LightningModule wrapping a ``ResNet18`` classifier with ten classes.

    ``forward`` returns log-probabilities, which is why the training and
    validation steps compute their loss with ``F.nll_loss``.

    Args:
        data_dir: Dataset location. Stored on the instance; no method of
            this class reads it.
        hidden_size: Stored on the instance; not used by the model.
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=16, learning_rate=2e-4):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset specific attributes
        self.num_classes = 10
        self.misclassified_indices = []

        # Define PyTorch model
        self.model = ResNet18()
        self.accuracy = Accuracy(num_classes=self.num_classes, task='multiclass')

    def forward(self, x):
        """Run the backbone and return log-probabilities, one row per sample."""
        x = self.model(x)
        # Reshape with the configured class count instead of a literal 10 so
        # the two cannot drift apart.
        x = x.view(-1, self.num_classes)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        """Return the negative log-likelihood loss for one ``(x, y)`` batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute the loss, update the accuracy metric and log both."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    # def test_step(self, batch, batch_idx):
    #     # Here we just reuse the validation_step for testing
    #     return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer
modelfin = LitCIFAR()
lit_cifar_instance = LitCIFAR()

# Use the GPU only when one is actually present, so importing this module does
# not crash on CPU-only hosts.
_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the state dictionary from the checkpoint file. map_location="cpu" lets a
# checkpoint that was saved from a GPU run deserialize without CUDA; the model
# is moved to the chosen device afterwards.
modelfin.load_state_dict(torch.load("TSAIGradcam/model.ckpt", map_location="cpu"))

# Set the model to evaluation mode
modelfin.eval()

# Move the model to the selected device after loading the weights
modelfin = modelfin.to(_device)
# Per-channel constants for the inverse normalisation. They are approximately
# -mean/std and 1/std for mean (0.4914, 0.4822, 0.4465) and
# std (0.247, 0.243, 0.261), the statistics passed to A.Normalize in this file.
_INV_NORM_MEAN = [-1.9899, -1.9844, -1.7111]
_INV_NORM_STD = [4.0486, 4.1152, 3.8314]

inv_normalize = transforms.Normalize(mean=_INV_NORM_MEAN, std=_INV_NORM_STD)
from torch.utils.data import Dataset
import numpy as np
class CIFAR10Dataset(Dataset):
    """CIFAR-10 wrapper that applies a keyword-style transform to each image.

    The wrapped ``CIFAR10`` dataset is created without a transform of its
    own. When ``transform`` is given, it is called as
    ``transform(image=<ndarray>)`` and the ``'image'`` entry of its result
    replaces the image.

    Args:
        data_dir: Root directory passed to ``CIFAR10`` (created with
            ``download=True``).
        train: Forwarded to ``CIFAR10`` as its ``train`` flag.
        transform: Optional callable taking an ``image`` keyword argument and
            returning a mapping with an ``'image'`` key.
    """

    def __init__(self, data_dir, train=True, transform=None):
        self.data = CIFAR10(data_dir, train=train, download=True, transform=None)
        self.transform = transform

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for ``idx``, transformed when configured."""
        image, label = self.data[idx]
        # Compare against None explicitly: a pipeline object that defines
        # __len__ is falsy when empty, and a plain truthiness test would then
        # silently skip it and return the untransformed image.
        if self.transform is not None:
            # Apply the transformation in the correct format
            image = self.transform(image=np.array(image))['image']
        return image, label
# Per-channel normalisation statistics shared by both pipelines below.
_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_NORM_STD = (0.247, 0.243, 0.261)

# Fill colour for the dropout hole: three per-channel values scaled by 255.
# NOTE(review): this fill is on the 0-255 scale, yet the dropout step runs
# after Normalize in the list below -- confirm that this is intended.
_DROPOUT_FILL = [channel * 255 for channel in (0.49139968, 0.48215827, 0.44653124)]

# Augmenting pipeline, used for the training split.
_train_steps = [
    A.RandomCrop(height=32, width=32, p=0.2),
    A.Normalize(_NORM_MEAN, _NORM_STD),
    A.HorizontalFlip(),
    A.CoarseDropout(
        max_holes=1,
        max_height=16,
        max_width=16,
        min_holes=1,
        min_height=1,
        min_width=1,
        fill_value=_DROPOUT_FILL,
        mask_fill_value=None,
    ),
    ToTensorV2(),
]
transform = A.Compose(_train_steps)

# Evaluation pipeline: normalisation and tensor conversion only.
_test_steps = [
    A.Normalize(_NORM_MEAN, _NORM_STD),
    ToTensorV2(),
]
test_transform = A.Compose(_test_steps)
# Training split with the augmenting pipeline, test split with the plain one.
cifar_full = CIFAR10Dataset(PATH_DATASETS, train=True, transform=transform)
cifar_test = CIFAR10Dataset(PATH_DATASETS, train=False, transform=test_transform)

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader needs an integer worker count; fall back to 0 workers then.
test_loader = DataLoader(cifar_test, batch_size=512, num_workers=os.cpu_count() or 0)
import gradio as gr
from utils import display_gradcam_output
from utils import get_misclassified_data
from visualize import display_cifar_misclassified_data
from PIL import Image
#misclassified_data, classes, inv_normalize, modelfin, target_layers, targets, number_of_samples=2, transparency=0.7