satpalsr committed on
Commit d7f5f1a
1 Parent(s): e971533

Create app.py

Files changed (1)
app.py +368 -0
app.py ADDED
@@ -0,0 +1,368 @@
# CODE WAS MODIFIED FROM https://github.com/leoxiaobin/deep-high-resolution-net.pytorch
import torch
import cv2
import torchvision.transforms as transforms
import numpy as np
import math
import torchvision
import gradio as gr

from PIL import Image
import requests


COCO_KEYPOINT_INDEXES = {
    0: 'nose',
    1: 'left_eye',
    2: 'right_eye',
    3: 'left_ear',
    4: 'right_ear',
    5: 'left_shoulder',
    6: 'right_shoulder',
    7: 'left_elbow',
    8: 'right_elbow',
    9: 'left_wrist',
    10: 'right_wrist',
    11: 'left_hip',
    12: 'right_hip',
    13: 'left_knee',
    14: 'right_knee',
    15: 'left_ankle',
    16: 'right_ankle'
}

COCO_INSTANCE_CATEGORY_NAMES = [
    '__background__', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A', 'stop sign',
    'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep', 'cow',
    'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A', 'N/A',
    'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard', 'sports ball',
    'kite', 'baseball bat', 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
    'bottle', 'N/A', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
    'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table',
    'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]


def get_max_preds(batch_heatmaps):
    '''
    get predictions from score maps
    heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
    '''
    assert isinstance(batch_heatmaps, np.ndarray), \
        'batch_heatmaps should be numpy.ndarray'
    assert batch_heatmaps.ndim == 4, 'batch_heatmaps should be 4-ndim'

    batch_size = batch_heatmaps.shape[0]
    num_joints = batch_heatmaps.shape[1]
    width = batch_heatmaps.shape[3]
    heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
    idx = np.argmax(heatmaps_reshaped, 2)
    maxvals = np.amax(heatmaps_reshaped, 2)

    maxvals = maxvals.reshape((batch_size, num_joints, 1))
    idx = idx.reshape((batch_size, num_joints, 1))

    preds = np.tile(idx, (1, 1, 2)).astype(np.float32)

    preds[:, :, 0] = preds[:, :, 0] % width
    preds[:, :, 1] = np.floor(preds[:, :, 1] / width)

    pred_mask = np.tile(np.greater(maxvals, 0.0), (1, 1, 2))
    pred_mask = pred_mask.astype(np.float32)

    preds *= pred_mask
    return preds, maxvals


def get_dir(src_point, rot_rad):
    sn, cs = np.sin(rot_rad), np.cos(rot_rad)

    src_result = [0, 0]
    src_result[0] = src_point[0] * cs - src_point[1] * sn
    src_result[1] = src_point[0] * sn + src_point[1] * cs

    return src_result


def get_3rd_point(a, b):
    direct = a - b
    return b + np.array([-direct[1], direct[0]], dtype=np.float32)


def get_affine_transform(
        center, scale, rot, output_size,
        shift=np.array([0, 0], dtype=np.float32), inv=0
):
    if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
        print(scale)
        scale = np.array([scale, scale])

    scale_tmp = scale * 200.0
    src_w = scale_tmp[0]
    dst_w = output_size[0]
    dst_h = output_size[1]

    rot_rad = np.pi * rot / 180
    src_dir = get_dir([0, src_w * -0.5], rot_rad)
    dst_dir = np.array([0, dst_w * -0.5], np.float32)

    src = np.zeros((3, 2), dtype=np.float32)
    dst = np.zeros((3, 2), dtype=np.float32)
    src[0, :] = center + scale_tmp * shift
    src[1, :] = center + src_dir + scale_tmp * shift
    dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
    dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir
    src[2:, :] = get_3rd_point(src[0, :], src[1, :])
    dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

    if inv:
        trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
    else:
        trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

    return trans


def affine_transform(pt, t):
    new_pt = np.array([pt[0], pt[1], 1.]).T
    new_pt = np.dot(t, new_pt)
    return new_pt[:2]


def transform_preds(coords, center, scale, output_size):
    target_coords = np.zeros(coords.shape)
    trans = get_affine_transform(center, scale, 0, output_size, inv=1)
    for p in range(coords.shape[0]):
        target_coords[p, 0:2] = affine_transform(coords[p, 0:2], trans)
    return target_coords


def taylor(hm, coord):
    heatmap_height = hm.shape[0]
    heatmap_width = hm.shape[1]
    px = int(coord[0])
    py = int(coord[1])
    if 1 < px < heatmap_width - 2 and 1 < py < heatmap_height - 2:
        dx = 0.5 * (hm[py][px + 1] - hm[py][px - 1])
        dy = 0.5 * (hm[py + 1][px] - hm[py - 1][px])
        dxx = 0.25 * (hm[py][px + 2] - 2 * hm[py][px] + hm[py][px - 2])
        dxy = 0.25 * (hm[py + 1][px + 1] - hm[py - 1][px + 1]
                      - hm[py + 1][px - 1] + hm[py - 1][px - 1])
        dyy = 0.25 * (hm[py + 2][px] - 2 * hm[py][px] + hm[py - 2][px])
        derivative = np.matrix([[dx], [dy]])
        hessian = np.matrix([[dxx, dxy], [dxy, dyy]])
        if dxx * dyy - dxy ** 2 != 0:
            hessianinv = hessian.I
            offset = -hessianinv * derivative
            offset = np.squeeze(np.array(offset.T), axis=0)
            coord += offset
    return coord


def gaussian_blur(hm, kernel):
    border = (kernel - 1) // 2
    batch_size = hm.shape[0]
    num_joints = hm.shape[1]
    height = hm.shape[2]
    width = hm.shape[3]
    for i in range(batch_size):
        for j in range(num_joints):
            origin_max = np.max(hm[i, j])
            dr = np.zeros((height + 2 * border, width + 2 * border))
            dr[border: -border, border: -border] = hm[i, j].copy()
            dr = cv2.GaussianBlur(dr, (kernel, kernel), 0)
            hm[i, j] = dr[border: -border, border: -border].copy()
            hm[i, j] *= origin_max / np.max(hm[i, j])
    return hm


def get_final_preds(hm, center, scale, transform_back=True, test_blur_kernel=3):
    coords, maxvals = get_max_preds(hm)
    heatmap_height = hm.shape[2]
    heatmap_width = hm.shape[3]

    # post-processing
    hm = gaussian_blur(hm, test_blur_kernel)
    hm = np.maximum(hm, 1e-10)
    hm = np.log(hm)
    for n in range(coords.shape[0]):
        for p in range(coords.shape[1]):
            coords[n, p] = taylor(hm[n][p], coords[n][p])

    preds = coords.copy()

    if transform_back:
        # Transform back
        for i in range(coords.shape[0]):
            preds[i] = transform_preds(
                coords[i], center[i], scale[i], [heatmap_width, heatmap_height]
            )

    return preds, maxvals
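
# Note on the decoding above: get_max_preds() takes the per-joint argmax of each
# heatmap and recovers pixel coordinates as x = idx % width, y = idx // width.
# taylor() then nudges each peak to sub-pixel precision with a second-order
# (Hessian-based) offset computed on the log of the blurred heatmap, and
# transform_preds() maps the refined heatmap coordinates back into the original
# image using the inverse of the crop's affine transform.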

SKELETON = [
    [1, 3], [1, 0], [2, 4], [2, 0], [0, 5], [0, 6], [5, 7], [7, 9], [6, 8], [8, 10], [5, 11], [6, 12], [11, 12],
    [11, 13], [13, 15], [12, 14], [14, 16]
]

CocoColors = [[255, 0, 0], [255, 85, 0], [255, 170, 0], [255, 255, 0], [170, 255, 0], [85, 255, 0], [0, 255, 0],
              [0, 255, 85], [0, 255, 170], [0, 255, 255], [0, 170, 255], [0, 85, 255], [0, 0, 255], [85, 0, 255],
              [170, 0, 255], [255, 0, 255], [255, 0, 170], [255, 0, 85]]

NUM_KPTS = 17


def get_person_detection_boxes(model, img, threshold=0.5):
    pred = model(img)
    pred_classes = [COCO_INSTANCE_CATEGORY_NAMES[i]
                    for i in list(pred[0]['labels'].cpu().numpy())]  # predicted class names
    pred_boxes = [[(i[0], i[1]), (i[2], i[3])]
                  for i in list(pred[0]['boxes'].detach().cpu().numpy())]  # bounding boxes
    pred_score = list(pred[0]['scores'].detach().cpu().numpy())
    if not pred_score or max(pred_score) < threshold:
        return []
    # Index of the last detection whose score is above the threshold
    pred_t = [pred_score.index(x) for x in pred_score if x > threshold][-1]
    pred_boxes = pred_boxes[:pred_t + 1]
    pred_classes = pred_classes[:pred_t + 1]

    person_boxes = []
    for idx, box in enumerate(pred_boxes):
        if pred_classes[idx] == 'person':
            person_boxes.append(box)

    return person_boxes


def draw_pose(keypoints, img):
    """Draw the keypoints and the skeleton.
    :param keypoints: array of shape [17, 2]
    :param img: image to draw on (modified in place)
    """
    assert keypoints.shape == (NUM_KPTS, 2)
    for i in range(len(SKELETON)):
        kpt_a, kpt_b = SKELETON[i][0], SKELETON[i][1]
        x_a, y_a = keypoints[kpt_a][0], keypoints[kpt_a][1]
        x_b, y_b = keypoints[kpt_b][0], keypoints[kpt_b][1]
        cv2.circle(img, (int(x_a), int(y_a)), 6, CocoColors[i], -1)
        cv2.circle(img, (int(x_b), int(y_b)), 6, CocoColors[i], -1)
        cv2.line(img, (int(x_a), int(y_a)), (int(x_b), int(y_b)), CocoColors[i], 2)


def box_to_center_scale(box, model_image_width, model_image_height):
    """Convert a box to the center/scale information required for pose transformation.

    Parameters
    ----------
    box : list of tuple
        list of length 2 with two tuples of floats representing
        bottom left and top right corner of a box
    model_image_width : int
    model_image_height : int

    Returns
    -------
    (numpy array, numpy array)
        Two numpy arrays, coordinates for the center of the box and the scale of the box
    """
    center = np.zeros((2), dtype=np.float32)

    bottom_left_corner = box[0]
    top_right_corner = box[1]
    box_width = top_right_corner[0] - bottom_left_corner[0]
    box_height = top_right_corner[1] - bottom_left_corner[1]
    bottom_left_x = bottom_left_corner[0]
    bottom_left_y = bottom_left_corner[1]
    center[0] = bottom_left_x + box_width * 0.5
    center[1] = bottom_left_y + box_height * 0.5

    aspect_ratio = model_image_width * 1.0 / model_image_height
    pixel_std = 200

    if box_width > aspect_ratio * box_height:
        box_height = box_width * 1.0 / aspect_ratio
    elif box_width < aspect_ratio * box_height:
        box_width = box_height * aspect_ratio
    scale = np.array(
        [box_width * 1.0 / pixel_std, box_height * 1.0 / pixel_std],
        dtype=np.float32)
    if center[0] != -1:
        scale = scale * 1.25

    return center, scale


def get_pose_estimation_prediction(pose_model, image, center, scale):
    rotation = 0
    img_size = (256, 192)
    # pose estimation transformation
    trans = get_affine_transform(center, scale, rotation, img_size)
    model_input = cv2.warpAffine(
        image,
        trans,
        (int(img_size[0]), int(img_size[1])),
        flags=cv2.INTER_LINEAR)
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])

    # pose estimation inference
    model_input = transform(model_input).unsqueeze(0)
    # switch to evaluate mode
    pose_model.eval()
    with torch.no_grad():
        # compute output heatmap
        output = pose_model(model_input)
        preds, _ = get_final_preds(
            output.clone().cpu().numpy(),
            np.asarray([center]),
            np.asarray([scale]))

    return preds


def main(image_bgr, box_model=torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)):
    CTX = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    box_model.to(CTX)
    box_model.eval()
    # TransPose pose-estimation model, loaded from torch hub
    model = torch.hub.load('yangsenius/TransPose:main', 'tph_a4_256x192', pretrained=True)

    img_dimensions = (256, 192)

    input = []
    image_rgb = image_bgr[:, :, [2, 1, 0]]
    img = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    img_tensor = torch.from_numpy(img / 255.).permute(2, 0, 1).float().to(CTX)
    input.append(img_tensor)

    pred_boxes = get_person_detection_boxes(box_model, input, threshold=0.9)

    if len(pred_boxes) >= 1:
        for box in pred_boxes:
            center, scale = box_to_center_scale(box, img_dimensions[0], img_dimensions[1])
            image_pose = image_rgb.copy()
            pose_preds = get_pose_estimation_prediction(model, image_pose, center, scale)
            if len(pose_preds) >= 1:
                for kpt in pose_preds:
                    draw_pose(kpt, image_bgr)  # draw the poses

    im = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB)
    return im

title = "TransPose"
description = "Gradio demo for TransPose: keypoint localization via Transformer. Dataset: COCO train2017 & val2017."
article = "<div style='text-align: center;'><a href='https://github.com/yangsenius/TransPose' target='_blank'>Full credits: github.com/yangsenius/TransPose</a></div>"

examples = [["./examples/one.jpg"], ["./examples/two.jpg"]]

iface = gr.Interface(main, inputs=gr.inputs.Image(), outputs="image", description=description, article=article, title=title, examples=examples)
iface.launch(enable_queue=True, debug=True)
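
For quick local testing outside the Gradio UI, a minimal sketch along these lines should work; the image path and output filename below are illustrative and not part of the app:

# Minimal sketch (illustrative): run the pipeline on one image and save the result.
# Assumes an image exists at ./examples/one.jpg; any readable image path works.
import cv2
image_bgr = cv2.imread("./examples/one.jpg")            # OpenCV loads images as BGR
result_rgb = main(image_bgr)                            # detect people, estimate and draw poses
cv2.imwrite("pose_result.jpg", result_rgb[:, :, ::-1])  # flip back to BGR for imwrite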