File size: 4,872 Bytes
71c714a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f0b1747
71c714a
 
 
83f328a
71c714a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
# Import all the required modules
import os
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import math
from collections import OrderedDict
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
import albumentations as A
from albumentations.pytorch import ToTensorV2
from torch_lr_finder import LRFinder
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from utils import display_gradcam_output
from torchmetrics import Accuracy
from torchvision.datasets import CIFAR10 
import torch
from pytorch_lightning import LightningModule, Trainer
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
import matplotlib.pyplot as plt
import random
from resnet import *

# Root directory for dataset downloads; override with the PATH_DATASETS env var.
PATH_DATASETS = os.getenv("PATH_DATASETS", ".")
# Human-readable CIFAR-10 class names (order assumed to match the dataset's
# integer labels).
classes = (
    "plane", "car", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
)



class LitCIFAR(LightningModule):
    """LightningModule wrapping a ResNet18 backbone for 10-class CIFAR-10.

    ``forward`` returns log-probabilities (``log_softmax``), which is why the
    training and validation losses use ``F.nll_loss``.

    Args:
        data_dir: Dataset root directory (stored; not used inside this class).
        hidden_size: Stored for compatibility; not used by the backbone.
        learning_rate: Learning rate for the Adam optimizer.
    """

    def __init__(self, data_dir=PATH_DATASETS, hidden_size=16, learning_rate=2e-4):

        super().__init__()

        # Keep constructor arguments as attributes.
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # CIFAR-10 has 10 classes; forward() reshapes outputs to this width.
        self.num_classes = 10

        # Not filled in by this class; presumably populated by external
        # analysis code (e.g. misclassification visualisation) — confirm.
        self.misclassified_indices = []

        # Backbone network (ResNet18 comes from the project-local `resnet` module).
        self.model = ResNet18()

        # Multiclass accuracy metric, logged during validation.
        self.accuracy = Accuracy(num_classes=self.num_classes, task='multiclass')

    def forward(self, x):
        """Return per-class log-probabilities with shape ``(-1, num_classes)``."""
        logits = self.model(x)
        # Use the attribute rather than a hard-coded 10 so the output width
        # always agrees with num_classes (and with the Accuracy metric).
        logits = logits.view(-1, self.num_classes)
        return F.log_softmax(logits, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute the NLL loss for one ``(inputs, targets)`` training batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss and accuracy for one batch."""
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    # def test_step(self, batch, batch_idx):
    #     # Here we just reuse the validation_step for testing
    #     return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters at ``self.learning_rate``."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer


modelfin = LitCIFAR()
# Second instance that is never loaded with the checkpoint weights; kept in
# case other code references it.
lit_cifar_instance = LitCIFAR()

# Load the checkpoint straight onto the CPU. Without map_location, a file
# saved from a GPU run fails to load on a CPU-only host, even though the
# model is kept on the CPU below.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
_checkpoint = torch.load("model.ckpt", map_location=torch.device("cpu"))
# A Lightning Trainer checkpoint nests the weights under "state_dict"; a file
# written with torch.save(model.state_dict()) is the weights dict itself.
if isinstance(_checkpoint, dict) and "state_dict" in _checkpoint:
    _checkpoint = _checkpoint["state_dict"]
modelfin.load_state_dict(_checkpoint)
# Put the model in inference mode (affects layers such as dropout or
# batch-norm, if the backbone has them).
modelfin.eval()
# Keep the model on the CPU; to run on a GPU, move it with .to('cuda') instead.
modelfin = modelfin.to('cpu')

# Per-channel CIFAR-10 normalization statistics: the same mean/std passed to
# A.Normalize in the training and test transforms.
_INV_NORM_MEAN = (0.4914, 0.4822, 0.4465)
_INV_NORM_STD = (0.247, 0.243, 0.261)

# Undo Normalize(mean, std) for display: Normalize(-mean/std, 1/std) maps a
# normalized value y back to y * std + mean. Deriving the coefficients from
# the forward statistics avoids the rounding drift of hand-typed constants.
inv_normalize = transforms.Normalize(
    mean=[-m / s for m, s in zip(_INV_NORM_MEAN, _INV_NORM_STD)],
    std=[1 / s for s in _INV_NORM_STD],
)

from torch.utils.data import Dataset
import numpy as np
class CIFAR10Dataset(Dataset):
    """CIFAR-10 dataset that applies an Albumentations-style transform.

    torchvision's ``CIFAR10`` is built without its own transform. Each sample's
    image is converted with ``np.array`` and passed to ``transform(image=...)``.
    The ``"image"`` entry of the result is returned.

    Args:
        data_dir: Root directory for the dataset (downloaded if missing).
        train: Use the training split when True, the test split otherwise.
        transform: Callable accepting ``image=`` and returning a dict with an
            ``"image"`` key, or None to return the untransformed image.
    """

    def __init__(self, data_dir, train=True, transform=None):
        self.data = CIFAR10(data_dir, train=train, download=True, transform=None)
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``, transformed if configured."""
        image, label = self.data[idx]
        # Compare against None explicitly. Transform objects may define
        # __len__ (Albumentations' Compose does), so an empty pipeline would be
        # falsy and silently skipped by a plain truthiness test.
        if self.transform is not None:
            # Albumentations operates on NumPy arrays passed by keyword.
            image = self.transform(image=np.array(image))['image']
        return image, label

# Training-time augmentation pipeline; Albumentations applies steps in order.
transform = A.Compose(
    [
        # NOTE(review): a 32x32 crop of a 32x32 CIFAR image is a no-op unless the
        # image is padded first (e.g. A.PadIfNeeded) — confirm the intent.
        A.RandomCrop(height=32, width=32, p=0.2),
        A.HorizontalFlip(),
        # Cut-out must run BEFORE normalization. Its fill value (dataset mean
        # scaled to 0-255) is on the raw pixel scale (presumably uint8 from
        # np.array(PIL image)), and A.Normalize then maps it to ~0. Applied after
        # Normalize, it would write ~125 into data whose range is roughly [-2, 2].
        A.CoarseDropout(
            max_holes=1, max_height=16, max_width=16,
            min_holes=1, min_height=1, min_width=1,
            fill_value=[0.49139968 * 255, 0.48215827 * 255, 0.44653124 * 255],
            mask_fill_value=None,
        ),
        A.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ToTensorV2(),
    ]
)
# Evaluation pipeline: normalization only, no augmentation.
test_transform = A.Compose(
    [
        A.Normalize((0.4914, 0.4822, 0.4465), (0.247, 0.243, 0.261)),
        ToTensorV2(),
    ]
)

# Training split with augmentation; test split with normalization only.
# Both download the data into PATH_DATASETS if it is not already present.
cifar_full = CIFAR10Dataset(PATH_DATASETS, train=True, transform=transform)
cifar_test = CIFAR10Dataset(PATH_DATASETS, train=False, transform=test_transform)
# (`classes` is already defined at the top of the module with identical values.)

# os.cpu_count() returns None when the CPU count cannot be determined, and
# DataLoader requires an int, so fall back to 0 (load in the main process).
test_loader = DataLoader(cifar_test, batch_size=512, num_workers=os.cpu_count() or 0)

import gradio as gr
from utils import display_gradcam_output
from utils import get_misclassified_data
from visualize import display_cifar_misclassified_data
from PIL import Image

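# Reference argument list for the Grad-CAM display helper (presumably
# utils.display_gradcam_output — confirm against its signature):
# misclassified_data, classes, inv_normalize, modelfin, target_layers, targets, number_of_samples=2, transparency=0.7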