YannisK commited on
Commit
0f77bb9
1 Parent(s): 4789885
Files changed (1) hide show
  1. app.py +43 -22
app.py CHANGED
@@ -37,13 +37,21 @@ transform = transforms.Compose([
37
  ])
38
 
39
 
40
- # which sf
41
- sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
42
 
43
  col = plt.get_cmap('tab10')
44
 
45
- def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
46
-
 
 
 
 
 
 
 
 
 
47
  im1_tensor = transform(im1).unsqueeze(0)
48
  im2_tensor = transform(im2).unsqueeze(0)
49
 
@@ -74,7 +82,7 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
74
  att_heat = np.array(attns1[0,i,:,:].numpy(), dtype=np.float32)
75
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
76
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
77
- print(att_heat_bin)
78
  all_att_bin1.append(att_heat_bin)
79
 
80
  att_heat = np.array(attns2[0,i,:,:].numpy(), dtype=np.float32)
@@ -86,7 +94,7 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
86
  fin_img = []
87
  img1rsz = np.copy(im1_cv)
88
  print('im1:', im1.size)
89
- print(img1rsz.shape)
90
  for j, att in enumerate(all_att_bin1):
91
  att = cv2.resize(att, im1.size, interpolation=cv2.INTER_NEAREST)
92
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
@@ -102,6 +110,8 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
102
  fin_img.append(img1rsz)
103
 
104
  img2rsz = np.copy(im2_cv)
 
 
105
  for j, att in enumerate(all_att_bin2):
106
  att = cv2.resize(att, im2.size, interpolation=cv2.INTER_NEAREST)
107
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
@@ -116,19 +126,21 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
116
  img2rsz[m,n, :] = col_[::-1]
117
  fin_img.append(img2rsz)
118
 
119
- fig1 = plt.figure()
120
- fig1.show(cv2.cvtColor(img1rsz, cv2.COLOR_BGR2RGB))
121
  ax1 = plt.gca()
122
- ax1.axis('scaled')
123
  ax1.axis('off')
124
-
125
  plt.tight_layout()
 
126
 
127
- fig2 = plt.figure()
128
- fig2.show(cv2.cvtColor(img2rsz, cv2.COLOR_BGR2RGB))
129
  ax2 = plt.gca()
130
- ax2.axis('scaled')
131
  ax2.axis('off')
 
 
132
 
133
  # fig = plt.figure()
134
  # grid = ImageGrid(fig, 111, nrows_ncols=(2, 1), axes_pad=0.1)
@@ -143,7 +155,7 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
143
  # # Now we can save it to a numpy array.
144
  # data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
145
  # data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
146
- return fig1,fig2
147
 
148
 
149
  # GRADIO APP
@@ -155,25 +167,34 @@ article = "<p style='text-align: center'><a href='https://github.com/naver/fire'
155
  # css = ".output-image, .input-image {height: 40rem !important; width: 100% !important;}"
156
  # css = "@media screen and (max-width: 600px) { .output_image, .input_image {height:20rem !important; width: 100% !important;} }"
157
  # css = ".output_image, .input_image {hieght: 1000px !important}"
158
- css = ".input_image {height: 600px !important; width: 600px !important;} .output_image {height: 1200px !important; width: 600px !important;}"
159
  # css = ".output-image, .input-image {height: 40rem !important; width: 100% !important;}"
160
 
161
 
162
  iface = gr.Interface(
163
  fn=generate_matching_superfeatures,
164
  inputs=[
165
- gr.inputs.Image(shape=(1024, 1024), type="pil"),
166
- gr.inputs.Image(shape=(1024, 1024), type="pil"),
167
- gr.inputs.Slider(minimum=1, maximum=7, step=1, default=3, label="Scale"),
168
- gr.inputs.Slider(minimum=1, maximum=255, step=25, default=100, label="Binarization Threshold")],
169
- outputs=["plot", "plot"],
 
 
 
 
 
 
170
  # outputs=gr.outputs.Image(shape=(1024,2048), type="plot"),
171
  title=title,
172
- theme='dark-peach',
173
  layout="horizontal",
174
  description=description,
175
  article=article,
176
  css=css,
177
- examples=[["chateau_1.png", "chateau_2.png", 3, 100]],
 
 
 
178
  )
179
  iface.launch(enable_queue=True)
 
37
  ])
38
 
39
 
40
+ # sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
 
41
 
42
  col = plt.get_cmap('tab10')
43
 
44
+ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50, sf_ids=''):
45
+ print('im1:', im1.size)
46
+ print('im2:', im2.size)
47
+ # which sf
48
+ sf_idx_ = [55, 14, 5, 4, 52, 57, 40, 9]
49
+ if sf_ids.lower().startswith('r'):
50
+ n_sf_ids = int(sf_ids[1:])
51
+ sf_idx_ = np.random.randint(256, size=n_sf_ids)
52
+ elif sf_ids != '':
53
+ sf_idx_ = map(int, sf_ids.strip().split(','))
54
+
55
  im1_tensor = transform(im1).unsqueeze(0)
56
  im2_tensor = transform(im2).unsqueeze(0)
57
 
 
82
  att_heat = np.array(attns1[0,i,:,:].numpy(), dtype=np.float32)
83
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
84
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
85
+ # print(att_heat_bin)
86
  all_att_bin1.append(att_heat_bin)
87
 
88
  att_heat = np.array(attns2[0,i,:,:].numpy(), dtype=np.float32)
 
94
  fin_img = []
95
  img1rsz = np.copy(im1_cv)
96
  print('im1:', im1.size)
97
+ print('img1rsz:', img1rsz.shape)
98
  for j, att in enumerate(all_att_bin1):
99
  att = cv2.resize(att, im1.size, interpolation=cv2.INTER_NEAREST)
100
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
 
110
  fin_img.append(img1rsz)
111
 
112
  img2rsz = np.copy(im2_cv)
113
+ print('im2:', im2.size)
114
+ print('img2rsz:', img2rsz.shape)
115
  for j, att in enumerate(all_att_bin2):
116
  att = cv2.resize(att, im2.size, interpolation=cv2.INTER_NEAREST)
117
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
 
126
  img2rsz[m,n, :] = col_[::-1]
127
  fin_img.append(img2rsz)
128
 
129
+ fig1 = plt.figure(1)
130
+ plt.imshow(cv2.cvtColor(img1rsz, cv2.COLOR_BGR2RGB))
131
  ax1 = plt.gca()
132
+ # ax1.axis('scaled')
133
  ax1.axis('off')
 
134
  plt.tight_layout()
135
+ # fig1.canvas.draw()
136
 
137
+ fig2 = plt.figure(2)
138
+ plt.imshow(cv2.cvtColor(img2rsz, cv2.COLOR_BGR2RGB))
139
  ax2 = plt.gca()
140
+ # ax2.axis('scaled')
141
  ax2.axis('off')
142
+ plt.tight_layout()
143
+ # fig2.canvas.draw()
144
 
145
  # fig = plt.figure()
146
  # grid = ImageGrid(fig, 111, nrows_ncols=(2, 1), axes_pad=0.1)
 
155
  # # Now we can save it to a numpy array.
156
  # data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
157
  # data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
158
+ return fig1, fig2, ','.join(map(str, sf_idx_))
159
 
160
 
161
  # GRADIO APP
 
167
  # css = ".output-image, .input-image {height: 40rem !important; width: 100% !important;}"
168
  # css = "@media screen and (max-width: 600px) { .output_image, .input_image {height:20rem !important; width: 100% !important;} }"
169
  # css = ".output_image, .input_image {hieght: 1000px !important}"
170
+ css = ".input_image, .input_image {height: 600px !important; width: 600px !important;} "
171
  # css = ".output-image, .input-image {height: 40rem !important; width: 100% !important;}"
172
 
173
 
174
  iface = gr.Interface(
175
  fn=generate_matching_superfeatures,
176
  inputs=[
177
+ # gr.inputs.Image(shape=(1024, 1024), type="pil", label="First Image"),
178
+ # gr.inputs.Image(shape=(1024, 1024), type="pil", label="Second Image"),
179
+ gr.inputs.Image(type="pil", label="First Image"),
180
+ gr.inputs.Image(type="pil", label="Second Image"),
181
+ gr.inputs.Slider(minimum=0, maximum=6, step=1, default=2, label="Scale"),
182
+ gr.inputs.Slider(minimum=1, maximum=255, step=25, default=100, label="Binarization Threshold"),
183
+ gr.inputs.Textbox(lines=1, default="", label="SF IDs to show (comma separated numbers from 0-255; typing 'rX' will return X random SFs", optional=True)],
184
+ outputs=[
185
+ "plot",
186
+ "plot",
187
+ gr.outputs.Textbox(label="SFs")],
188
  # outputs=gr.outputs.Image(shape=(1024,2048), type="plot"),
189
  title=title,
190
+ theme='peach',
191
  layout="horizontal",
192
  description=description,
193
  article=article,
194
  css=css,
195
+ examples=[
196
+ ["chateau_1.png", "chateau_2.png", 2, 100, '55,14,5,4,52,57,40,9'],
197
+ ["anafi1.jpeg", "anafi2.jpeg", 4, 50, '99,100,142,213,236']
198
+ ],
199
  )
200
  iface.launch(enable_queue=True)