# visiontest / test_predict_keypoints_video.py
# Uploaded by tarto2 via huggingface_hub (commit e4189f9, verified).
import argparse
import time
from pathlib import Path
from typing import List, Dict, Tuple
import sys
import os
import cv2
import numpy as np
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from miner import Miner
def parse_args() -> argparse.Namespace:
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        A namespace with ``repo_path``, ``video_path``, ``output_video`` and
        ``output_dir`` (``Path``), ``batch_size``, ``stride``, ``max_frames``
        and ``n_keypoints`` (``int``; ``batch_size``/``max_frames`` default to
        ``None``), ``conf_threshold`` and ``corner_conf_threshold``
        (``float``), and ``no_visualization`` (``bool``).
    """
    cli = argparse.ArgumentParser(
        description="Test keypoint prediction on video file with maximum speed optimization."
    )

    # Path-valued options: (flag, default, help). argparse runs ``type`` on
    # string defaults too, so these defaults also arrive as ``Path`` objects.
    path_options = (
        (
            "--repo-path",
            "",
            "Path to the HuggingFace/SecretVision repository (models, configs).",
        ),
        (
            "--video-path",
            "test.mp4",
            "Path to the input video file.",
        ),
        (
            "--output-video",
            "outputs-keypoints/annotated.mp4",
            "Optional path to save an annotated video with keypoints.",
        ),
        (
            "--output-dir",
            "outputs-keypoints/frames",
            "Optional directory to save annotated frames.",
        ),
    )
    for flag, fallback, description in path_options:
        cli.add_argument(flag, type=Path, default=fallback, help=description)

    # Numeric tuning knobs: (flag, converter, default, help).
    numeric_options = (
        (
            "--batch-size",
            int,
            None,
            "Batch size for keypoint prediction (None = auto, processes all frames at once for max speed).",
        ),
        (
            "--stride",
            int,
            1,
            "Sample every Nth frame from the video (1 = all frames).",
        ),
        (
            "--max-frames",
            int,
            None,
            "Maximum number of frames to process (after stride).",
        ),
        (
            "--n-keypoints",
            int,
            32,
            "Number of keypoints expected per frame.",
        ),
        (
            "--conf-threshold",
            float,
            0.5,
            "Confidence threshold for regular keypoints.",
        ),
        (
            "--corner-conf-threshold",
            float,
            0.3,
            "Confidence threshold for corner keypoints.",
        ),
    )
    for flag, converter, fallback, description in numeric_options:
        cli.add_argument(flag, type=converter, default=fallback, help=description)

    cli.add_argument(
        "--no-visualization",
        action="store_true",
        help="Skip visualization to maximize speed.",
    )
    return cli.parse_args()
def draw_keypoints(frame: np.ndarray, keypoints: List[Tuple[int, int]],
                   color: Tuple[int, int, int] = (0, 255, 255)) -> None:
    """Draw each keypoint onto ``frame`` in place as a filled dot with a black ring.

    Args:
        frame: Image to draw on; modified in place by ``cv2.circle``.
        keypoints: ``(x, y)`` pixel coordinates. ``(0, 0)`` is treated as the
            "not detected" placeholder and skipped. Non-integer coordinates
            (e.g. sub-pixel model output) are rounded to the nearest pixel,
            because ``cv2.circle`` only accepts integer centres.
        color: Fill colour of the dot, in the channel order of ``frame``
            (presumably BGR, as usual for OpenCV images — confirm).
    """
    for x, y in keypoints:
        if x == 0 and y == 0:
            continue
        # Round rather than truncate so sub-pixel points land on the nearest pixel.
        center = (int(round(x)), int(round(y)))
        cv2.circle(frame, center, radius=3, color=color, thickness=-1)
        cv2.circle(frame, center, radius=5, color=(0, 0, 0), thickness=1)
def annotate_frame(frame: np.ndarray, keypoints: List[Tuple[int, int]],
                   frame_id: int) -> np.ndarray:
    """Return a copy of ``frame`` with keypoints and a status caption drawn on it.

    The input frame is not modified. The caption is drawn at pixel (10, 30)
    and reads ``Frame <id> | Keypoints: <valid>/<total>``. A keypoint counts
    as valid unless it is the ``(0, 0)`` placeholder.
    """
    canvas = frame.copy()
    draw_keypoints(canvas, keypoints)

    detected = len([point for point in keypoints if not (point[0] == 0 and point[1] == 0)])
    caption = f"Frame {frame_id} | Keypoints: {detected}/{len(keypoints)}"

    text_origin = (10, 30)
    font_scale = 0.7
    white = (255, 255, 255)
    stroke = 2
    cv2.putText(canvas, caption, text_origin, cv2.FONT_HERSHEY_SIMPLEX,
                font_scale, white, stroke, lineType=cv2.LINE_AA)
    return canvas
def load_video_frames(video_path: Path, stride: int = 1, max_frames: int = None) -> List[np.ndarray]:
    """Decode frames from a video file into memory, optionally subsampled.

    Args:
        video_path: Video file readable by ``cv2.VideoCapture``.
        stride: Keep every ``stride``-th source frame (1 keeps all). Must be >= 1.
        max_frames: Stop once this many frames have been kept. ``None`` (or
            ``0``) means no limit.

    Returns:
        The kept frames in source order, as returned by ``VideoCapture.read()``.

    Raises:
        ValueError: If ``stride`` is less than 1.
        RuntimeError: If the video cannot be opened.
    """
    # Validate before opening the file. stride == 0 used to crash with
    # ZeroDivisionError, and a negative stride silently behaved like abs(stride).
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    cap = cv2.VideoCapture(str(video_path))
    if not cap.isOpened():
        cap.release()
        raise RuntimeError(f"Unable to open video: {video_path}")

    frames: List[np.ndarray] = []
    source_frame_idx = 0
    print(f"Loading frames from video: {video_path}")
    try:
        while True:
            if source_frame_idx % stride != 0:
                # Skipped frames only need to be advanced past. grab() does
                # not decode them, unlike read() (= grab() + retrieve()).
                if not cap.grab():
                    break
                source_frame_idx += 1
                continue
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
            source_frame_idx += 1
            if max_frames and len(frames) >= max_frames:
                break
            if len(frames) % 100 == 0:
                print(f"Loaded {len(frames)} frames...")
    finally:
        # Release the decoder even if reading raises part-way through.
        cap.release()
    print(f"Total frames loaded: {len(frames)}")
    return frames
def save_results(
    frames: List[np.ndarray],
    keypoints_dict: Dict[int, List[Tuple[int, int]]],
    output_video: Path = None,
    output_dir: Path = None,
    fps: float = 25.0,
    width: int = None,
    height: int = None,
) -> None:
    """Write annotated frames to a video file and/or a directory of JPEGs.

    Args:
        frames: Frames to annotate, in order. Frame ``i`` is drawn with
            ``keypoints_dict[i]``; missing entries draw no keypoints.
        keypoints_dict: Per-frame keypoint lists keyed by frame index.
        output_video: Destination video path (written with the ``mp4v``
            FourCC), or ``None`` to skip video output.
        output_dir: Directory for ``frame_XXXXXX.jpg`` files, or ``None``.
        fps: Frame rate written into the output video.
        width, height: Output video size. Both are taken from the first frame
            when either is ``None``. Frames are written unresized, so these
            presumably must match the frame size — confirm.

    Raises:
        RuntimeError: If the video writer cannot be opened.
    """
    if output_video is None and output_dir is None:
        return
    if not frames:
        # Nothing to write; previously frames[0] below raised IndexError.
        print("No frames to save.")
        return
    if width is None or height is None:
        height, width = frames[0].shape[:2]

    writer = None
    if output_video:
        output_video.parent.mkdir(parents=True, exist_ok=True)
        writer = cv2.VideoWriter(
            str(output_video),
            cv2.VideoWriter_fourcc(*"mp4v"),
            fps,
            (width, height),
        )
        # VideoWriter does not raise when it cannot open (bad codec/path); it
        # would silently drop every frame. Fail loudly instead.
        if not writer.isOpened():
            writer.release()
            raise RuntimeError(f"Unable to open video writer: {output_video}")
        print(f"Saving annotated video to: {output_video}")
    if output_dir:
        output_dir.mkdir(parents=True, exist_ok=True)
        print(f"Saving annotated frames to: {output_dir}")

    total = len(frames)
    try:
        for frame_idx, frame in enumerate(frames):
            keypoints = keypoints_dict.get(frame_idx, [])
            annotated = annotate_frame(frame, keypoints, frame_idx)
            if writer is not None:
                writer.write(annotated)
            if output_dir:
                frame_path = output_dir / f"frame_{frame_idx:06d}.jpg"
                cv2.imwrite(str(frame_path), annotated)
            if (frame_idx + 1) % 100 == 0:
                print(f"Saved {frame_idx + 1}/{total} frames...")
    finally:
        # Release the writer even if annotation or writing raises.
        if writer is not None:
            writer.release()
    if writer is not None:
        print(f"Video saved: {output_video}")
def calculate_statistics(keypoints_dict: Dict[int, List[Tuple[int, int]]]) -> Dict[str, float]:
    """Summarize how many keypoints were detected per frame.

    A keypoint counts as valid unless it is the ``(0, 0)`` placeholder.

    Args:
        keypoints_dict: Per-frame keypoint lists; only the values are used.

    Returns:
        A dict with ``total_frames``, ``avg_valid_keypoints``,
        ``max_valid_keypoints``, ``min_valid_keypoints``,
        ``frames_with_keypoints`` and ``keypoint_detection_rate`` (the
        fraction of frames with at least one valid keypoint). An empty input
        returns the same keys, all zero, so callers see one schema.
    """
    valid_counts = [
        sum(1 for kp in keypoints if kp[0] != 0 or kp[1] != 0)
        for keypoints in keypoints_dict.values()
    ]
    total_frames = len(valid_counts)
    if total_frames == 0:
        return {
            "total_frames": 0,
            "avg_valid_keypoints": 0.0,
            "max_valid_keypoints": 0,
            "min_valid_keypoints": 0,
            "frames_with_keypoints": 0,
            "keypoint_detection_rate": 0.0,
        }

    frames_with_keypoints = sum(1 for count in valid_counts if count > 0)
    return {
        "total_frames": total_frames,
        "avg_valid_keypoints": sum(valid_counts) / total_frames,
        "max_valid_keypoints": max(valid_counts),
        "min_valid_keypoints": min(valid_counts),
        "frames_with_keypoints": frames_with_keypoints,
        "keypoint_detection_rate": frames_with_keypoints / total_frames,
    }
def main() -> None:
    """Benchmark ``Miner.predict_keypoints`` on a video and report the results.

    Steps:
        1. Parse CLI arguments and initialize the miner.
        2. Load the (optionally strided) frames.
        3. Run keypoint prediction in a single ``predict_keypoints`` call.
        4. Print timing and detection statistics.
        5. Write annotated output, unless ``--no-visualization`` is set.
    """
    args = parse_args()
    separator = "=" * 60

    # All timings use perf_counter(): it is monotonic and high-resolution,
    # whereas time.time() follows the wall clock and can jump mid-benchmark.
    print("Initializing Miner...")
    init_start = time.perf_counter()
    miner = Miner(args.repo_path)
    init_time = time.perf_counter() - init_start
    print(f"Miner initialized in {init_time:.2f} seconds")

    # Load video frames
    print("\n" + separator)
    print("Loading video frames...")
    load_start = time.perf_counter()
    frames = load_video_frames(args.video_path, args.stride, args.max_frames)
    load_time = time.perf_counter() - load_start
    print(f"Frames loaded in {load_time:.2f} seconds")
    if not frames:
        print("No frames loaded. Exiting.")
        return

    # Output video properties. The source FPS needs a second capture because
    # load_video_frames does not return it; `or 25.0` covers backends that
    # report 0.
    height, width = frames[0].shape[:2]
    cap = cv2.VideoCapture(str(args.video_path))
    try:
        fps = cap.get(cv2.CAP_PROP_FPS) or 25.0
    finally:
        cap.release()

    # Predict keypoints
    print("\n" + separator)
    print("Predicting keypoints...")
    print(f"Total frames: {len(frames)}")
    print(f"Batch size: {args.batch_size if args.batch_size else 'auto (all frames)'}")
    print(f"Confidence threshold: {args.conf_threshold}")
    print(f"Corner confidence threshold: {args.corner_conf_threshold}")
    predict_start = time.perf_counter()
    keypoints_dict = miner.predict_keypoints(
        images=frames,
        n_keypoints=args.n_keypoints,
        batch_size=args.batch_size,
        conf_threshold=args.conf_threshold,
        corner_conf_threshold=args.corner_conf_threshold,
        verbose=True,
    )
    predict_time = time.perf_counter() - predict_start

    # Performance metrics; guard against a zero-duration measurement.
    total_frames = len(frames)
    fps_achieved = total_frames / predict_time if predict_time > 0 else 0
    time_per_frame = predict_time / total_frames if total_frames > 0 else 0

    print("\n" + separator)
    print("KEYPOINT PREDICTION PERFORMANCE")
    print(separator)
    print(f"Total frames processed: {total_frames}")
    print(f"Total prediction time: {predict_time:.3f} seconds")
    print(f"Average time per frame: {time_per_frame*1000:.2f} ms")
    print(f"Throughput: {fps_achieved:.2f} FPS")
    print(f"Batch processing: {'Yes' if args.batch_size else 'No (single batch)'}")

    # Detection statistics
    stats = calculate_statistics(keypoints_dict)
    print("\n" + separator)
    print("KEYPOINT DETECTION STATISTICS")
    print(separator)
    for key, value in stats.items():
        if isinstance(value, float):
            print(f"{key}: {value:.2f}")
        else:
            print(f"{key}: {value}")

    # Save results if requested. The output FPS is divided by the stride so
    # the annotated video plays back at the source's real-time speed.
    if not args.no_visualization and (args.output_video or args.output_dir):
        print("\n" + separator)
        print("Saving results...")
        save_start = time.perf_counter()
        save_results(
            frames=frames,
            keypoints_dict=keypoints_dict,
            output_video=args.output_video,
            output_dir=args.output_dir,
            fps=fps / args.stride,
            width=width,
            height=height,
        )
        save_time = time.perf_counter() - save_start
        print(f"Results saved in {save_time:.2f} seconds")

    print("\n" + separator)
    print("Done!")
    print(separator)
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()