orhir committed
Commit 248b92d · 1 Parent(s): ea9182f

Update app.py

Files changed (1)
  1. app.py +83 -71
app.py CHANGED
@@ -66,23 +66,14 @@ COLORS = [
     [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]
 ]
 
-kp_src = []
-skeleton = []
-count = 0
-color_idx = 0
-prev_pt = None
-prev_pt_idx = None
-prev_clicked = None
-original_support_image = None
-checkpoint_path = ''
-
-def process(query_img,
+def process(query_img, state,
             cfg_path='configs/demo_b.py'):
-    global skeleton
     cfg = Config.fromfile(cfg_path)
-    kp_src_np = np.array(kp_src).copy().astype(np.float32)
-    kp_src_np[:, 0] = kp_src_np[:, 0] / 128. * cfg.model.encoder_config.img_size
-    kp_src_np[:, 1] = kp_src_np[:, 1] / 128. * cfg.model.encoder_config.img_size
+    kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)
+    kp_src_np[:, 0] = kp_src_np[:,
+                      0] / 128. * cfg.model.encoder_config.img_size
+    kp_src_np[:, 1] = kp_src_np[:,
+                      1] / 128. * cfg.model.encoder_config.img_size
     kp_src_np = np.flip(kp_src_np, 1).copy()
     kp_src_tensor = torch.tensor(kp_src_np).float()
     preprocess = transforms.Compose([
@@ -91,10 +82,10 @@ def process(query_img,
         Resize_Pad(cfg.model.encoder_config.img_size,
                    cfg.model.encoder_config.img_size)])
 
-    if len(skeleton) == 0:
+    if len(state['skeleton']) == 0:
         skeleton = [(0, 0)]
 
-    support_img = preprocess(original_support_image).flip(0)[None]
+    support_img = preprocess(state['original_support_image']).flip(0)[None]
     np_query = np.array(query_img)[:, :, ::-1].copy()
     q_img = preprocess(np_query).flip(0)[None]
     # Create heatmap from keypoints
@@ -104,9 +95,9 @@ def process(query_img,
                                       cfg.model.encoder_config.img_size])
     data_cfg['joint_weights'] = None
     data_cfg['use_different_joint_weights'] = False
-    kp_src_3d = torch.cat(
+    kp_src_3d = torch.concatenate(
         (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
-    kp_src_3d_weight = torch.cat(
+    kp_src_3d_weight = torch.concatenate(
         (torch.ones_like(kp_src_tensor),
          torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
     target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,
@@ -125,8 +116,8 @@ def process(query_img,
            'target_q': None,
            'target_weight_q': None,
            'return_loss': False,
-           'img_metas': [{'sample_skeleton': [skeleton],
-                          'query_skeleton': skeleton,
+           'img_metas': [{'sample_skeleton': [state['skeleton']],
+                          'query_skeleton': state['skeleton'],
                           'sample_joints_3d': [kp_src_3d],
                           'query_joints_3d': kp_src_3d,
                           'sample_center': [kp_src_tensor.mean(dim=0)],
@@ -165,54 +156,77 @@ def process(query_img,
                          vis_s_weight,
                          None,
                          vis_q_weight,
-                         skeleton,
+                         state['skeleton'],
                          None,
                          torch.tensor(outputs['points']).squeeze(0),
                          )
-    return out
+    return out, state
 
 
 with gr.Blocks() as demo:
+    state = gr.State({
+        'kp_src': [],
+        'skeleton': [],
+        'count': 0,
+        'color_idx': 0,
+        'prev_pt': None,
+        'prev_pt_idx': None,
+        'prev_clicked': None,
+        'original_support_image': None,
+    })
+
     gr.Markdown('''
     # Pose Anything Demo
-    We present a novel approach to category agnostic pose estimation that leverages the inherent geometrical relations between keypoints through a newly designed Graph Transformer Decoder. By capturing and incorporating this crucial structural information, our method enhances the accuracy of keypoint localization, marking a significant departure from conventional CAPE techniques that treat keypoints as isolated entities.
-    ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](https://github.com/orhir/PoseAnything)
-    ![](/file=gradio_teaser.png)
+    We present a novel approach to category agnostic pose estimation that
+    leverages the inherent geometrical relations between keypoints through a
+    newly designed Graph Transformer Decoder. By capturing and incorporating
+    this crucial structural information, our method enhances the accuracy of
+    keypoint localization, marking a significant departure from conventional
+    CAPE techniques that treat keypoints as isolated entities.
+    ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](
+    https://github.com/orhir/PoseAnything)
    ## Instructions
     1. Upload an image of the object you want to pose on the **left** image.
     2. Click on the **left** image to mark keypoints.
     3. Click on the keypoints on the **right** image to mark limbs.
-    4. Upload an image of the object you want to pose to the query image (**bottom**).
+    4. Upload an image of the object you want to pose to the query image (
+    **bottom**).
     5. Click **Evaluate** to pose the query image.
     ''')
     with gr.Row():
         support_img = gr.Image(label="Support Image",
                                type="pil",
                                info='Click to mark keypoints').style(
-            height=256, width=256)
+            height=400, width=400)
         posed_support = gr.Image(label="Posed Support Image",
                                  type="pil",
-                                 interactive=False).style(height=256, width=256)
+                                 interactive=False).style(height=400,
+                                                          width=400)
     with gr.Row():
         query_img = gr.Image(label="Query Image",
-                             type="pil").style(height=256, width=256)
+                             type="pil").style(height=400, width=400)
     with gr.Row():
         eval_btn = gr.Button(value="Evaluate")
     with gr.Row():
-        output_img = gr.Plot(label="Output Image", height=256, width=256)
+        output_img = gr.Plot(label="Output Image", height=400, width=400)
 
 
     def get_select_coords(kp_support,
                           limb_support,
+                          state,
                           evt: gr.SelectData,
                           r=0.015):
+        # global original_support_image
+        # if len(kp_src) == 0:
+        #     original_support_image = np.array(kp_support)[:, :,
+        #                                       ::-1].copy()
         pixels_in_queue = set()
         pixels_in_queue.add((evt.index[1], evt.index[0]))
         while len(pixels_in_queue) > 0:
             pixel = pixels_in_queue.pop()
             if pixel[0] is not None and pixel[
-                1] is not None and pixel not in kp_src:
-                kp_src.append(pixel)
+                1] is not None and pixel not in state['kp_src']:
+                state['kp_src'].append(pixel)
             else:
                 print("Invalid pixel")
             if limb_support is None:
@@ -230,13 +244,13 @@ with gr.Blocks() as demo:
             draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
             draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
 
-        return canvas_kp, canvas_limb
+        return canvas_kp, canvas_limb, state
 
 
     def get_limbs(kp_support,
+                  state,
                   evt: gr.SelectData,
                   r=0.02, width=0.02):
-        global count, color_idx, prev_pt, skeleton, prev_pt_idx, prev_clicked
         curr_pixel = (evt.index[1], evt.index[0])
         pixels_in_queue = set()
         pixels_in_queue.add((evt.index[1], evt.index[0]))
@@ -244,64 +258,62 @@ with gr.Blocks() as demo:
         w, h = canvas_kp.size
         r = int(r * w)
         width = int(width * w)
-        while (len(pixels_in_queue) > 0 and
-               curr_pixel != prev_clicked and
-               len(kp_src) > 0):
+        while len(pixels_in_queue) > 0 and curr_pixel != state['prev_clicked']:
             pixel = pixels_in_queue.pop()
-            prev_clicked = pixel
-            closest_point = min(kp_src,
+            state['prev_clicked'] = pixel
+            closest_point = min(state['kp_src'],
                                 key=lambda p: (p[0] - pixel[0]) ** 2 +
                                               (p[1] - pixel[1]) ** 2)
-            closest_point_index = kp_src.index(closest_point)
+            closest_point_index = state['kp_src'].index(closest_point)
             draw_limb = ImageDraw.Draw(canvas_kp)
-            if color_idx < len(COLORS):
-                c = COLORS[color_idx]
+            if state['color_idx'] < len(COLORS):
+                c = COLORS[state['color_idx']]
             else:
                 c = random.choices(range(256), k=3)
             leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
             rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
             twoPointList = [leftUpPoint, rightDownPoint]
             draw_limb.ellipse(twoPointList, fill=tuple(c))
-            if count == 0:
-                prev_pt = closest_point[1], closest_point[0]
-                prev_pt_idx = closest_point_index
-                count = count + 1
+            if state['count'] == 0:
+                state['prev_pt'] = closest_point[1], closest_point[0]
+                state['prev_pt_idx'] = closest_point_index
+                state['count'] = state['count'] + 1
             else:
-                if prev_pt_idx != closest_point_index:
+                if state['prev_pt_idx'] != closest_point_index:
                     # Create Line and add Limb
-                    draw_limb.line([prev_pt, (closest_point[1], closest_point[0])],
-                                   fill=tuple(c),
-                                   width=width)
-                    skeleton.append((prev_pt_idx, closest_point_index))
-                    color_idx = color_idx + 1
+                    draw_limb.line(
+                        [state['prev_pt'], (closest_point[1], closest_point[0])],
+                        fill=tuple(c),
+                        width=width)
+                    state['skeleton'].append((state['prev_pt_idx'], closest_point_index))
+                    state['color_idx'] = state['color_idx'] + 1
                 else:
                     draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
-                count = 0
-        return canvas_kp
+                state['count'] = 0
+        return canvas_kp, state
 
 
-    def set_query(support_img):
-        global original_support_image
-        skeleton.clear()
-        kp_src.clear()
-        original_support_image = np.array(support_img)[:, :, ::-1].copy()
+    def set_qery(support_img, state):
+        state['skeleton'].clear()
+        state['kp_src'].clear()
+        state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
         support_img = support_img.resize((128, 128), Image.Resampling.LANCZOS)
-        return support_img, support_img
+        return support_img, support_img, state
 
 
     support_img.select(get_select_coords,
-                       [support_img, posed_support],
-                       [support_img, posed_support],
-                       )
-    support_img.upload(set_query,
-                       inputs=support_img,
-                       outputs=[support_img,posed_support])
+                       [support_img, posed_support, state],
+                       [support_img, posed_support, state])
+    support_img.upload(set_qery,
+                       inputs=[support_img, state],
+                       outputs=[support_img, posed_support, state])
     posed_support.select(get_limbs,
-                         posed_support,
-                         posed_support)
+                         [posed_support, state],
+                         [posed_support, state])
     eval_btn.click(fn=process,
-                   inputs=[query_img],
-                   outputs=output_img)
+                   inputs=[query_img, state],
+                   outputs=[output_img, state])
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Pose Anything Demo')
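The substance of this commit is a migration from module-level globals (`kp_src`, `skeleton`, `count`, `color_idx`, and friends) to a per-session `gr.State` dictionary. Globals in a Gradio app are shared by every connected browser session, so two users annotating keypoints at the same time would overwrite each other's clicks; `gr.State` keeps one copy per session, passed into each event handler as an input and returned as an output so updates persist to the next event. Below is a minimal sketch of that pattern, independent of this app; the `add_point` handler and component names are illustrative, not part of app.py:

```python
import gradio as gr

with gr.Blocks() as demo:
    # One dict per browser session, unlike a module-level global,
    # which every session would share.
    state = gr.State({'points': []})

    img = gr.Image(type="pil", label="Click to add points")
    log = gr.Textbox(label="Clicked points")

    def add_point(state, evt: gr.SelectData):
        # Gradio injects the SelectData event from the typed parameter;
        # it is not listed in `inputs`. Mutate the session's dict and
        # return it so the updated value is stored for the next event.
        state['points'].append((evt.index[1], evt.index[0]))
        return str(state['points']), state

    # `state` appears in both inputs and outputs, just as app.py now
    # wires get_select_coords, get_limbs, set_qery, and process.
    img.select(add_point, inputs=[state], outputs=[log, state])

if __name__ == "__main__":
    demo.launch()
```

Returning `state` from every handler is what makes an update stick; the same wiring shows up in the diff's `.select`, `.upload`, and `.click` calls, each of which now lists `state` in both `inputs` and `outputs`.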