import argparse
import os
import sys
import time
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import cv2
import numpy as np

sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from miner import Miner
|
|
|
|
|
def parse_args() -> argparse.Namespace:
    """Build and parse the command-line options for the keypoint test script."""
    p = argparse.ArgumentParser(
        description="Test keypoint prediction on video file with maximum speed optimization."
    )
    p.add_argument(
        "--repo-path",
        type=Path,
        default="",
        help="Path to the HuggingFace/SecretVision repository (models, configs).",
    )
    p.add_argument(
        "--video-path",
        type=Path,
        default="test.mp4",
        help="Path to the input video file.",
    )
    p.add_argument(
        "--output-video",
        type=Path,
        default="outputs-keypoints/annotated.mp4",
        help="Optional path to save an annotated video with keypoints.",
    )
    p.add_argument(
        "--output-dir",
        type=Path,
        default="outputs-keypoints/frames",
        help="Optional directory to save annotated frames.",
    )
    p.add_argument(
        "--batch-size",
        type=int,
        default=None,
        help="Batch size for keypoint prediction (None = auto, processes all frames at once for max speed).",
    )
    p.add_argument(
        "--stride",
        type=int,
        default=1,
        help="Sample every Nth frame from the video (1 = all frames).",
    )
    p.add_argument(
        "--max-frames",
        type=int,
        default=None,
        help="Maximum number of frames to process (after stride).",
    )
    p.add_argument(
        "--n-keypoints",
        type=int,
        default=32,
        help="Number of keypoints expected per frame.",
    )
    p.add_argument(
        "--conf-threshold",
        type=float,
        default=0.5,
        help="Confidence threshold for regular keypoints.",
    )
    p.add_argument(
        "--corner-conf-threshold",
        type=float,
        default=0.3,
        help="Confidence threshold for corner keypoints.",
    )
    p.add_argument(
        "--no-visualization",
        action="store_true",
        help="Skip visualization to maximize speed.",
    )
    return p.parse_args()
|
|
|
|
|
def draw_keypoints(frame: np.ndarray, keypoints: List[Tuple[int, int]],
                   color: Tuple[int, int, int] = (0, 255, 255)) -> None:
    """Draw keypoints onto ``frame`` in place.

    Each point is rendered as a filled circle in ``color`` with a thin black
    outline ring for contrast. Points at (0, 0) are treated as "not detected"
    sentinels and skipped.
    """
    for x, y in keypoints:
        if (x, y) == (0, 0):
            continue
        center = (x, y)
        cv2.circle(frame, center, radius=3, color=color, thickness=-1)
        cv2.circle(frame, center, radius=5, color=(0, 0, 0), thickness=1)
|
|
|
|
|
def annotate_frame(frame: np.ndarray, keypoints: List[Tuple[int, int]],
                   frame_id: int) -> np.ndarray:
    """Return a copy of ``frame`` overlaid with keypoints and a header line."""
    annotated = frame.copy()
    draw_keypoints(annotated, keypoints)

    # (0, 0) keypoints are "not detected" sentinels; count only real ones.
    valid_count = len([kp for kp in keypoints if kp[0] != 0 or kp[1] != 0])

    overlay_text = f"Frame {frame_id} | Keypoints: {valid_count}/{len(keypoints)}"
    cv2.putText(
        annotated,
        overlay_text,
        (10, 30),
        cv2.FONT_HERSHEY_SIMPLEX,
        0.7,
        (255, 255, 255),
        2,
        lineType=cv2.LINE_AA,
    )
    return annotated
|
|
|
|
|
def load_video_frames(video_path: Path, stride: int = 1,
                      max_frames: Optional[int] = None) -> List[np.ndarray]:
    """Load frames from a video file.

    Args:
        video_path: Path to the input video.
        stride: Keep every ``stride``-th source frame (1 = keep all).
        max_frames: Stop after this many kept frames; a falsy value
            (None or 0) means no limit.

    Returns:
        List of frames as decoded by OpenCV (BGR ndarrays).

    Raises:
        ValueError: If ``stride`` is less than 1.
        RuntimeError: If the video cannot be opened.
    """
    if stride < 1:
        # Guard against modulo-by-zero in the sampling test below.
        raise ValueError(f"stride must be >= 1, got {stride}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        raise RuntimeError(f"Unable to open video: {video_path}")

    frames: List[np.ndarray] = []
    frame_count = 0
    source_frame_idx = 0

    print(f"Loading frames from video: {video_path}")
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            # Skip frames not on the stride grid.
            if source_frame_idx % stride != 0:
                source_frame_idx += 1
                continue

            frames.append(frame)
            frame_count += 1
            source_frame_idx += 1

            if max_frames and frame_count >= max_frames:
                break

            # Periodic progress report while decoding long videos.
            if frame_count % 100 == 0:
                print(f"Loaded {frame_count} frames...")
    finally:
        # Release the capture even if decoding raises mid-loop.
        cap.release()

    print(f"Total frames loaded: {len(frames)}")
    return frames
|
|
|
|
|
def save_results(
    frames: List[np.ndarray],
    keypoints_dict: Dict[int, List[Tuple[int, int]]],
    output_video: Optional[Path] = None,
    output_dir: Optional[Path] = None,
    fps: float = 25.0,
    width: Optional[int] = None,
    height: Optional[int] = None,
) -> None:
    """Save annotated frames and/or an annotated video.

    Args:
        frames: Frames to annotate (BGR ndarrays).
        keypoints_dict: Mapping of frame index -> keypoint list; frames with
            no entry are annotated with an empty keypoint list.
        output_video: Optional path for the annotated video (mp4v codec).
        output_dir: Optional directory for per-frame annotated JPEGs.
        fps: Frame rate for the output video.
        width: Output video width; inferred from the first frame if None.
        height: Output video height; inferred from the first frame if None.
    """
    if output_video is None and output_dir is None:
        return
    if not frames:
        # Nothing to write; also avoids IndexError when inferring dimensions.
        return

    if width is None or height is None:
        height, width = frames[0].shape[:2]

    writer = None
    if output_video:
        output_video.parent.mkdir(parents=True, exist_ok=True)
        writer = cv2.VideoWriter(
            str(output_video),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )
        print(f"Saving annotated video to: {output_video}")

    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Saving annotated frames to: {output_dir}")

    try:
        for frame_idx, frame in enumerate(frames):
            keypoints = keypoints_dict.get(frame_idx, [])
            annotated = annotate_frame(frame, keypoints, frame_idx)

            if writer:
                writer.write(annotated)

            if output_dir:
                frame_path = output_dir / f"frame_{frame_idx:06d}.jpg"
                cv2.imwrite(str(frame_path), annotated)

            # Periodic progress report for long sequences.
            if (frame_idx + 1) % 100 == 0:
                print(f"Saved {frame_idx + 1}/{len(frames)} frames...")
    finally:
        # Finalize the container even if annotation/writing raises.
        if writer:
            writer.release()

    if writer:
        print(f"Video saved: {output_video}")
|
|
|
|
|
def calculate_statistics(keypoints_dict: Dict[int, List[Tuple[int, int]]]) -> Dict[str, float]:
    """Calculate keypoint detection statistics.

    Args:
        keypoints_dict: Mapping of frame index -> list of (x, y) keypoints,
            where (0, 0) marks an undetected keypoint.

    Returns:
        Dict with total_frames, avg/max/min valid keypoints per frame,
        frames_with_keypoints, and keypoint_detection_rate. The same keys
        are returned for empty input (all zeroed).
    """
    total_frames = len(keypoints_dict)
    if total_frames == 0:
        # Fix: include keypoint_detection_rate here too, so the result
        # schema is identical for empty and non-empty input.
        return {
            "total_frames": 0,
            "avg_valid_keypoints": 0.0,
            "max_valid_keypoints": 0,
            "min_valid_keypoints": 0,
            "frames_with_keypoints": 0,
            "keypoint_detection_rate": 0.0,
        }

    valid_counts = []
    frames_with_keypoints = 0

    for keypoints in keypoints_dict.values():
        # (0, 0) keypoints are "not detected" sentinels.
        valid_count = sum(1 for kp in keypoints if kp[0] != 0 or kp[1] != 0)
        valid_counts.append(valid_count)
        if valid_count > 0:
            frames_with_keypoints += 1

    return {
        "total_frames": total_frames,
        "avg_valid_keypoints": sum(valid_counts) / len(valid_counts) if valid_counts else 0.0,
        "max_valid_keypoints": max(valid_counts) if valid_counts else 0,
        "min_valid_keypoints": min(valid_counts) if valid_counts else 0,
        "frames_with_keypoints": frames_with_keypoints,
        "keypoint_detection_rate": frames_with_keypoints / total_frames if total_frames > 0 else 0.0,
    }
|
|
|
|
|
def main() -> None:
    """Entry point: load a video, predict keypoints, report timing and stats,
    and optionally save annotated results."""
    args = parse_args()
    banner = "=" * 60

    # --- Model initialization -------------------------------------------
    print("Initializing Miner...")
    t0 = time.time()
    miner = Miner(args.repo_path)
    print(f"Miner initialized in {time.time() - t0:.2f} seconds")

    # --- Frame loading ---------------------------------------------------
    print("\n" + banner)
    print("Loading video frames...")
    t0 = time.time()
    frames = load_video_frames(args.video_path, args.stride, args.max_frames)
    print(f"Frames loaded in {time.time() - t0:.2f} seconds")

    if not frames:
        print("No frames loaded. Exiting.")
        return

    height, width = frames[0].shape[:2]
    # Re-open the source only to query its FPS (frames are already decoded).
    cap = cv2.VideoCapture(str(args.video_path))
    fps = cap.get(cv2.CAP_PROP_FPS) or 25.0  # OpenCV returns 0.0 when unknown
    cap.release()

    # --- Keypoint prediction ---------------------------------------------
    print("\n" + banner)
    print("Predicting keypoints...")
    print(f"Total frames: {len(frames)}")
    print(f"Batch size: {args.batch_size if args.batch_size else 'auto (all frames)'}")
    print(f"Confidence threshold: {args.conf_threshold}")
    print(f"Corner confidence threshold: {args.corner_conf_threshold}")

    t0 = time.time()
    keypoints_dict = miner.predict_keypoints(
        images=frames,
        n_keypoints=args.n_keypoints,
        batch_size=args.batch_size,
        conf_threshold=args.conf_threshold,
        corner_conf_threshold=args.corner_conf_threshold,
        verbose=True,
    )
    predict_time = time.time() - t0

    n_frames = len(frames)
    throughput = n_frames / predict_time if predict_time > 0 else 0
    per_frame = predict_time / n_frames if n_frames > 0 else 0

    # --- Performance report ----------------------------------------------
    print("\n" + banner)
    print("KEYPOINT PREDICTION PERFORMANCE")
    print(banner)
    print(f"Total frames processed: {n_frames}")
    print(f"Total prediction time: {predict_time:.3f} seconds")
    print(f"Average time per frame: {per_frame*1000:.2f} ms")
    print(f"Throughput: {throughput:.2f} FPS")
    print(f"Batch processing: {'Yes' if args.batch_size else 'No (single batch)'}")

    # --- Detection statistics --------------------------------------------
    stats = calculate_statistics(keypoints_dict)
    print("\n" + banner)
    print("KEYPOINT DETECTION STATISTICS")
    print(banner)
    for stat_name, stat_value in stats.items():
        rendered = f"{stat_value:.2f}" if isinstance(stat_value, float) else str(stat_value)
        print(f"{stat_name}: {rendered}")

    # --- Optional visualization output -----------------------------------
    if not args.no_visualization and (args.output_video or args.output_dir):
        print("\n" + banner)
        print("Saving results...")
        t0 = time.time()
        save_results(
            frames=frames,
            keypoints_dict=keypoints_dict,
            output_video=args.output_video,
            output_dir=args.output_dir,
            fps=fps / args.stride,  # striding lowers the effective frame rate
            width=width,
            height=height,
        )
        print(f"Results saved in {time.time() - t0:.2f} seconds")

    print("\n" + banner)
    print("Done!")
    print(banner)
|
|
|
|
|
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()
|
|
|
|
|