Update app.py
app.py CHANGED
@@ -66,23 +66,14 @@ COLORS = [
     [255, 0, 255], [255, 0, 170], [255, 0, 85], [255, 0, 0]
 ]
 
-
-skeleton = []
-count = 0
-color_idx = 0
-prev_pt = None
-prev_pt_idx = None
-prev_clicked = None
-original_support_image = None
-checkpoint_path = ''
-
-def process(query_img,
+def process(query_img, state,
             cfg_path='configs/demo_b.py'):
-    global skeleton
     cfg = Config.fromfile(cfg_path)
-    kp_src_np = np.array(kp_src).copy().astype(np.float32)
-    kp_src_np[:, 0] = kp_src_np[:,
-
+    kp_src_np = np.array(state['kp_src']).copy().astype(np.float32)
+    kp_src_np[:, 0] = kp_src_np[:,
+                      0] / 128. * cfg.model.encoder_config.img_size
+    kp_src_np[:, 1] = kp_src_np[:,
+                      1] / 128. * cfg.model.encoder_config.img_size
     kp_src_np = np.flip(kp_src_np, 1).copy()
     kp_src_tensor = torch.tensor(kp_src_np).float()
     preprocess = transforms.Compose([
@@ -91,10 +82,10 @@ def process(query_img,
         Resize_Pad(cfg.model.encoder_config.img_size,
                    cfg.model.encoder_config.img_size)])
 
-    if len(skeleton) == 0:
+    if len(state['skeleton']) == 0:
         skeleton = [(0, 0)]
 
-    support_img = preprocess(original_support_image).flip(0)[None]
+    support_img = preprocess(state['original_support_image']).flip(0)[None]
     np_query = np.array(query_img)[:, :, ::-1].copy()
     q_img = preprocess(np_query).flip(0)[None]
     # Create heatmap from keypoints
@@ -104,9 +95,9 @@ def process(query_img,
                                       cfg.model.encoder_config.img_size])
     data_cfg['joint_weights'] = None
    data_cfg['use_different_joint_weights'] = False
-    kp_src_3d = torch.
+    kp_src_3d = torch.concatenate(
         (kp_src_tensor, torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
-    kp_src_3d_weight = torch.
+    kp_src_3d_weight = torch.concatenate(
         (torch.ones_like(kp_src_tensor),
          torch.zeros(kp_src_tensor.shape[0], 1)), dim=-1)
     target_s, target_weight_s = genHeatMap._msra_generate_target(data_cfg,
@@ -125,8 +116,8 @@ def process(query_img,
             'target_q': None,
             'target_weight_q': None,
             'return_loss': False,
-            'img_metas': [{'sample_skeleton': [skeleton],
-                           'query_skeleton': skeleton,
+            'img_metas': [{'sample_skeleton': [state['skeleton']],
+                           'query_skeleton': state['skeleton'],
                            'sample_joints_3d': [kp_src_3d],
                            'query_joints_3d': kp_src_3d,
                            'sample_center': [kp_src_tensor.mean(dim=0)],
@@ -165,54 +156,77 @@ def process(query_img,
                        vis_s_weight,
                        None,
                        vis_q_weight,
-                       skeleton,
+                       state['skeleton'],
                        None,
                        torch.tensor(outputs['points']).squeeze(0),
                        )
-    return out
+    return out, state
 
 
 with gr.Blocks() as demo:
+    state = gr.State({
+        'kp_src': [],
+        'skeleton': [],
+        'count': 0,
+        'color_idx': 0,
+        'prev_pt': None,
+        'prev_pt_idx': None,
+        'prev_clicked': None,
+        'original_support_image': None,
+    })
+
     gr.Markdown('''
     # Pose Anything Demo
-    We present a novel approach to category agnostic pose estimation that
-
-
+    We present a novel approach to category agnostic pose estimation that
+    leverages the inherent geometrical relations between keypoints through a
+    newly designed Graph Transformer Decoder. By capturing and incorporating
+    this crucial structural information, our method enhances the accuracy of
+    keypoint localization, marking a significant departure from conventional
+    CAPE techniques that treat keypoints as isolated entities.
+    ### [Paper](https://arxiv.org/abs/2311.17891) | [Official Repo](
+    https://github.com/orhir/PoseAnything)
     ## Instructions
     1. Upload an image of the object you want to pose on the **left** image.
     2. Click on the **left** image to mark keypoints.
     3. Click on the keypoints on the **right** image to mark limbs.
-    4. Upload an image of the object you want to pose to the query image (
+    4. Upload an image of the object you want to pose to the query image (
+    **bottom**).
     5. Click **Evaluate** to pose the query image.
     ''')
     with gr.Row():
         support_img = gr.Image(label="Support Image",
                                type="pil",
                                info='Click to mark keypoints').style(
-            height=
+            height=400, width=400)
         posed_support = gr.Image(label="Posed Support Image",
                                  type="pil",
-                                 interactive=False).style(height=
+                                 interactive=False).style(height=400,
+                                                          width=400)
     with gr.Row():
         query_img = gr.Image(label="Query Image",
-                             type="pil").style(height=
+                             type="pil").style(height=400, width=400)
     with gr.Row():
         eval_btn = gr.Button(value="Evaluate")
     with gr.Row():
-        output_img = gr.Plot(label="Output Image", height=
+        output_img = gr.Plot(label="Output Image", height=400, width=400)
 
 
     def get_select_coords(kp_support,
                           limb_support,
+                          state,
                           evt: gr.SelectData,
                           r=0.015):
+        # global original_support_image
+        # if len(kp_src) == 0:
+        #     original_support_image = np.array(kp_support)[:, :,
+        #                              ::-1].copy()
         pixels_in_queue = set()
         pixels_in_queue.add((evt.index[1], evt.index[0]))
         while len(pixels_in_queue) > 0:
             pixel = pixels_in_queue.pop()
             if pixel[0] is not None and pixel[
-                1] is not None and pixel not in kp_src:
-                kp_src.append(pixel)
+                1] is not None and pixel not in state['kp_src']:
+                state['kp_src'].append(pixel)
             else:
                 print("Invalid pixel")
             if limb_support is None:
@@ -230,13 +244,13 @@ with gr.Blocks() as demo:
             draw_pose.ellipse(twoPointList, fill=(255, 0, 0, 255))
             draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
 
-        return canvas_kp, canvas_limb
+        return canvas_kp, canvas_limb, state
 
 
     def get_limbs(kp_support,
+                  state,
                   evt: gr.SelectData,
                   r=0.02, width=0.02):
-        global count, color_idx, prev_pt, skeleton, prev_pt_idx, prev_clicked
         curr_pixel = (evt.index[1], evt.index[0])
         pixels_in_queue = set()
         pixels_in_queue.add((evt.index[1], evt.index[0]))
@@ -244,64 +258,62 @@ with gr.Blocks() as demo:
         w, h = canvas_kp.size
         r = int(r * w)
         width = int(width * w)
-        while
-               curr_pixel != prev_clicked and
-               len(kp_src) > 0):
+        while len(pixels_in_queue) > 0 and curr_pixel != state['prev_clicked']:
             pixel = pixels_in_queue.pop()
-            prev_clicked = pixel
-            closest_point = min(kp_src,
+            state['prev_clicked'] = pixel
+            closest_point = min(state['kp_src'],
                                 key=lambda p: (p[0] - pixel[0]) ** 2 +
                                               (p[1] - pixel[1]) ** 2)
-            closest_point_index = kp_src.index(closest_point)
+            closest_point_index = state['kp_src'].index(closest_point)
             draw_limb = ImageDraw.Draw(canvas_kp)
-            if color_idx < len(COLORS):
-                c = COLORS[color_idx]
+            if state['color_idx'] < len(COLORS):
+                c = COLORS[state['color_idx']]
             else:
                 c = random.choices(range(256), k=3)
             leftUpPoint = (closest_point[1] - r, closest_point[0] - r)
             rightDownPoint = (closest_point[1] + r, closest_point[0] + r)
             twoPointList = [leftUpPoint, rightDownPoint]
             draw_limb.ellipse(twoPointList, fill=tuple(c))
-            if count == 0:
-                prev_pt = closest_point[1], closest_point[0]
-                prev_pt_idx = closest_point_index
-                count = count + 1
+            if state['count'] == 0:
+                state['prev_pt'] = closest_point[1], closest_point[0]
+                state['prev_pt_idx'] = closest_point_index
+                state['count'] = state['count'] + 1
             else:
-                if prev_pt_idx != closest_point_index:
+                if state['prev_pt_idx'] != closest_point_index:
                     # Create Line and add Limb
-                    draw_limb.line(
-
-
-
-
+                    draw_limb.line(
+                        [state['prev_pt'], (closest_point[1], closest_point[0])],
+                        fill=tuple(c),
+                        width=width)
+                    state['skeleton'].append((state['prev_pt_idx'], closest_point_index))
+                    state['color_idx'] = state['color_idx'] + 1
                 else:
                     draw_limb.ellipse(twoPointList, fill=(255, 0, 0, 255))
-                count = 0
-        return canvas_kp
+                state['count'] = 0
+        return canvas_kp, state
 
 
-    def
-
-
-
-        original_support_image = np.array(support_img)[:, :, ::-1].copy()
+    def set_qery(support_img, state):
+        state['skeleton'].clear()
+        state['kp_src'].clear()
+        state['original_support_image'] = np.array(support_img)[:, :, ::-1].copy()
         support_img = support_img.resize((128, 128), Image.Resampling.LANCZOS)
-        return support_img, support_img
+        return support_img, support_img, state
 
 
     support_img.select(get_select_coords,
-                       [support_img, posed_support],
-                       [support_img, posed_support]
-
-
-
-                       outputs=[support_img,posed_support])
+                       [support_img, posed_support, state],
+                       [support_img, posed_support, state])
+    support_img.upload(set_qery,
+                       inputs=[support_img, state],
+                       outputs=[support_img, posed_support, state])
     posed_support.select(get_limbs,
-                         posed_support,
-                         posed_support)
+                         [posed_support, state],
+                         [posed_support, state])
     eval_btn.click(fn=process,
-                   inputs=[query_img],
-                   outputs=output_img)
+                   inputs=[query_img, state],
+                   outputs=[output_img, state])
+
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser(description='Pose Anything Demo')