YannisK commited on
Commit
402b433
1 Parent(s): c712472
Files changed (1) hide show
  1. app.py +74 -16
app.py CHANGED
@@ -2,12 +2,18 @@ import gradio as gr
2
 
3
  import torch
4
 
 
5
  import matplotlib.pyplot as plt
 
 
 
6
 
7
  from torchvision import transforms
8
 
9
  import fire_network
10
 
 
 
11
  # Possible Scales for multiscale inference
12
  scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
13
 
@@ -42,23 +48,75 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
42
 
43
  # extract features
44
  with torch.no_grad():
45
- output1 = net.get_superfeatures(im1_tensor.to(device), scales=scales)
46
- feats1 = output1[0]
47
- attns1 = output1[1]
48
- strenghts1 = output1[2]
49
-
50
- output2 = net.get_superfeatures(im2_tensor.to(device), scales=scales)
51
- feats2 = output2[0]
52
- attns2 = output2[1]
53
- strenghts2 = output2[2]
54
-
55
- print(len(feats1))
56
- # print(feats1.shape)
57
- print(feats1[0].shape)
58
- print(attns1[0].shape)
59
- # print(attns1.shape)
60
- # print(strenghts1.shape)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
 
 
 
 
 
 
 
 
 
62
 
63
 
64
  # GRADIO APP
 
2
 
3
  import torch
4
 
5
+
6
  import matplotlib.pyplot as plt
7
+ from matplotlib import cm
8
+ from matplotlib import colors
9
+ from mpl_toolkits.axes_grid1 import ImageGrid
10
 
11
  from torchvision import transforms
12
 
13
  import fire_network
14
 
15
+ from PIL import Image
16
+
17
  # Possible Scales for multiscale inference
18
  scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
19
 
 
48
 
49
  # extract features
50
  with torch.no_grad():
51
+ output1 = net.get_superfeatures(im1_tensor.to(device), scales=[scale_id])
52
+ feats1 = output1[0][0]
53
+ attns1 = output1[1][0]
54
+ strenghts1 = output1[2][0]
55
+
56
+ output2 = net.get_superfeatures(im2_tensor.to(device), scales=[scale_id])
57
+ feats2 = output2[0][0]
58
+ attns2 = output2[1][0]
59
+ strenghts2 = output2[2][0]
60
+
61
+ print(feats1.shape, feats2.shape)
62
+ print(attns1.shape, attns2.shape)
63
+ print(strenghts1.shape, strenghts2.shape)
64
+
65
+ # Store all binary SF att maps to show them all at once in the end
66
+ all_att_bin1 = defaultdict(list)
67
+ all_att_bin2 = defaultdict(list)
68
+ for n, i in enumerate(sf_idx_):
69
+ # all_atts[n].append(attn[j][scale_id][0,i,:,:].numpy())
70
+ att_heat = np.array(attns1[0,i,:,:].numpy(), dtype=np.float32)
71
+ att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
72
+ att_heat_bin = np.where(att_heat>threshold, 255, 0)
73
+ all_att_bin1.append(att_heat_bin)
74
+
75
+ att_heat = np.array(attns2[0,i,:,:].numpy(), dtype=np.float32)
76
+ att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
77
+ att_heat_bin = np.where(att_heat>threshold, 255, 0)
78
+ all_att_bin2.append(att_heat_bin)
79
+
80
+
81
+ fin_img = []
82
+ img1rsz = np.copy(im1)
83
+ for j, att in enumerate(all_att_bin1):
84
+ # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
85
+ # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
86
+ # att = cv2.resize(att, imgz[i].shape[:2][::-1])
87
+ att = att.resize(im1.size)
88
+ mask2d = zip(*np.where(att==255))
89
+ for m,n in mask2d:
90
+ col_ = col.colors[j] if j < 7 else col.colors[j+1]
91
+ if j == 0: col_ = col.colors[9]
92
+ col_ = 255*np.array(colors.to_rgba(col_))[:3]
93
+ img1rsz[m,n, :] = col_[::-1]
94
+ fin_img.append(img1rsz)
95
+
96
+ img2rsz = np.copy(im2)
97
+ for j, att in enumerate(all_att_bin2):
98
+ # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
99
+ # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
100
+ # att = cv2.resize(att, imgz[i].shape[:2][::-1])
101
+ att = att.resize(im2.size)
102
+ mask2d = zip(*np.where(att==255))
103
+ for m,n in mask2d:
104
+ col_ = col.colors[j] if j < 7 else col.colors[j+1]
105
+ if j == 0: col_ = col.colors[9]
106
+ col_ = 255*np.array(colors.to_rgba(col_))[:3]
107
+ img2rsz[m,n, :] = col_[::-1]
108
+ fin_img.append(img2rsz)
109
+
110
 
111
+ fig = plt.figure(figsize=(12,25))
112
+ grid = ImageGrid(fig, 111, nrows_ncols=(2, 1), axes_pad=0.1)
113
+ for ax, img in zip(grid, fin_img):
114
+ ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
115
+ ax.axis('scaled')
116
+ ax.axis('off')
117
+ plt.tight_layout()
118
+ fig.suptitle("Matching SFs", fontsize=16)
119
+ return fig
120
 
121
 
122
  # GRADIO APP