Update app.py
Browse files
app.py
CHANGED
@@ -7,6 +7,52 @@ class Hook():
|
|
7 |
|
8 |
learn = load_learner("resnet152_fit_one_cycle_freeze_91acc.pkl", cpu=True)
|
9 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
10 |
categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula', 'kuzguncuk', 'larissa_ampelakia', 'mardin', 'ohrid', 'pristina', 'safranbolu', 'selanik', 'sozopol_suzebolu', 'tiran', 'varna')
|
11 |
def classify_img(img):
|
12 |
pred,idx,probs=learn.predict(img)
|
@@ -23,6 +69,6 @@ examples=["filibe-1-1.jpg",
|
|
23 |
"varna-1-1.jpg"]
|
24 |
|
25 |
|
26 |
-
demo = gr.Interface(fn=
|
27 |
|
28 |
demo.launch(inline=False)
|
|
|
7 |
|
8 |
learn = load_learner("resnet152_fit_one_cycle_freeze_91acc.pkl", cpu=True)
|
9 |
|
10 |
+
#@title DataLoader
|
11 |
+
path = "/content/drive/My Drive/arc_project/dataset/train_val_cropped"
|
12 |
+
dblock = DataBlock(blocks = (ImageBlock, CategoryBlock),
|
13 |
+
get_items = get_image_files,
|
14 |
+
splitter = GrandparentSplitter(train_name="train",
|
15 |
+
valid_name="valid"),
|
16 |
+
get_y=parent_label,
|
17 |
+
item_tfms=RandomResizedCrop(128, min_scale=0.7),
|
18 |
+
batch_tfms=[*aug_transforms(),
|
19 |
+
Normalize.from_stats(*imagenet_stats)])
|
20 |
+
dls_augmented = dblock.dataloaders(path, shuffle=True)
|
21 |
+
dls_augmented.train.show_batch(max_n=8, nrows=2, figsize=(28,10))
|
22 |
+
|
23 |
+
class Hook():
|
24 |
+
def hook_func(self, m, i, o): self.stored = o.detach().clone()
|
25 |
+
|
26 |
+
def gradcam(img_create):
|
27 |
+
|
28 |
+
pred,idx,probs=learn.predict(img_create)
|
29 |
+
x,= first(dls_augmented.test_dl([img_create]))
|
30 |
+
hook_output = Hook()
|
31 |
+
hook = learn.model[0].register_forward_hook(hook_output.hook_func)
|
32 |
+
with torch.no_grad(): output = learn.model.eval()(x)
|
33 |
+
act = hook_output.stored[0]
|
34 |
+
hook.remove()
|
35 |
+
|
36 |
+
input_size=act.shape[0]
|
37 |
+
out_size=learn.model[1][-1].in_features
|
38 |
+
kernel_size=act.shape[1]
|
39 |
+
a=act
|
40 |
+
new_act=tensor(np.zeros((out_size,kernel_size,kernel_size)))
|
41 |
+
sum=tensor(np.zeros((1,kernel_size,kernel_size)))
|
42 |
+
for i in range(0,input_size,4):
|
43 |
+
sum=tensor(np.zeros((1,kernel_size,kernel_size)))
|
44 |
+
for j in range(i,i+4):
|
45 |
+
sum=sum+act[j,:,:]
|
46 |
+
new_act[int(i/4),:,:]=sum/4
|
47 |
+
cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, new_act)
|
48 |
+
|
49 |
+
x_dec = TensorImage(dls_augmented.train.decode((x,))[0][0])
|
50 |
+
fig,ax = plt.subplots()
|
51 |
+
x_dec.show(ctx=ax)
|
52 |
+
ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,128,128,0), interpolation='bilinear', cmap='magma');
|
53 |
+
|
54 |
+
return dict(zip(categories, map(float, probs))), fig
|
55 |
+
|
56 |
categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula', 'kuzguncuk', 'larissa_ampelakia', 'mardin', 'ohrid', 'pristina', 'safranbolu', 'selanik', 'sozopol_suzebolu', 'tiran', 'varna')
|
57 |
def classify_img(img):
|
58 |
pred,idx,probs=learn.predict(img)
|
|
|
69 |
"varna-1-1.jpg"]
|
70 |
|
71 |
|
72 |
+
demo = gr.Interface(fn=gradcam, inputs=image, outputs=[label, image], examples=examples)
|
73 |
|
74 |
demo.launch(inline=False)
|