taesiri commited on
Commit
45fb2aa
β€’
1 Parent(s): 999140c
Files changed (3) hide show
  1. README.md +1 -1
  2. app.py +147 -60
  3. requirements.txt +0 -1
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: πŸ“š
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
- sdk_version: 3.0.5
8
  app_file: app.py
9
  pinned: false
10
  ---
 
4
  colorFrom: red
5
  colorTo: yellow
6
  sdk: gradio
7
+ sdk_version: 3.30.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py CHANGED
@@ -139,44 +139,59 @@ def run_chm(
139
  for x, y in zip(tgt_points[0], tgt_points[1]):
140
  tgt_grid.append([int(((x + 1) / 2.0) * 7), int(((y + 1) / 2.0) * 7)])
141
 
 
142
  # PLOT
143
  fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
144
 
 
145
  ax[0].imshow(display_transform(source_image))
146
  ax[0].scatter(
147
  src_points_converted[:, 0],
148
  src_points_converted[:, 1],
149
- c=colors[:number_src_points],
 
 
 
150
  )
151
- ax[0].set_title("Source")
152
  ax[0].set_xticks([])
153
  ax[0].set_yticks([])
154
 
 
155
  ax[1].imshow(display_transform(target_image))
156
  ax[1].scatter(
157
  tgt_points_converted[:, 0],
158
  tgt_points_converted[:, 1],
159
- c=colors[:number_src_points],
 
 
 
160
  )
161
- ax[1].set_title("Target")
162
  ax[1].set_xticks([])
163
  ax[1].set_yticks([])
164
 
165
- for TL in range(49):
166
- ax[0].text(
167
- x=src_points_converted[TL][0],
168
- y=src_points_converted[TL][1],
169
- s=str(TL),
170
- fontdict=dict(color="red", size=11),
 
 
 
 
 
 
 
 
 
171
  )
 
172
 
173
- for TL in range(49):
174
- ax[1].text(
175
- x=tgt_points_converted[TL][0],
176
- y=tgt_points_converted[TL][1],
177
- s=f"{str(TL)}",
178
- fontdict=dict(color="orange", size=11),
179
- )
180
 
181
  plt.tight_layout()
182
  fig.suptitle("CHM Correspondences\nUsing $\it{pas\_psi.pt}$ Weights ", fontsize=16)
@@ -201,44 +216,41 @@ def generate_correspondences(
201
  )
202
 
203
 
204
- # Gradio App
205
- main = gr.Interface(
206
- fn=generate_correspondences,
207
- inputs=[
208
- gr.Image(shape=(240, 240), type="pil"),
209
- gr.Image(shape=(240, 240), type="pil"),
210
- gr.Slider(minimum=1, maximum=240, step=1, default=15, label="Min X"),
211
- gr.Slider(minimum=1, maximum=240, step=1, default=215, label="Max X"),
212
- gr.Slider(minimum=1, maximum=240, step=1, default=15, label="Min Y"),
213
- gr.Slider(minimum=1, maximum=240, step=1, default=215, label="Max Y"),
214
- ],
215
- allow_flagging="never",
216
- outputs="plot",
217
- examples=[
218
- ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
219
- [
220
- "./examples/Red_Winged_Blackbird_0012_6015.jpg",
221
- "./examples/Red_Winged_Blackbird_0025_5342.jpg",
222
- 17,
223
- 223,
224
- 17,
225
- 223,
226
- ],
227
- [
228
- "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
229
- "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
230
- 17,
231
- 223,
232
- 17,
233
- 223,
234
- ],
235
- ],
236
- )
237
-
238
-
239
- blocks = gr.Blocks()
240
- with blocks:
241
-
242
  gr.Markdown(
243
  """
244
  # Correspondence Matching with Convolutional Hough Matching Networks
@@ -247,10 +259,85 @@ Performs keypoint transform from a 7x7 gird on the source image to the target im
247
  """
248
  )
249
 
250
- gr.TabbedInterface([main], ["Main"])
 
 
 
 
 
 
 
 
 
 
 
 
 
251
 
 
 
 
 
 
 
 
 
 
252
 
253
- blocks.launch(
254
- debug=True,
255
- enable_queue=False,
256
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  for x, y in zip(tgt_points[0], tgt_points[1]):
140
  tgt_grid.append([int(((x + 1) / 2.0) * 7), int(((y + 1) / 2.0) * 7)])
141
 
142
+ # VISUALIZATION
143
  # PLOT
144
  fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
145
 
146
+ # Source image plot
147
  ax[0].imshow(display_transform(source_image))
148
  ax[0].scatter(
149
  src_points_converted[:, 0],
150
  src_points_converted[:, 1],
151
+ c="blue",
152
+ edgecolors="white",
153
+ s=50,
154
+ label="Source points",
155
  )
156
+ ax[0].set_title("Source Image with Selected Points")
157
  ax[0].set_xticks([])
158
  ax[0].set_yticks([])
159
 
160
+ # Target image plot
161
  ax[1].imshow(display_transform(target_image))
162
  ax[1].scatter(
163
  tgt_points_converted[:, 0],
164
  tgt_points_converted[:, 1],
165
+ c="red",
166
+ edgecolors="white",
167
+ s=50,
168
+ label="Target points",
169
  )
170
+ ax[1].set_title("Target Image with Corresponding Points")
171
  ax[1].set_xticks([])
172
  ax[1].set_yticks([])
173
 
174
+ # Adding labels to points
175
+ for i, (src, tgt) in enumerate(zip(src_points_converted, tgt_points_converted)):
176
+ ax[0].text(*src, str(i), color="white", bbox=dict(facecolor="black", alpha=0.5))
177
+ ax[1].text(*tgt, str(i), color="black", bbox=dict(facecolor="white", alpha=0.7))
178
+
179
+ # Drawing lines between corresponding source and target points
180
+ for src, tgt in zip(src_points_converted, tgt_points_converted):
181
+ con = ConnectionPatch(
182
+ xyA=tgt,
183
+ xyB=src,
184
+ coordsA="data",
185
+ coordsB="data",
186
+ axesA=ax[1],
187
+ axesB=ax[0],
188
+ color="green",
189
  )
190
+ ax[1].add_artist(con)
191
 
192
+ # Adding legend
193
+ ax[0].legend()
194
+ ax[1].legend()
 
 
 
 
195
 
196
  plt.tight_layout()
197
  fig.suptitle("CHM Correspondences\nUsing $\it{pas\_psi.pt}$ Weights ", fontsize=16)
 
216
  )
217
 
218
 
219
+ # # Gradio App
220
+ # main = gr.Interface(
221
+ # fn=generate_correspondences,
222
+ # inputs=[
223
+ # gr.Image(shape=(240, 240), type="pil"),
224
+ # gr.Image(shape=(240, 240), type="pil"),
225
+ # gr.Slider(minimum=1, maximum=240, step=1, default=15, label="Min X"),
226
+ # gr.Slider(minimum=1, maximum=240, step=1, default=215, label="Max X"),
227
+ # gr.Slider(minimum=1, maximum=240, step=1, default=15, label="Min Y"),
228
+ # gr.Slider(minimum=1, maximum=240, step=1, default=215, label="Max Y"),
229
+ # ],
230
+ # allow_flagging="never",
231
+ # outputs="plot",
232
+ # examples=[
233
+ # ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
234
+ # [
235
+ # "./examples/Red_Winged_Blackbird_0012_6015.jpg",
236
+ # "./examples/Red_Winged_Blackbird_0025_5342.jpg",
237
+ # 17,
238
+ # 223,
239
+ # 17,
240
+ # 223,
241
+ # ],
242
+ # [
243
+ # "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
244
+ # "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
245
+ # 17,
246
+ # 223,
247
+ # 17,
248
+ # 223,
249
+ # ],
250
+ # ],
251
+ # )
252
+
253
+ with gr.Blocks() as demo:
 
 
 
254
  gr.Markdown(
255
  """
256
  # Correspondence Matching with Convolutional Hough Matching Networks
 
259
  """
260
  )
261
 
262
+ with gr.Row():
263
+ # Add an Image component to display the source image.
264
+ image1 = gr.Image(
265
+ shape=(240, 240),
266
+ type="pil",
267
+ label="Source Image",
268
+ )
269
+
270
+ # Add an Image component to display the target image.
271
+ image2 = gr.Image(
272
+ shape=(240, 240),
273
+ type="pil",
274
+ label="Target Image",
275
+ )
276
 
277
+ with gr.Row():
278
+ # Add a Slider component to adjust the minimum x-coordinate of the grid.
279
+ min_x = gr.Slider(
280
+ minimum=1,
281
+ maximum=240,
282
+ step=1,
283
+ default=15,
284
+ label="Min X",
285
+ )
286
 
287
+ # Add a Slider component to adjust the maximum x-coordinate of the grid.
288
+ max_x = gr.Slider(
289
+ minimum=1,
290
+ maximum=240,
291
+ step=1,
292
+ default=215,
293
+ label="Max X",
294
+ )
295
+
296
+ # Add a Slider component to adjust the minimum y-coordinate of the grid.
297
+ min_y = gr.Slider(
298
+ minimum=1,
299
+ maximum=240,
300
+ step=1,
301
+ default=15,
302
+ label="Min Y",
303
+ )
304
+
305
+ # Add a Slider component to adjust the maximum y-coordinate of the grid.
306
+ max_y = gr.Slider(
307
+ minimum=1,
308
+ maximum=240,
309
+ step=1,
310
+ default=215,
311
+ label="Max Y",
312
+ )
313
+
314
+ with gr.Row():
315
+ output_plot = gr.Plot(
316
+ type="plot",
317
+ label="Output Plot",
318
+ )
319
+
320
+ gr.Examples(
321
+ [
322
+ ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
323
+ ],
324
+ inputs=[
325
+ image1,
326
+ image2,
327
+ min_x,
328
+ max_x,
329
+ min_y,
330
+ max_y,
331
+ ],
332
+ )
333
+
334
+ # Add a Button component to run the app.
335
+ run_btn = gr.Button("Run")
336
+
337
+ run_btn.click(
338
+ generate_correspondences,
339
+ inputs=[image1, image2, min_x, max_x, min_y, max_y],
340
+ outputs=output_plot,
341
+ )
342
+
343
+ demo.launch(debug=True, enable_queue=False)
requirements.txt CHANGED
@@ -1,4 +1,3 @@
1
- gradio==3.0.5
2
  pandas==1.3.4
3
  requests==2.26.0
4
  scipy==1.7.1
 
 
1
  pandas==1.3.4
2
  requests==2.26.0
3
  scipy==1.7.1