Amol Kaushik committed
Commit 6678ad8 · Parent: 8ea0417

A8: Add MoveNet pose estimator module (#33)


- Create pose_estimator.py with MoveNet Lightning/Thunder support
- Add TensorFlow, TensorFlow Hub, OpenCV dependencies
- Include test image and annotated output
- Support image and video processing
- Detect the 17 COCO keypoints and visualize the skeleton
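
A minimal usage sketch, mirroring the example in the module's class docstring; the image path and package-style import are illustrative and assume the repo root is on the Python path:

import cv2
from A8.pose_estimator import MoveNetPoseEstimator

estimator = MoveNetPoseEstimator(model_name='lightning')  # or 'thunder'
image = cv2.imread('test.jpg')                            # placeholder path (BGR)
result = estimator.detect_pose(image)                     # 17 COCO keypoints + timing
annotated = estimator.draw_keypoints(image, result)
cv2.imwrite('test_annotated.jpg', annotated)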

Files changed (2)
  1. A8/pose_estimator.py +439 -0
  2. requirements.txt +5 -0
A8/pose_estimator.py ADDED
@@ -0,0 +1,439 @@
+"""
+MoveNet Pose Estimator Module
+=============================
+A Python module for human pose estimation using TensorFlow's MoveNet model.
+
+This module provides functionality to:
+- Load and run the MoveNet pose estimation model
+- Process images and videos
+- Extract 17 COCO keypoints
+- Visualize pose detection results
+
+Issue #33 - A8: PoseNet/MoveNet Python Environment Setup
+"""
+
+import time
+from typing import Dict, List, Optional
+
+import cv2
+import numpy as np
+import tensorflow as tf
+import tensorflow_hub as hub
+
+
+# COCO keypoint definitions (17 keypoints)
+KEYPOINT_NAMES = [
+    'nose',
+    'left_eye',
+    'right_eye',
+    'left_ear',
+    'right_ear',
+    'left_shoulder',
+    'right_shoulder',
+    'left_elbow',
+    'right_elbow',
+    'left_wrist',
+    'right_wrist',
+    'left_hip',
+    'right_hip',
+    'left_knee',
+    'right_knee',
+    'left_ankle',
+    'right_ankle'
+]
+
+# Skeleton connections for visualization
+KEYPOINT_EDGES = {
+    (0, 1): 'face',
+    (0, 2): 'face',
+    (1, 3): 'face',
+    (2, 4): 'face',
+    (0, 5): 'torso',
+    (0, 6): 'torso',
+    (5, 7): 'left_arm',
+    (7, 9): 'left_arm',
+    (6, 8): 'right_arm',
+    (8, 10): 'right_arm',
+    (5, 6): 'torso',
+    (5, 11): 'torso',
+    (6, 12): 'torso',
+    (11, 12): 'torso',
+    (11, 13): 'left_leg',
+    (13, 15): 'left_leg',
+    (12, 14): 'right_leg',
+    (14, 16): 'right_leg',
+}
+
+# Colors for different body parts (BGR format for OpenCV)
+EDGE_COLORS = {
+    'face': (255, 255, 0),       # Cyan
+    'torso': (0, 255, 0),        # Green
+    'left_arm': (255, 0, 0),     # Blue
+    'right_arm': (0, 0, 255),    # Red
+    'left_leg': (0, 165, 255),   # Orange (BGR; was RGB-ordered)
+    'right_leg': (128, 0, 128),  # Purple
+}
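+
+# The index pairs in KEYPOINT_EDGES refer to positions in KEYPOINT_NAMES
+# (e.g. (5, 7) links 'left_shoulder' to 'left_elbow'); every body-part label
+# used in KEYPOINT_EDGES has a matching color entry in EDGE_COLORS.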
+
+
+class MoveNetPoseEstimator:
+    """
+    MoveNet-based human pose estimator.
+
+    Supports two model variants:
+    - 'lightning': Faster, lower accuracy (default)
+    - 'thunder': Slower, higher accuracy
+
+    Example usage:
+        estimator = MoveNetPoseEstimator(model_name='lightning')
+        keypoints = estimator.detect_pose(image)
+        visualized = estimator.draw_keypoints(image, keypoints)
+    """
+
+    # TensorFlow Hub model URLs
+    MODEL_URLS = {
+        'lightning': 'https://tfhub.dev/google/movenet/singlepose/lightning/4',
+        'thunder': 'https://tfhub.dev/google/movenet/singlepose/thunder/4',
+    }
+
+    # Input sizes for each model
+    INPUT_SIZES = {
+        'lightning': 192,
+        'thunder': 256,
+    }
+
+    def __init__(self, model_name: str = 'lightning'):
+        """
+        Initialize the MoveNet pose estimator.
+
+        Args:
+            model_name: Model variant ('lightning' or 'thunder')
+        """
+        if model_name not in self.MODEL_URLS:
+            raise ValueError(f"Model must be one of: {list(self.MODEL_URLS.keys())}")
+
+        self.model_name = model_name
+        self.input_size = self.INPUT_SIZES[model_name]
+
+        print(f"Loading MoveNet {model_name} model...")
+        self.model = hub.load(self.MODEL_URLS[model_name])
+        self.movenet = self.model.signatures['serving_default']
+        print(f"Model loaded successfully. Input size: {self.input_size}x{self.input_size}")
+
+    def preprocess_image(self, image: np.ndarray) -> tf.Tensor:
+        """
+        Preprocess image for MoveNet inference.
+
+        Args:
+            image: Input image (BGR or RGB format, any size)
+
+        Returns:
+            Preprocessed tensor ready for inference
+        """
+        # Convert BGR to RGB if needed (OpenCV loads as BGR)
+        if len(image.shape) == 3 and image.shape[2] == 3:
+            image_rgb = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+        else:
+            image_rgb = image
+
+        # Resize to model input size
+        input_image = tf.image.resize_with_pad(
+            tf.expand_dims(image_rgb, axis=0),
+            self.input_size,
+            self.input_size
+        )
+
+        # Convert to int32 as required by MoveNet
+        input_image = tf.cast(input_image, dtype=tf.int32)
+
+        return input_image
+
+    def detect_pose(self, image: np.ndarray) -> Dict:
+        """
+        Detect pose keypoints in an image.
+
+        Args:
+            image: Input image (BGR format from OpenCV)
+
+        Returns:
+            Dictionary with keypoint data, with x and y normalized to [0, 1]
+            relative to the input image:
+            {
+                'keypoints': {
+                    'nose': {'x': float, 'y': float, 'confidence': float},
+                    ...
+                },
+                'inference_time_ms': float
+            }
+        """
+        start_time = time.perf_counter()
+
+        # Preprocess
+        input_tensor = self.preprocess_image(image)
+
+        # Run inference
+        outputs = self.movenet(input_tensor)
+        keypoints_with_scores = outputs['output_0'].numpy()[0, 0, :, :]
+
+        inference_time = (time.perf_counter() - start_time) * 1000
+
+        # MoveNet normalizes coordinates to the padded square produced by
+        # resize_with_pad, so undo the padding to express keypoints relative
+        # to the original frame (otherwise they drift on non-square images).
+        height, width = image.shape[:2]
+        scale = self.input_size / max(height, width)
+        pad_x = (self.input_size - width * scale) / 2.0
+        pad_y = (self.input_size - height * scale) / 2.0
+
+        # Parse keypoints
+        keypoints_dict = {}
+        for i, name in enumerate(KEYPOINT_NAMES):
+            y, x, confidence = keypoints_with_scores[i]
+            keypoints_dict[name] = {
+                'x': float((x * self.input_size - pad_x) / (width * scale)),
+                'y': float((y * self.input_size - pad_y) / (height * scale)),
+                'confidence': float(confidence)
+            }
+
+        return {
+            'keypoints': keypoints_dict,
+            'inference_time_ms': inference_time
+        }
+
+    def detect_pose_raw(self, image: np.ndarray) -> np.ndarray:
+        """
+        Detect pose and return the raw keypoints array.
+
+        Args:
+            image: Input image (BGR format)
+
+        Returns:
+            Array of shape (17, 3) with [y, x, confidence] for each keypoint,
+            normalized to the padded square model input (not the original frame)
+        """
+        input_tensor = self.preprocess_image(image)
+        outputs = self.movenet(input_tensor)
+        return outputs['output_0'].numpy()[0, 0, :, :]
+
+    def draw_keypoints(
+        self,
+        image: np.ndarray,
+        keypoints: Dict,
+        confidence_threshold: float = 0.3,
+        circle_radius: int = 5,
+        line_thickness: int = 2
+    ) -> np.ndarray:
+        """
+        Draw detected keypoints and skeleton on image.
+
+        Args:
+            image: Input image (will be copied, not modified)
+            keypoints: Keypoint dictionary from detect_pose()
+            confidence_threshold: Minimum confidence to draw keypoint
+            circle_radius: Radius of keypoint circles
+            line_thickness: Thickness of skeleton lines
+
+        Returns:
+            Image with keypoints and skeleton drawn
+        """
+        output_image = image.copy()
+        height, width = image.shape[:2]
+
+        kps = keypoints['keypoints']
+
+        # Draw skeleton edges first (so keypoints appear on top)
+        for (start_idx, end_idx), body_part in KEYPOINT_EDGES.items():
+            start_name = KEYPOINT_NAMES[start_idx]
+            end_name = KEYPOINT_NAMES[end_idx]
+
+            start_kp = kps[start_name]
+            end_kp = kps[end_name]
+
+            if start_kp['confidence'] > confidence_threshold and end_kp['confidence'] > confidence_threshold:
+                start_point = (int(start_kp['x'] * width), int(start_kp['y'] * height))
+                end_point = (int(end_kp['x'] * width), int(end_kp['y'] * height))
+                color = EDGE_COLORS[body_part]
+                cv2.line(output_image, start_point, end_point, color, line_thickness)
+
+        # Draw keypoints
+        for name, kp in kps.items():
+            if kp['confidence'] > confidence_threshold:
+                x = int(kp['x'] * width)
+                y = int(kp['y'] * height)
+                cv2.circle(output_image, (x, y), circle_radius, (0, 255, 255), -1)
+                cv2.circle(output_image, (x, y), circle_radius, (0, 0, 0), 1)
+
+        return output_image
+
+    def process_video(
+        self,
+        video_path: str,
+        output_path: Optional[str] = None,
+        show_preview: bool = False,
+        confidence_threshold: float = 0.3
+    ) -> List[Dict]:
+        """
+        Process a video file and extract keypoints from each frame.
+
+        Args:
+            video_path: Path to input video file
+            output_path: Optional path to save annotated video
+            show_preview: Whether to show live preview (press 'q' to quit)
+            confidence_threshold: Minimum confidence for visualization
+
+        Returns:
+            List of keypoint dictionaries, one per frame
+        """
+        cap = cv2.VideoCapture(video_path)
+
+        if not cap.isOpened():
+            raise ValueError(f"Could not open video: {video_path}")
+
+        # Get video properties
+        fps = cap.get(cv2.CAP_PROP_FPS)
+        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
+        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
+        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+
+        print(f"Video: {video_path}")
+        print(f"Resolution: {width}x{height}, FPS: {fps:.2f}, Frames: {total_frames}")
+
+        # Set up video writer if an output path was specified
+        writer = None
+        if output_path:
+            fourcc = cv2.VideoWriter_fourcc(*'mp4v')
+            writer = cv2.VideoWriter(output_path, fourcc, fps, (width, height))
+
+        all_keypoints = []
+        frame_idx = 0
+
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            # Detect pose
+            result = self.detect_pose(frame)
+            result['frame_id'] = frame_idx
+            result['timestamp'] = frame_idx / fps if fps > 0 else 0
+            all_keypoints.append(result)
+
+            # Draw and optionally show/save
+            annotated_frame = self.draw_keypoints(frame, result, confidence_threshold)
+
+            if writer:
+                writer.write(annotated_frame)
+
+            if show_preview:
+                cv2.imshow('Pose Estimation', annotated_frame)
+                if cv2.waitKey(1) & 0xFF == ord('q'):
+                    break
+
+            frame_idx += 1
+            if frame_idx % 30 == 0:
+                print(f"Processed {frame_idx}/{total_frames} frames...")
+
+        cap.release()
+        if writer:
+            writer.release()
+        if show_preview:
+            cv2.destroyAllWindows()
+
+        print(f"Completed! Processed {frame_idx} frames.")
+        if all_keypoints:  # guard against an empty video (np.mean([]) is nan)
+            avg_inference = np.mean([r['inference_time_ms'] for r in all_keypoints])
+            print(f"Average inference time: {avg_inference:.2f} ms/frame")
+
+        return all_keypoints
+
+    def process_image_file(
+        self,
+        image_path: str,
+        output_path: Optional[str] = None,
+        confidence_threshold: float = 0.3
+    ) -> Dict:
+        """
+        Process a single image file.
+
+        Args:
+            image_path: Path to input image
+            output_path: Optional path to save annotated image
+            confidence_threshold: Minimum confidence for visualization
+
+        Returns:
+            Keypoint dictionary for the image
+        """
+        image = cv2.imread(image_path)
+        if image is None:
+            raise ValueError(f"Could not read image: {image_path}")
+
+        result = self.detect_pose(image)
+
+        if output_path:
+            annotated = self.draw_keypoints(image, result, confidence_threshold)
+            cv2.imwrite(output_path, annotated)
+            print(f"Saved annotated image to: {output_path}")
+
+        return result
+
+
+def main():
+    """Demo: test the pose estimator on an image, a video, or the webcam."""
+    import argparse
+
+    parser = argparse.ArgumentParser(description='MoveNet Pose Estimation Demo')
+    parser.add_argument('--model', choices=['lightning', 'thunder'], default='lightning',
+                        help='Model variant (default: lightning)')
+    parser.add_argument('--image', type=str, help='Path to input image')
+    parser.add_argument('--video', type=str, help='Path to input video')
+    parser.add_argument('--webcam', action='store_true', help='Use webcam')
+    parser.add_argument('--output', type=str, help='Output path for annotated image/video')
+    args = parser.parse_args()
+
+    # Initialize estimator
+    estimator = MoveNetPoseEstimator(model_name=args.model)
+
+    if args.image:
+        # Process image
+        print(f"\nProcessing image: {args.image}")
+        result = estimator.process_image_file(
+            args.image,
+            output_path=args.output
+        )
+        print(f"Inference time: {result['inference_time_ms']:.2f} ms")
+        print("\nDetected keypoints:")
+        for name, kp in result['keypoints'].items():
+            if kp['confidence'] > 0.3:
+                print(f"  {name}: ({kp['x']:.3f}, {kp['y']:.3f}) conf={kp['confidence']:.3f}")
+
+    elif args.video:
+        # Process video
+        print(f"\nProcessing video: {args.video}")
+        keypoints = estimator.process_video(
+            args.video,
+            output_path=args.output,
+            show_preview=True
+        )
+        print(f"\nExtracted keypoints from {len(keypoints)} frames")
+
+    elif args.webcam:
+        # Webcam demo
+        print("\nStarting webcam demo (press 'q' to quit)...")
+        cap = cv2.VideoCapture(0)
+        if not cap.isOpened():
+            print("Could not open webcam")
+            return
+
+        while True:
+            ret, frame = cap.read()
+            if not ret:
+                break
+
+            result = estimator.detect_pose(frame)
+            annotated = estimator.draw_keypoints(frame, result)
+
+            # Overlay the per-frame inference time
+            fps_text = f"Inference: {result['inference_time_ms']:.1f} ms"
+            cv2.putText(annotated, fps_text, (10, 30),
+                        cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
+
+            cv2.imshow('MoveNet Pose Estimation', annotated)
+            if cv2.waitKey(1) & 0xFF == ord('q'):
+                break
+
+        cap.release()
+        cv2.destroyAllWindows()
+
+    else:
+        print("Please specify --image, --video, or --webcam")
+        print("Example: python pose_estimator.py --image test.jpg --output result.jpg")
+
+
+if __name__ == '__main__':
+    main()
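
Since detect_pose() returns plain Python floats, the per-frame results from process_video() serialize directly to JSON. A short sketch of persisting them (file names are illustrative, and the import again assumes the repo root is on the Python path):

import json
from A8.pose_estimator import MoveNetPoseEstimator

estimator = MoveNetPoseEstimator(model_name='thunder')
frames = estimator.process_video('input.mp4', output_path='annotated.mp4')

# Each entry carries 'keypoints', 'inference_time_ms', 'frame_id' and 'timestamp'
with open('keypoints.json', 'w') as f:
    json.dump(frames, f, indent=2)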
requirements.txt CHANGED
@@ -8,6 +8,11 @@ gdown==5.2.0
 xgboost==3.2.0
 lightgbm==4.6.0
 
+# A8: Deep Learning / Pose Estimation
+tensorflow>=2.21.0
+tensorflow-hub>=0.16.1
+opencv-python>=4.10.0
+
 pytest==8.3.4
 pytest-cov==6.0.0
 
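
After pip install -r requirements.txt, a quick sanity check that the new A8 dependencies import cleanly (prints whatever versions pip resolved):

import cv2
import tensorflow as tf
import tensorflow_hub as hub

print('OpenCV:', cv2.__version__)
print('TensorFlow:', tf.__version__)
print('TensorFlow Hub:', hub.__version__)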