MostHumble committed on
Commit eed12b2
1 Parent(s): b8e9456

add inference script

utils/__init__.py ADDED
File without changes
utils/data.py ADDED
@@ -0,0 +1,341 @@
import os
import pickle

import matplotlib.pyplot as plt
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision import transforms  # used by _get_stats
from torchvision.datasets import ImageFolder

CLASS_NAMES = [
    'Abra', 'Aerodactyl', 'Alakazam', 'Alolan Sandslash', 'Arbok', 'Arcanine',
    'Articuno', 'Beedrill', 'Bellsprout', 'Blastoise', 'Bulbasaur', 'Butterfree',
    'Caterpie', 'Chansey', 'Charizard', 'Charmander', 'Charmeleon', 'Clefable',
    'Clefairy', 'Cloyster', 'Cubone', 'Dewgong', 'Diglett', 'Ditto', 'Dodrio',
    'Doduo', 'Dragonair', 'Dragonite', 'Dratini', 'Drowzee', 'Dugtrio', 'Eevee',
    'Ekans', 'Electabuzz', 'Electrode', 'Exeggcute', 'Exeggutor', 'Farfetchd',
    'Fearow', 'Flareon', 'Gastly', 'Gengar', 'Geodude', 'Gloom', 'Golbat',
    'Goldeen', 'Golduck', 'Golem', 'Graveler', 'Grimer', 'Growlithe', 'Gyarados',
    'Haunter', 'Hitmonchan', 'Hitmonlee', 'Horsea', 'Hypno', 'Ivysaur',
    'Jigglypuff', 'Jolteon', 'Jynx', 'Kabuto', 'Kabutops', 'Kadabra', 'Kakuna',
    'Kangaskhan', 'Kingler', 'Koffing', 'Krabby', 'Lapras', 'Lickitung',
    'Machamp', 'Machoke', 'Machop', 'Magikarp', 'Magmar', 'Magnemite',
    'Magneton', 'Mankey', 'Marowak', 'Meowth', 'Metapod', 'Mew', 'Mewtwo',
    'Moltres', 'MrMime', 'Muk', 'Nidoking', 'Nidoqueen', 'Nidorina', 'Nidorino',
    'Ninetales', 'Oddish', 'Omanyte', 'Omastar', 'Onix', 'Paras', 'Parasect',
    'Persian', 'Pidgeot', 'Pidgeotto', 'Pidgey', 'Pikachu', 'Pinsir', 'Poliwag',
    'Poliwhirl', 'Poliwrath', 'Ponyta', 'Porygon', 'Primeape', 'Psyduck',
    'Raichu', 'Rapidash', 'Raticate', 'Rattata', 'Rhydon', 'Rhyhorn',
    'Sandshrew', 'Sandslash', 'Scyther', 'Seadra', 'Seaking', 'Seel',
    'Shellder', 'Slowbro', 'Slowpoke', 'Snorlax', 'Spearow', 'Squirtle',
    'Starmie', 'Staryu', 'Tangela', 'Tauros', 'Tentacool', 'Tentacruel',
    'Vaporeon', 'Venomoth', 'Venonat', 'Venusaur', 'Victreebel', 'Vileplume',
    'Voltorb', 'Vulpix', 'Wartortle', 'Weedle', 'Weepinbell', 'Weezing',
    'Wigglytuff', 'Zapdos', 'Zubat',
]


class TransformSubset(Dataset):
    """
    Wrapper for applying transformations to a Subset.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, idx):
        img, label = self.subset[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.subset)


class PokemonDataModule(Dataset):
    def __init__(self, data_dir):
        self.dataset = ImageFolder(root=data_dir)
        self.class_names = self.dataset.classes
        # Populated by `prepare_data`; initialized here so `get_dataloaders`
        # fails with a clear assertion instead of an AttributeError.
        self.train_dataset = None
        self.test_dataset = None

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]
        return image, label

    def plot_examples(self, dataloader, n_rows=1, n_cols=4, stats=None):
        """
        Plot examples from a DataLoader.

        Args:
            dataloader (DataLoader): DataLoader object to fetch images and labels from.
            n_rows (int): Number of rows in the plot grid.
            n_cols (int): Number of columns in the plot grid.
            stats (dict, optional): Normalization statistics ({'mean': ..., 'std': ...})
                used to reverse standardization for visualization.
        """
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3))
        axes = axes.flatten()  # Flatten to iterate easily

        # Iterate over the dataloader to get a batch of data
        for data, labels in dataloader:
            # Take the first n_rows * n_cols samples from the batch
            for i, ax in enumerate(axes[: n_rows * n_cols]):
                if i >= len(data):  # If fewer samples than the grid size, stop
                    break

                img, label = data[i], labels[i]

                # Reverse normalization if stats are provided
                if stats:
                    img = self._denormalize(img, stats)

                # Convert CHW to HWC for plotting
                img = img.permute(1, 2, 0).cpu().numpy()

                ax.imshow(img)
                ax.set_title(self.class_names[label.item()])
                ax.axis("off")
            break  # Only process the first batch

        plt.tight_layout()
        plt.show()

    def _denormalize(self, img, stats):
        """
        Denormalize an image tensor.

        Args:
            img (Tensor): Image tensor with shape (C, H, W).
            stats (dict): Dictionary containing 'mean' and 'std' for each channel.
                Example: {'mean': [0.485, 0.456, 0.406], 'std': [0.229, 0.224, 0.225]}.

        Returns:
            Tensor: Denormalized image tensor.
        """
        return img * stats["std"].view(-1, 1, 1) + stats["mean"].view(-1, 1, 1)

    def _get_stats(self, dataset):
        """
        Calculate the per-channel mean and standard deviation of the dataset
        for standardization.
        """
        # The raw ImageFolder yields PIL images of varying sizes; convert them
        # to fixed-size tensors (224x224, matching the input size used for
        # inference elsewhere in this repo) so the DataLoader can collate batches.
        to_tensor = transforms.Compose(
            [transforms.Resize((224, 224)), transforms.ToTensor()]
        )
        dataloader = DataLoader(
            TransformSubset(dataset, to_tensor), batch_size=2048, shuffle=False
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        total_sum, total_squared_sum, total_count = 0, 0, 0
        for data, _ in dataloader:
            data = data.to(device)
            total_sum += data.sum(dim=(0, 2, 3))
            total_squared_sum += (data**2).sum(dim=(0, 2, 3))
            total_count += data.size(0) * data.size(2) * data.size(3)

        # Var[x] = E[x^2] - E[x]^2, accumulated in a single pass
        means = total_sum / total_count
        stds = torch.sqrt((total_squared_sum / total_count) - (means**2))
        return {"mean": means.cpu(), "std": stds.cpu()}

    def prepare_data(self, indices_file="indices.pkl", get_stats=False):
        """
        Create a stratified train/test split, reusing saved indices when available.

        Args:
            indices_file (str): Path to save or load train/test indices.
            get_stats (bool): If True, compute and return channel statistics
                over the training split.

        Returns:
            dict or None: {'mean': ..., 'std': ...} if `get_stats` is True, else None.
        """
        try:
            with open(indices_file, "rb") as f:
                self.train_indices, self.test_indices = pickle.load(f)
        except (EOFError, FileNotFoundError):
            # Generate new indices if the file is empty or doesn't exist
            self.train_indices, self.test_indices = train_test_split(
                range(len(self.dataset)),
                test_size=0.2,
                stratify=self.dataset.targets,
                random_state=42,
            )

            # Ensure the directory exists before saving
            os.makedirs(os.path.dirname(indices_file) or ".", exist_ok=True)

            with open(indices_file, "wb") as f:
                pickle.dump([self.train_indices, self.test_indices], f)

        # Prepare train and test subsets
        self.train_dataset = Subset(self.dataset, self.train_indices)
        self.test_dataset = Subset(self.dataset, self.test_indices)

        return self._get_stats(self.train_dataset) if get_stats else None

    def get_dataloaders(
        self,
        train_transform=None,
        test_transform=None,
        train_batch_size=None,
        test_batch_size=None,
    ):
        """
        Prepare train and test dataloaders with optional transformations.

        Args:
            train_transform (callable): Transformation to apply to training data.
            test_transform (callable): Transformation to apply to test data.
            train_batch_size (int): Batch size for the training dataloader.
            test_batch_size (int): Batch size for the test dataloader;
                defaults to `train_batch_size`.

        Returns:
            tuple: trainloader, testloader
        """
        assert (
            self.train_dataset is not None
        ), "You need to call `prepare_data` before using `get_dataloaders`."

        # Default the test batch size to the train batch size if not provided
        test_batch_size = (
            train_batch_size if test_batch_size is None else test_batch_size
        )

        # Wrap subsets in a transformed dataset if transformations are provided
        train_dataset = (
            TransformSubset(self.train_dataset, train_transform)
            if train_transform
            else self.train_dataset
        )

        test_dataset = (
            TransformSubset(self.test_dataset, test_transform)
            if test_transform
            else self.test_dataset
        )

        trainloader = DataLoader(
            train_dataset, batch_size=train_batch_size, shuffle=True
        )
        testloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

        return trainloader, testloader
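
A minimal usage sketch for this module (illustrative, not part of the commit; the dataset root, batch size, and 224x224 input size are assumptions):

from torchvision import transforms
from utils.data import PokemonDataModule

dm = PokemonDataModule("data/pokemon")  # assumed dataset root: one folder per class
stats = dm.prepare_data(indices_file="indices.pkl", get_stats=True)

# Standardize with the statistics computed on the training split
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=stats["mean"], std=stats["std"]),
])
trainloader, testloader = dm.get_dataloaders(
    train_transform=transform, test_transform=transform, train_batch_size=64
)
dm.plot_examples(trainloader, n_rows=2, n_cols=4, stats=stats)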
utils/inference_utils.py ADDED
@@ -0,0 +1,193 @@
import os
import random

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from utils.data import CLASS_NAMES


# Function to find correctly and incorrectly classified images
def find_images(dataloader, model, device, num_correct, num_incorrect):
    correct_images = []
    incorrect_images = []
    correct_labels = []
    incorrect_labels = []
    correct_preds = []
    incorrect_preds = []

    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            for i in range(images.size(0)):
                if preds[i] == labels[i] and len(correct_images) < num_correct:
                    correct_images.append(images[i].cpu())
                    correct_labels.append(labels[i].cpu())
                    correct_preds.append(preds[i].cpu())
                elif preds[i] != labels[i] and len(incorrect_images) < num_incorrect:
                    incorrect_images.append(images[i].cpu())
                    incorrect_labels.append(labels[i].cpu())
                    incorrect_preds.append(preds[i].cpu())

                if (
                    len(correct_images) >= num_correct
                    and len(incorrect_images) >= num_incorrect
                ):
                    break
            if (
                len(correct_images) >= num_correct
                and len(incorrect_images) >= num_incorrect
            ):
                break

    return (
        correct_images,
        correct_labels,
        correct_preds,
        incorrect_images,
        incorrect_labels,
        incorrect_preds,
    )


def find_images_from_path(
    data_path, model, device, num_correct=2, num_incorrect=2, label=None
):
    """
    Sample images from a class-per-folder directory, classify them, and save
    them under predictions/<class>/{correct,mistake}. If `label` is given,
    only that class is sampled.
    """
    correct_images_paths = []
    incorrect_images_paths = []
    correct_labels = []
    incorrect_labels = []

    label_to_idx = {name: idx for idx, name in enumerate(CLASS_NAMES)}

    model.eval()
    # First collect the available images for the specified label, or for all labels
    label_images = {}
    if label:
        label_path = os.path.join(data_path, label)
        if os.path.isdir(label_path):
            label_images[label] = [
                os.path.join(label_path, img) for img in os.listdir(label_path)
            ]
    else:
        # `class_name` avoids shadowing the `label` argument
        for class_name in os.listdir(data_path):
            label_path = os.path.join(data_path, class_name)
            if not os.path.isdir(label_path):
                continue
            label_images[class_name] = [
                os.path.join(label_path, img) for img in os.listdir(label_path)
            ]

    # Randomly process images until we have enough samples
    with torch.no_grad():
        while (
            len(correct_images_paths) < num_correct
            or len(incorrect_images_paths) < num_incorrect
        ):
            # Randomly select a label that still has unprocessed images
            available_labels = [l for l in label_images if label_images[l]]
            if not available_labels:
                break

            selected_label = random.choice(available_labels)
            image_path = random.choice(label_images[selected_label])
            label_images[selected_label].remove(image_path)  # Remove the selected image

            image = preprocess_image(image_path, (224, 224)).to(device)
            label_idx = label_to_idx[selected_label]

            outputs = model(image)
            _, pred = torch.max(outputs, 1)

            if pred.item() == label_idx and len(correct_images_paths) < num_correct:
                correct_images_paths.append(image_path)
                correct_labels.append(label_idx)
            elif pred.item() != label_idx and len(incorrect_images_paths) < num_incorrect:
                incorrect_images_paths.append(image_path)
                incorrect_labels.append(label_idx)

    save_images_by_class(
        correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels
    )


def save_images_by_class(
    correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels
):
    # Create root directories for correct and incorrect classifications
    for class_name in CLASS_NAMES:
        os.makedirs(os.path.join('predictions', class_name, 'correct'), exist_ok=True)
        os.makedirs(os.path.join('predictions', class_name, 'mistake'), exist_ok=True)

    # Save correctly classified images
    for img_path, label in zip(correct_images_paths, correct_labels):
        class_name = CLASS_NAMES[label]
        img_name = os.path.basename(img_path)
        destination = os.path.join('predictions', class_name, 'correct', img_name)
        Image.open(img_path).save(destination)

    # Save incorrectly classified images
    for img_path, label in zip(incorrect_images_paths, incorrect_labels):
        class_name = CLASS_NAMES[label]
        img_name = os.path.basename(img_path)
        destination = os.path.join('predictions', class_name, 'mistake', img_name)
        Image.open(img_path).save(destination)


def show_samples(dataloader, model, device, num_correct=3, num_incorrect=3):
    # Get some correctly and incorrectly classified images
    (
        correct_images,
        correct_labels,
        correct_preds,
        incorrect_images,
        incorrect_labels,
        incorrect_preds,
    ) = find_images(dataloader, model, device, num_correct, num_incorrect)
    # Display the results in a grid
    fig, axes = plt.subplots(
        num_correct + num_incorrect, 1, figsize=(10, (num_correct + num_incorrect) * 5)
    )

    for i in range(num_correct):
        axes[i].imshow(correct_images[i].permute(1, 2, 0))
        axes[i].set_title(
            f"Correctly Classified: True Label = {CLASS_NAMES[correct_labels[i].item()]}, "
            f"Predicted = {CLASS_NAMES[correct_preds[i].item()]}"
        )
        axes[i].axis("off")

    for i in range(num_incorrect):
        axes[num_correct + i].imshow(incorrect_images[i].permute(1, 2, 0))
        axes[num_correct + i].set_title(
            f"Incorrectly Classified: True Label = {CLASS_NAMES[incorrect_labels[i].item()]}, "
            f"Predicted = {CLASS_NAMES[incorrect_preds[i].item()]}"
        )
        axes[num_correct + i].axis("off")

    plt.tight_layout()
    plt.show()


# Function to preprocess an image for the model
def preprocess_image(image_path, img_shape):
    # Load the image using PIL; convert to RGB so grayscale/RGBA inputs
    # don't break the 3-channel normalization below
    image = Image.open(image_path).convert("RGB")

    # Apply preprocessing transformations (ImageNet normalization)
    preprocess = transforms.Compose([
        transforms.Resize(img_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = preprocess(image).unsqueeze(0)  # add a batch dimension

    return image


# Function to predict on a preprocessed image
def predict(model, image):
    model.eval()
    with torch.no_grad():
        outputs = model(image)
    return outputs


# Function to get model probabilities for a batch of image paths (e.g. for LIME)
def batch_predict(model, images, device):
    model.eval()
    # preprocess_image already adds a batch dimension, so concatenate along it
    # rather than stacking (which would add a second one)
    batch = torch.cat([preprocess_image(image, (224, 224)) for image in images], dim=0)
    batch = batch.to(device)
    logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()
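
A hedged usage sketch (illustrative only; the checkpoint path and image path below are placeholders, and a full serialized model is assumed):

import torch
from utils.data import CLASS_NAMES
from utils.inference_utils import preprocess_image, predict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pth", map_location=device)  # placeholder checkpoint
model.to(device)

image = preprocess_image("data/pokemon/Pikachu/0001.jpg", (224, 224)).to(device)
outputs = predict(model, image)
print(CLASS_NAMES[outputs.argmax(dim=1).item()])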
utils/interpretability.py ADDED
@@ -0,0 +1,110 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from lime import lime_image
from skimage.segmentation import mark_boundaries


def unnormalize(image):
    """Reverse ImageNet normalization; expects a channel-last (H, W, C) image."""
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)

    # Accept either a torch tensor or a NumPy array
    if isinstance(image, torch.Tensor):
        image = image * std + mean
    else:
        image = torch.tensor(image, dtype=torch.float32) * std + mean

    return image


def lime_interpret_image_inference(args, model, image, device):
    # Prepare an image for plotting (clamp to [0, 1] to avoid imshow clipping warnings)
    def prepare_for_plot(image):
        return unnormalize(image).clamp(0, 1).cpu().numpy()

    # Remove the batch dimension and rearrange to (H, W, C):
    # from [1, 3, 224, 224] to [224, 224, 3]
    image = image.squeeze(0).permute(1, 2, 0)

    # Convert to a NumPy array (LIME works on NumPy images)
    image_np = image.cpu().numpy()  # ensure the tensor is on the CPU

    # Initialize the LIME explainer
    explainer = lime_image.LimeImageExplainer()

    # Define the prediction function
    def predict_fn(x):
        # LIME passes float64 (B, H, W, C) arrays; convert to a float32
        # (B, C, H, W) tensor to match the model's weights
        x_tensor = torch.tensor(x, dtype=torch.float32).permute(0, 3, 1, 2).to(device)
        preds = model(x_tensor)
        return preds.detach().cpu().numpy()

    # Run the LIME explanation
    explanation = explainer.explain_instance(
        image_np,
        predict_fn,
        top_labels=5,
        hide_color=0,
        num_samples=5000,
    )

    # Create a 2x2 subplot
    fig, axs = plt.subplots(2, 2, figsize=(15, 15))

    # Plot the original image
    axs[0, 0].imshow(prepare_for_plot(image))
    axs[0, 0].set_title("Original Image")

    # Plot the features that contributed the most positively
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True
    )
    axs[0, 1].imshow(mark_boundaries(prepare_for_plot(temp), mask))
    axs[0, 1].set_title("Top Positive Features")

    # Plot the features that contributed the most positively and negatively
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=False,
        num_features=1000,
        hide_rest=False,
        min_weight=0.1,
    )
    axs[1, 0].imshow(mark_boundaries(prepare_for_plot(temp), mask))
    axs[1, 0].set_title("Top Positive and Negative Features")

    # Plot a heatmap of the feature weights for the top predicted class
    ind = explanation.top_labels[0]
    dict_heatmap = dict(explanation.local_exp[ind])
    heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
    im = axs[1, 1].imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max())
    axs[1, 1].set_title("Feature Heatmap")
    fig.colorbar(im, ax=axs[1, 1])

    plt.tight_layout()

    # If classification mode is enabled, save under explanations/<class>/<correctness>
    if args.classify:
        # Extract the class name and correctness from the image path
        path_parts = args.image_path.split(os.sep)
        class_name = path_parts[-3]
        correctness = path_parts[-2]  # 'correct' or 'mistake'
        assert correctness in ['correct', 'mistake'], (
            "The image path should contain 'correct' or 'mistake'"
        )

        # Create the full save path under the explanations directory
        save_path = os.path.join(
            'explanations', class_name, correctness, os.path.basename(args.image_path)
        )
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Save the explanation
        plt.savefig(save_path, dpi=300)
        print(f"Explanation saved at {save_path}")
    else:
        # Make a directory for the explanations and save under the image's name
        os.makedirs("./explanations", exist_ok=True)
        plt.savefig(f"./explanations/{os.path.basename(args.image_path)}")
        print(f"Explanation saved at ./explanations/{os.path.basename(args.image_path)}")
utils/train_utils.py ADDED
@@ -0,0 +1,251 @@
import mlflow
import torch
import torch.nn as nn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torchvision import models
from tqdm import tqdm


# Define the training loop
def train_one_epoch(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(trainloader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track loss and accuracy
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = 100.0 * correct / total
    return epoch_loss, epoch_accuracy


# Define the evaluation loop
@torch.no_grad()
def evaluate(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []

    for images, labels in tqdm(testloader, desc="Evaluating", leave=False):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Track loss and accuracy
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

        # Collect all labels and predictions for metrics
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

    epoch_loss = running_loss / len(testloader)

    # Calculate accuracy, precision, recall, and F1-score
    epoch_accuracy = accuracy_score(all_labels, all_predictions) * 100
    precision = precision_score(all_labels, all_predictions, average="weighted")
    recall = recall_score(all_labels, all_predictions, average="weighted")
    f1 = f1_score(all_labels, all_predictions, average="weighted")

    return epoch_loss, epoch_accuracy, precision, recall, f1


# Define the pipeline
def train_and_evaluate(
    model,
    trainloader,
    testloader,
    criterion,
    optimizer,
    device,
    epochs,
    use_mlflow=False,
):
    """
    Train and evaluate the model.

    Args:
        model (nn.Module): The neural network model.
        trainloader (DataLoader): DataLoader for training data.
        testloader (DataLoader): DataLoader for test data.
        criterion (nn.Module): Loss function.
        optimizer (optim.Optimizer): Optimizer.
        device (torch.device): Device to train on ('cuda' or 'cpu').
        epochs (int): Number of epochs to train.
        use_mlflow (bool): If True, log per-epoch metrics to MLflow.

    Returns:
        dict: Training and evaluation statistics.
    """
    history = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": [],
        "precision": [],
        "recall": [],
        "f1": [],
    }

    model.to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        # Train for one epoch
        train_loss, train_acc = train_one_epoch(
            model, trainloader, criterion, optimizer, device
        )
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Evaluate the model
        test_loss, test_acc, precision, recall, f1 = evaluate(
            model, testloader, criterion, device
        )
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")

        # Save statistics
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)
        history["precision"].append(precision)
        history["recall"].append(recall)
        history["f1"].append(f1)

        if use_mlflow:
            # Log with an explicit step so MLflow charts metrics per epoch
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_acc", train_acc, step=epoch)
            mlflow.log_metric("test_loss", test_loss, step=epoch)
            mlflow.log_metric("test_acc", test_acc, step=epoch)
            mlflow.log_metric("precision", precision, step=epoch)
            mlflow.log_metric("recall", recall, step=epoch)
            mlflow.log_metric("f1", f1, step=epoch)
    return history


def set_parameter_requires_grad(model, feature_extracting):
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def initialize_model(
    model_name,
    num_classes,
    feature_extract=True,
    use_pretrained=True,
    hidden_size=512,
    image_shape=(224, 224, 3),
):
    # The classifier head is model-specific, so each branch reshapes it for
    # `num_classes`. Note: `pretrained=` is deprecated in newer torchvision
    # releases in favor of `weights=`.
    model_ft = None

    if model_name == "resnet":
        # ResNet-18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "vgg":
        # VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "squeezenet":
        # SqueezeNet
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(
            512, num_classes, kernel_size=(1, 1), stride=(1, 1)
        )
        model_ft.num_classes = num_classes

    elif model_name == "densenet":
        # DenseNet-121
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    elif model_name == "custom_mlp":
        # Custom MLP; nn.Flatten turns (B, C, H, W) image batches into the
        # (B, C*H*W) vectors the first Linear layer expects
        model_ft = nn.Sequential(
            nn.Flatten(),
            nn.Linear(image_shape[0] * image_shape[1] * image_shape[2], hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, num_classes),
        )
    elif model_name == "custom_cnn":
        # Custom CNN; after three 2x2 max-pools a 224x224 input becomes 28x28,
        # hence the 64 * 28 * 28 input size of the first Linear layer
        model_ft = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            nn.Linear(64 * 28 * 28, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )
    elif model_name == "random_forest":
        # Random forest (scikit-learn; not compatible with the torch training loop)
        model_ft = RandomForestClassifier(n_estimators=100, random_state=42)

    else:
        raise ValueError(f"Invalid model name: {model_name}")

    return model_ft
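
A minimal end-to-end sketch tying these utilities together (illustrative, not part of the commit; the dataset root, transform, and hyperparameters are assumptions):

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from utils.data import CLASS_NAMES, PokemonDataModule
from utils.train_utils import initialize_model, train_and_evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dm = PokemonDataModule("data/pokemon")  # assumed dataset root
dm.prepare_data()
transform = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])
trainloader, testloader = dm.get_dataloaders(
    train_transform=transform, test_transform=transform, train_batch_size=64
)

model = initialize_model("resnet", num_classes=len(CLASS_NAMES))
criterion = nn.CrossEntropyLoss()
# Only the unfrozen head parameters are optimized under feature extraction
optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3)
history = train_and_evaluate(
    model, trainloader, testloader, criterion, optimizer, device, epochs=10
)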