taesiri committed on
Commit 1bdb168
1 Parent(s): cb998cb
Files changed (2)
  1. app.py +162 -108
  2. requirements.txt +1 -0
app.py CHANGED
@@ -22,123 +22,170 @@ import torch.nn as nn
  import torch.nn.functional as F
  import random
  import gradio as gr

  # Downloading the Model
- torchvision.datasets.utils.download_file_from_google_drive('1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6', '.', 'pas_psi.pt')

  # Model Initialization
- args = dict({
-     'alpha': [0.05, 0.1],
-     'benchmark': 'pfpascal',
-     'bsz': 90,
-     'datapath': '../Datasets_CHM',
-     'img_size': 240,
-     'ktype': 'psi',
-     'load': 'pas_psi.pt',
-     'thres': 'img'
- })
-
- model = chmnet.CHMNet(args['ktype'])
- model.load_state_dict(torch.load(args['load'], map_location=torch.device('cpu')))
- Evaluator.initialize(args['alpha'])
- Geometry.initialize(img_size=args['img_size'])
- model.eval();

  # Transforms

  chm_transform = transforms.Compose(
-     [transforms.Resize(args['img_size']),
-      transforms.CenterCrop((args['img_size'], args['img_size'])),
-      transforms.ToTensor(),
-      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

  chm_transform_plot = transforms.Compose(
-     [transforms.Resize(args['img_size']),
-      transforms.CenterCrop((args['img_size'], args['img_size']))])

  # A Helper Function
- to_np = lambda x: x.data.to('cpu').numpy()

  # Colors for Plotting
- cmap = matplotlib.cm.get_cmap('Spectral')
  rgba = cmap(0.5)
  colors = []
  for k in range(49):
-     colors.append(cmap(k/49.0))


  # CHM MODEL
- def run_chm(source_image, target_image, selected_points, number_src_points, chm_transform, display_transform):
-     # Convert to Tensor
-     src_img_tnsr = chm_transform(source_image).unsqueeze(0)
-     tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
-
-     # Selected_points = selected_points.T
-     keypoints = torch.tensor(selected_points).unsqueeze(0)
-     n_pts = torch.tensor(np.asarray([number_src_points]))
-
-     # RUN CHM ------------------------------------------------------------------------
-     with torch.no_grad():
-         corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
-         prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)
-
-     # VISUALIZATION
-     src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
-     tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()
-
-     src_points_converted = []
-     w, h = display_transform(source_image).size
-
-     for x, y in zip(src_points[0], src_points[1]):
-         src_points_converted.append([int(x*w/args['img_size']), int((y)*h/args['img_size'])])
-
-     src_points_converted = np.asarray(src_points_converted[:number_src_points])
-     tgt_points_converted = []
-
-     w, h = display_transform(target_image).size
-     for x, y in zip(tgt_points[0], tgt_points[1]):
-         tgt_points_converted.append([int(((x+1)/2.0)*w), int(((y+1)/2.0)*h)])
-
-     tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])
-
-     tgt_grid = []
-
-     for x, y in zip(tgt_points[0], tgt_points[1]):
-         tgt_grid.append([int(((x+1)/2.0)*7), int(((y+1)/2.0)*7)])
-
-     # PLOT
-     fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
-
-     ax[0].imshow(display_transform(source_image))
-     ax[0].scatter(src_points_converted[:, 0], src_points_converted[:, 1], c=colors[:number_src_points])
-     ax[0].set_title('Source')
-     ax[0].set_xticks([])
-     ax[0].set_yticks([])
-
-     ax[1].imshow(display_transform(target_image))
-     ax[1].scatter(tgt_points_converted[:, 0], tgt_points_converted[:, 1], c=colors[:number_src_points])
-     ax[1].set_title('Target')
-     ax[1].set_xticks([])
-     ax[1].set_yticks([])
-
-     for TL in range(49):
-         ax[0].text(x=src_points_converted[TL][0], y=src_points_converted[TL][1], s=str(TL), fontdict=dict(color='red', size=11))
-
-     for TL in range(49):
-         ax[1].text(x=tgt_points_converted[TL][0], y=tgt_points_converted[TL][1], s=f'{str(TL)}', fontdict=dict(color='orange', size=11))
-
-     plt.tight_layout()
-     fig.suptitle('CHM Correspondences\nUsing $\it{pas\_psi.pt}$ Weights ', fontsize=16)
-     return fig
-
-
- # Wrapper
- def generate_correspondences(source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100):
-     A = np.linspace(min_x, max_x, 7)
-     B = np.linspace(min_y, max_y, 7)
-     point_list = list(product(A, B))
-     new_points = np.asarray(point_list, dtype=np.float64).T
-     return run_chm(source_image, target_image, selected_points=new_points, number_src_points=49, chm_transform=chm_transform, display_transform=chm_transform_plot)


  # GRADIO APP
@@ -146,14 +193,21 @@ title = "Correspondence Matching with Convolutional Hough Matching Networks "
  description = "Performs keypoint transform from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid."
  article = "<p style='text-align: center'><a href='https://github.com/juhongm999/chm' target='_blank'>Original Github Repo</a></p>"

- iface = gr.Interface(fn=generate_correspondences,
-     inputs=[gr.inputs.Image(shape=(240, 240), type='pil'),
-             gr.inputs.Image(shape=(240, 240), type='pil'),
-             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='Min X'),
-             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='Max X'),
-             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='Min Y'),
-             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='Max Y')],
-     outputs="plot", enable_queue=True, title=title,
-     description=description,
-     article=article,
-     examples=[['sample1.jpeg', 'sample2.jpeg', 15, 215, 15, 215]])
- iface.launch()
  import torch.nn.functional as F
  import random
  import gradio as gr
+ import gdown

  # Downloading the Model
+ gdown.download(id="1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6", output="pas_psi.pt", quiet=False)

  # Model Initialization
+ args = dict(
+     {
+         "alpha": [0.05, 0.1],
+         "benchmark": "pfpascal",
+         "bsz": 90,
+         "datapath": "../Datasets_CHM",
+         "img_size": 240,
+         "ktype": "psi",
+         "load": "pas_psi.pt",
+         "thres": "img",
+     }
+ )
+
+ model = chmnet.CHMNet(args["ktype"])
+ model.load_state_dict(torch.load(args["load"], map_location=torch.device("cpu")))
+ Evaluator.initialize(args["alpha"])
+ Geometry.initialize(img_size=args["img_size"])
+ model.eval()

  # Transforms

  chm_transform = transforms.Compose(
+     [
+         transforms.Resize(args["img_size"]),
+         transforms.CenterCrop((args["img_size"], args["img_size"])),
+         transforms.ToTensor(),
+         transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+     ]
+ )

  chm_transform_plot = transforms.Compose(
+     [
+         transforms.Resize(args["img_size"]),
+         transforms.CenterCrop((args["img_size"], args["img_size"])),
+     ]
+ )

  # A Helper Function
+ to_np = lambda x: x.data.to("cpu").numpy()

  # Colors for Plotting
+ cmap = matplotlib.cm.get_cmap("Spectral")
  rgba = cmap(0.5)
  colors = []
  for k in range(49):
+     colors.append(cmap(k / 49.0))


  # CHM MODEL
+ def run_chm(
+     source_image,
+     target_image,
+     selected_points,
+     number_src_points,
+     chm_transform,
+     display_transform,
+ ):
+     # Convert to Tensor
+     src_img_tnsr = chm_transform(source_image).unsqueeze(0)
+     tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)
+
+     # Selected_points = selected_points.T
+     keypoints = torch.tensor(selected_points).unsqueeze(0)
+     n_pts = torch.tensor(np.asarray([number_src_points]))
+
+     # RUN CHM ------------------------------------------------------------------------
+     with torch.no_grad():
+         corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
+         prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)
+
+     # VISUALIZATION
+     src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
+     tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()
+
+     src_points_converted = []
+     w, h = display_transform(source_image).size
+
+     for x, y in zip(src_points[0], src_points[1]):
+         src_points_converted.append(
+             [int(x * w / args["img_size"]), int(y * h / args["img_size"])]
+         )
+
+     src_points_converted = np.asarray(src_points_converted[:number_src_points])
+     tgt_points_converted = []
+
+     w, h = display_transform(target_image).size
+     for x, y in zip(tgt_points[0], tgt_points[1]):
+         tgt_points_converted.append(
+             [int(((x + 1) / 2.0) * w), int(((y + 1) / 2.0) * h)]
+         )
+
+     tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])
+
+     tgt_grid = []
+
+     for x, y in zip(tgt_points[0], tgt_points[1]):
+         tgt_grid.append([int(((x + 1) / 2.0) * 7), int(((y + 1) / 2.0) * 7)])
+
+     # PLOT
+     fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))
+
+     ax[0].imshow(display_transform(source_image))
+     ax[0].scatter(
+         src_points_converted[:, 0],
+         src_points_converted[:, 1],
+         c=colors[:number_src_points],
+     )
+     ax[0].set_title("Source")
+     ax[0].set_xticks([])
+     ax[0].set_yticks([])
+
+     ax[1].imshow(display_transform(target_image))
+     ax[1].scatter(
+         tgt_points_converted[:, 0],
+         tgt_points_converted[:, 1],
+         c=colors[:number_src_points],
+     )
+     ax[1].set_title("Target")
+     ax[1].set_xticks([])
+     ax[1].set_yticks([])
+
+     for TL in range(49):
+         ax[0].text(
+             x=src_points_converted[TL][0],
+             y=src_points_converted[TL][1],
+             s=str(TL),
+             fontdict=dict(color="red", size=11),
+         )
+
+     for TL in range(49):
+         ax[1].text(
+             x=tgt_points_converted[TL][0],
+             y=tgt_points_converted[TL][1],
+             s=str(TL),
+             fontdict=dict(color="orange", size=11),
+         )
+
+     plt.tight_layout()
+     fig.suptitle("CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights", fontsize=16)
+     return fig
+
+
+ # Wrapper
+ def generate_correspondences(
+     source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
+ ):
+     A = np.linspace(min_x, max_x, 7)
+     B = np.linspace(min_y, max_y, 7)
+     point_list = list(product(A, B))
+     new_points = np.asarray(point_list, dtype=np.float64).T
+     return run_chm(
+         source_image,
+         target_image,
+         selected_points=new_points,
+         number_src_points=49,
+         chm_transform=chm_transform,
+         display_transform=chm_transform_plot,
+     )


  # GRADIO APP
  title = "Correspondence Matching with Convolutional Hough Matching Networks "
  description = "Performs keypoint transform from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid."
  article = "<p style='text-align: center'><a href='https://github.com/juhongm999/chm' target='_blank'>Original Github Repo</a></p>"

+ iface = gr.Interface(
+     fn=generate_correspondences,
+     inputs=[
+         gr.inputs.Image(shape=(240, 240), type="pil"),
+         gr.inputs.Image(shape=(240, 240), type="pil"),
+         gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label="Min X"),
+         gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label="Max X"),
+         gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label="Min Y"),
+         gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label="Max Y"),
+     ],
+     outputs="plot",
+     enable_queue=True,
+     title=title,
+     description=description,
+     article=article,
+     examples=[["sample1.jpeg", "sample2.jpeg", 15, 215, 15, 215]],
+ )
+ iface.launch()
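
A note on the coordinate handling in run_chm above: the source grid is laid out in the model's 240x240 input space and rescaled to the display size via x * w / args["img_size"], while the predicted target keypoints appear to come back normalized to [-1, 1], hence the ((x + 1) / 2.0) * w expressions. A minimal sketch of that target-side mapping, under the assumption that the [-1, 1] convention holds (denorm_kp is an illustrative helper, not part of the repo):

    def denorm_kp(x_norm, y_norm, w, h):
        # Map one keypoint from the assumed [-1, 1] output range back to
        # pixel coordinates of a w x h display image.
        return int((x_norm + 1) / 2.0 * w), int((y_norm + 1) / 2.0 * h)

    # The center of the normalized range lands at the image center:
    assert denorm_kp(0.0, 0.0, 240, 240) == (120, 120)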
requirements.txt CHANGED
@@ -8,3 +8,4 @@ scipy==1.7.1
  tensorboardX==2.4.1
  torch==1.10.0
  torchvision==0.11.1
+ gdown
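
For context on the dependency change: a common reason to swap torchvision's download_file_from_google_drive for gdown is that the torchvision helper often fails on Drive files that trigger the large-file confirmation page, which gdown's token handling gets past; that matches the download swap at the top of app.py. A small usage sketch of the new dependency, with an existence check added as a hypothetical hardening step (the committed code downloads unconditionally):

    import os
    import gdown

    WEIGHTS_ID = "1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6"  # same file id as in app.py
    WEIGHTS_PATH = "pas_psi.pt"

    # Skip the download when the checkpoint already exists, e.g. on a restart.
    if not os.path.exists(WEIGHTS_PATH):
        gdown.download(id=WEIGHTS_ID, output=WEIGHTS_PATH, quiet=False)

Note that gdown is added unpinned while every other requirement is pinned; pinning it as well would keep the build reproducible.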