YannisK commited on
Commit
dd504c0
1 Parent(s): 70e3f4a
Files changed (1) hide show
  1. app.py +15 -12
app.py CHANGED
@@ -1,7 +1,8 @@
1
  import gradio as gr
2
 
3
- import torch
4
 
 
5
 
6
  import matplotlib.pyplot as plt
7
  from matplotlib import cm
@@ -14,8 +15,9 @@ import fire_network
14
 
15
  import numpy as np
16
 
 
 
17
  from PIL import Image
18
- from skimage.transform import resize
19
 
20
  # Possible Scales for multiscale inference
21
  scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
@@ -78,17 +80,17 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
78
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
79
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
80
  all_att_bin2.append(att_heat_bin)
81
- print(all_att_bin2[0].shape)
82
 
83
  fin_img = []
84
  img1rsz = np.copy(im1)
85
- print(img1rsz.shape)
86
  for j, att in enumerate(all_att_bin1):
87
- # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
88
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
89
  # att = cv2.resize(att, imgz[i].shape[:2][::-1])
90
- att = resize(att, im1.shape[:2])
91
- print(att.shape)
92
  mask2d = zip(*np.where(att==255))
93
  for m,n in mask2d:
94
  col_ = col.colors[j] if j < 7 else col.colors[j+1]
@@ -99,10 +101,11 @@ def generate_matching_superfeatures(im1, im2, scale_id=6, threshold=50):
99
 
100
  img2rsz = np.copy(im2)
101
  for j, att in enumerate(all_att_bin2):
102
- # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_NEAREST)
103
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
104
- # att = cv2.resize(att, imgz[i].shape[:2][::-1])
105
- att = resize(att, im2.shape[:2])
 
106
  mask2d = zip(*np.where(att==255))
107
  for m,n in mask2d:
108
  col_ = col.colors[j] if j < 7 else col.colors[j+1]
@@ -132,8 +135,8 @@ article = "<p style='text-align: center'><a href='https://github.com/naver/fire'
132
  iface = gr.Interface(
133
  fn=generate_matching_superfeatures,
134
  inputs=[
135
- gr.inputs.Image(shape=(1024, 1024), type="numpy"),
136
- gr.inputs.Image(shape=(1024, 1024), type="numpy"),
137
  gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
138
  gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarizatio Threshold")],
139
  outputs="plot",
 
1
  import gradio as gr
2
 
3
+ import cv2
4
 
5
+ import torch
6
 
7
  import matplotlib.pyplot as plt
8
  from matplotlib import cm
 
15
 
16
  import numpy as np
17
 
18
+
19
+
20
  from PIL import Image
 
21
 
22
  # Possible Scales for multiscale inference
23
  scales = [2.0, 1.414, 1.0, 0.707, 0.5, 0.353, 0.25]
 
80
  att_heat = np.uint8(att_heat / np.max(att_heat[:]) * 255.0)
81
  att_heat_bin = np.where(att_heat>threshold, 255, 0)
82
  all_att_bin2.append(att_heat_bin)
83
+
84
 
85
  fin_img = []
86
  img1rsz = np.copy(im1)
87
+ print(img1rsz.size)
88
  for j, att in enumerate(all_att_bin1):
89
+ att = cv2.resize(att, im1.size, interpolation=cv2.INTER_NEAREST)
90
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
91
  # att = cv2.resize(att, imgz[i].shape[:2][::-1])
92
+ # att = att.resize(shape)
93
+ # att = resize(att, im1.size)
94
  mask2d = zip(*np.where(att==255))
95
  for m,n in mask2d:
96
  col_ = col.colors[j] if j < 7 else col.colors[j+1]
 
101
 
102
  img2rsz = np.copy(im2)
103
  for j, att in enumerate(all_att_bin2):
104
+ att = cv2.resize(att, im2.size, interpolation=cv2.INTER_NEAREST)
105
  # att = cv2.resize(att, imgz[i].shape[:2][::-1], interpolation=cv2.INTER_CUBIC)
106
+ # # att = cv2.resize(att, imgz[i].shape[:2][::-1])
107
+ # att = att.resize(im2.shape)
108
+ # print('att:', att.shape)
109
  mask2d = zip(*np.where(att==255))
110
  for m,n in mask2d:
111
  col_ = col.colors[j] if j < 7 else col.colors[j+1]
 
135
  iface = gr.Interface(
136
  fn=generate_matching_superfeatures,
137
  inputs=[
138
+ gr.inputs.Image(shape=(1024, 1024), type="pil"),
139
+ gr.inputs.Image(shape=(1024, 1024), type="pil"),
140
  gr.inputs.Slider(minimum=1, maximum=7, step=1, default=2, label="Scale"),
141
  gr.inputs.Slider(minimum=1, maximum=255, step=25, default=50, label="Binarizatio Threshold")],
142
  outputs="plot",