MostHumble committed
Commit: eed12b2
Parent(s): b8e9456
add inference script
- utils/__init__.py  +0 -0
- utils/data.py  +341 -0
- utils/inference_utils.py  +193 -0
- utils/interpretability.py  +110 -0
- utils/train_utils.py  +251 -0
utils/__init__.py
ADDED
File without changes
utils/data.py
ADDED
@@ -0,0 +1,341 @@
import os
import pickle

import matplotlib.pyplot as plt
import torch
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import ImageFolder

CLASS_NAMES = [
    'Abra', 'Aerodactyl', 'Alakazam', 'Alolan Sandslash', 'Arbok', 'Arcanine',
    'Articuno', 'Beedrill', 'Bellsprout', 'Blastoise', 'Bulbasaur', 'Butterfree',
    'Caterpie', 'Chansey', 'Charizard', 'Charmander', 'Charmeleon', 'Clefable',
    'Clefairy', 'Cloyster', 'Cubone', 'Dewgong', 'Diglett', 'Ditto', 'Dodrio',
    'Doduo', 'Dragonair', 'Dragonite', 'Dratini', 'Drowzee', 'Dugtrio', 'Eevee',
    'Ekans', 'Electabuzz', 'Electrode', 'Exeggcute', 'Exeggutor', 'Farfetchd',
    'Fearow', 'Flareon', 'Gastly', 'Gengar', 'Geodude', 'Gloom', 'Golbat',
    'Goldeen', 'Golduck', 'Golem', 'Graveler', 'Grimer', 'Growlithe', 'Gyarados',
    'Haunter', 'Hitmonchan', 'Hitmonlee', 'Horsea', 'Hypno', 'Ivysaur',
    'Jigglypuff', 'Jolteon', 'Jynx', 'Kabuto', 'Kabutops', 'Kadabra', 'Kakuna',
    'Kangaskhan', 'Kingler', 'Koffing', 'Krabby', 'Lapras', 'Lickitung',
    'Machamp', 'Machoke', 'Machop', 'Magikarp', 'Magmar', 'Magnemite',
    'Magneton', 'Mankey', 'Marowak', 'Meowth', 'Metapod', 'Mew', 'Mewtwo',
    'Moltres', 'MrMime', 'Muk', 'Nidoking', 'Nidoqueen', 'Nidorina', 'Nidorino',
    'Ninetales', 'Oddish', 'Omanyte', 'Omastar', 'Onix', 'Paras', 'Parasect',
    'Persian', 'Pidgeot', 'Pidgeotto', 'Pidgey', 'Pikachu', 'Pinsir', 'Poliwag',
    'Poliwhirl', 'Poliwrath', 'Ponyta', 'Porygon', 'Primeape', 'Psyduck',
    'Raichu', 'Rapidash', 'Raticate', 'Rattata', 'Rhydon', 'Rhyhorn',
    'Sandshrew', 'Sandslash', 'Scyther', 'Seadra', 'Seaking', 'Seel',
    'Shellder', 'Slowbro', 'Slowpoke', 'Snorlax', 'Spearow', 'Squirtle',
    'Starmie', 'Staryu', 'Tangela', 'Tauros', 'Tentacool', 'Tentacruel',
    'Vaporeon', 'Venomoth', 'Venonat', 'Venusaur', 'Victreebel', 'Vileplume',
    'Voltorb', 'Vulpix', 'Wartortle', 'Weedle', 'Weepinbell', 'Weezing',
    'Wigglytuff', 'Zapdos', 'Zubat',
]


class TransformSubset(Dataset):
    """Wrapper for applying transformations to a Subset."""

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, idx):
        img, label = self.subset[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

    def __len__(self):
        return len(self.subset)


class PokemonDataModule(Dataset):
    def __init__(self, data_dir):
        self.dataset = ImageFolder(root=data_dir)
        self.class_names = self.dataset.classes

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image, label = self.dataset[index]
        return image, label

    def plot_examples(self, dataloader, n_rows=1, n_cols=4, stats=None):
        """
        Plot examples from a DataLoader.

        Args:
            dataloader (DataLoader): DataLoader to fetch images and labels from.
            n_rows (int): Number of rows in the plot grid.
            n_cols (int): Number of columns in the plot grid.
            stats (dict, optional): Normalization statistics
                ({'mean': ..., 'std': ...}) used to denormalize images
                for visualization.
        """
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3, n_rows * 3))
        axes = axes.flatten()  # Flatten to iterate easily

        # Iterate over the dataloader to get a batch of data
        for data, labels in dataloader:
            # Take the first n_rows * n_cols samples from the batch
            for i, ax in enumerate(axes[: n_rows * n_cols]):
                if i >= len(data):  # If fewer samples than the grid size, stop
                    break

                img, label = data[i], labels[i]

                # Apply denormalization if stats are provided
                if stats:
                    img = self._denormalize(img, stats)

                # Convert CHW to HWC for plotting
                img = img.permute(1, 2, 0).cpu().numpy()

                ax.imshow(img)
                ax.set_title(self.class_names[label.item()])
                ax.axis("off")
            break  # Only process the first batch

        plt.tight_layout()
        plt.show()

    def _denormalize(self, img, stats):
        """
        Denormalize an image tensor.

        Args:
            img (Tensor): Image tensor with shape (C, H, W).
            stats (dict): Dictionary containing per-channel 'mean' and 'std' tensors.
                Example: {'mean': tensor([0.485, 0.456, 0.406]),
                          'std': tensor([0.229, 0.224, 0.225])}.

        Returns:
            Tensor: Denormalized image tensor.
        """
        return img * stats["std"].view(-1, 1, 1) + stats["mean"].view(-1, 1, 1)

    def _get_stats(self, dataset):
        """
        Calculate the per-channel mean and standard deviation of the dataset
        for standardization. Assumes the dataset yields equally sized image
        tensors of shape (C, H, W).
        """
        dataloader = DataLoader(dataset, batch_size=2048, shuffle=False)
        total_sum, total_squared_sum, total_count = 0, 0, 0
        # Accumulate running sums on the CPU; the original per-batch
        # `data.cuda()` discarded its result, so the computation effectively
        # ran on the CPU anyway.
        for data, _ in dataloader:
            total_sum += data.sum(dim=(0, 2, 3))
            total_squared_sum += (data ** 2).sum(dim=(0, 2, 3))
            total_count += data.size(0) * data.size(2) * data.size(3)

        means = total_sum / total_count
        stds = torch.sqrt((total_squared_sum / total_count) - (means ** 2))
        return {"mean": means, "std": stds}

    def prepare_data(self, indices_file="indices.pkl", get_stats=False):
        """
        Split the dataset into stratified train and test subsets, caching the
        split indices on disk.

        Args:
            indices_file (str): Path to save or load train/test indices.
            get_stats (bool): If True, compute and return normalization
                statistics on the training subset.

        Returns:
            dict or None: {'mean': ..., 'std': ...} if get_stats is True,
            otherwise None.
        """
        try:
            with open(indices_file, "rb") as f:
                self.train_indices, self.test_indices = pickle.load(f)
        except (EOFError, FileNotFoundError):
            # Generate new indices if the file is empty or doesn't exist
            self.train_indices, self.test_indices = train_test_split(
                range(len(self.dataset)),
                test_size=0.2,
                stratify=self.dataset.targets,
                random_state=42,
            )

            # Ensure the directory exists before saving
            os.makedirs(os.path.dirname(indices_file) or ".", exist_ok=True)

            with open(indices_file, "wb") as f:
                pickle.dump([self.train_indices, self.test_indices], f)

        # Prepare train and test subsets
        self.train_dataset = Subset(self.dataset, self.train_indices)
        self.test_dataset = Subset(self.dataset, self.test_indices)

        return self._get_stats(self.train_dataset) if get_stats else None

    def get_dataloaders(
        self,
        train_transform=None,
        test_transform=None,
        train_batch_size=None,
        test_batch_size=None,
    ):
        """
        Prepare train and test dataloaders with optional transformations.

        Args:
            train_transform (callable): Transformation to apply to training data.
            test_transform (callable): Transformation to apply to test data.
            train_batch_size (int): Batch size for the training dataloader.
            test_batch_size (int): Batch size for the test dataloader
                (defaults to train_batch_size).

        Returns:
            tuple: trainloader, testloader
        """
        assert (
            getattr(self, "train_dataset", None) is not None
        ), "You need to call `prepare_data` before using `get_dataloaders`."

        # Default the test batch size to the train batch size if not provided
        test_batch_size = (
            train_batch_size if test_batch_size is None else test_batch_size
        )

        # Wrap subsets in a transformed dataset if transformations are provided
        train_dataset = (
            TransformSubset(self.train_dataset, train_transform)
            if train_transform
            else self.train_dataset
        )

        test_dataset = (
            TransformSubset(self.test_dataset, test_transform)
            if test_transform
            else self.test_dataset
        )

        trainloader = DataLoader(
            train_dataset, batch_size=train_batch_size, shuffle=True
        )
        testloader = DataLoader(test_dataset, batch_size=test_batch_size, shuffle=False)

        return trainloader, testloader
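A minimal usage sketch for the data module, not part of this commit; the ./data/pokemon path and the batch size are assumptions. Note that ImageFolder yields PIL images, so a ToTensor-based transform is required before batching:

from torchvision import transforms
from utils.data import PokemonDataModule

dm = PokemonDataModule("./data/pokemon")        # assumed dataset location
dm.prepare_data(indices_file="indices.pkl")     # caches the stratified split

tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
trainloader, testloader = dm.get_dataloaders(
    train_transform=tf, test_transform=tf, train_batch_size=32
)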
utils/inference_utils.py
ADDED
@@ -0,0 +1,193 @@
import os
import random

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms

from utils.data import CLASS_NAMES


# Find correctly and incorrectly classified images in a dataloader
def find_images(dataloader, model, device, num_correct, num_incorrect):
    correct_images = []
    incorrect_images = []
    correct_labels = []
    incorrect_labels = []
    correct_preds = []
    incorrect_preds = []

    model.eval()
    with torch.no_grad():
        for images, labels in dataloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, preds = torch.max(outputs, 1)

            for i in range(images.size(0)):
                if preds[i] == labels[i] and len(correct_images) < num_correct:
                    correct_images.append(images[i].cpu())
                    correct_labels.append(labels[i].cpu())
                    correct_preds.append(preds[i].cpu())
                elif preds[i] != labels[i] and len(incorrect_images) < num_incorrect:
                    incorrect_images.append(images[i].cpu())
                    incorrect_labels.append(labels[i].cpu())
                    incorrect_preds.append(preds[i].cpu())

                if (
                    len(correct_images) >= num_correct
                    and len(incorrect_images) >= num_incorrect
                ):
                    break
            if (
                len(correct_images) >= num_correct
                and len(incorrect_images) >= num_incorrect
            ):
                break

    return (
        correct_images,
        correct_labels,
        correct_preds,
        incorrect_images,
        incorrect_labels,
        incorrect_preds,
    )


def find_images_from_path(data_path, model, device, num_correct=2, num_incorrect=2, label=None):
    correct_images_paths = []
    incorrect_images_paths = []
    correct_labels = []
    incorrect_labels = []

    label_to_idx = {name: idx for idx, name in enumerate(CLASS_NAMES)}

    model.eval()
    # First collect the available images for the specified label, or for all labels
    label_images = {}
    if label:
        label_path = os.path.join(data_path, label)
        if os.path.isdir(label_path):
            label_images[label] = [os.path.join(label_path, img) for img in os.listdir(label_path)]
    else:
        for class_dir in os.listdir(data_path):
            label_path = os.path.join(data_path, class_dir)
            if not os.path.isdir(label_path):
                continue
            label_images[class_dir] = [os.path.join(label_path, img) for img in os.listdir(label_path)]

    # Randomly process images until we have enough samples
    with torch.no_grad():
        while len(correct_images_paths) < num_correct or len(incorrect_images_paths) < num_incorrect:
            # Randomly select a label that still has unprocessed images
            available_labels = [l for l in label_images if label_images[l]]
            if not available_labels:
                break

            selected_label = random.choice(available_labels)
            image_path = random.choice(label_images[selected_label])
            label_images[selected_label].remove(image_path)  # Don't process the same image twice

            image = preprocess_image(image_path, (224, 224)).to(device)
            label_idx = label_to_idx[selected_label]

            outputs = model(image)
            _, pred = torch.max(outputs, 1)

            if pred == label_idx and len(correct_images_paths) < num_correct:
                correct_images_paths.append(image_path)
                correct_labels.append(label_idx)
            elif pred != label_idx and len(incorrect_images_paths) < num_incorrect:
                incorrect_images_paths.append(image_path)
                incorrect_labels.append(label_idx)

    save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels)


def save_images_by_class(correct_images_paths, correct_labels, incorrect_images_paths, incorrect_labels):
    # Create root directories for correct and incorrect classifications
    for class_name in CLASS_NAMES:
        os.makedirs(os.path.join('predictions', class_name, 'correct'), exist_ok=True)
        os.makedirs(os.path.join('predictions', class_name, 'mistake'), exist_ok=True)

    # Save correctly classified images
    for img_path, label in zip(correct_images_paths, correct_labels):
        class_name = CLASS_NAMES[label]
        img_name = os.path.basename(img_path)
        destination = os.path.join('predictions', class_name, 'correct', img_name)
        os.makedirs(os.path.dirname(destination), exist_ok=True)
        Image.open(img_path).save(destination)

    # Save incorrectly classified images
    for img_path, label in zip(incorrect_images_paths, incorrect_labels):
        class_name = CLASS_NAMES[label]
        img_name = os.path.basename(img_path)
        destination = os.path.join('predictions', class_name, 'mistake', img_name)
        os.makedirs(os.path.dirname(destination), exist_ok=True)
        Image.open(img_path).save(destination)


def show_samples(dataloader, model, device, num_correct=3, num_incorrect=3):
    # Get some correctly and incorrectly classified images
    (
        correct_images,
        correct_labels,
        correct_preds,
        incorrect_images,
        incorrect_labels,
        incorrect_preds,
    ) = find_images(dataloader, model, device, num_correct, num_incorrect)

    # Display the results in a grid
    fig, axes = plt.subplots(
        num_correct + num_incorrect, 1, figsize=(10, (num_correct + num_incorrect) * 5)
    )

    for i in range(num_correct):
        axes[i].imshow(correct_images[i].permute(1, 2, 0))
        axes[i].set_title(
            f"Correctly Classified: True Label = {correct_labels[i]}, Predicted = {correct_preds[i]}"
        )
        axes[i].axis("off")

    for i in range(num_incorrect):
        axes[num_correct + i].imshow(incorrect_images[i].permute(1, 2, 0))
        axes[num_correct + i].set_title(
            f"Incorrectly Classified: True Label = {incorrect_labels[i]}, Predicted = {incorrect_preds[i]}"
        )
        axes[num_correct + i].axis("off")

    plt.tight_layout()
    plt.show()


# Preprocess a single image file into a normalized (1, C, H, W) tensor
def preprocess_image(image_path, img_shape):
    # Load the image with PIL, forcing RGB in case of grayscale or RGBA inputs
    image = Image.open(image_path).convert("RGB")

    # Apply preprocessing transformations (ImageNet normalization)
    preprocess = transforms.Compose([
        transforms.Resize(img_shape),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    image = preprocess(image).unsqueeze(0)

    return image


# Run the model on a preprocessed image without tracking gradients
def predict(model, image):
    model.eval()
    with torch.no_grad():
        outputs = model(image)
    return outputs


# Get model prediction probabilities for a batch of image paths (e.g. for LIME)
def batch_predict(model, images, device):
    model.eval()
    # preprocess_image already adds a batch dimension, so concatenate rather
    # than stack to obtain a (B, C, H, W) batch
    batch = torch.cat(
        tuple(preprocess_image(image, (224, 224)) for image in images), dim=0
    )
    batch = batch.to(device)
    logits = model(batch)
    probs = torch.nn.functional.softmax(logits, dim=1)
    return probs.detach().cpu().numpy()
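A short sketch of how these helpers fit together, not part of the commit; the data path and the checkpoint loading step are assumptions:

import torch
from utils.data import CLASS_NAMES
from utils.inference_utils import find_images_from_path, preprocess_image, predict

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pt", map_location=device)  # hypothetical checkpoint

# Sort a few predictions into predictions/<class>/{correct,mistake}/
find_images_from_path("./data/pokemon", model, device, num_correct=2, num_incorrect=2)

# Single-image prediction
image = preprocess_image("./data/pokemon/Pikachu/0001.jpg", (224, 224)).to(device)
probs = predict(model, image).softmax(dim=1)
print(CLASS_NAMES[probs.argmax(dim=1).item()])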
utils/interpretability.py
ADDED
@@ -0,0 +1,110 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from lime import lime_image
from skimage.segmentation import mark_boundaries


def unnormalize(image):
    # Convert the ImageNet mean and std to float32 tensors so they broadcast
    # over the channel dimension of an (H, W, C) image
    mean = torch.tensor([0.485, 0.456, 0.406], dtype=torch.float32)
    std = torch.tensor([0.229, 0.224, 0.225], dtype=torch.float32)

    # Accept either a PyTorch tensor or a NumPy array
    if isinstance(image, torch.Tensor):
        image = image * std + mean
    else:
        image = torch.tensor(image, dtype=torch.float32) * std + mean

    return image


def lime_interpret_image_inference(args, model, image, device):
    # Helper to undo normalization before plotting
    def prepare_for_plot(img):
        return unnormalize(img).cpu().numpy()

    # Remove the batch dimension and rearrange to (H, W, C):
    # from [1, 3, 224, 224] to [224, 224, 3]
    image = image.squeeze(0).permute(1, 2, 0)

    # Convert to a NumPy array (LIME works on NumPy images on the CPU)
    image_np = image.cpu().numpy()

    # Initialize the LIME explainer
    explainer = lime_image.LimeImageExplainer()

    # Prediction function: convert LIME's (B, H, W, C) float64 batches into
    # float32 (B, C, H, W) tensors and return class probabilities
    def predict_fn(x):
        x_tensor = torch.tensor(x, dtype=torch.float32).permute(0, 3, 1, 2).to(device)
        preds = model(x_tensor)
        return torch.nn.functional.softmax(preds, dim=1).detach().cpu().numpy()

    # Run the LIME explanation
    explanation = explainer.explain_instance(
        image_np,
        predict_fn,
        top_labels=5,
        hide_color=0,
        num_samples=5000
    )

    # Get the mask for the top predicted class
    temp, mask = explanation.get_image_and_mask(
        explanation.top_labels[0],
        positive_only=True,
        num_features=10,
        hide_rest=False
    )

    # Create a 2x2 subplot
    fig, axs = plt.subplots(2, 2, figsize=(15, 15))

    # Plot the original image
    axs[0, 0].imshow(prepare_for_plot(image))
    axs[0, 0].set_title("Original Image")

    # Plot the features that contributed the most positively
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=True, num_features=5, hide_rest=True)
    axs[0, 1].imshow(prepare_for_plot(mark_boundaries(temp, mask)))
    axs[0, 1].set_title("Top Positive Features")

    # Plot the features that contributed the most positively and negatively
    temp, mask = explanation.get_image_and_mask(explanation.top_labels[0], positive_only=False, num_features=1000, hide_rest=False, min_weight=0.1)
    axs[1, 0].imshow(mark_boundaries(prepare_for_plot(temp), mask))
    axs[1, 0].set_title("Top Positive and Negative Features")

    # Plot a heatmap of the per-segment feature weights
    ind = explanation.top_labels[0]
    dict_heatmap = dict(explanation.local_exp[ind])
    heatmap = np.vectorize(dict_heatmap.get)(explanation.segments)
    im = axs[1, 1].imshow(heatmap, cmap='RdBu', vmin=-heatmap.max(), vmax=heatmap.max())
    axs[1, 1].set_title("Feature Heatmap")
    fig.colorbar(im, ax=axs[1, 1])

    plt.tight_layout()

    # If classification mode is enabled, save under
    # explanations/<class>/<correct|mistake>/, mirroring the layout produced
    # by save_images_by_class
    if args.classify:
        # Extract the class name and correctness from the image path
        path_parts = args.image_path.split(os.sep)
        class_name = path_parts[-3]
        correctness = path_parts[-2]  # 'correct' or 'mistake'
        assert correctness in ['correct', 'mistake'], "The image path should contain 'correct' or 'mistake'"

        # Create the full save path under the explanations directory
        save_path = os.path.join('explanations', class_name, correctness, os.path.basename(args.image_path))
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        # Save the explanation
        plt.savefig(save_path, dpi=300)
        print(f"Explanation saved at {save_path}")
    else:
        # Otherwise save the explanation under ./explanations with the same
        # name as the image
        os.makedirs("./explanations", exist_ok=True)
        plt.savefig(f"./explanations/{os.path.basename(args.image_path)}")
        print(f"Explanation saved at ./explanations/{os.path.basename(args.image_path)}")
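A sketch of the intended entry point; the actual inference script is not shown in this diff, so the argparse wiring and checkpoint name below are assumptions:

import argparse
import torch
from utils.inference_utils import preprocess_image
from utils.interpretability import lime_interpret_image_inference

parser = argparse.ArgumentParser()
parser.add_argument("--image_path", required=True)
parser.add_argument("--classify", action="store_true")
args = parser.parse_args()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.load("model.pt", map_location=device)  # hypothetical checkpoint

image = preprocess_image(args.image_path, (224, 224)).to(device)
lime_interpret_image_inference(args, model, image, device)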
utils/train_utils.py
ADDED
@@ -0,0 +1,251 @@
import mlflow
import torch
import torch.nn as nn
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from torchvision import models
from tqdm import tqdm


# Training loop for a single epoch
def train_one_epoch(model, trainloader, criterion, optimizer, device):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for images, labels in tqdm(trainloader, desc="Training", leave=False):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Track loss and accuracy
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

    epoch_loss = running_loss / len(trainloader)
    epoch_accuracy = 100.0 * correct / total
    return epoch_loss, epoch_accuracy


# Evaluation loop
@torch.no_grad()
def evaluate(model, testloader, criterion, device):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_labels = []
    all_predictions = []

    for images, labels in tqdm(testloader, desc="Evaluating", leave=False):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model(images)
        loss = criterion(outputs, labels)

        # Track loss and accuracy
        running_loss += loss.item()
        _, predicted = outputs.max(1)
        correct += predicted.eq(labels).sum().item()
        total += labels.size(0)

        # Collect all labels and predictions for metrics
        all_labels.extend(labels.cpu().numpy())
        all_predictions.extend(predicted.cpu().numpy())

    epoch_loss = running_loss / len(testloader)

    # Calculate accuracy, precision, recall, and F1-score
    epoch_accuracy = accuracy_score(all_labels, all_predictions, normalize=True) * 100
    precision = precision_score(all_labels, all_predictions, average="weighted")
    recall = recall_score(all_labels, all_predictions, average="weighted")
    f1 = f1_score(all_labels, all_predictions, average="weighted")

    return epoch_loss, epoch_accuracy, precision, recall, f1


# Full train/evaluate pipeline
def train_and_evaluate(
    model,
    trainloader,
    testloader,
    criterion,
    optimizer,
    device,
    epochs,
    use_mlflow=False,
):
    """
    Train and evaluate the model.

    Args:
        model (nn.Module): The neural network model.
        trainloader (DataLoader): DataLoader for training data.
        testloader (DataLoader): DataLoader for test data.
        criterion (nn.Module): Loss function.
        optimizer (optim.Optimizer): Optimizer.
        device (torch.device): Device to train on ('cuda' or 'cpu').
        epochs (int): Number of epochs to train.
        use_mlflow (bool): If True, log metrics to MLflow after each epoch.

    Returns:
        dict: Training and evaluation statistics.
    """
    history = {
        "train_loss": [],
        "train_acc": [],
        "test_loss": [],
        "test_acc": [],
        "precision": [],
        "recall": [],
        "f1": [],
    }

    model.to(device)

    for epoch in range(epochs):
        print(f"Epoch {epoch + 1}/{epochs}")

        # Train for one epoch
        train_loss, train_acc = train_one_epoch(
            model, trainloader, criterion, optimizer, device
        )
        print(f"Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.2f}%")

        # Evaluate the model
        test_loss, test_acc, precision, recall, f1 = evaluate(
            model, testloader, criterion, device
        )
        print(f"Test Loss: {test_loss:.4f}, Test Accuracy: {test_acc:.2f}%")

        # Save statistics
        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["test_loss"].append(test_loss)
        history["test_acc"].append(test_acc)
        history["precision"].append(precision)
        history["recall"].append(recall)
        history["f1"].append(f1)

        if use_mlflow:
            # Pass step=epoch so MLflow records a per-epoch curve rather than
            # overwriting a single step
            mlflow.log_metric("train_loss", train_loss, step=epoch)
            mlflow.log_metric("train_acc", train_acc, step=epoch)
            mlflow.log_metric("test_loss", test_loss, step=epoch)
            mlflow.log_metric("test_acc", test_acc, step=epoch)
            mlflow.log_metric("precision", precision, step=epoch)
            mlflow.log_metric("recall", recall, step=epoch)
            mlflow.log_metric("f1", f1, step=epoch)
    return history


def set_parameter_requires_grad(model, feature_extracting):
    # Freeze all parameters when the model is used as a fixed feature extractor
    if feature_extracting:
        for param in model.parameters():
            param.requires_grad = False


def initialize_model(
    model_name,
    num_classes,
    feature_extract=True,
    use_pretrained=True,
    hidden_size=512,
    image_shape=(224, 224, 3),
):
    # model_ft is set in the branch matching model_name; each branch is
    # model specific.
    model_ft = None

    if model_name == "resnet":
        # ResNet18
        model_ft = models.resnet18(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.fc.in_features
        model_ft.fc = nn.Linear(num_ftrs, num_classes)

    elif model_name == "alexnet":
        # AlexNet
        model_ft = models.alexnet(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "vgg":
        # VGG11_bn
        model_ft = models.vgg11_bn(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier[6].in_features
        model_ft.classifier[6] = nn.Linear(num_ftrs, num_classes)

    elif model_name == "squeezenet":
        # SqueezeNet
        model_ft = models.squeezenet1_0(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        model_ft.classifier[1] = nn.Conv2d(
            512, num_classes, kernel_size=(1, 1), stride=(1, 1)
        )
        model_ft.num_classes = num_classes

    elif model_name == "densenet":
        # DenseNet
        model_ft = models.densenet121(pretrained=use_pretrained)
        set_parameter_requires_grad(model_ft, feature_extract)
        num_ftrs = model_ft.classifier.in_features
        model_ft.classifier = nn.Linear(num_ftrs, num_classes)

    elif model_name == "custom_mlp":
        # Custom MLP
        model_ft = nn.Sequential(
            # Flatten the (C, H, W) image into a vector before the linear layers
            nn.Flatten(),
            nn.Linear(image_shape[0] * image_shape[1] * image_shape[2], hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Linear(hidden_size // 2, num_classes),
        )
    elif model_name == "custom_cnn":
        # Custom CNN
        model_ft = nn.Sequential(
            nn.Conv2d(3, 16, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(16, 32, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, 1, 1),
            nn.ReLU(),
            nn.MaxPool2d(2),
            nn.Flatten(),
            # A 224x224 input halved three times gives 28x28 feature maps
            nn.Linear(64 * 28 * 28, hidden_size),
            nn.ReLU(),
            nn.Linear(hidden_size, num_classes),
        )
    elif model_name == "random_forest":
        # Random Forest
        model_ft = RandomForestClassifier(n_estimators=100, random_state=42)

    else:
        raise ValueError(f"Invalid model name: {model_name}")

    return model_ft
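A minimal end-to-end sketch tying the utilities together; the dataset path, batch size, and hyperparameters below are assumptions, not values from the commit:

import torch
import torch.nn as nn
from torchvision import transforms
from utils.data import CLASS_NAMES, PokemonDataModule
from utils.train_utils import initialize_model, train_and_evaluate

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tf = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

dm = PokemonDataModule("./data/pokemon")  # assumed dataset location
dm.prepare_data()
trainloader, testloader = dm.get_dataloaders(
    train_transform=tf, test_transform=tf, train_batch_size=32
)

model = initialize_model("resnet", num_classes=len(CLASS_NAMES))
# With feature_extract=True only the new head has requires_grad=True
optimizer = torch.optim.Adam(
    [p for p in model.parameters() if p.requires_grad], lr=1e-3
)
history = train_and_evaluate(
    model, trainloader, testloader, nn.CrossEntropyLoss(), optimizer, device, epochs=5
)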