khan994 commited on
Commit
b00a19c
·
1 Parent(s): 7797488

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +47 -1
app.py CHANGED
@@ -7,6 +7,52 @@ class Hook():
7
 
8
  learn = load_learner("resnet152_fit_one_cycle_freeze_91acc.pkl", cpu=True)
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula', 'kuzguncuk', 'larissa_ampelakia', 'mardin', 'ohrid', 'pristina', 'safranbolu', 'selanik', 'sozopol_suzebolu', 'tiran', 'varna')
11
  def classify_img(img):
12
  pred,idx,probs=learn.predict(img)
@@ -23,6 +69,6 @@ examples=["filibe-1-1.jpg",
23
  "varna-1-1.jpg"]
24
 
25
 
26
- demo = gr.Interface(fn=classify_img, inputs=image, outputs=label, examples=examples)
27
 
28
  demo.launch(inline=False)
 
7
 
8
  learn = load_learner("resnet152_fit_one_cycle_freeze_91acc.pkl", cpu=True)
9
 
10
+ #@title DataLoader
11
+ path = "/content/drive/My Drive/arc_project/dataset/train_val_cropped"
12
+ dblock = DataBlock(blocks = (ImageBlock, CategoryBlock),
13
+ get_items = get_image_files,
14
+ splitter = GrandparentSplitter(train_name="train",
15
+ valid_name="valid"),
16
+ get_y=parent_label,
17
+ item_tfms=RandomResizedCrop(128, min_scale=0.7),
18
+ batch_tfms=[*aug_transforms(),
19
+ Normalize.from_stats(*imagenet_stats)])
20
+ dls_augmented = dblock.dataloaders(path, shuffle=True)
21
+ dls_augmented.train.show_batch(max_n=8, nrows=2, figsize=(28,10))
22
+
23
+ class Hook():
24
+ def hook_func(self, m, i, o): self.stored = o.detach().clone()
25
+
26
+ def gradcam(img_create):
27
+
28
+ pred,idx,probs=learn.predict(img_create)
29
+ x,= first(dls_augmented.test_dl([img_create]))
30
+ hook_output = Hook()
31
+ hook = learn.model[0].register_forward_hook(hook_output.hook_func)
32
+ with torch.no_grad(): output = learn.model.eval()(x)
33
+ act = hook_output.stored[0]
34
+ hook.remove()
35
+
36
+ input_size=act.shape[0]
37
+ out_size=learn.model[1][-1].in_features
38
+ kernel_size=act.shape[1]
39
+ a=act
40
+ new_act=tensor(np.zeros((out_size,kernel_size,kernel_size)))
41
+ sum=tensor(np.zeros((1,kernel_size,kernel_size)))
42
+ for i in range(0,input_size,4):
43
+ sum=tensor(np.zeros((1,kernel_size,kernel_size)))
44
+ for j in range(i,i+4):
45
+ sum=sum+act[j,:,:]
46
+ new_act[int(i/4),:,:]=sum/4
47
+ cam_map = torch.einsum('ck,kij->cij', learn.model[1][-1].weight, new_act)
48
+
49
+ x_dec = TensorImage(dls_augmented.train.decode((x,))[0][0])
50
+ fig,ax = plt.subplots()
51
+ x_dec.show(ctx=ax)
52
+ ax.imshow(cam_map[1].detach().cpu(), alpha=0.6, extent=(0,128,128,0), interpolation='bilinear', cmap='magma');
53
+
54
+ return dict(zip(categories, map(float, probs))), fig
55
+
56
  categories = ('arbanasi', 'filibe', 'gjirokoster', 'iskodra', 'kula', 'kuzguncuk', 'larissa_ampelakia', 'mardin', 'ohrid', 'pristina', 'safranbolu', 'selanik', 'sozopol_suzebolu', 'tiran', 'varna')
57
  def classify_img(img):
58
  pred,idx,probs=learn.predict(img)
 
69
  "varna-1-1.jpg"]
70
 
71
 
72
+ demo = gr.Interface(fn=gradcam, inputs=image, outputs=[label, image], examples=examples)
73
 
74
  demo.launch(inline=False)