Deploy gesture detection & validation API
- Add Docker container with FastAPI application
- Include gesture detection and identity validation endpoints
- Add ONNX models for hand detection and classification
- Provide comprehensive API documentation
- Support for multiple gesture types: thumbs_up, peace, ok_sign, open_palm, call_me, grabbing
- Facial validation in placeholder mode (always returns success)
- Gesture validation fully functional with configurable parameters (see the client sketch below)
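
For reference, a minimal Python client sketch for the two endpoints this commit deploys. Field names follow the README below; `requests` is an assumed client-side library, not one of this Space's pinned dependencies:

```python
import json
import requests  # assumed client library; not part of this repo's requirements

API = "http://localhost:7860"  # HF Spaces default port, per the Dockerfile below

# Gesture detection: multipart video upload, processing every 3rd frame
with open("my_video.mp4", "rb") as f:
    r = requests.post(f"{API}/gestures", files={"video": f}, data={"frame_skip": 3})
r.raise_for_status()
print(r.json())  # e.g. {"gestures": [{"gesture": "thumbs_up", "duration": 45, ...}]}

# Identity validation: ID photo + video; required gestures ride along as a JSON array
with open("id_photo.jpg", "rb") as p, open("user_video.mp4", "rb") as v:
    r = requests.post(
        f"{API}/validate",
        files={"photo": p, "video": v},
        data={"gestures": json.dumps(["thumbs_up", "peace"]), "include_details": "true"},
    )
print(r.json())  # {"face": true, "gestures": true, "overall": true, ...}
```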
This view is limited to 50 files because it contains too many changes.
- Dockerfile +57 -0
- README.md +121 -6
- main.py +337 -0
- models/crops_classifier.onnx +3 -0
- models/hand_detector.onnx +3 -0
- requirements.txt +13 -0
- src/.DS_Store +0 -0
- src/facialembeddingsmatch/__init__.py +15 -0
- src/facialembeddingsmatch/__pycache__/__init__.cpython-312.pyc +0 -0
- src/facialembeddingsmatch/__pycache__/facial_matcher.cpython-312.pyc +0 -0
- src/facialembeddingsmatch/facial_matcher.py +433 -0
- src/gesturedetection/.DS_Store +0 -0
- src/gesturedetection/__init__.py +23 -0
- src/gesturedetection/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gesturedetection/__pycache__/api.cpython-312.pyc +0 -0
- src/gesturedetection/__pycache__/config.cpython-312.pyc +0 -0
- src/gesturedetection/__pycache__/main_controller.cpython-312.pyc +0 -0
- src/gesturedetection/__pycache__/models.cpython-312.pyc +0 -0
- src/gesturedetection/__pycache__/onnx_models.cpython-312.pyc +0 -0
- src/gesturedetection/api.py +318 -0
- src/gesturedetection/config.py +55 -0
- src/gesturedetection/main_controller.py +271 -0
- src/gesturedetection/models.py +89 -0
- src/gesturedetection/ocsort/__init__.py +2 -0
- src/gesturedetection/ocsort/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gesturedetection/ocsort/__pycache__/__init__.cpython-39.pyc +0 -0
- src/gesturedetection/ocsort/__pycache__/association.cpython-312.pyc +0 -0
- src/gesturedetection/ocsort/__pycache__/association.cpython-39.pyc +0 -0
- src/gesturedetection/ocsort/__pycache__/kalmanboxtracker.cpython-312.pyc +0 -0
- src/gesturedetection/ocsort/__pycache__/kalmanboxtracker.cpython-39.pyc +0 -0
- src/gesturedetection/ocsort/__pycache__/kalmanfilter.cpython-312.pyc +0 -0
- src/gesturedetection/ocsort/__pycache__/kalmanfilter.cpython-39.pyc +0 -0
- src/gesturedetection/ocsort/association.py +511 -0
- src/gesturedetection/ocsort/kalmanboxtracker.py +157 -0
- src/gesturedetection/ocsort/kalmanfilter.py +1557 -0
- src/gesturedetection/onnx_models.py +194 -0
- src/gesturedetection/utils/__init__.py +16 -0
- src/gesturedetection/utils/__pycache__/__init__.cpython-312.pyc +0 -0
- src/gesturedetection/utils/__pycache__/__init__.cpython-39.pyc +0 -0
- src/gesturedetection/utils/__pycache__/action_controller.cpython-312.pyc +0 -0
- src/gesturedetection/utils/__pycache__/action_controller.cpython-39.pyc +0 -0
- src/gesturedetection/utils/__pycache__/box_utils_numpy.cpython-312.pyc +0 -0
- src/gesturedetection/utils/__pycache__/box_utils_numpy.cpython-39.pyc +0 -0
- src/gesturedetection/utils/__pycache__/drawer.cpython-312.pyc +0 -0
- src/gesturedetection/utils/__pycache__/drawer.cpython-39.pyc +0 -0
- src/gesturedetection/utils/__pycache__/enums.cpython-312.pyc +0 -0
- src/gesturedetection/utils/__pycache__/enums.cpython-39.pyc +0 -0
- src/gesturedetection/utils/__pycache__/hand.cpython-312.pyc +0 -0
- src/gesturedetection/utils/__pycache__/hand.cpython-39.pyc +0 -0
- src/gesturedetection/utils/action_controller.py +598 -0
Dockerfile
ADDED
@@ -0,0 +1,57 @@
# Use Python 3.12 as base image
FROM python:3.12-slim

# Install system dependencies including OpenCV requirements
RUN apt-get update && apt-get install -y \
    curl \
    libgl1 \
    libglib2.0-0 \
    libsm6 \
    libxext6 \
    libxrender1 \
    libgomp1 \
    && rm -rf /var/lib/apt/lists/*

# Set up a new user named "user" with user ID 1000 (HF Spaces requirement)
RUN useradd -m -u 1000 user

# Switch to the "user" user
USER user

# Set home to the user's home directory
ENV HOME=/home/user \
    PATH=/home/user/.local/bin:$PATH

# Set the working directory to the user's home directory
WORKDIR $HOME/app

# Upgrade pip and install dependencies
RUN pip install --no-cache-dir --upgrade pip

# Copy requirements first for better Docker layer caching
COPY --chown=user docker/requirements.txt $HOME/app/

# Install Python dependencies
RUN pip install --no-cache-dir --user -r requirements.txt

# Copy the source code (COPY paths are relative to the build context, i.e. the
# repo root when building with -f docker/Dockerfile; "../" would escape the context)
COPY --chown=user src/ $HOME/app/src/
COPY --chown=user models/ $HOME/app/models/

# Copy the main entry point and README
COPY --chown=user main.py $HOME/app/
COPY --chown=user README.md $HOME/app/

# Expose the port that the app runs on (HF Spaces default is 7860)
EXPOSE 7860

# Set environment variables
ENV PYTHONPATH=$HOME/app
ENV PORT=7860

# Health check to ensure the API is running
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD curl -f http://localhost:7860/health || exit 1

# Start the application directly
CMD ["python", "main.py"]
README.md
CHANGED
@@ -1,11 +1,126 @@
 ---
-title: Validation
-emoji:
-colorFrom:
-colorTo:
+title: Gesture Detection & Identity Validation API
+emoji: 👋
+colorFrom: blue
+colorTo: purple
 sdk: docker
 pinned: false
-license:
+license: mit
+app_port: 7860
 ---
 
-
+# 👋 Gesture Detection & Identity Validation API
+
+A unified API for gesture detection in videos and identity validation using facial recognition and gesture verification.
+
+## 🚀 Features
+
+- **Gesture Detection**: Detect and track hand gestures in video files
+- **Identity Validation**: Validate user identity using facial recognition and required gestures
+- **Real-time Processing**: Efficient video processing with configurable frame skip
+- **RESTful API**: Clean, documented API endpoints
+
+## 📋 API Endpoints
+
+### `GET /`
+Get API information and available endpoints
+
+### `GET /health`
+Health check endpoint showing service status
+
+### `POST /gestures`
+Detect gestures in an uploaded video file
+
+**Parameters:**
+- `video` (file): Video file to process
+- `frame_skip` (int, optional): Number of frames to skip (default: 1)
+
+**Response:**
+```json
+{
+  "gestures": [
+    {
+      "gesture": "thumbs_up",
+      "duration": 45,
+      "confidence": 0.92
+    }
+  ]
+}
+```
+
+### `POST /validate`
+Validate user identity using facial recognition and gesture verification
+
+**Parameters:**
+- `photo` (file): ID document photo
+- `video` (file): User video containing face and gestures
+- `gestures` (JSON array): Required gestures (e.g., `["thumbs_up","peace"]`)
+- `error_margin` (float, optional): Error margin for validation (default: 0.33)
+- `require_all_gestures` (bool, optional): Whether all gestures must be present
+- `similarity_threshold` (float, optional): Facial similarity threshold
+- `include_details` (bool, optional): Include detailed validation results
+
+**Response:**
+```json
+{
+  "face": true,
+  "gestures": true,
+  "overall": true,
+  "status": "success",
+  "processing_time_ms": 6925,
+  "timestamp": "2025-09-30T08:30:22Z"
+}
+```
+
+## 🎯 Supported Gestures
+
+- `thumbs_up` (👍)
+- `peace` (✌️)
+- `ok_sign` (👌)
+- `open_palm` (👋)
+- `call_me` (🤙)
+- `grabbing` (✊)
+
+## 📖 Documentation
+
+Interactive API documentation is available at:
+- **Swagger UI**: `/docs`
+- **ReDoc**: `/redoc`
+
+## 🔧 Usage Example
+
+```bash
+# Detect gestures in a video
+curl -X POST http://localhost:7860/gestures \
+  -F "video=@my_video.mp4" \
+  -F "frame_skip=3"
+
+# Validate identity
+curl -X POST http://localhost:7860/validate \
+  -F "photo=@id_photo.jpg" \
+  -F "video=@user_video.mp4" \
+  -F 'gestures=["thumbs_up","peace"]' \
+  -F "include_details=true"
+```
+
+## 🏗️ Technology Stack
+
+- **Framework**: FastAPI
+- **ML Models**: ONNX Runtime
+- **Computer Vision**: OpenCV
+- **Tracking**: OCSort with Kalman filters
+- **Facial Recognition**: Custom embeddings module
+
+## 📝 Note
+
+Facial validation is currently in placeholder mode and always returns success. Gesture validation is fully functional.
+
+## 📄 License
+
+MIT License - See LICENSE file for details
+
+## 🔗 Links
+
+- [GitHub Repository](https://github.com/kybtech/gesture-detection)
+- [API Documentation](/docs)
+- [Hugging Face Space](https://huggingface.co/spaces/algoryn/validation)
main.py
ADDED
@@ -0,0 +1,337 @@
#!/usr/bin/env python3
"""
Main entry point for the unified gesture detection and identity validation API.
Provides a flat API structure with all endpoints at the root level.
"""
import uvicorn
import os
import sys
import tempfile
import time
import json
import logging
from typing import Optional
from datetime import datetime, timezone

# Add the project root to Python path
project_root = os.path.dirname(os.path.abspath(__file__))
sys.path.insert(0, project_root)

from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Depends
from fastapi.responses import ORJSONResponse

# Import gesture detection functionality
from src.gesturedetection.api import process_video_for_gestures
from src.gesturedetection.models import GestureResponse

# Import validation functionality
from src.validate.models import ValidationRequest, ValidationResponse, ValidationStatus
from src.validate.facial_validator import FacialValidator
from src.validate.gesture_validator import GestureValidator
from src.validate.api import get_validation_request
from src.validate.config import config

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Create main FastAPI application
app = FastAPI(
    title="Gesture Detection & Identity Validation API",
    description="Unified API for gesture detection and identity validation services",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
    default_response_class=ORJSONResponse
)

# Initialize validators for validation endpoint
facial_validator = FacialValidator()
gesture_validator = GestureValidator()


@app.get("/")
async def root():
    """
    Root endpoint providing API information.

    Returns
    -------
    dict
        API information and available endpoints
    """
    return {
        "name": "Gesture Detection & Identity Validation API",
        "version": "1.0.0",
        "description": "Unified API providing gesture detection and identity validation services",
        "endpoints": {
            "GET /": "API information",
            "GET /health": "Health check",
            "POST /validate": "Validate identity using facial recognition and gestures",
            "POST /gestures": "Detect gestures in video",
            "GET /docs": "Interactive API documentation"
        }
    }


@app.get("/health")
async def health():
    """
    Health check endpoint for the unified API.

    Returns
    -------
    dict
        Health status of all service components
    """
    return {
        "status": "healthy",
        "service": "unified-api",
        "version": "1.0.0",
        "timestamp": datetime.now(timezone.utc).isoformat(),
        "components": {
            "gesture_detection": "available",
            "identity_validation": "available",
            "facial_validator": "initialized",
            "gesture_validator": "initialized"
        }
    }


@app.post("/gestures", response_model=GestureResponse)
async def detect_gestures(video: UploadFile = File(...), frame_skip: int = Form(1)):
    """
    Detect gestures in an uploaded video file.

    Parameters
    ----------
    video : UploadFile
        The video file to process
    frame_skip : int
        Number of frames to skip between processing (1 = process every frame, 3 = process every 3rd frame)

    Returns
    -------
    GestureResponse
        Response containing detected gestures with duration and confidence
    """
    logger.info(f"Gesture detection request received: {video.filename}")

    # Validate file type
    if not video.content_type or not video.content_type.startswith('video/'):
        raise HTTPException(status_code=400, detail="File must be a video")

    # Create temporary file to save uploaded video
    with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
        try:
            # Write uploaded content to temporary file
            content = await video.read()
            temp_file.write(content)
            temp_file.flush()

            logger.info(f"Processing video: {temp_file.name} ({len(content)} bytes)")

            # Process the video with frame skip parameter
            gestures = process_video_for_gestures(temp_file.name, frame_skip=frame_skip)

            logger.info(f"Gesture detection completed: {len(gestures)} gestures detected")

            return GestureResponse(gestures=gestures)

        except Exception as e:
            logger.error(f"Error processing video: {str(e)}", exc_info=True)
            raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}")

        finally:
            # Clean up temporary file
            if os.path.exists(temp_file.name):
                os.unlink(temp_file.name)
                logger.debug(f"Cleaned up temporary file: {temp_file.name}")


@app.post("/validate", response_model=ValidationResponse)
async def validate_identity(
    photo: UploadFile = File(...),
    video: UploadFile = File(...),
    request: ValidationRequest = Depends(get_validation_request)
):
    """
    Validate user identity using facial recognition and gesture validation.

    This endpoint accepts an ID document photo, a user video containing
    the person's face and required gestures, and a list of gestures that
    must be performed. It returns validation results for both facial
    recognition and gesture compliance.

    Parameters
    ----------
    photo : UploadFile
        ID document photo file (image format)
    video : UploadFile
        User video file containing face and gestures (video format)
    request : ValidationRequest
        Validation configuration and gesture requirements

    Returns
    -------
    ValidationResponse
        Validation results with success indicators and optional details

    Raises
    ------
    HTTPException
        If validation fails or processing errors occur
    """
    start_time = time.time()
    logger.info(f"Identity validation request received for {request.asked_gestures}")

    # Validate file types
    if not photo.content_type or not photo.content_type.startswith(('image/', 'application/')):
        raise HTTPException(
            status_code=400,
            detail="Photo file must be an image"
        )

    if not video.content_type or not video.content_type.startswith('video/'):
        raise HTTPException(
            status_code=400,
            detail="Video file must be a video"
        )

    # Validate file sizes (basic check)
    MAX_FILE_SIZE = 100 * 1024 * 1024  # 100MB
    if photo.size and photo.size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=413,
            detail="Photo file too large (max 100MB)"
        )

    if video.size and video.size > MAX_FILE_SIZE:
        raise HTTPException(
            status_code=413,
            detail="Video file too large (max 100MB)"
        )

    # Create temporary files for processing
    temp_photo = None
    temp_video = None

    try:
        # Save uploaded files to temporary location
        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_photo.{photo.filename.split('.')[-1] if '.' in photo.filename else 'jpg'}") as temp_photo_file:
            temp_photo = temp_photo_file.name
            photo_content = await photo.read()
            temp_photo_file.write(photo_content)

        with tempfile.NamedTemporaryFile(delete=False, suffix=f"_video.{video.filename.split('.')[-1] if '.' in video.filename else 'mp4'}") as temp_video_file:
            temp_video = temp_video_file.name
            video_content = await video.read()
            temp_video_file.write(video_content)

        logger.info(f"Files saved: photo={temp_photo}, video={temp_video}")

        # Perform facial validation
        logger.info("Starting facial validation")

        # Update facial validator with request-specific parameters if provided
        if request.similarity_threshold is not None:
            facial_validator.similarity_threshold = request.similarity_threshold
        if request.frame_sample_rate is not None:
            facial_validator.frame_sample_rate = request.frame_sample_rate

        face_result = facial_validator.validate_facial_match(temp_photo, temp_video)

        # Perform gesture validation
        logger.info("Starting gesture validation")

        # Update gesture validator with request-specific parameters if provided
        if request.confidence_threshold is not None:
            gesture_validator.confidence_threshold = request.confidence_threshold
        if request.min_gesture_duration is not None:
            gesture_validator.min_gesture_duration = request.min_gesture_duration

        gesture_result = gesture_validator.validate_gestures(
            temp_video,
            request.asked_gestures,
            error_margin=request.error_margin,
            require_all=request.require_all_gestures
        )

        # Determine overall result
        overall_success = face_result.success and gesture_result.success
        overall_status = ValidationStatus.SUCCESS if overall_success else ValidationStatus.PARTIAL

        # Calculate processing time
        processing_time_ms = int((time.time() - start_time) * 1000)

        # Build response
        response = ValidationResponse(
            face=face_result.success,
            gestures=gesture_result.success,
            overall=overall_success,
            status=overall_status,
            face_result=face_result if request.include_details else None,
            gesture_result=gesture_result if request.include_details else None,
            processing_time_ms=processing_time_ms,
            timestamp=datetime.now(timezone.utc).isoformat()
        )

        # Log results
        logger.info(
            "Identity validation completed",
            extra={
                "face_success": face_result.success,
                "gesture_success": gesture_result.success,
                "overall_success": overall_success,
                "processing_time_ms": processing_time_ms,
                "requested_gestures": request.asked_gestures
            }
        )

        return response

    except Exception as e:
        logger.error(f"Error during identity validation: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Internal server error during validation: {str(e)}"
        )

    finally:
        # Clean up temporary files
        for temp_file in [temp_photo, temp_video]:
            if temp_file and os.path.exists(temp_file):
                try:
                    os.unlink(temp_file)
                    logger.debug(f"Cleaned up temporary file: {temp_file}")
                except Exception as e:
                    logger.warning(f"Failed to clean up temporary file {temp_file}: {e}")


def main():
    """Start the unified API server."""
    # Get port from environment variable, default to 7860 for HF Spaces compatibility
    port = int(os.getenv("PORT", 7860))

    print("🚀 Starting Unified Gesture Detection & Identity Validation API")
    print(f"📍 API will be available at: http://localhost:{port}")
    print(f"📚 API documentation at: http://localhost:{port}/docs")
    print(f"❤️ Health check at: http://localhost:{port}/health")
    print(f"🔐 Identity validation at: POST http://localhost:{port}/validate")
    print(f"👋 Gesture detection at: POST http://localhost:{port}/gestures")
    print("\nPress Ctrl+C to stop the server")

    uvicorn.run(
        app,
        host="0.0.0.0",
        port=port,
        reload=False,  # Disable reload in production/Docker
        log_level="info"
    )


if __name__ == "__main__":
    main()
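
The `src.validate` modules that main.py imports fall outside this 50-file view. As a rough sketch only, a `get_validation_request` dependency consistent with the attributes main.py reads and the form fields the README documents might look like the following; every name and default here is inferred, not the repo's actual code:

```python
# Hypothetical sketch of src/validate/api.py's get_validation_request dependency.
# Field names are inferred from main.py's attribute accesses and the README's
# /validate parameters; types and defaults are assumptions.
import json
from typing import List, Optional

from fastapi import Form
from pydantic import BaseModel


class ValidationRequest(BaseModel):
    asked_gestures: List[str]
    error_margin: float = 0.33
    require_all_gestures: bool = True
    similarity_threshold: Optional[float] = None
    frame_sample_rate: Optional[int] = None
    confidence_threshold: Optional[float] = None
    min_gesture_duration: Optional[int] = None
    include_details: bool = False


def get_validation_request(
    gestures: str = Form(...),  # JSON-encoded array, e.g. '["thumbs_up","peace"]'
    error_margin: float = Form(0.33),
    require_all_gestures: bool = Form(True),
    similarity_threshold: Optional[float] = Form(None),
    frame_sample_rate: Optional[int] = Form(None),
    confidence_threshold: Optional[float] = Form(None),
    min_gesture_duration: Optional[int] = Form(None),
    include_details: bool = Form(False),
) -> ValidationRequest:
    # Parse the JSON array out of the multipart form so the gesture list can
    # ride alongside the file uploads, then hand main.py a typed request object.
    return ValidationRequest(
        asked_gestures=json.loads(gestures),
        error_margin=error_margin,
        require_all_gestures=require_all_gestures,
        similarity_threshold=similarity_threshold,
        frame_sample_rate=frame_sample_rate,
        confidence_threshold=confidence_threshold,
        min_gesture_duration=min_gesture_duration,
        include_details=include_details,
    )
```

Using `Depends` this way keeps `/validate` a single multipart request (two files plus scalar form fields) instead of requiring a separate JSON body.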
models/crops_classifier.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:12a02344f63a7c4f2a2ca90f8740ca10a08c17b683b5585d73c3e88323056762
size 411683
models/hand_detector.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a8ef73d466b61a8e8677be9c47008b217a11d1b265d95e36bf2521ff93329af6
size 1219959
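
The two `.onnx` entries are Git LFS pointer files; the actual weights are fetched at checkout. A small sketch (illustrative paths, assuming `git lfs pull` has run from the repo root) to confirm a fetched model matches the `oid` recorded in its pointer:

```python
# Verify a Git LFS-fetched file against the sha256 oid from its pointer file.
import hashlib

def sha256_of(path: str) -> str:
    h = hashlib.sha256()
    with open(path, "rb") as f:
        # Stream in 1 MiB chunks so large models don't load into memory at once
        for chunk in iter(lambda: f.read(1 << 20), b""):
            h.update(chunk)
    return h.hexdigest()

expected = "a8ef73d466b61a8e8677be9c47008b217a11d1b265d95e36bf2521ff93329af6"
assert sha256_of("models/hand_detector.onnx") == expected, "LFS object not fetched?"
```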
requirements.txt
ADDED
@@ -0,0 +1,13 @@
# Direct dependencies from pyproject.toml
filterpy>=1.4.5
onnx>=1.19.0
onnxruntime>=1.22.1
opencv-contrib-python>=4.12.0.88
fastapi>=0.104.0
pydantic>=2.0.0
uvicorn>=0.24.0
python-multipart>=0.0.6
orjson>=3.9.0
numpy>=1.24.0
scipy>=1.11.0
logfire[fastapi,sqlite3,httpx]>=0.0.0
src/.DS_Store
ADDED
Binary file (6.15 kB)
src/facialembeddingsmatch/__init__.py
ADDED
@@ -0,0 +1,15 @@
"""
Facial embeddings matching module for identity verification.

This module provides facial recognition functionality using embedding-based
matching algorithms. It handles face detection, feature extraction, and
similarity comparison for identity verification purposes.
"""

__version__ = "1.0.0"
__all__ = [
    "FacialEmbeddingMatcher",
    "FaceDetector",
    "EmbeddingExtractor",
    "SimilarityCalculator"
]
src/facialembeddingsmatch/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (617 Bytes)
src/facialembeddingsmatch/__pycache__/facial_matcher.cpython-312.pyc
ADDED
Binary file (13.5 kB)
src/facialembeddingsmatch/facial_matcher.py
ADDED
@@ -0,0 +1,433 @@
"""
Facial embedding matcher for identity verification.

This module provides comprehensive facial recognition functionality including
face detection, embedding extraction, and similarity comparison. It serves
as the core facial matching component for the identity validation system.
"""

import os
import logging
import tempfile
from typing import List, Dict, Any, Optional, Tuple
from datetime import datetime, timezone
import numpy as np

logger = logging.getLogger(__name__)


class FaceDetector:
    """
    Face detection component for identifying faces in images.

    This class handles face detection in both ID photos and video frames.
    Currently implemented as a stub, designed to be replaced with actual
    face detection algorithms (e.g., MTCNN, DLib, or OpenCV cascades).
    """

    def __init__(self, confidence_threshold: float = 0.8):
        """
        Initialize the face detector.

        Parameters
        ----------
        confidence_threshold : float, optional
            Minimum confidence threshold for face detection, by default 0.8
        """
        self.confidence_threshold = confidence_threshold
        logger.info(f"FaceDetector initialized with confidence_threshold={confidence_threshold}")

    def detect_faces(self, image_path: str) -> List[Dict[str, Any]]:
        """
        Detect faces in an image.

        This is currently a stub implementation that simulates face detection.
        In the future, this will be replaced with actual face detection algorithms.

        Parameters
        ----------
        image_path : str
            Path to the image file

        Returns
        -------
        List[Dict[str, Any]]
            List of detected faces with bounding boxes and confidence scores
        """
        logger.debug(f"Detecting faces in {image_path} (stub implementation)")

        # Validate input file
        if not os.path.exists(image_path):
            logger.error(f"Image file not found: {image_path}")
            raise FileNotFoundError(f"Image file not found: {image_path}")

        # Stub implementation: simulate detecting one face
        # In a real implementation, this would use actual face detection
        detected_faces = [
            {
                "bbox": [100, 100, 200, 200],  # x1, y1, x2, y2
                "confidence": 0.95,
                "landmarks": None,  # Facial landmarks if available
                "image_path": image_path
            }
        ]

        logger.debug(f"Detected {len(detected_faces)} faces")
        return detected_faces


class EmbeddingExtractor:
    """
    Facial embedding extraction component.

    This class extracts facial feature embeddings from detected faces.
    Currently implemented as a stub, designed to be replaced with actual
    embedding extraction models (e.g., FaceNet, ArcFace, or VGGFace).
    """

    def __init__(self, model_path: Optional[str] = None):
        """
        Initialize the embedding extractor.

        Parameters
        ----------
        model_path : Optional[str], optional
            Path to the embedding extraction model, by default None
        """
        self.model_path = model_path
        logger.info(f"EmbeddingExtractor initialized with model_path={model_path}")

    def extract_embedding(self, image_path: str, face_bbox: List[int]) -> Optional[np.ndarray]:
        """
        Extract facial embedding from a face region.

        This is currently a stub implementation that returns a random embedding.
        In the future, this will extract actual facial embeddings using deep learning models.

        Parameters
        ----------
        image_path : str
            Path to the image file
        face_bbox : List[int]
            Bounding box coordinates [x1, y1, x2, y2]

        Returns
        -------
        Optional[np.ndarray]
            Facial embedding vector, or None if extraction fails
        """
        logger.debug(f"Extracting embedding from {image_path} with bbox {face_bbox}")

        # Validate input file
        if not os.path.exists(image_path):
            logger.error(f"Image file not found: {image_path}")
            return None

        # Stub implementation: return deterministic 128-dimensional embedding for testing
        # In a real implementation, this would use a trained model
        # Use a seed based on the image path to make it deterministic for testing
        import hashlib
        seed = int(hashlib.md5(image_path.encode()).hexdigest()[:8], 16) % 2**32
        np.random.seed(seed)
        embedding = np.random.randn(128).astype(np.float32)
        # Normalize the embedding
        embedding = embedding / np.linalg.norm(embedding)

        logger.debug(f"Extracted embedding with shape {embedding.shape}")
        return embedding


class SimilarityCalculator:
    """
    Similarity calculation component for comparing facial embeddings.

    This class computes similarity scores between facial embeddings using
    various distance metrics. Currently supports cosine similarity.
    """

    def __init__(self):
        """Initialize the similarity calculator."""
        logger.info("SimilarityCalculator initialized")

    def calculate_similarity(self, embedding1: np.ndarray, embedding2: np.ndarray) -> float:
        """
        Calculate similarity between two facial embeddings.

        Parameters
        ----------
        embedding1 : np.ndarray
            First facial embedding
        embedding2 : np.ndarray
            Second facial embedding

        Returns
        -------
        float
            Similarity score between 0.0 (dissimilar) and 1.0 (identical)
        """
        # Calculate cosine similarity
        dot_product = np.dot(embedding1, embedding2)
        norm1 = np.linalg.norm(embedding1)
        norm2 = np.linalg.norm(embedding2)

        if norm1 == 0 or norm2 == 0:
            return 0.0

        cosine_similarity = dot_product / (norm1 * norm2)

        # Convert to similarity score (0.0 to 1.0)
        similarity = (cosine_similarity + 1.0) / 2.0

        logger.debug(f"Calculated similarity: {similarity}")
        return similarity


class FacialEmbeddingMatcher:
    """
    Main facial embedding matcher for identity verification.

    This class orchestrates the complete facial recognition pipeline:
    face detection, embedding extraction, and similarity comparison.
    It serves as the primary interface for facial matching functionality.
    """

    def __init__(
        self,
        detector_confidence: float = 0.8,
        similarity_threshold: float = 0.7,
        embedding_model_path: Optional[str] = None
    ):
        """
        Initialize the facial embedding matcher.

        Parameters
        ----------
        detector_confidence : float, optional
            Confidence threshold for face detection, by default 0.8
        similarity_threshold : float, optional
            Similarity threshold for facial matching, by default 0.7
        embedding_model_path : Optional[str], optional
            Path to embedding extraction model, by default None
        """
        self.detector_confidence = detector_confidence
        self.similarity_threshold = similarity_threshold
        self.embedding_model_path = embedding_model_path

        # Initialize components
        self.face_detector = FaceDetector(confidence_threshold=detector_confidence)
        self.embedding_extractor = EmbeddingExtractor(model_path=embedding_model_path)
        self.similarity_calculator = SimilarityCalculator()

        logger.info(
            "FacialEmbeddingMatcher initialized",
            extra={
                "detector_confidence": detector_confidence,
                "similarity_threshold": similarity_threshold,
                "embedding_model_path": embedding_model_path
            }
        )

    def match_faces(
        self,
        id_image_path: str,
        video_path: str,
        frame_sample_rate: int = 10
    ) -> Dict[str, Any]:
        """
        Match faces between ID image and video frames.

        This method performs comprehensive facial matching by:
        1. Detecting faces in the ID image
        2. Sampling frames from the video and detecting faces
        3. Extracting embeddings from detected faces
        4. Computing similarity scores
        5. Determining overall match result

        Parameters
        ----------
        id_image_path : str
            Path to the ID document image
        video_path : str
            Path to the user video
        frame_sample_rate : int, optional
            Rate at which to sample video frames, by default 10

        Returns
        -------
        Dict[str, Any]
            Matching results with similarity scores and metadata
        """
        logger.info(f"Starting facial matching between {id_image_path} and {video_path}")

        try:
            # Step 1: Extract reference embedding from ID image
            # (inner try so a missing file is reported distinctly from the
            # generic error handler at the end of this method)
            try:
                id_faces = self.face_detector.detect_faces(id_image_path)

                if not id_faces:
                    return {
                        "success": False,
                        "error": "No faces detected in ID image",
                        "similarity_score": 0.0,
                        "matches": False,
                        "details": {
                            "id_faces_detected": 0,
                            "video_faces_detected": 0,
                            "processing_timestamp": datetime.now(timezone.utc).isoformat()
                        }
                    }

            except FileNotFoundError as e:
                return {
                    "success": False,
                    "error": f"File not found: {str(e)}",
                    "similarity_score": 0.0,
                    "matches": False,
                    "details": {
                        "id_faces_detected": 0,
                        "video_faces_detected": 0,
                        "processing_timestamp": datetime.now(timezone.utc).isoformat()
                    }
                }

            # Extract embedding from the first (best) face in ID image
            id_face = id_faces[0]
            id_embedding = self.embedding_extractor.extract_embedding(
                id_image_path, id_face["bbox"]
            )

            if id_embedding is None:
                return {
                    "success": False,
                    "error": "Failed to extract embedding from ID image",
                    "similarity_score": 0.0,
                    "matches": False,
                    "details": {
                        "id_faces_detected": len(id_faces),
                        "video_faces_detected": 0,
                        "processing_timestamp": datetime.now(timezone.utc).isoformat()
                    }
                }

            # Step 2: Extract faces from video frames
            video_faces = self._extract_faces_from_video(video_path, frame_sample_rate)

            if not video_faces:
                return {
                    "success": False,
                    "error": "No faces detected in video",
                    "similarity_score": 0.0,
                    "matches": False,
                    "details": {
                        "id_faces_detected": len(id_faces),
                        "video_faces_detected": 0,
                        "processing_timestamp": datetime.now(timezone.utc).isoformat()
                    }
                }

            # Step 3: Compare embeddings and find best match
            best_similarity = 0.0
            best_video_face = None

            for video_face in video_faces:
                video_embedding = self.embedding_extractor.extract_embedding(
                    video_path, video_face["bbox"]
                )

                if video_embedding is not None:
                    similarity = self.similarity_calculator.calculate_similarity(
                        id_embedding, video_embedding
                    )

                    if similarity > best_similarity:
                        best_similarity = similarity
                        best_video_face = video_face

            # Step 4: Determine if faces match
            matches = best_similarity >= self.similarity_threshold

            result = {
                "success": True,
                "matches": matches,
                "similarity_score": best_similarity,
                "similarity_threshold": self.similarity_threshold,
                "details": {
                    "id_faces_detected": len(id_faces),
                    "video_faces_detected": len(video_faces),
                    "best_video_face": best_video_face,
                    "processing_timestamp": datetime.now(timezone.utc).isoformat(),
                    "frame_sample_rate": frame_sample_rate,
                    "note": "This is a stub implementation. Real facial recognition will be implemented in the future."
                }
            }

            logger.info(
                "Facial matching completed",
                extra={
                    "matches": matches,
                    "similarity_score": best_similarity,
                    "faces_detected_id": len(id_faces),
                    "faces_detected_video": len(video_faces)
                }
            )

            return result

        except Exception as e:
            logger.error(f"Error during facial matching: {str(e)}", exc_info=True)
            return {
                "success": False,
                "error": f"Processing error: {str(e)}",
                "similarity_score": 0.0,
                "matches": False,
                "details": {
                    "processing_timestamp": datetime.now(timezone.utc).isoformat()
                }
            }

    def _extract_faces_from_video(self, video_path: str, frame_sample_rate: int) -> List[Dict[str, Any]]:
        """
        Extract faces from video frames.

        This method samples frames from the video and detects faces in each frame.
        Currently implemented as a stub that simulates face detection.

        Parameters
        ----------
        video_path : str
            Path to the video file
        frame_sample_rate : int
            Rate at which to sample frames

        Returns
        -------
        List[Dict[str, Any]]
            List of detected faces with frame information
        """
        logger.debug(f"Extracting faces from video: {video_path}")

        # Stub implementation: simulate detecting faces in video
        # In a real implementation, this would:
        # 1. Open the video file
        # 2. Sample frames at the specified rate
        # 3. Detect faces in each sampled frame
        # 4. Return face information with frame metadata

        detected_faces = [
            {
                "bbox": [120, 120, 220, 220],  # x1, y1, x2, y2
                "confidence": 0.92,
                "frame_number": 15,
                "timestamp": 0.5,  # seconds
                "image_path": video_path
            },
            {
                "bbox": [110, 110, 210, 210],
                "confidence": 0.88,
                "frame_number": 30,
                "timestamp": 1.0,
                "image_path": video_path
            }
        ]

        logger.debug(f"Extracted {len(detected_faces)} faces from video")
        return detected_faces
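
`SimilarityCalculator` maps raw cosine similarity from [-1, 1] onto [0, 1] via (cos + 1) / 2, so the matcher's default 0.7 threshold corresponds to a raw cosine of 2(0.7) - 1 = 0.4. A quick standalone check of that rescaling:

```python
# Verify the (cosine + 1) / 2 rescaling used by SimilarityCalculator:
# identical embeddings -> 1.0, orthogonal -> 0.5, opposite -> 0.0.
import numpy as np

def similarity(e1: np.ndarray, e2: np.ndarray) -> float:
    cos = np.dot(e1, e2) / (np.linalg.norm(e1) * np.linalg.norm(e2))
    return (cos + 1.0) / 2.0

a = np.array([1.0, 0.0])
b = np.array([0.0, 1.0])

print(similarity(a, a))   # 1.0
print(similarity(a, b))   # 0.5
print(similarity(a, -a))  # 0.0
```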
src/gesturedetection/.DS_Store
ADDED
Binary file (6.15 kB)
src/gesturedetection/__init__.py
ADDED
@@ -0,0 +1,23 @@
# Gesture detection package
from .api import app
from .models import Gesture, GestureResponse, GESTURE_MAPPING, FULL_GESTURE_MAPPING
from .main_controller import MainController
from .onnx_models import HandDetection, HandClassification
from .utils import Deque, Drawer, Hand, Event, HandPosition, targets

__all__ = [
    "app",
    "Gesture",
    "GestureResponse",
    "GESTURE_MAPPING",
    "FULL_GESTURE_MAPPING",
    "MainController",
    "HandDetection",
    "HandClassification",
    "Deque",
    "Drawer",
    "Hand",
    "Event",
    "HandPosition",
    "targets"
]
src/gesturedetection/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (703 Bytes)
src/gesturedetection/__pycache__/api.cpython-312.pyc
ADDED
Binary file (11.7 kB)
src/gesturedetection/__pycache__/config.cpython-312.pyc
ADDED
Binary file (2.18 kB)
src/gesturedetection/__pycache__/main_controller.cpython-312.pyc
ADDED
Binary file (12.3 kB)
src/gesturedetection/__pycache__/models.cpython-312.pyc
ADDED
Binary file (2.5 kB)
src/gesturedetection/__pycache__/onnx_models.cpython-312.pyc
ADDED
Binary file (10 kB)
src/gesturedetection/api.py
ADDED
@@ -0,0 +1,318 @@
import cv2
import numpy as np
import tempfile
import os
from collections import defaultdict
from typing import Dict, List, Tuple, Optional
from fastapi import FastAPI, UploadFile, File, HTTPException, Form
from fastapi.responses import ORJSONResponse
from fastapi.encoders import jsonable_encoder

from .models import Gesture, GestureResponse, GESTURE_MAPPING, FULL_GESTURE_MAPPING
from .config import get_logfire_token, is_monitoring_enabled

# Import the gesture detection components
from .main_controller import MainController

# Configure logfire monitoring if token is available
logfire = None
if is_monitoring_enabled():
    try:
        import logfire
        logfire.configure(token=get_logfire_token())
        logfire.instrument_fastapi = logfire.instrument_fastapi
    except ImportError:
        logfire = None

app = FastAPI(default_response_class=ORJSONResponse)

# Instrument FastAPI with logfire if monitoring is enabled
if logfire is not None:
    logfire.instrument_fastapi(app, capture_headers=True)


def process_video_for_gestures(video_path: str, detector_path: str = "models/hand_detector.onnx",
                               classifier_path: str = "models/crops_classifier.onnx",
                               frame_skip: int = 1) -> List[Gesture]:
    """
    Process a video file to detect gestures using the MainController.

    Parameters
    ----------
    video_path : str
        Path to the video file to process
    detector_path : str
        Path to the hand detection ONNX model
    classifier_path : str
        Path to the gesture classification ONNX model
    frame_skip : int
        Number of frames to skip between processing (1 = process every frame, 3 = process every 3rd frame)

    Returns
    -------
    List[Gesture]
        List of detected gestures with duration and confidence
    """
    # Create monitoring span for video processing
    span_context = None
    if logfire is not None:
        span_context = logfire.span('process_video_for_gestures',
                                    video_path=video_path,
                                    detector_path=detector_path,
                                    classifier_path=classifier_path)
        span_context.__enter__()

    try:
        # Initialize the main controller
        if logfire is not None:
            with logfire.span('initialize_controller'):
                controller = MainController(detector_path, classifier_path)
        else:
            controller = MainController(detector_path, classifier_path)

        # Open video file
        cap = cv2.VideoCapture(video_path)
        if not cap.isOpened():
            raise ValueError(f"Could not open video file: {video_path}")

        # Get video properties for monitoring
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)

        if logfire is not None:
            logfire.info('Video properties',
                         total_frames=total_frames,
                         fps=fps,
                         duration_seconds=total_frames/fps if fps > 0 else 0)

        # Track gestures per hand ID
        gesture_tracks: Dict[int, List[Tuple[int, float]]] = defaultdict(list)  # {hand_id: [(gesture_id, confidence), ...]}
        frame_count = 0
        processed_frames = 0
        detection_stats = {
            'frames_with_detections': 0,
            'total_detections': 0,
            'gesture_counts': defaultdict(int)
        }

        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                # Skip frames based on frame_skip parameter
                if frame_count % frame_skip == 0:
                    # Process frame through the controller
                    bboxes, ids, labels = controller(frame)
                    processed_frames += 1

                    if bboxes is not None and ids is not None and labels is not None:
                        detection_stats['frames_with_detections'] += 1
                        detection_stats['total_detections'] += len(bboxes)

                        # Track gestures for each detected hand
                        for i in range(len(bboxes)):
                            hand_id = int(ids[i])
                            gesture_id = labels[i]

                            if gesture_id is not None:
                                # Get confidence from bbox (assuming it's the last element)
                                confidence = 0.8  # Default confidence, could be extracted from bbox if available
                                gesture_tracks[hand_id].append((gesture_id, confidence))
                                detection_stats['gesture_counts'][gesture_id] += 1

                                # Log individual detections for debugging
                                if logfire is not None:
                                    gesture_name = FULL_GESTURE_MAPPING.get(gesture_id, f"unknown_{gesture_id}")
                                    logfire.debug('Hand detection',
                                                  frame=frame_count,
                                                  hand_id=hand_id,
                                                  gesture_id=gesture_id,
                                                  gesture_name=gesture_name,
                                                  confidence=confidence,
                                                  bbox=bboxes[i].tolist() if len(bboxes[i]) >= 4 else None)
                else:
                    # Advance tracker on skipped frames to keep state consistent
                    controller.update(np.empty((0, 5)), None)

                frame_count += 1

                # Log progress every 100 frames
                if frame_count % 100 == 0 and logfire is not None:
                    progress = (frame_count / total_frames) * 100 if total_frames > 0 else 0
                    logfire.info('Processing progress',
                                 frame=frame_count,
                                 total_frames=total_frames,
                                 progress_percent=round(progress, 2))

        finally:
            cap.release()

        # Log final detection statistics
        if logfire is not None:
            logfire.info('Detection statistics',
                         total_frames=frame_count,
                         processed_frames=processed_frames,
                         frame_skip=frame_skip,
                         frames_with_detections=detection_stats['frames_with_detections'],
                         total_detections=detection_stats['total_detections'],
                         detection_rate=detection_stats['frames_with_detections']/processed_frames if processed_frames > 0 else 0,
                         gesture_counts=dict(detection_stats['gesture_counts']))

        # Process gesture tracks to find continuous gestures
        detected_gestures = []

        for hand_id, gesture_sequence in gesture_tracks.items():
            if not gesture_sequence:
                continue

            # Group consecutive identical gestures
            current_gesture = None
            current_duration = 0
            current_confidence = 0.0

            for gesture_id, confidence in gesture_sequence:
                if current_gesture is None or current_gesture != gesture_id:
                    # Save previous gesture if it was significant
                    # Adjust minimum duration based on frame skip
                    min_duration = max(5, frame_skip * 2)  # At least 2 processed frames
                    if current_gesture is not None and current_duration >= min_duration:
                        gesture_name = FULL_GESTURE_MAPPING.get(current_gesture, f"unknown_{current_gesture}")
                        avg_confidence = current_confidence / current_duration if current_duration > 0 else 0.0
                        # Scale duration back to original frame count
                        scaled_duration = current_duration * frame_skip
                        detected_gestures.append(Gesture(
                            gesture=gesture_name,
                            duration=scaled_duration,
                            confidence=avg_confidence
                        ))

                        # Log significant gesture detection
                        if logfire is not None:
                            logfire.info('Significant gesture detected',
                                         hand_id=hand_id,
                                         gesture=gesture_name,
                                         duration_frames=current_duration,
                                         confidence=avg_confidence)

                    # Start new gesture
                    current_gesture = gesture_id
                    current_duration = 1
                    current_confidence = confidence
                else:
+
Process a video file to detect gestures using the MainController.
|
| 39 |
+
|
| 40 |
+
Parameters
|
| 41 |
+
----------
|
| 42 |
+
video_path : str
|
| 43 |
+
Path to the video file to process
|
| 44 |
+
detector_path : str
|
| 45 |
+
Path to the hand detection ONNX model
|
| 46 |
+
classifier_path : str
|
| 47 |
+
Path to the gesture classification ONNX model
|
| 48 |
+
frame_skip : int
|
| 49 |
+
Number of frames to skip between processing (1 = process every frame, 3 = process every 3rd frame)
|
| 50 |
+
|
| 51 |
+
Returns
|
| 52 |
+
-------
|
| 53 |
+
List[Gesture]
|
| 54 |
+
List of detected gestures with duration and confidence
|
| 55 |
+
"""
|
| 56 |
+
# Create monitoring span for video processing
|
| 57 |
+
span_context = None
|
| 58 |
+
if logfire is not None:
|
| 59 |
+
span_context = logfire.span('process_video_for_gestures',
|
| 60 |
+
video_path=video_path,
|
| 61 |
+
detector_path=detector_path,
|
| 62 |
+
classifier_path=classifier_path)
|
| 63 |
+
span_context.__enter__()
|
| 64 |
+
|
| 65 |
+
try:
|
| 66 |
+
# Initialize the main controller
|
| 67 |
+
if logfire is not None:
|
| 68 |
+
with logfire.span('initialize_controller'):
|
| 69 |
+
controller = MainController(detector_path, classifier_path)
|
| 70 |
+
else:
|
| 71 |
+
controller = MainController(detector_path, classifier_path)
|
| 72 |
+
|
| 73 |
+
# Open video file
|
| 74 |
+
cap = cv2.VideoCapture(video_path)
|
| 75 |
+
if not cap.isOpened():
|
| 76 |
+
raise ValueError(f"Could not open video file: {video_path}")
|
| 77 |
+
|
| 78 |
+
# Get video properties for monitoring
|
| 79 |
+
total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
|
| 80 |
+
fps = cap.get(cv2.CAP_PROP_FPS)
|
| 81 |
+
|
| 82 |
+
if logfire is not None:
|
| 83 |
+
logfire.info('Video properties',
|
| 84 |
+
total_frames=total_frames,
|
| 85 |
+
fps=fps,
|
| 86 |
+
duration_seconds=total_frames/fps if fps > 0 else 0)
|
| 87 |
+
|
| 88 |
+
# Track gestures per hand ID
|
| 89 |
+
gesture_tracks: Dict[int, List[Tuple[int, float]]] = defaultdict(list) # {hand_id: [(gesture_id, confidence), ...]}
|
| 90 |
+
frame_count = 0
|
| 91 |
+
processed_frames = 0
|
| 92 |
+
detection_stats = {
|
| 93 |
+
'frames_with_detections': 0,
|
| 94 |
+
'total_detections': 0,
|
| 95 |
+
'gesture_counts': defaultdict(int)
|
| 96 |
+
}
|
| 97 |
+
|
| 98 |
+
try:
|
| 99 |
+
while True:
|
| 100 |
+
ret, frame = cap.read()
|
| 101 |
+
if not ret:
|
| 102 |
+
break
|
| 103 |
+
|
| 104 |
+
# Skip frames based on frame_skip parameter
|
| 105 |
+
if frame_count % frame_skip == 0:
|
| 106 |
+
# Process frame through the controller
|
| 107 |
+
bboxes, ids, labels = controller(frame)
|
| 108 |
+
processed_frames += 1
|
| 109 |
+
|
| 110 |
+
if bboxes is not None and ids is not None and labels is not None:
|
| 111 |
+
detection_stats['frames_with_detections'] += 1
|
| 112 |
+
detection_stats['total_detections'] += len(bboxes)
|
| 113 |
+
|
| 114 |
+
# Track gestures for each detected hand
|
| 115 |
+
for i in range(len(bboxes)):
|
| 116 |
+
hand_id = int(ids[i])
|
| 117 |
+
gesture_id = labels[i]
|
| 118 |
+
|
| 119 |
+
if gesture_id is not None:
|
| 120 |
+
# Get confidence from bbox (assuming it's the last element)
|
| 121 |
+
confidence = 0.8 # Default confidence, could be extracted from bbox if available
|
| 122 |
+
gesture_tracks[hand_id].append((gesture_id, confidence))
|
| 123 |
+
detection_stats['gesture_counts'][gesture_id] += 1
|
| 124 |
+
|
| 125 |
+
# Log individual detections for debugging
|
| 126 |
+
if logfire is not None:
|
| 127 |
+
gesture_name = FULL_GESTURE_MAPPING.get(gesture_id, f"unknown_{gesture_id}")
|
| 128 |
+
logfire.debug('Hand detection',
|
| 129 |
+
frame=frame_count,
|
| 130 |
+
hand_id=hand_id,
|
| 131 |
+
gesture_id=gesture_id,
|
| 132 |
+
gesture_name=gesture_name,
|
| 133 |
+
confidence=confidence,
|
| 134 |
+
bbox=bboxes[i].tolist() if len(bboxes[i]) >= 4 else None)
|
| 135 |
+
else:
|
| 136 |
+
# Advance tracker on skipped frames to keep state consistent
|
| 137 |
+
controller.update(np.empty((0, 5)), None)
|
| 138 |
+
|
| 139 |
+
frame_count += 1
|
| 140 |
+
|
| 141 |
+
# Log progress every 100 frames
|
| 142 |
+
if frame_count % 100 == 0 and logfire is not None:
|
| 143 |
+
progress = (frame_count / total_frames) * 100 if total_frames > 0 else 0
|
| 144 |
+
logfire.info('Processing progress',
|
| 145 |
+
frame=frame_count,
|
| 146 |
+
total_frames=total_frames,
|
| 147 |
+
progress_percent=round(progress, 2))
|
| 148 |
+
|
| 149 |
+
finally:
|
| 150 |
+
cap.release()
|
| 151 |
+
|
| 152 |
+
# Log final detection statistics
|
| 153 |
+
if logfire is not None:
|
| 154 |
+
logfire.info('Detection statistics',
|
| 155 |
+
total_frames=frame_count,
|
| 156 |
+
processed_frames=processed_frames,
|
| 157 |
+
frame_skip=frame_skip,
|
| 158 |
+
frames_with_detections=detection_stats['frames_with_detections'],
|
| 159 |
+
total_detections=detection_stats['total_detections'],
|
| 160 |
+
detection_rate=detection_stats['frames_with_detections']/processed_frames if processed_frames > 0 else 0,
|
| 161 |
+
gesture_counts=dict(detection_stats['gesture_counts']))
|
| 162 |
+
|
| 163 |
+
# Process gesture tracks to find continuous gestures
|
| 164 |
+
detected_gestures = []
|
| 165 |
+
|
| 166 |
+
for hand_id, gesture_sequence in gesture_tracks.items():
|
| 167 |
+
if not gesture_sequence:
|
| 168 |
+
continue
|
| 169 |
+
|
| 170 |
+
# Group consecutive identical gestures
|
| 171 |
+
current_gesture = None
|
| 172 |
+
current_duration = 0
|
| 173 |
+
current_confidence = 0.0
|
| 174 |
+
|
| 175 |
+
for gesture_id, confidence in gesture_sequence:
|
| 176 |
+
if current_gesture is None or current_gesture != gesture_id:
|
| 177 |
+
# Save previous gesture if it was significant
|
| 178 |
+
# Adjust minimum duration based on frame skip
|
| 179 |
+
min_duration = max(5, frame_skip * 2) # At least 2 processed frames
|
| 180 |
+
if current_gesture is not None and current_duration >= min_duration:
|
| 181 |
+
gesture_name = FULL_GESTURE_MAPPING.get(current_gesture, f"unknown_{current_gesture}")
|
| 182 |
+
avg_confidence = current_confidence / current_duration if current_duration > 0 else 0.0
|
| 183 |
+
# Scale duration back to original frame count
|
| 184 |
+
scaled_duration = current_duration * frame_skip
|
| 185 |
+
detected_gestures.append(Gesture(
|
| 186 |
+
gesture=gesture_name,
|
| 187 |
+
duration=scaled_duration,
|
| 188 |
+
confidence=avg_confidence
|
| 189 |
+
))
|
| 190 |
+
|
| 191 |
+
# Log significant gesture detection
|
| 192 |
+
if logfire is not None:
|
| 193 |
+
logfire.info('Significant gesture detected',
|
| 194 |
+
hand_id=hand_id,
|
| 195 |
+
gesture=gesture_name,
|
| 196 |
+
duration_frames=current_duration,
|
| 197 |
+
confidence=avg_confidence)
|
| 198 |
+
|
| 199 |
+
# Start new gesture
|
| 200 |
+
current_gesture = gesture_id
|
| 201 |
+
current_duration = 1
|
| 202 |
+
current_confidence = confidence
|
| 203 |
+
else:
|
| 204 |
+
# Continue current gesture
|
| 205 |
+
current_duration += 1
|
| 206 |
+
current_confidence += confidence
|
| 207 |
+
|
| 208 |
+
# Don't forget the last gesture
|
| 209 |
+
min_duration = max(5, frame_skip * 2) # At least 2 processed frames
|
| 210 |
+
if current_gesture is not None and current_duration >= min_duration:
|
| 211 |
+
gesture_name = FULL_GESTURE_MAPPING.get(current_gesture, f"unknown_{current_gesture}")
|
| 212 |
+
avg_confidence = current_confidence / current_duration if current_duration > 0 else 0.0
|
| 213 |
+
# Scale duration back to original frame count
|
| 214 |
+
scaled_duration = current_duration * frame_skip
|
| 215 |
+
detected_gestures.append(Gesture(
|
| 216 |
+
gesture=gesture_name,
|
| 217 |
+
duration=scaled_duration,
|
| 218 |
+
confidence=avg_confidence
|
| 219 |
+
))
|
| 220 |
+
|
| 221 |
+
# Log final gesture detection
|
| 222 |
+
if logfire is not None:
|
| 223 |
+
logfire.info('Final gesture detected',
|
| 224 |
+
hand_id=hand_id,
|
| 225 |
+
gesture=gesture_name,
|
| 226 |
+
duration_frames=current_duration,
|
| 227 |
+
confidence=avg_confidence)
|
| 228 |
+
|
| 229 |
+
# Log final results
|
| 230 |
+
if logfire is not None:
|
| 231 |
+
logfire.info('Video processing completed',
|
| 232 |
+
total_gestures_detected=len(detected_gestures),
|
| 233 |
+
unique_hands=len(gesture_tracks),
|
| 234 |
+
gestures=[{'gesture': g.gesture, 'duration': g.duration, 'confidence': g.confidence} for g in detected_gestures])
|
| 235 |
+
|
| 236 |
+
return detected_gestures
|
| 237 |
+
|
| 238 |
+
finally:
|
| 239 |
+
if span_context is not None:
|
| 240 |
+
span_context.__exit__(None, None, None)
|
| 241 |
+
|
| 242 |
+
|
| 243 |
+
@app.get("/health")
|
| 244 |
+
async def health():
|
| 245 |
+
"""Health check endpoint."""
|
| 246 |
+
if logfire is not None:
|
| 247 |
+
logfire.info('Health check requested')
|
| 248 |
+
return {"message": "OK"}
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
@app.post("/gestures", response_model=GestureResponse)
|
| 252 |
+
async def detect_gestures(video: UploadFile = File(...), frame_skip: int = Form(1)):
|
| 253 |
+
"""
|
| 254 |
+
Detect gestures in an uploaded video file.
|
| 255 |
+
|
| 256 |
+
Parameters
|
| 257 |
+
----------
|
| 258 |
+
video : UploadFile
|
| 259 |
+
The video file to process
|
| 260 |
+
frame_skip : int
|
| 261 |
+
Number of frames to skip between processing (1 = process every frame, 3 = process every 3rd frame)
|
| 262 |
+
|
| 263 |
+
Returns
|
| 264 |
+
-------
|
| 265 |
+
GestureResponse
|
| 266 |
+
Response containing detected gestures with duration and confidence
|
| 267 |
+
"""
|
| 268 |
+
# Log request details
|
| 269 |
+
if logfire is not None:
|
| 270 |
+
logfire.info('Gesture detection request received',
|
| 271 |
+
filename=video.filename,
|
| 272 |
+
content_type=video.content_type,
|
| 273 |
+
content_length=video.size if hasattr(video, 'size') else 'unknown')
|
| 274 |
+
|
| 275 |
+
# Validate file type
|
| 276 |
+
if not video.content_type.startswith('video/'):
|
| 277 |
+
if logfire is not None:
|
| 278 |
+
logfire.warning('Invalid file type received', content_type=video.content_type)
|
| 279 |
+
raise HTTPException(status_code=400, detail="File must be a video")
|
| 280 |
+
|
| 281 |
+
# Create temporary file to save uploaded video
|
| 282 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix='.mp4') as temp_file:
|
| 283 |
+
try:
|
| 284 |
+
# Write uploaded content to temporary file
|
| 285 |
+
content = await video.read()
|
| 286 |
+
temp_file.write(content)
|
| 287 |
+
temp_file.flush()
|
| 288 |
+
|
| 289 |
+
if logfire is not None:
|
| 290 |
+
logfire.info('Video file saved for processing',
|
| 291 |
+
temp_file=temp_file.name,
|
| 292 |
+
file_size_bytes=len(content))
|
| 293 |
+
|
| 294 |
+
# Process the video with frame skip parameter
|
| 295 |
+
gestures = process_video_for_gestures(temp_file.name, frame_skip=frame_skip)
|
| 296 |
+
|
| 297 |
+
if logfire is not None:
|
| 298 |
+
logfire.info('Gesture detection completed successfully',
|
| 299 |
+
total_gestures=len(gestures),
|
| 300 |
+
gestures=[g.gesture for g in gestures])
|
| 301 |
+
|
| 302 |
+
return GestureResponse(gestures=gestures)
|
| 303 |
+
|
| 304 |
+
except Exception as e:
|
| 305 |
+
if logfire is not None:
|
| 306 |
+
logfire.error('Error processing video',
|
| 307 |
+
error=str(e),
|
| 308 |
+
error_type=type(e).__name__,
|
| 309 |
+
temp_file=temp_file.name)
|
| 310 |
+
raise HTTPException(status_code=500, detail=f"Error processing video: {str(e)}")
|
| 311 |
+
|
| 312 |
+
finally:
|
| 313 |
+
# Clean up temporary file
|
| 314 |
+
if os.path.exists(temp_file.name):
|
| 315 |
+
os.unlink(temp_file.name)
|
| 316 |
+
if logfire is not None:
|
| 317 |
+
logfire.debug('Temporary file cleaned up', temp_file=temp_file.name)
|
| 318 |
+
|
src/gesturedetection/config.py
ADDED
@@ -0,0 +1,55 @@
"""
Configuration module for the gesture detection system.
Handles environment variables and logfire token configuration.
"""

import os
from pathlib import Path
from typing import Optional


def get_logfire_token() -> Optional[str]:
    """
    Get the logfire token from environment variables or local configuration.

    Priority order:
    1. LOGFIRE_TOKEN environment variable (for production/deployment)
    2. .env file in the project root (for local development)
    3. None (monitoring disabled)

    Returns
    -------
    Optional[str]
        The logfire token if found, None otherwise
    """
    # First check the environment variable (for production)
    token = os.getenv("LOGFIRE_TOKEN")
    if token:
        return token

    # Check for a .env file in the project root (for local development)
    env_file = Path(__file__).parent.parent.parent / ".env"
    if env_file.exists():
        try:
            with open(env_file, "r") as f:
                for line in f:
                    line = line.strip()
                    if line.startswith("LOGFIRE_TOKEN="):
                        return line.split("=", 1)[1].strip('"\'')
        except Exception:
            # If we can't read the .env file, continue without a token
            pass

    return None


def is_monitoring_enabled() -> bool:
    """
    Check whether monitoring is enabled, i.e. whether a logfire token is available.

    Returns
    -------
    bool
        True if monitoring is enabled, False otherwise
    """
    return get_logfire_token() is not None
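A short sketch of how this helper is intended to be consumed at startup. The `logfire.configure(token=...)` call is an assumption about the monitoring SDK's entry point, and the import path assumes the repository root is on `sys.path`.

# Hypothetical startup wiring for the config helpers above.
from src.gesturedetection.config import get_logfire_token, is_monitoring_enabled  # path is an assumption

if is_monitoring_enabled():
    import logfire
    logfire.configure(token=get_logfire_token())  # assumed SDK entry point
else:
    print("LOGFIRE_TOKEN not set; monitoring disabled")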
src/gesturedetection/main_controller.py
ADDED
@@ -0,0 +1,271 @@
import numpy as np

from .ocsort import (
    KalmanBoxTracker,
    associate,
    ciou_batch,
    ct_dist,
    diou_batch,
    giou_batch,
    iou_batch,
    linear_assignment,
)
from .onnx_models import HandClassification, HandDetection
from .utils import Deque, Drawer, Hand
from .config import is_monitoring_enabled

# Configure logfire monitoring if available
logfire = None
if is_monitoring_enabled():
    try:
        import logfire
    except ImportError:
        logfire = None

ASSO_FUNCS = {"iou": iou_batch, "giou": giou_batch, "ciou": ciou_batch, "diou": diou_batch, "ct_dist": ct_dist}


def k_previous_obs(observations, cur_age, k):
    if len(observations) == 0:
        return [-1, -1, -1, -1, -1]
    for i in range(k):
        dt = k - i
        if cur_age - dt in observations:
            return observations[cur_age - dt]
    max_age = max(observations.keys())
    return observations[max_age]


class MainController:
    """
    Main tracking controller.
    It maintains a list of tracks; each track holds a KalmanBoxTracker object and a Deque of Hand objects.
    """

    def __init__(
        self, detection_model, classification_model, max_age=30, min_hits=3, iou_threshold=0.3, maxlen=30, min_frames=20
    ):
        """
        Parameters
        ----------
        detection_model : str
            Path to the detection model.
        classification_model : str
            Path to the classification model.
        max_age : int
            Maximum age of a track.
        min_hits : int
            Minimum number of hits to confirm a track.
        iou_threshold : float
            IOU threshold for track association.
        maxlen : int
            Maximum length of the deque in a track.
        min_frames : int
            Minimum number of frames to confirm a track.
        """
        self.maxlen = maxlen
        self.min_frames = min_frames
        self.max_age = max_age
        self.min_hits = min_hits
        self.delta_t = 3
        self.iou_threshold = iou_threshold
        self.inertia = 0.2
        self.asso_func = ASSO_FUNCS["giou"]
        self.tracks = []
        self.frame_count = 0
        self.detection_model = HandDetection(detection_model)
        self.classification_model = HandClassification(classification_model)
        self.drawer = Drawer()

    def update(self, dets=np.empty((0, 5)), labels=None):
        """
        Parameters
        ----------
        dets : np.array
            Bounding boxes with shape [[x1,y1,x2,y2,score],[x1,y1,x2,y2,score],...].
            Requires: this method must be called once for each frame, even with empty
            detections (use np.empty((0, 5)) for frames without detections).
        labels : np.array
            Labels with shape (N, 1), where N is the number of bounding boxes.

        Returns
        -------
        np.array
            A similar array, where the last column is the object ID.

        Notes
        -----
        The number of objects returned may differ from the number of detections provided.
        """
        # Advance the frame count on every call to keep the tracker state in sync with real time.
        # This method is required to be called once per frame (even if there are no detections),
        # so we must advance the internal Kalman filters and aging logic on empty frames as well.
        self.frame_count += 1

        # Get predicted locations from existing trackers for this frame.
        # This advances age/time_since_update and is required also when there are no detections,
        # ensuring tracks can age out (max_age) and do not persist indefinitely across gaps.
        trks = np.zeros((len(self.tracks), 5))
        to_del = []
        ret = []
        lbs = []
        for t, trk in enumerate(trks):
            pos = self.tracks[t]["tracker"].predict()[0]
            trk[:] = [pos[0], pos[1], pos[2], pos[3], 0]
            if np.any(np.isnan(pos)):
                to_del.append(t)
        trks = np.ma.compress_rows(np.ma.masked_invalid(trks))
        for t in reversed(to_del):
            self.tracks.pop(t)

        velocities = np.array(
            [
                trk["tracker"].velocity if trk["tracker"].velocity is not None else np.array((0, 0))
                for trk in self.tracks
            ]
        )
        last_boxes = np.array([trk["tracker"].last_observation for trk in self.tracks])
        k_observations = np.array(
            [k_previous_obs(trk["tracker"].observations, trk["tracker"].age, self.delta_t) for trk in self.tracks]
        )

        """
        First round of association
        """
        matched, unmatched_dets, unmatched_trks = associate(
            dets, trks, self.iou_threshold, velocities, k_observations, self.inertia
        )

        for m in matched:
            self.tracks[m[1]]["tracker"].update(dets[m[0], :])
            self.tracks[m[1]]["hands"].append(Hand(bbox=dets[m[0], :4], gesture=labels[m[0]]))

        """
        Second round of association by OCR
        """
        if unmatched_dets.shape[0] > 0 and unmatched_trks.shape[0] > 0:
            left_dets = dets[unmatched_dets]
            left_trks = last_boxes[unmatched_trks]
            iou_left = self.asso_func(left_dets, left_trks)
            iou_left = np.array(iou_left)
            if iou_left.max() > self.iou_threshold:
                """
                NOTE: by using a lower threshold, e.g., self.iou_threshold - 0.1, you may
                get higher performance, especially on the MOT17/MOT20 datasets. But we keep
                it uniform here for simplicity.
                """
                rematched_indices = linear_assignment(-iou_left)
                to_remove_det_indices = []
                to_remove_trk_indices = []
                for m in rematched_indices:
                    det_ind, trk_ind = unmatched_dets[m[0]], unmatched_trks[m[1]]
                    if iou_left[m[0], m[1]] < self.iou_threshold:
                        continue
                    self.tracks[trk_ind]["tracker"].update(dets[det_ind, :])
                    self.tracks[trk_ind]["hands"].append(Hand(bbox=dets[det_ind, :4], gesture=labels[det_ind]))
                    to_remove_det_indices.append(det_ind)
                    to_remove_trk_indices.append(trk_ind)
                unmatched_dets = np.setdiff1d(unmatched_dets, np.array(to_remove_det_indices))
                unmatched_trks = np.setdiff1d(unmatched_trks, np.array(to_remove_trk_indices))

        # For unmatched trackers (including the case with no detections),
        # update with None to keep the filter consistent and append a dummy Hand.
        for m in unmatched_trks:
            self.tracks[m]["tracker"].update(None)
            self.tracks[m]["hands"].append(Hand(bbox=None, gesture=None))

        # Create and initialise new trackers for unmatched detections
        for i in unmatched_dets:
            self.tracks.append(
                {
                    "hands": Deque(self.maxlen, self.min_frames),
                    "tracker": KalmanBoxTracker(dets[i, :], delta_t=self.delta_t),
                }
            )
        i = len(self.tracks)
        for trk in reversed(self.tracks):
            if trk["tracker"].last_observation.sum() < 0:
                d = trk["tracker"].get_state()[0]
            else:
                """
                It is optional to use the recent observation or the Kalman filter prediction
                here; we didn't notice a significant difference.
                """
                d = trk["tracker"].last_observation[:4]
            if (trk["tracker"].time_since_update < 1) and (
                trk["tracker"].hit_streak >= self.min_hits or self.frame_count <= self.min_hits
            ):
                # +1 as the MOT benchmark requires positive IDs
                ret.append(np.concatenate((d, [trk["tracker"].id + 1])).reshape(1, -1))
                if len(trk["hands"]) > 0:
                    lbs.append(trk["hands"][-1].gesture)
                else:
                    lbs.append(None)

            i -= 1
            # Remove dead tracklets
            if trk["tracker"].time_since_update > self.max_age:
                self.tracks.pop(i)
        if len(ret) > 0:
            return np.concatenate(ret), lbs
        return np.empty((0, 5)), np.empty((0, 1))

    def __call__(self, frame):
        """
        Parameters
        ----------
        frame : np.array
            Image frame with shape (H, W, 3).

        Returns
        -------
        tuple
            (bboxes, ids, labels) for confirmed tracks, or (None, None, None)
            when no hands are detected in the frame.
        """
        # Log frame processing if monitoring is enabled
        if logfire is not None:
            with logfire.span('frame_processing', frame_shape=frame.shape):
                bboxes, probs = self.detection_model(frame)

                if len(bboxes):
                    detection_scores = np.asarray(probs).tolist()
                    logfire.debug(
                        'Hand detections found',
                        num_detections=len(bboxes),
                        detection_scores=detection_scores,
                    )

                    labels = self.classification_model(frame, bboxes)
                    bboxes = np.concatenate((bboxes, np.expand_dims(probs, axis=1)), axis=1)
                    new_bboxes, labels = self.update(dets=bboxes, labels=labels)

                    # Log classification results
                    if labels is not None and len(labels) > 0:
                        labels_list = np.asarray(labels).tolist()
                        gesture_names = [
                            f"gesture_{label}" if label is not None else "none"
                            for label in labels_list
                        ]
                        logfire.debug(
                            'Gesture classifications',
                            labels=labels_list,
                            gesture_names=gesture_names,
                        )

                    return new_bboxes[:, :-1], new_bboxes[:, -1], labels
                else:
                    logfire.debug('No hand detections in frame')
                    self.update(np.empty((0, 5)), None)
                    return None, None, None
        else:
            # Original logic without monitoring
            bboxes, probs = self.detection_model(frame)
            if len(bboxes):
                labels = self.classification_model(frame, bboxes)
                bboxes = np.concatenate((bboxes, np.expand_dims(probs, axis=1)), axis=1)
                new_bboxes, labels = self.update(dets=bboxes, labels=labels)
                return new_bboxes[:, :-1], new_bboxes[:, -1], labels
            else:
                self.update(np.empty((0, 5)), None)
                return None, None, None
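A minimal sketch of driving MainController directly on a local video, mirroring the loop in process_video_for_gestures. The video path is illustrative, and the import path assumes the repository root is on `sys.path`; the model paths match the files shipped under models/.

# Hypothetical standalone driver for MainController.
import cv2
from src.gesturedetection.main_controller import MainController  # path is an assumption

controller = MainController("models/hand_detector.onnx", "models/crops_classifier.onnx")
cap = cv2.VideoCapture("sample.mp4")  # hypothetical input video
while True:
    ret, frame = cap.read()
    if not ret:
        break
    bboxes, ids, labels = controller(frame)
    if bboxes is not None:
        # One row per confirmed track: box coordinates plus a stable track ID.
        for box, track_id, label in zip(bboxes, ids, labels):
            print(int(track_id), label, box)
cap.release()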
src/gesturedetection/models.py
ADDED
@@ -0,0 +1,89 @@
from pydantic import BaseModel
from typing import List, Optional


class Gesture(BaseModel):
    """Represents a detected gesture with metadata."""
    gesture: str
    duration: int  # Duration in frames
    confidence: float


class GestureResponse(BaseModel):
    """Response model containing a list of detected gestures."""
    gestures: List[Gesture]


# Primary gesture mappings for the main gestures + additional ones
GESTURE_MAPPING = {
    # Original 5 main gestures
    27: "thumbs_up",      # like
    31: "palm",           # open palm wave (5 fingers)
    32: "peace",          # peace sign (2 fingers)
    29: "ok",             # OK sign
    20: "call",           # call me (little finger)

    # Finger counting (1-5)
    30: "one",            # 1 finger
    39: "two_up",         # 2 fingers (peace sign)
    37: "three",          # 3 fingers
    26: "four",           # 4 fingers
    # Note: 5 fingers is the same as palm (31)

    # Surprise gesture
    23: "middle_finger",  # middle finger (surprise!)

    # Additional useful gestures
    25: "fist",           # closed fist
    19: "point",          # pointing with the index finger
    35: "stop",           # stop gesture
}

# Additional gesture mappings for completeness
FULL_GESTURE_MAPPING = {
    0: "hand_down",
    1: "hand_right",
    2: "hand_left",
    3: "thumb_index",
    4: "thumb_left",
    5: "thumb_right",
    6: "thumb_down",
    7: "half_up",
    8: "half_left",
    9: "half_right",
    10: "half_down",
    11: "part_hand_heart",
    12: "part_hand_heart2",
    13: "fist_inverted",
    14: "two_left",
    15: "two_right",
    16: "two_down",
    17: "grabbing",
    18: "grip",
    19: "point",
    20: "call",
    21: "three3",
    22: "little_finger",
    23: "middle_finger",
    24: "dislike",
    25: "fist",
    26: "four",
    27: "like",
    28: "mute",
    29: "ok",
    30: "one",
    31: "palm",
    32: "peace",
    33: "peace_inverted",
    34: "rock",
    35: "stop",
    36: "stop_inverted",
    37: "three",
    38: "three2",
    39: "two_up",
    40: "two_up_inverted",
    41: "three_gun",
    42: "one_left",
    43: "one_right",
    44: "one_down"
}
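A tiny usage sketch of the lookup pattern the service uses elsewhere: resolve a raw classifier label through FULL_GESTURE_MAPPING, falling back to an `unknown_<id>` name. The import path is an assumption.

# Hypothetical helper mirroring the lookup in process_video_for_gestures.
from src.gesturedetection.models import FULL_GESTURE_MAPPING  # path is an assumption

def gesture_name(label_id: int) -> str:
    # Known IDs resolve to their name; anything else gets a stable fallback.
    return FULL_GESTURE_MAPPING.get(label_id, f"unknown_{label_id}")

assert gesture_name(27) == "like"
assert gesture_name(99) == "unknown_99"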
src/gesturedetection/ocsort/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .association import associate, ciou_batch, ct_dist, diou_batch, giou_batch, iou_batch, linear_assignment
from .kalmanboxtracker import KalmanBoxTracker
src/gesturedetection/ocsort/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (434 Bytes).
src/gesturedetection/ocsort/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (400 Bytes).
src/gesturedetection/ocsort/__pycache__/association.cpython-312.pyc
ADDED
Binary file (22.3 kB).
src/gesturedetection/ocsort/__pycache__/association.cpython-39.pyc
ADDED
Binary file (11.4 kB).
src/gesturedetection/ocsort/__pycache__/kalmanboxtracker.cpython-312.pyc
ADDED
Binary file (7.89 kB).
src/gesturedetection/ocsort/__pycache__/kalmanboxtracker.cpython-39.pyc
ADDED
Binary file (4.63 kB).
src/gesturedetection/ocsort/__pycache__/kalmanfilter.cpython-312.pyc
ADDED
Binary file (69.3 kB).
src/gesturedetection/ocsort/__pycache__/kalmanfilter.cpython-39.pyc
ADDED
Binary file (50.3 kB).
src/gesturedetection/ocsort/association.py
ADDED
@@ -0,0 +1,511 @@
import numpy as np


def iou_batch(bboxes1, bboxes2):
    """
    Calculate the Intersection over Union (IoU) between bounding boxes.

    Parameters
    ----------
    bboxes1: numpy.ndarray
        shape is [N, 4]
    bboxes2: numpy.ndarray
        shape is [M, 4]

    Returns
    -------
    ious: numpy.ndarray
        shape is [N, M]
    """
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0.0, xx2 - xx1)
    h = np.maximum(0.0, yy2 - yy1)
    wh = w * h
    o = wh / (
        (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
        - wh
    )
    return o


def giou_batch(bboxes1, bboxes2):
    """
    Calculate the Generalized Intersection over Union (GIoU) between bounding boxes.

    Parameters
    ----------
    bboxes1: numpy.ndarray
        shape is [N, 4]
    bboxes2: numpy.ndarray
        shape is [M, 4]

    Returns
    -------
    gious: numpy.ndarray
        shape is [N, M]
    """
    # for details, see https://arxiv.org/pdf/1902.09630.pdf
    # ensure the predicted bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0.0, xx2 - xx1)
    h = np.maximum(0.0, yy2 - yy1)
    wh = w * h
    union = (
        (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
        - wh
    )
    iou = wh / union

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])
    wc = xxc2 - xxc1
    hc = yyc2 - yyc1
    assert (wc > 0).all() and (hc > 0).all()
    area_enclose = wc * hc
    giou = iou - (area_enclose - union) / area_enclose
    giou = (giou + 1.0) / 2.0  # rescale from (-1,1) to (0,1)
    return giou


def diou_batch(bboxes1, bboxes2):
    """
    Calculate the Distance Intersection over Union (DIoU) between bounding boxes.

    Parameters
    ----------
    bboxes1: numpy.ndarray
        shape is [N, 4]
    bboxes2: numpy.ndarray
        shape is [M, 4]

    Returns
    -------
    dious: numpy.ndarray
    """
    # for details, see https://arxiv.org/pdf/1902.09630.pdf
    # ensure the predicted bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    # calculate the intersection box
    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0.0, xx2 - xx1)
    h = np.maximum(0.0, yy2 - yy1)
    wh = w * h
    union = (
        (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
        - wh
    )
    iou = wh / union
    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])

    outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2
    diou = iou - inner_diag / outer_diag

    return (diou + 1) / 2.0  # rescale from (-1,1) to (0,1)


def ciou_batch(bboxes1, bboxes2):
    """
    Calculate the Complete Intersection over Union (CIoU) between bounding boxes.

    Parameters
    ----------
    bboxes1: numpy.ndarray
        shape is [N, 4]
    bboxes2: numpy.ndarray
        shape is [M, 4]

    Returns
    -------
    ciou: numpy.ndarray
    """
    # for details, see https://arxiv.org/pdf/1902.09630.pdf
    # ensure the predicted bbox form
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    # calculate the intersection box
    xx1 = np.maximum(bboxes1[..., 0], bboxes2[..., 0])
    yy1 = np.maximum(bboxes1[..., 1], bboxes2[..., 1])
    xx2 = np.minimum(bboxes1[..., 2], bboxes2[..., 2])
    yy2 = np.minimum(bboxes1[..., 3], bboxes2[..., 3])
    w = np.maximum(0.0, xx2 - xx1)
    h = np.maximum(0.0, yy2 - yy1)
    wh = w * h
    union = (
        (bboxes1[..., 2] - bboxes1[..., 0]) * (bboxes1[..., 3] - bboxes1[..., 1])
        + (bboxes2[..., 2] - bboxes2[..., 0]) * (bboxes2[..., 3] - bboxes2[..., 1])
        - wh
    )
    iou = wh / union

    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    inner_diag = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    xxc1 = np.minimum(bboxes1[..., 0], bboxes2[..., 0])
    yyc1 = np.minimum(bboxes1[..., 1], bboxes2[..., 1])
    xxc2 = np.maximum(bboxes1[..., 2], bboxes2[..., 2])
    yyc2 = np.maximum(bboxes1[..., 3], bboxes2[..., 3])

    outer_diag = (xxc2 - xxc1) ** 2 + (yyc2 - yyc1) ** 2

    w1 = bboxes1[..., 2] - bboxes1[..., 0]
    h1 = bboxes1[..., 3] - bboxes1[..., 1]
    w2 = bboxes2[..., 2] - bboxes2[..., 0]
    h2 = bboxes2[..., 3] - bboxes2[..., 1]

    # prevent division by zero: add a one-pixel shift
    h2 = h2 + 1.0
    h1 = h1 + 1.0
    arctan = np.arctan(w2 / h2) - np.arctan(w1 / h1)
    v = (4 / (np.pi**2)) * (arctan**2)
    S = 1 - iou
    alpha = v / (S + v)
    ciou = iou - inner_diag / outer_diag - alpha * v

    return (ciou + 1) / 2.0  # rescale from (-1,1) to (0,1)


def ct_dist(bboxes1, bboxes2):
    """
    Measure the center distance between two sets of bounding boxes. This is a
    coarse implementation; we don't recommend using it alone for association,
    as it can be unstable and sensitive to frame rate and object speed.

    Parameters
    ----------
    bboxes1: numpy.ndarray
        shape is [N, 4]
    bboxes2: numpy.ndarray
        shape is [M, 4]

    Returns
    -------
    ct_dist: numpy.ndarray
    """
    bboxes2 = np.expand_dims(bboxes2, 0)
    bboxes1 = np.expand_dims(bboxes1, 1)

    centerx1 = (bboxes1[..., 0] + bboxes1[..., 2]) / 2.0
    centery1 = (bboxes1[..., 1] + bboxes1[..., 3]) / 2.0
    centerx2 = (bboxes2[..., 0] + bboxes2[..., 2]) / 2.0
    centery2 = (bboxes2[..., 1] + bboxes2[..., 3]) / 2.0

    ct_dist2 = (centerx1 - centerx2) ** 2 + (centery1 - centery2) ** 2

    ct_dist = np.sqrt(ct_dist2)

    # The linear rescaling is a naive version and needs more study
    ct_dist = ct_dist / ct_dist.max()
    return ct_dist.max() - ct_dist  # rescale to (0,1)


def speed_direction_batch(dets, tracks):
    """
    Calculate the speed and direction between detections and tracks.

    Parameters
    ----------
    dets: numpy.ndarray
        shape is [N, 4]
    tracks: numpy.ndarray
        shape is [M, 4]

    Returns
    -------
    dy: numpy.ndarray
    dx: numpy.ndarray
    """
    tracks = tracks[..., np.newaxis]
    CX1, CY1 = (dets[:, 0] + dets[:, 2]) / 2.0, (dets[:, 1] + dets[:, 3]) / 2.0
    CX2, CY2 = (tracks[:, 0] + tracks[:, 2]) / 2.0, (tracks[:, 1] + tracks[:, 3]) / 2.0
    dx = CX1 - CX2
    dy = CY1 - CY2
    norm = np.sqrt(dx**2 + dy**2) + 1e-6
    dx = dx / norm
    dy = dy / norm
    return dy, dx  # size: num_track x num_det


def linear_assignment(cost_matrix):
    """
    Solve the linear assignment problem, using lap if available and falling back
    to scipy.optimize.linear_sum_assignment otherwise.

    Parameters
    ----------
    cost_matrix: numpy.ndarray
        shape is [N, M]

    Returns
    -------
    indices: numpy.ndarray
        shape is [N, 2]
    """
    try:
        import lap

        _, x, y = lap.lapjv(cost_matrix, extend_cost=True)
        return np.array([[y[i], i] for i in x if i >= 0])
    except ImportError:
        from scipy.optimize import linear_sum_assignment

        x, y = linear_sum_assignment(cost_matrix)
        return np.array(list(zip(x, y)))


def associate_detections_to_trackers(detections, trackers, iou_threshold=0.3):
    """
    Assigns detections to tracked objects (both represented as bounding boxes).
    Returns 3 lists: matches, unmatched_detections and unmatched_trackers.

    Parameters
    ----------
    detections: numpy.ndarray
        shape is [N, 4]
    trackers: numpy.ndarray
        shape is [M, 4]
    iou_threshold: float
        in [0, 1]. Default is 0.3
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    iou_matrix = iou_batch(detections, trackers)

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-iou_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matches with low IOU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


def associate(detections, trackers, iou_threshold, velocities, previous_obs, vdc_weight):
    """
    Assigns detections to tracked objects (both represented as bounding boxes).
    Returns 3 lists: matches, unmatched_detections and unmatched_trackers.

    Parameters
    ----------
    detections: numpy.ndarray
        shape is [N, 4]
    trackers: numpy.ndarray
        shape is [M, 4]
    iou_threshold: float
        in [0, 1]. Default is 0.3
    velocities: numpy.ndarray
        shape is [M, 2]
    previous_obs: numpy.ndarray
        shape is [M, 4]
    vdc_weight: float
    """
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0

    iou_matrix = iou_batch(detections, trackers)
    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    # iou_matrix = iou_matrix * scores  # a trick that sometimes works; we don't encourage it
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(-(iou_matrix + angle_diff_cost))
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matches with low IOU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)


def associate_kitti(detections, trackers, det_cates, iou_threshold, velocities, previous_obs, vdc_weight):
    if len(trackers) == 0:
        return np.empty((0, 2), dtype=int), np.arange(len(detections)), np.empty((0, 5), dtype=int)

    """
    Cost from the velocity direction consistency

    Parameters
    ----------
    detections: numpy.ndarray
        shape is [N, 4]
    trackers: numpy.ndarray
        shape is [M, 4]
    det_cates: numpy.ndarray
        shape is [N, 1]
    iou_threshold: float
        in [0, 1]. Default is 0.3
    velocities: numpy.ndarray
        shape is [M, 2]
    previous_obs: numpy.ndarray
        shape is [M, 4]
    vdc_weight: float
    """
    Y, X = speed_direction_batch(detections, previous_obs)
    inertia_Y, inertia_X = velocities[:, 0], velocities[:, 1]
    inertia_Y = np.repeat(inertia_Y[:, np.newaxis], Y.shape[1], axis=1)
    inertia_X = np.repeat(inertia_X[:, np.newaxis], X.shape[1], axis=1)
    diff_angle_cos = inertia_X * X + inertia_Y * Y
    diff_angle_cos = np.clip(diff_angle_cos, a_min=-1, a_max=1)
    diff_angle = np.arccos(diff_angle_cos)
    diff_angle = (np.pi / 2.0 - np.abs(diff_angle)) / np.pi

    valid_mask = np.ones(previous_obs.shape[0])
    valid_mask[np.where(previous_obs[:, 4] < 0)] = 0
    valid_mask = np.repeat(valid_mask[:, np.newaxis], X.shape[1], axis=1)

    scores = np.repeat(detections[:, -1][:, np.newaxis], trackers.shape[0], axis=1)
    angle_diff_cost = (valid_mask * diff_angle) * vdc_weight
    angle_diff_cost = angle_diff_cost.T
    angle_diff_cost = angle_diff_cost * scores

    """
    Cost from IoU
    """
    iou_matrix = iou_batch(detections, trackers)

    """
    With multiple categories, generate the cost for category mismatch
    """
    num_dets = detections.shape[0]
    num_trk = trackers.shape[0]
    cate_matrix = np.zeros((num_dets, num_trk))
    for i in range(num_dets):
        for j in range(num_trk):
            if det_cates[i] != trackers[j, 4]:
                cate_matrix[i][j] = -1e6

    cost_matrix = -iou_matrix - angle_diff_cost - cate_matrix

    if min(iou_matrix.shape) > 0:
        a = (iou_matrix > iou_threshold).astype(np.int32)
        if a.sum(1).max() == 1 and a.sum(0).max() == 1:
            matched_indices = np.stack(np.where(a), axis=1)
        else:
            matched_indices = linear_assignment(cost_matrix)
    else:
        matched_indices = np.empty(shape=(0, 2))

    unmatched_detections = []
    for d, det in enumerate(detections):
        if d not in matched_indices[:, 0]:
            unmatched_detections.append(d)
    unmatched_trackers = []
    for t, trk in enumerate(trackers):
        if t not in matched_indices[:, 1]:
            unmatched_trackers.append(t)

    # filter out matches with low IOU
    matches = []
    for m in matched_indices:
        if iou_matrix[m[0], m[1]] < iou_threshold:
            unmatched_detections.append(m[0])
            unmatched_trackers.append(m[1])
        else:
            matches.append(m.reshape(1, 2))
    if len(matches) == 0:
        matches = np.empty((0, 2), dtype=int)
    else:
        matches = np.concatenate(matches, axis=0)

    return matches, np.array(unmatched_detections), np.array(unmatched_trackers)
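A quick sketch of the shape contract of iou_batch above: N detection boxes against M tracker boxes yield an (N, M) IoU matrix. The boxes are made up for illustration, and the import path is an assumption.

# Hypothetical demonstration of the pairwise IoU matrix shape.
import numpy as np
from src.gesturedetection.ocsort.association import iou_batch  # path is an assumption

dets = np.array([[0, 0, 10, 10], [20, 20, 30, 30]], dtype=float)  # 2 detections
trks = np.array([[0, 0, 10, 10]], dtype=float)                    # 1 tracker box
ious = iou_batch(dets, trks)
print(ious.shape)   # (2, 1)
print(ious[0, 0])   # 1.0 for the identical box; ious[1, 0] is 0.0 for the disjoint one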
src/gesturedetection/ocsort/kalmanboxtracker.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import print_function
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def convert_bbox_to_z(bbox):
|
| 7 |
+
"""
|
| 8 |
+
Takes a bounding box in the form [x1,y1,x2,y2] and returns z in the form
|
| 9 |
+
[x,y,s,r] where x,y is the centre of the box and s is the scale/area and r is
|
| 10 |
+
the aspect ratio
|
| 11 |
+
"""
|
| 12 |
+
w = bbox[2] - bbox[0]
|
| 13 |
+
h = bbox[3] - bbox[1]
|
| 14 |
+
x = bbox[0] + w / 2.0
|
| 15 |
+
y = bbox[1] + h / 2.0
|
| 16 |
+
s = w * h # scale is just area
|
| 17 |
+
r = w / float(h + 1e-6)
|
| 18 |
+
return np.array([x, y, s, r]).reshape((4, 1))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def speed_direction(bbox1, bbox2):
|
| 22 |
+
cx1, cy1 = (bbox1[0] + bbox1[2]) / 2.0, (bbox1[1] + bbox1[3]) / 2.0
|
| 23 |
+
cx2, cy2 = (bbox2[0] + bbox2[2]) / 2.0, (bbox2[1] + bbox2[3]) / 2.0
|
| 24 |
+
    speed = np.array([cy2 - cy1, cx2 - cx1])
    norm = np.sqrt((cy2 - cy1) ** 2 + (cx2 - cx1) ** 2) + 1e-6
    return speed / norm


def convert_x_to_bbox(x, score=None):
    """
    Takes a bounding box in the centre form [x,y,s,r] and returns it in the form
    [x1,y1,x2,y2] where x1,y1 is the top left and x2,y2 is the bottom right
    """
    w = np.sqrt(x[2] * x[3])
    h = x[2] / w
    if score is None:
        return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0]).reshape((1, 4))
    else:
        return np.array([x[0] - w / 2.0, x[1] - h / 2.0, x[0] + w / 2.0, x[1] + h / 2.0, score]).reshape((1, 5))


class KalmanBoxTracker(object):
    """
    This class represents the internal state of individual tracked objects observed as bbox.
    """

    count = 0

    def __init__(self, bbox, delta_t=3, orig=False):
        """
        Initialises a tracker using the initial bounding box.
        """
        # define constant velocity model
        if not orig:
            from .kalmanfilter import KalmanFilterNew as KalmanFilter

            self.kf = KalmanFilter(dim_x=7, dim_z=4)
        else:
            from filterpy.kalman import KalmanFilter

            self.kf = KalmanFilter(dim_x=7, dim_z=4)
        self.kf.F = np.array(
            [
                [1, 0, 0, 0, 1, 0, 0],
                [0, 1, 0, 0, 0, 1, 0],
                [0, 0, 1, 0, 0, 0, 1],
                [0, 0, 0, 1, 0, 0, 0],
                [0, 0, 0, 0, 1, 0, 0],
                [0, 0, 0, 0, 0, 1, 0],
                [0, 0, 0, 0, 0, 0, 1],
            ]
        )
        self.kf.H = np.array(
            [[1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0], [0, 0, 1, 0, 0, 0, 0], [0, 0, 0, 1, 0, 0, 0]]
        )

        self.kf.R[2:, 2:] *= 10.0
        self.kf.P[4:, 4:] *= 1000.0  # give high uncertainty to the unobservable initial velocities
        self.kf.P *= 10.0
        self.kf.Q[-1, -1] *= 0.01
        self.kf.Q[4:, 4:] *= 0.01

        self.kf.x[:4] = convert_bbox_to_z(bbox)
        self.time_since_update = 0
        self.id = KalmanBoxTracker.count
        KalmanBoxTracker.count += 1
        self.history = []
        self.hits = 0
        self.hit_streak = 0
        self.age = 0
        """
        NOTE: [-1,-1,-1,-1,-1] is a compromise placeholder for the non-observation status, and the same value
        is returned by k_previous_obs. It is ugly, but it lets the observation array be generated in a fast and
        unified way (see k_observations = np.array([k_previous_obs(...)]) below), so let's bear with it for now.
        """
        self.last_observation = np.array([-1, -1, -1, -1, -1])  # placeholder
        self.observations = dict()
        self.history_observations = []
        self.velocity = None
        self.delta_t = delta_t

    def update(self, bbox):
        """
        Updates the state vector with the observed bbox.
        """
        if bbox is not None:
            if self.last_observation.sum() >= 0:  # a previous observation exists
                previous_box = None
                for i in range(self.delta_t):
                    dt = self.delta_t - i
                    if self.age - dt in self.observations:
                        previous_box = self.observations[self.age - dt]
                        break
                if previous_box is None:
                    previous_box = self.last_observation
                """
                Estimate the track speed direction with observations Delta t steps away.
                """
                self.velocity = speed_direction(previous_box, bbox)

            """
            Insert new observations. This is an ugly way to maintain both self.observations
            and self.history_observations. Bear with it for the moment.
            """
            self.last_observation = bbox
            self.observations[self.age] = bbox
            self.history_observations.append(bbox)

            self.time_since_update = 0
            self.history = []
            self.hits += 1
            self.hit_streak += 1
            self.kf.update(convert_bbox_to_z(bbox))
        else:
            self.kf.update(bbox)

    def predict(self):
        """
        Advances the state vector and returns the predicted bounding box estimate.
        """
        if (self.kf.x[6] + self.kf.x[2]) <= 0:
            self.kf.x[6] *= 0.0

        self.kf.predict()
        self.age += 1
        if self.time_since_update > 0:
            self.hit_streak = 0
        self.time_since_update += 1
        self.history.append(convert_x_to_bbox(self.kf.x))
        return self.history[-1]

    def get_state(self):
        """
        Returns the current bounding box estimate.
        """
        return convert_x_to_bbox(self.kf.x)
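For reference, a minimal sketch of the per-frame loop that drives this tracker (the detection values are hypothetical, and the import path assumes this repo's package layout): each frame calls predict() once to advance the constant-velocity model, then update() with either an [x1, y1, x2, y2, score] detection or None for a missed frame.

import numpy as np
from src.gesturedetection.ocsort.kalmanboxtracker import KalmanBoxTracker

tracker = KalmanBoxTracker(np.array([100.0, 100.0, 150.0, 180.0, 0.9]))
detections = [
    np.array([102.0, 101.0, 152.0, 182.0, 0.8]),
    None,  # a missed frame
    np.array([108.0, 104.0, 158.0, 186.0, 0.85]),
]
for det in detections:
    predicted_box = tracker.predict()  # predicted [x1, y1, x2, y2] for this frame
    tracker.update(det)                # None marks a missed detection
print(tracker.get_state())             # current bounding box estimate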
src/gesturedetection/ocsort/kalmanfilter.py
ADDED
@@ -0,0 +1,1557 @@
# -*- coding: utf-8 -*-
# pylint: disable=invalid-name, too-many-arguments, too-many-branches,
# pylint: disable=too-many-locals, too-many-instance-attributes, too-many-lines

"""
This module implements the linear Kalman filter in both an object
oriented and procedural form. The KalmanFilter class implements
the filter by storing the various matrices in instance variables,
minimizing the amount of bookkeeping you have to do.
All Kalman filters operate with a predict->update cycle. The
predict step, implemented with the method or function predict(),
uses the state transition matrix F to predict the state in the next
time period (epoch). The state is stored as a gaussian (x, P), where
x is the state (column) vector, and P is its covariance. Covariance
matrix Q specifies the process covariance. In Bayesian terms, this
prediction is called the *prior*, which you can think of colloquially
as the estimate prior to incorporating the measurement.
The update step, implemented with the method or function `update()`,
incorporates the measurement z with covariance R, into the state
estimate (x, P). The class stores the system uncertainty in S,
the innovation (residual between prediction and measurement in
measurement space) in y, and the Kalman gain in K. The procedural
form returns these variables to you. In Bayesian terms this computes
the *posterior* - the estimate after the information from the
measurement is incorporated.
Whether you use the OO form or procedural form is up to you. If
matrices such as H, R, and F are changing each epoch, you'll probably
opt to use the procedural form. If they are unchanging, the OO
form is perhaps easier to use since you won't need to keep track
of these matrices. This is especially useful if you are implementing
banks of filters or comparing various KF designs for performance;
a trivial coding bug could lead to using the wrong sets of matrices.
This module also offers an implementation of the RTS smoother, and
other helper functions, such as log likelihood computations.
The Saver class allows you to easily save the state of the
KalmanFilter class after every update.
This module expects NumPy arrays for all values that expect
arrays, although in a few cases, particularly method parameters,
it will accept types that convert to NumPy arrays, such as lists
of lists. These exceptions are documented in the method or function.
Examples
--------
The following example constructs a constant velocity kinematic
filter, filters noisy data, and plots the results. It also demonstrates
using the Saver class to save the state of the filter at each epoch.
.. code-block:: Python
    import matplotlib.pyplot as plt
    import numpy as np
    from numpy.random import randn
    from filterpy.kalman import KalmanFilter
    from filterpy.common import Q_discrete_white_noise, Saver
    r_std, q_std = 2., 0.003
    dt = 1.
    cv = KalmanFilter(dim_x=2, dim_z=1)
    cv.x = np.array([0., 1.]) # position, velocity
    cv.F = np.array([[1., dt], [0., 1.]])
    cv.R = np.array([[r_std**2]])
    cv.H = np.array([[1., 0.]])
    cv.P = np.diag([.1**2, .03**2])
    cv.Q = Q_discrete_white_noise(2, dt, q_std**2)
    saver = Saver(cv)
    for z in range(100):
        cv.predict()
        cv.update([z + randn() * r_std])
        saver.save() # save the filter's state
    saver.to_array()
    plt.plot(saver.x[:, 0])
    # plot all of the priors
    plt.plot(saver.x_prior[:, 0])
    # plot mahalanobis distance
    plt.figure()
    plt.plot(saver.mahalanobis)
This code implements the same filter using the procedural form
    x = np.array([0., 1.]) # position, velocity
    F = np.array([[1., dt], [0., 1.]])
    R = np.array([[r_std**2]])
    H = np.array([[1., 0.]])
    P = np.diag([.1**2, .03**2])
    Q = Q_discrete_white_noise(2, dt, q_std**2)
    xs = []
    for z in range(100):
        x, P = predict(x, P, F=F, Q=Q)
        x, P = update(x, P, z=[z + randn() * r_std], R=R, H=H)
        xs.append(x[0])
    plt.plot(xs)
For more examples see the test subdirectory, or refer to the
book cited below. In it I both teach Kalman filtering from basic
principles, and teach the use of this library in great detail.
FilterPy library.
http://github.com/rlabbe/filterpy
Documentation at:
https://filterpy.readthedocs.org
Supporting book at:
https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
This is licensed under an MIT license. See the readme.MD file
for more information.
Copyright 2014-2018 Roger R Labbe Jr.
"""

from __future__ import absolute_import, division

import sys
from copy import deepcopy
from math import exp, log, sqrt

import numpy as np
import numpy.linalg as linalg
from filterpy.common import pretty_str, reshape_z
from filterpy.stats import logpdf
from numpy import dot, eye, isscalar, shape, zeros


class KalmanFilterNew(object):
    """Implements a Kalman filter. You are responsible for setting the
    various state variables to reasonable values; the defaults will
    not give you a functional filter.
    For now the best documentation is my free book Kalman and Bayesian
    Filters in Python [2]_. The test files in this directory also give you a
    basic idea of use, albeit without much description.
    In brief, you will first construct this object, specifying the size of
    the state vector with dim_x and the size of the measurement vector that
    you will be using with dim_z. These are mostly used to perform size checks
    when you assign values to the various matrices. For example, if you
    specified dim_z=2 and then try to assign a 3x3 matrix to R (the
    measurement noise matrix), you will get an assert exception because R
    should be 2x2. (If for whatever reason you need to alter the size of
    things midstream just use the underscore version of the matrices to
    assign directly: your_filter._R = a_3x3_matrix.)
    After construction the filter will have default matrices created for you,
    but you must specify the values for each. It's usually easiest to just
    overwrite them rather than assign to each element yourself. This will be
    clearer in the example below. All are of type numpy.array.
    Examples
    --------
    Here is a filter that tracks position and velocity using a sensor that only
    reads position.
    First construct the object with the required dimensionality. Here the state
    (`dim_x`) has 2 coefficients (position and velocity), and the measurement
    (`dim_z`) has one. In FilterPy `x` is the state, `z` is the measurement.
    .. code::
        from filterpy.kalman import KalmanFilter
        f = KalmanFilter(dim_x=2, dim_z=1)
    Assign the initial value for the state (position and velocity). You can do this
    with a two dimensional array like so:
    .. code::
        f.x = np.array([[2.],  # position
                        [0.]]) # velocity
    or just use a one dimensional array, which I prefer doing.
    .. code::
        f.x = np.array([2., 0.])
    Define the state transition matrix:
    .. code::
        f.F = np.array([[1., 1.],
                        [0., 1.]])
    Define the measurement function. Here we need to convert a position-velocity
    vector into just a position vector, so we use:
    .. code::
        f.H = np.array([[1., 0.]])
    Define the state's covariance matrix P.
    .. code::
        f.P = np.array([[1000., 0.],
                        [0., 1000.]])
    Now assign the measurement noise. Here the dimension is 1x1, so I can
    use a scalar
    .. code::
        f.R = 5
    I could have done this instead:
    .. code::
        f.R = np.array([[5.]])
    Note that this must be a 2 dimensional array.
    Finally, I will assign the process noise. Here I will take advantage of
    another FilterPy library function:
    .. code::
        from filterpy.common import Q_discrete_white_noise
        f.Q = Q_discrete_white_noise(dim=2, dt=0.1, var=0.13)
    Now just perform the standard predict/update loop:
    .. code::
        while some_condition_is_true:
            z = get_sensor_reading()
            f.predict()
            f.update(z)
            do_something_with_estimate(f.x)
    **Procedural Form**
    This module also contains stand alone functions to perform Kalman filtering.
    Use these if you are not a fan of objects.
    **Example**
    .. code::
        while True:
            z, R = read_sensor()
            x, P = predict(x, P, F, Q)
            x, P = update(x, P, z, R, H)
    See my book Kalman and Bayesian Filters in Python [2]_.
    You will have to set the following attributes after constructing this
    object for the filter to perform properly. Please note that there are
    various checks in place to ensure that you have made everything the
    'correct' size. However, it is possible to provide incorrectly sized
    arrays such that the linear algebra can not perform an operation.
    It can also fail silently - you can end up with matrices of a size that
    allows the linear algebra to work, but are the wrong shape for the problem
    you are trying to solve.
    Parameters
    ----------
    dim_x : int
        Number of state variables for the Kalman filter. For example, if
        you are tracking the position and velocity of an object in two
        dimensions, dim_x would be 4.
        This is used to set the default size of P, Q, and u
    dim_z : int
        Number of measurement inputs. For example, if the sensor
        provides you with position in (x,y), dim_z would be 2.
    dim_u : int (optional)
        size of the control input, if it is being used.
        Default value of 0 indicates it is not used.
    compute_log_likelihood : bool (default = True)
        Computes log likelihood by default, but this can be a slow
        computation, so if you never use it you can turn this computation
        off.
    Attributes
    ----------
    x : numpy.array(dim_x, 1)
        Current state estimate. Any call to update() or predict() updates
        this variable.
    P : numpy.array(dim_x, dim_x)
        Current state covariance matrix. Any call to update() or predict()
        updates this variable.
    x_prior : numpy.array(dim_x, 1)
        Prior (predicted) state estimate. The *_prior and *_post attributes
        are for convenience; they store the prior and posterior of the
        current epoch. Read Only.
    P_prior : numpy.array(dim_x, dim_x)
        Prior (predicted) state covariance matrix. Read Only.
    x_post : numpy.array(dim_x, 1)
        Posterior (updated) state estimate. Read Only.
    P_post : numpy.array(dim_x, dim_x)
        Posterior (updated) state covariance matrix. Read Only.
    z : numpy.array
        Last measurement used in update(). Read only.
    R : numpy.array(dim_z, dim_z)
        Measurement noise covariance matrix. Also known as the
        observation covariance.
    Q : numpy.array(dim_x, dim_x)
        Process noise covariance matrix. Also known as the transition
        covariance.
    F : numpy.array()
        State Transition matrix. Also known as `A` in some formulations.
    H : numpy.array(dim_z, dim_x)
        Measurement function. Also known as the observation matrix, or as `C`.
    y : numpy.array
        Residual of the update step. Read only.
    K : numpy.array(dim_x, dim_z)
        Kalman gain of the update step. Read only.
    S : numpy.array
        System uncertainty (P projected to measurement space). Read only.
    SI : numpy.array
        Inverse system uncertainty. Read only.
    log_likelihood : float
        log-likelihood of the last measurement. Read only.
    likelihood : float
        likelihood of last measurement. Read only.
        Computed from the log-likelihood. The log-likelihood can be very
        small, meaning a large negative value such as -28000. Taking the
        exp() of that results in 0.0, which can break typical algorithms
        which multiply by this value, so by default we always return a
        number >= sys.float_info.min.
    mahalanobis : float
        mahalanobis distance of the innovation. Read only.
    inv : function, default numpy.linalg.inv
        If you prefer another inverse function, such as the Moore-Penrose
        pseudo inverse, set it to that instead: kf.inv = np.linalg.pinv
        This is only used to invert self.S. If you know it is diagonal, you
        might choose to set it to filterpy.common.inv_diagonal, which is
        several times faster than numpy.linalg.inv for diagonal matrices.
    alpha : float
        Fading memory setting. 1.0 gives the normal Kalman filter, and
        values slightly larger than 1.0 (such as 1.02) give a fading
        memory effect - previous measurements have less influence on the
        filter's estimates. This formulation of the Fading memory filter
        (there are many) is due to Dan Simon [1]_.
    References
    ----------
    .. [1] Dan Simon. "Optimal State Estimation." John Wiley & Sons.
        p. 208-212. (2006)
    .. [2] Roger Labbe. "Kalman and Bayesian Filters in Python"
        https://github.com/rlabbe/Kalman-and-Bayesian-Filters-in-Python
    """

    def __init__(self, dim_x, dim_z, dim_u=0):
        if dim_x < 1:
            raise ValueError("dim_x must be 1 or greater")
        if dim_z < 1:
            raise ValueError("dim_z must be 1 or greater")
        if dim_u < 0:
            raise ValueError("dim_u must be 0 or greater")

        self.dim_x = dim_x
        self.dim_z = dim_z
        self.dim_u = dim_u

        self.x = zeros((dim_x, 1))  # state
        self.P = eye(dim_x)  # uncertainty covariance
        self.Q = eye(dim_x)  # process uncertainty
        self.B = None  # control transition matrix
        self.F = eye(dim_x)  # state transition matrix
        self.H = zeros((dim_z, dim_x))  # measurement function
        self.R = eye(dim_z)  # measurement uncertainty
        self._alpha_sq = 1.0  # fading memory control
        self.M = np.zeros((dim_x, dim_z))  # process-measurement cross correlation
        self.z = np.array([[None] * self.dim_z]).T

        # gain and residual are computed during the innovation step. We
        # save them so that in case you want to inspect them for various
        # purposes
        self.K = np.zeros((dim_x, dim_z))  # kalman gain
        self.y = zeros((dim_z, 1))
        self.S = np.zeros((dim_z, dim_z))  # system uncertainty
        self.SI = np.zeros((dim_z, dim_z))  # inverse system uncertainty

        # identity matrix. Do not alter this.
        self._I = np.eye(dim_x)

        # these will always be a copy of x,P after predict() is called
        self.x_prior = self.x.copy()
        self.P_prior = self.P.copy()

        # these will always be a copy of x,P after update() is called
        self.x_post = self.x.copy()
        self.P_post = self.P.copy()

        # only computed if requested via property
        self._log_likelihood = log(sys.float_info.min)
        self._likelihood = sys.float_info.min
        self._mahalanobis = None

        # keep all observations
        self.history_obs = []

        self.inv = np.linalg.inv

        self.attr_saved = None
        self.observed = False

    def predict(self, u=None, B=None, F=None, Q=None):
        """
        Predict next state (prior) using the Kalman filter state propagation
        equations.
        Parameters
        ----------
        u : np.array, default 0
            Optional control vector.
        B : np.array(dim_x, dim_u), or None
            Optional control transition matrix; a value of None
            will cause the filter to use `self.B`.
        F : np.array(dim_x, dim_x), or None
            Optional state transition matrix; a value of None
            will cause the filter to use `self.F`.
        Q : np.array(dim_x, dim_x), scalar, or None
            Optional process noise matrix; a value of None will cause the
            filter to use `self.Q`.
        """

        if B is None:
            B = self.B
        if F is None:
            F = self.F
        if Q is None:
            Q = self.Q
        elif isscalar(Q):
            Q = eye(self.dim_x) * Q

        # x = Fx + Bu
        if B is not None and u is not None:
            self.x = dot(F, self.x) + dot(B, u)
        else:
            self.x = dot(F, self.x)

        # P = FPF' + Q
        self.P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q

        # save prior
        self.x_prior = self.x.copy()
        self.P_prior = self.P.copy()

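    # A minimal numeric sketch (hypothetical 2-state filter) of the predict
    # equations x = Fx + Bu and P = FPF' + Q implemented above:
    #
    #     kf = KalmanFilterNew(dim_x=2, dim_z=1)
    #     kf.x = np.array([[0.], [1.]])          # position 0, velocity 1
    #     kf.F = np.array([[1., 1.], [0., 1.]])  # dt = 1
    #     kf.predict()                           # kf.x is now [[1.], [1.]]
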
    def freeze(self):
        """
        Save the parameters before a non-observation forward pass.
        """
        self.attr_saved = deepcopy(self.__dict__)

    def unfreeze(self):
        if self.attr_saved is not None:
            new_history = deepcopy(self.history_obs)
            self.__dict__ = self.attr_saved
            # self.history_obs = new_history
            self.history_obs = self.history_obs[:-1]
            occur = [int(d is None) for d in new_history]
            indices = np.where(np.array(occur) == 0)[0]
            index1 = indices[-2]
            index2 = indices[-1]
            box1 = new_history[index1]
            x1, y1, s1, r1 = box1
            w1 = np.sqrt(s1 * r1)
            h1 = np.sqrt(s1 / r1)
            box2 = new_history[index2]
            x2, y2, s2, r2 = box2
            w2 = np.sqrt(s2 * r2)
            h2 = np.sqrt(s2 / r2)
            time_gap = index2 - index1
            dx = (x2 - x1) / time_gap
            dy = (y2 - y1) / time_gap
            dw = (w2 - w1) / time_gap
            dh = (h2 - h1) / time_gap
            for i in range(index2 - index1):
                """
                The default virtual trajectory is generated by linear
                motion (constant speed hypothesis); you could modify this
                part to implement your own.
                """
                x = x1 + (i + 1) * dx
                y = y1 + (i + 1) * dy
                w = w1 + (i + 1) * dw
                h = h1 + (i + 1) * dh
                s = w * h
                r = w / float(h)
                new_box = np.array([x, y, s, r]).reshape((4, 1))
                """
                I still use the predict-update loop here to refresh the parameters,
                but this could be made faster by directly modifying the internal
                parameters as suggested in the paper. I keep this naive but slow
                way for readability.
                """
                self.update(new_box)
                if not i == (index2 - index1 - 1):
                    self.predict()

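    # A minimal sketch (hypothetical call sequence) of how update() drives
    # freeze()/unfreeze(): the first miss snapshots the filter, and the first
    # observation after a gap rolls the snapshot back and replays
    # predict/update along the linearly interpolated virtual trajectory:
    #
    #     kf.update(z1)    # observed; self.observed -> True
    #     kf.predict()
    #     kf.update(None)  # miss: freeze() saves __dict__, observed -> False
    #     kf.predict()
    #     kf.update(None)  # still missing; the prediction keeps drifting
    #     kf.predict()
    #     kf.update(z4)    # re-observed: unfreeze() re-runs the gap, then
    #                      # the normal update incorporates z4
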
    def update(self, z, R=None, H=None):
        """
        Add a new measurement (z) to the Kalman filter.
        If z is None, nothing is computed. However, x_post and P_post are
        updated with the prior (x_prior, P_prior), and self.z is set to None.
        Parameters
        ----------
        z : (dim_z, 1): array_like
            measurement for this update. z can be a scalar if dim_z is 1,
            otherwise it must be convertible to a column vector.
            If you pass in a value of H, z must be a column vector of
            the correct size.
        R : np.array, scalar, or None
            Optionally provide R to override the measurement noise for this
            one call, otherwise self.R will be used.
        H : np.array, or None
            Optionally provide H to override the measurement function for this
            one call, otherwise self.H will be used.
        """

        # set to None to force recompute
        self._log_likelihood = None
        self._likelihood = None
        self._mahalanobis = None

        # append the observation
        self.history_obs.append(z)

        if z is None:
            if self.observed:
                """
                Got no observation, so freeze the current parameters for future
                potential online smoothing.
                """
                self.freeze()
            self.observed = False
            self.z = np.array([[None] * self.dim_z]).T
            self.x_post = self.x.copy()
            self.P_post = self.P.copy()
            self.y = zeros((self.dim_z, 1))
            return

        # self.observed = True
        if not self.observed:
            """
            Got an observation, so use online smoothing to re-update the parameters.
            """
            self.unfreeze()
        self.observed = True

        if R is None:
            R = self.R
        elif isscalar(R):
            R = eye(self.dim_z) * R

        if H is None:
            z = reshape_z(z, self.dim_z, self.x.ndim)
            H = self.H

        # y = z - Hx
        # error (residual) between measurement and prediction
        self.y = z - dot(H, self.x)

        # common subexpression for speed
        PHT = dot(self.P, H.T)

        # S = HPH' + R
        # project system uncertainty into measurement space
        self.S = dot(H, PHT) + R
        self.SI = self.inv(self.S)
        # K = PH'inv(S)
        # map system uncertainty into kalman gain
        self.K = dot(PHT, self.SI)

        # x = x + Ky
        # predict new x with residual scaled by the kalman gain
        self.x = self.x + dot(self.K, self.y)

        # P = (I-KH)P(I-KH)' + KRK'
        # This is more numerically stable
        # and works for non-optimal K vs the equation
        # P = (I-KH)P usually seen in the literature.

        I_KH = self._I - dot(self.K, H)
        self.P = dot(dot(I_KH, self.P), I_KH.T) + dot(dot(self.K, R), self.K.T)

        # save measurement and posterior state
        self.z = deepcopy(z)
        self.x_post = self.x.copy()
        self.P_post = self.P.copy()

    def predict_steadystate(self, u=0, B=None):
        """
        Predict state (prior) using the Kalman filter state propagation
        equations. Only x is updated, P is left unchanged. See
        update_steadystate() for a longer explanation of when to use this
        method.
        Parameters
        ----------
        u : np.array
            Optional control vector. If non-zero, it is multiplied by B
            to create the control input into the system.
        B : np.array(dim_x, dim_u), or None
            Optional control transition matrix; a value of None
            will cause the filter to use `self.B`.
        """

        if B is None:
            B = self.B

        # x = Fx + Bu
        if B is not None:
            self.x = dot(self.F, self.x) + dot(B, u)
        else:
            self.x = dot(self.F, self.x)

        # save prior
        self.x_prior = self.x.copy()
        self.P_prior = self.P.copy()

    def update_steadystate(self, z):
        """
        Add a new measurement (z) to the Kalman filter without recomputing
        the Kalman gain K, the state covariance P, or the system
        uncertainty S.
        You can use this for LTI systems since the Kalman gain and covariance
        converge to a fixed value. Precompute these and assign them explicitly,
        or run the Kalman filter using the normal predict()/update() cycle
        until they converge.
        The main advantage of this call is speed. We do significantly less
        computation, notably avoiding a costly matrix inversion.
        Use in conjunction with predict_steadystate(), otherwise P will grow
        without bound.
        Parameters
        ----------
        z : (dim_z, 1): array_like
            measurement for this update. z can be a scalar if dim_z is 1,
            otherwise it must be convertible to a column vector.
        Examples
        --------
        >>> cv = kinematic_kf(dim=3, order=2) # 3D const velocity filter
        >>> # let filter converge on representative data, then save K and P
        >>> for i in range(100):
        >>>     cv.predict()
        >>>     cv.update([i, i, i])
        >>> saved_K = np.copy(cv.K)
        >>> saved_P = np.copy(cv.P)
        later on:
        >>> cv = kinematic_kf(dim=3, order=2) # 3D const velocity filter
        >>> cv.K = np.copy(saved_K)
        >>> cv.P = np.copy(saved_P)
        >>> for i in range(100):
        >>>     cv.predict_steadystate()
        >>>     cv.update_steadystate([i, i, i])
        """

        # set to None to force recompute
        self._log_likelihood = None
        self._likelihood = None
        self._mahalanobis = None

        if z is None:
            self.z = np.array([[None] * self.dim_z]).T
            self.x_post = self.x.copy()
            self.P_post = self.P.copy()
            self.y = zeros((self.dim_z, 1))
            return

        z = reshape_z(z, self.dim_z, self.x.ndim)

        # y = z - Hx
        # error (residual) between measurement and prediction
        self.y = z - dot(self.H, self.x)

        # x = x + Ky
        # predict new x with residual scaled by the kalman gain
        self.x = self.x + dot(self.K, self.y)

        self.z = deepcopy(z)
        self.x_post = self.x.copy()
        self.P_post = self.P.copy()

        # set to None to force recompute
        self._log_likelihood = None
        self._likelihood = None
        self._mahalanobis = None

    def update_correlated(self, z, R=None, H=None):
        """Add a new measurement (z) to the Kalman filter assuming that
        process noise and measurement noise are correlated as defined in
        the `self.M` matrix.
        A partial derivation can be found in [1]_.
        If z is None, nothing is changed.
        Parameters
        ----------
        z : (dim_z, 1): array_like
            measurement for this update. z can be a scalar if dim_z is 1,
            otherwise it must be convertible to a column vector.
        R : np.array, scalar, or None
            Optionally provide R to override the measurement noise for this
            one call, otherwise self.R will be used.
        H : np.array, or None
            Optionally provide H to override the measurement function for this
            one call, otherwise self.H will be used.
        References
        ----------
        .. [1] Bulut, Y. (2011). Applied Kalman filter theory (Doctoral dissertation, Northeastern University).
            http://people.duke.edu/~hpgavin/SystemID/References/Balut-KalmanFilter-PhD-NEU-2011.pdf
        """

        # set to None to force recompute
        self._log_likelihood = None
        self._likelihood = None
        self._mahalanobis = None

        if z is None:
            self.z = np.array([[None] * self.dim_z]).T
            self.x_post = self.x.copy()
            self.P_post = self.P.copy()
            self.y = zeros((self.dim_z, 1))
            return

        if R is None:
            R = self.R
        elif isscalar(R):
            R = eye(self.dim_z) * R

        # rename for readability and a tiny extra bit of speed
        if H is None:
            z = reshape_z(z, self.dim_z, self.x.ndim)
            H = self.H

        # handle special case: if z is in form [[z]] but x is not a column
        # vector dimensions will not match
        if self.x.ndim == 1 and shape(z) == (1, 1):
            z = z[0]

        if shape(z) == ():  # is it scalar, e.g. z=3 or z=np.array(3)
            z = np.asarray([z])

        # y = z - Hx
        # error (residual) between measurement and prediction
        self.y = z - dot(H, self.x)

        # common subexpression for speed
        PHT = dot(self.P, H.T)

        # project system uncertainty into measurement space
        self.S = dot(H, PHT) + dot(H, self.M) + dot(self.M.T, H.T) + R
        self.SI = self.inv(self.S)

        # K = PH'inv(S)
        # map system uncertainty into kalman gain
        self.K = dot(PHT + self.M, self.SI)

        # x = x + Ky
        # predict new x with residual scaled by the kalman gain
        self.x = self.x + dot(self.K, self.y)
        self.P = self.P - dot(self.K, dot(H, self.P) + self.M.T)

        self.z = deepcopy(z)
        self.x_post = self.x.copy()
        self.P_post = self.P.copy()

    def batch_filter(self, zs, Fs=None, Qs=None, Hs=None, Rs=None, Bs=None, us=None, update_first=False, saver=None):
        """Batch processes a sequence of measurements.
        Parameters
        ----------
        zs : list-like
            list of measurements at each time step `self.dt`. Missing
            measurements must be represented by `None`.
        Fs : None, list-like, default=None
            optional value or list of values to use for the state transition
            matrix F.
            If Fs is None then self.F is used for all epochs.
            Otherwise it must contain a list-like list of F's, one for
            each epoch. This allows you to have varying F per epoch.
        Qs : None, np.array or list-like, default=None
            optional value or list of values to use for the process error
            covariance Q.
            If Qs is None then self.Q is used for all epochs.
            Otherwise it must contain a list-like list of Q's, one for
            each epoch. This allows you to have varying Q per epoch.
        Hs : None, np.array or list-like, default=None
            optional list of values to use for the measurement matrix H.
            If Hs is None then self.H is used for all epochs.
            If Hs contains a single matrix, then it is used as H for all
            epochs.
            Otherwise it must contain a list-like list of H's, one for
            each epoch. This allows you to have varying H per epoch.
        Rs : None, np.array or list-like, default=None
            optional list of values to use for the measurement error
            covariance R.
            If Rs is None then self.R is used for all epochs.
            Otherwise it must contain a list-like list of R's, one for
            each epoch. This allows you to have varying R per epoch.
        Bs : None, np.array or list-like, default=None
            optional list of values to use for the control transition matrix B.
            If Bs is None then self.B is used for all epochs.
            Otherwise it must contain a list-like list of B's, one for
            each epoch. This allows you to have varying B per epoch.
        us : None, np.array or list-like, default=None
            optional list of values to use for the control input vector;
            If us is None then None is used for all epochs (equivalent to 0,
            or no control input).
            Otherwise it must contain a list-like list of u's, one for
            each epoch.
        update_first : bool, optional, default=False
            controls whether the order of operations is update followed by
            predict, or predict followed by update. Default is predict->update.
        saver : filterpy.common.Saver, optional
            filterpy.common.Saver object. If provided, saver.save() will be
            called after every epoch
        Returns
        -------
        means : np.array((n,dim_x,1))
            array of the state for each time step after the update. Each entry
            is an np.array. In other words `means[k,:]` is the state at step
            `k`.
        covariance : np.array((n,dim_x,dim_x))
            array of the covariances for each time step after the update.
            In other words `covariance[k,:,:]` is the covariance at step `k`.
        means_predictions : np.array((n,dim_x,1))
            array of the state for each time step after the predictions. Each
            entry is an np.array. In other words `means[k,:]` is the state at
            step `k`.
        covariance_predictions : np.array((n,dim_x,dim_x))
            array of the covariances for each time step after the prediction.
            In other words `covariance[k,:,:]` is the covariance at step `k`.
        Examples
        --------
        .. code-block:: Python
            # this example demonstrates tracking a measurement where the time
            # between measurement varies, as stored in dts. This requires
            # that F be recomputed for each epoch. The output is then smoothed
            # with an RTS smoother.
            zs = [t + random.randn()*4 for t in range(40)]
            Fs = [np.array([[1., dt], [0., 1.]]) for dt in dts]
            (mu, cov, _, _) = kf.batch_filter(zs, Fs=Fs)
            (xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs)
        """

        # pylint: disable=too-many-statements
        n = np.size(zs, 0)
        if Fs is None:
            Fs = [self.F] * n
        if Qs is None:
            Qs = [self.Q] * n
        if Hs is None:
            Hs = [self.H] * n
        if Rs is None:
            Rs = [self.R] * n
        if Bs is None:
            Bs = [self.B] * n
        if us is None:
            us = [0] * n

        # mean estimates from Kalman Filter
        if self.x.ndim == 1:
            means = zeros((n, self.dim_x))
            means_p = zeros((n, self.dim_x))
        else:
            means = zeros((n, self.dim_x, 1))
            means_p = zeros((n, self.dim_x, 1))

        # state covariances from Kalman Filter
        covariances = zeros((n, self.dim_x, self.dim_x))
        covariances_p = zeros((n, self.dim_x, self.dim_x))

        if update_first:
            for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)):

                self.update(z, R=R, H=H)
                means[i, :] = self.x
                covariances[i, :, :] = self.P

                self.predict(u=u, B=B, F=F, Q=Q)
                means_p[i, :] = self.x
                covariances_p[i, :, :] = self.P

                if saver is not None:
                    saver.save()
        else:
            for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)):

                self.predict(u=u, B=B, F=F, Q=Q)
                means_p[i, :] = self.x
                covariances_p[i, :, :] = self.P

                self.update(z, R=R, H=H)
                means[i, :] = self.x
                covariances[i, :, :] = self.P

                if saver is not None:
                    saver.save()

        return (means, covariances, means_p, covariances_p)

    def rts_smoother(self, Xs, Ps, Fs=None, Qs=None, inv=np.linalg.inv):
        """
        Runs the Rauch-Tung-Striebel Kalman smoother on a set of
        means and covariances computed by a Kalman filter. The usual input
        would come from the output of `KalmanFilter.batch_filter()`.
        Parameters
        ----------
        Xs : numpy.array
            array of the means (state variable x) of the output of a Kalman
            filter.
        Ps : numpy.array
            array of the covariances of the output of a Kalman filter.
        Fs : list-like collection of numpy.array, optional
            State transition matrix of the Kalman filter at each time step.
            Optional, if not provided the filter's self.F will be used
        Qs : list-like collection of numpy.array, optional
            Process noise of the Kalman filter at each time step. Optional,
            if not provided the filter's self.Q will be used
        inv : function, default numpy.linalg.inv
            If you prefer another inverse function, such as the Moore-Penrose
            pseudo inverse, set it to that instead: kf.inv = np.linalg.pinv
        Returns
        -------
        x : numpy.ndarray
            smoothed means
        P : numpy.ndarray
            smoothed state covariances
        K : numpy.ndarray
            smoother gain at each step
        Pp : numpy.ndarray
            Predicted state covariances
        Examples
        --------
        .. code-block:: Python
            zs = [t + random.randn()*4 for t in range(40)]
            (mu, cov, _, _) = kalman.batch_filter(zs)
            (x, P, K, Pp) = rts_smoother(mu, cov, kf.F, kf.Q)
        """

        if len(Xs) != len(Ps):
            raise ValueError("length of Xs and Ps must be the same")

        n = Xs.shape[0]
        dim_x = Xs.shape[1]

        if Fs is None:
            Fs = [self.F] * n
        if Qs is None:
            Qs = [self.Q] * n

        # smoother gain
        K = zeros((n, dim_x, dim_x))

        x, P, Pp = Xs.copy(), Ps.copy(), Ps.copy()
        for k in range(n - 2, -1, -1):
            Pp[k] = dot(dot(Fs[k + 1], P[k]), Fs[k + 1].T) + Qs[k + 1]

            # pylint: disable=bad-whitespace
            K[k] = dot(dot(P[k], Fs[k + 1].T), inv(Pp[k]))
            x[k] += dot(K[k], x[k + 1] - dot(Fs[k + 1], x[k]))
            P[k] += dot(dot(K[k], P[k + 1] - Pp[k]), K[k].T)

        return (x, P, K, Pp)

    def get_prediction(self, u=None, B=None, F=None, Q=None):
        """
        Predicts the next state (prior) using the Kalman filter state
        propagation equations and returns it without modifying the object.
        Parameters
        ----------
        u : np.array, default 0
            Optional control vector.
        B : np.array(dim_x, dim_u), or None
            Optional control transition matrix; a value of None
            will cause the filter to use `self.B`.
        F : np.array(dim_x, dim_x), or None
            Optional state transition matrix; a value of None
            will cause the filter to use `self.F`.
        Q : np.array(dim_x, dim_x), scalar, or None
            Optional process noise matrix; a value of None will cause the
            filter to use `self.Q`.
        Returns
        -------
        (x, P) : tuple
            State vector and covariance array of the prediction.
        """

        if B is None:
            B = self.B
        if F is None:
            F = self.F
        if Q is None:
            Q = self.Q
        elif isscalar(Q):
            Q = eye(self.dim_x) * Q

        # x = Fx + Bu
        if B is not None and u is not None:
            x = dot(F, self.x) + dot(B, u)
        else:
            x = dot(F, self.x)

        # P = FPF' + Q
        P = self._alpha_sq * dot(dot(F, self.P), F.T) + Q

        return x, P

    def get_update(self, z=None):
        """
        Computes the new estimate based on measurement `z` and returns it
        without altering the state of the filter.
        Parameters
        ----------
        z : (dim_z, 1): array_like
            measurement for this update. z can be a scalar if dim_z is 1,
            otherwise it must be convertible to a column vector.
        Returns
        -------
        (x, P) : tuple
            State vector and covariance array of the update.
        """

        if z is None:
            return self.x, self.P
        z = reshape_z(z, self.dim_z, self.x.ndim)

        R = self.R
        H = self.H
        P = self.P
        x = self.x

        # error (residual) between measurement and prediction
        y = z - dot(H, x)

        # common subexpression for speed
        PHT = dot(P, H.T)

        # project system uncertainty into measurement space
        S = dot(H, PHT) + R

        # map system uncertainty into kalman gain
        K = dot(PHT, self.inv(S))

        # predict new x with residual scaled by the kalman gain
        x = x + dot(K, y)

        # P = (I-KH)P(I-KH)' + KRK'
        I_KH = self._I - dot(K, H)
        P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T)

        return x, P

    def residual_of(self, z):
        """
        Returns the residual for the given measurement (z). Does not alter
        the state of the filter.
        """
        z = reshape_z(z, self.dim_z, self.x.ndim)
        return z - dot(self.H, self.x_prior)

    def measurement_of_state(self, x):
        """
        Helper function that converts a state into a measurement.
        Parameters
        ----------
        x : np.array
            kalman state vector
        Returns
        -------
        z : (dim_z, 1): array_like
            measurement for this update. z can be a scalar if dim_z is 1,
            otherwise it must be convertible to a column vector.
        """

        return dot(self.H, x)

    @property
    def log_likelihood(self):
        """
        log-likelihood of the last measurement.
        """
        if self._log_likelihood is None:
            self._log_likelihood = logpdf(x=self.y, cov=self.S)
        return self._log_likelihood

    @property
    def likelihood(self):
        """
        Computed from the log-likelihood. The log-likelihood can be very
        small, meaning a large negative value such as -28000. Taking the
        exp() of that results in 0.0, which can break typical algorithms
        which multiply by this value, so by default we always return a
        number >= sys.float_info.min.
        """
        if self._likelihood is None:
            self._likelihood = exp(self.log_likelihood)
            if self._likelihood == 0:
                self._likelihood = sys.float_info.min
        return self._likelihood

    @property
    def mahalanobis(self):
        """
        Mahalanobis distance of the measurement. E.g. 3 means the measurement
        was 3 standard deviations away from the predicted value.
        Returns
        -------
        mahalanobis : float
        """
        if self._mahalanobis is None:
            self._mahalanobis = sqrt(float(dot(dot(self.y.T, self.SI), self.y)))
        return self._mahalanobis

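    # A small gating sketch (hypothetical threshold): after update(), the
    # mahalanobis property can flag measurements that were unlikely under
    # the predicted distribution:
    #
    #     kf.predict()
    #     kf.update(z)
    #     if kf.mahalanobis > 3.0:  # more than 3 sigma from the prediction
    #         ...  # e.g. treat z as an outlier and discount the track
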
@property
|
| 1044 |
+
def alpha(self):
|
| 1045 |
+
"""
|
| 1046 |
+
Fading memory setting. 1.0 gives the normal Kalman filter, and
|
| 1047 |
+
values slightly larger than 1.0 (such as 1.02) give a fading
|
| 1048 |
+
memory effect - previous measurements have less influence on the
|
| 1049 |
+
filter's estimates. This formulation of the Fading memory filter
|
| 1050 |
+
(there are many) is due to Dan Simon [1]_.
|
| 1051 |
+
"""
|
| 1052 |
+
return self._alpha_sq**0.5

    def log_likelihood_of(self, z):
        """
        log likelihood of the measurement `z`. This should only be called
        after a call to update(). Calling after predict() will yield an
        incorrect result."""

        if z is None:
            return log(sys.float_info.min)
        return logpdf(z, dot(self.H, self.x), self.S)

    @alpha.setter
    def alpha(self, value):
        if not np.isscalar(value) or value < 1:
            raise ValueError("alpha must be a float greater than or equal to 1")

        self._alpha_sq = value**2

    def __repr__(self):
        return "\n".join(
            [
                "KalmanFilter object",
                pretty_str("dim_x", self.dim_x),
                pretty_str("dim_z", self.dim_z),
                pretty_str("dim_u", self.dim_u),
                pretty_str("x", self.x),
                pretty_str("P", self.P),
                pretty_str("x_prior", self.x_prior),
                pretty_str("P_prior", self.P_prior),
                pretty_str("x_post", self.x_post),
                pretty_str("P_post", self.P_post),
                pretty_str("F", self.F),
                pretty_str("Q", self.Q),
                pretty_str("R", self.R),
                pretty_str("H", self.H),
                pretty_str("K", self.K),
                pretty_str("y", self.y),
                pretty_str("S", self.S),
                pretty_str("SI", self.SI),
                pretty_str("M", self.M),
                pretty_str("B", self.B),
                pretty_str("z", self.z),
                pretty_str("log-likelihood", self.log_likelihood),
                pretty_str("likelihood", self.likelihood),
                pretty_str("mahalanobis", self.mahalanobis),
                pretty_str("alpha", self.alpha),
                pretty_str("inv", self.inv),
            ]
        )

    def test_matrix_dimensions(self, z=None, H=None, R=None, F=None, Q=None):
        """
        Performs a series of asserts to check that the size of everything
        is what it should be. This can help you debug problems in your design.
        If you pass in H, R, F, Q those will be used instead of this object's
        value for those matrices.

        Testing `z` (the measurement) is problematic. x is a vector, and can be
        implemented as either a 1D array or as an nx1 column vector. Thus Hx
        can be of different shapes. Then, if Hx is a single value, it can
        be either a 1D array or 2D vector. If either is true, z can reasonably
        be a scalar (either '3' or np.array('3') are scalars under this
        definition), a 1D, 1 element array, or a 2D, 1 element array. You are
        allowed to pass in any combination that works.
        """

        if H is None:
            H = self.H
        if R is None:
            R = self.R
        if F is None:
            F = self.F
        if Q is None:
            Q = self.Q
        x = self.x
        P = self.P

        assert x.ndim == 1 or x.ndim == 2, "x must have one or two dimensions, but has {}".format(x.ndim)

        if x.ndim == 1:
            assert x.shape[0] == self.dim_x, "Shape of x must be ({},{}), but is {}".format(self.dim_x, 1, x.shape)
        else:
            assert x.shape == (self.dim_x, 1), "Shape of x must be ({},{}), but is {}".format(self.dim_x, 1, x.shape)

        assert P.shape == (self.dim_x, self.dim_x), "Shape of P must be ({},{}), but is {}".format(
            self.dim_x, self.dim_x, P.shape
        )

        assert Q.shape == (self.dim_x, self.dim_x), "Shape of Q must be ({},{}), but is {}".format(
            self.dim_x, self.dim_x, Q.shape
        )

        assert F.shape == (self.dim_x, self.dim_x), "Shape of F must be ({},{}), but is {}".format(
            self.dim_x, self.dim_x, F.shape
        )

        assert np.ndim(H) == 2, "Shape of H must be (dim_z, {}), but is {}".format(P.shape[0], shape(H))

        assert H.shape[1] == P.shape[0], "Shape of H must be (dim_z, {}), but is {}".format(P.shape[0], H.shape)

        # shape of R must be the same as HPH'
        hph_shape = (H.shape[0], H.shape[0])
        r_shape = shape(R)

        if H.shape[0] == 1:
            # R can be scalar, 1D, or 2D in this case
            assert r_shape in [(), (1,), (1, 1)], "R must be scalar or one element array, but is shaped {}".format(
                r_shape
            )
        else:
            assert r_shape == hph_shape, "shape of R should be {} but it is {}".format(hph_shape, r_shape)

        if z is not None:
            z_shape = shape(z)
        else:
            z_shape = (self.dim_z, 1)

        # H@x must have shape of z
        Hx = dot(H, x)

        if z_shape == ():  # scalar or np.array(scalar)
            assert Hx.ndim == 1 or shape(Hx) == (1, 1), "shape of z should be {}, not {} for the given H".format(
                shape(Hx), z_shape
            )

        elif shape(Hx) == (1,):
            assert z_shape[0] == 1, "Shape of z must be {} for the given H".format(shape(Hx))

        else:
            assert z_shape == shape(Hx) or (
                len(z_shape) == 1 and shape(Hx) == (z_shape[0], 1)
            ), "shape of z should be {}, not {} for the given H".format(shape(Hx), z_shape)

        if np.ndim(Hx) > 1 and shape(Hx) != (1, 1):
            assert shape(Hx) == z_shape, "shape of z should be {} for the given H, but it is {}".format(
                shape(Hx), z_shape
            )
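
The diagnostic properties above (`likelihood`, `mahalanobis`, `alpha`) are what make this vendored filter useful for track management. Below is a minimal sketch of gating outlier measurements with them. It assumes the module is importable as `src.gesturedetection.ocsort.kalmanfilter` (the path added by this commit, with the repo root on `PYTHONPATH`) and that `KalmanFilter`'s constructor, `predict()`, and `update()` behave as in the earlier part of this file:

```python
import numpy as np

from src.gesturedetection.ocsort.kalmanfilter import KalmanFilter

kf = KalmanFilter(dim_x=2, dim_z=1)
kf.x = np.array([[0.0], [0.0]])            # state: [position, velocity]
kf.F = np.array([[1.0, 1.0], [0.0, 1.0]])  # constant-velocity transition
kf.H = np.array([[1.0, 0.0]])              # we only measure position
kf.P *= 10.0                               # initial state uncertainty
kf.R *= 0.5                                # measurement noise
kf.alpha = 1.02                            # mild fading memory (see the alpha property above)

for z in [1.0, 2.1, 2.9, 25.0, 5.2]:
    kf.predict()
    kf.update(z)
    # mahalanobis reports how many standard deviations the residual was
    # from the prediction; roughly 3 sigma is a common outlier gate
    if kf.mahalanobis > 3.0:
        print(f"z={z}: likely outlier ({kf.mahalanobis:.1f} sigma)")
```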


def update(x, P, z, R, H=None, return_all=False):
    """
    Add a new measurement (z) to the Kalman filter. If z is None, nothing
    is changed.

    This can handle either the multidimensional or unidimensional case. If
    all parameters are floats instead of arrays the filter will still work,
    and return floats for x, P as the result.

    update(1, 2, 1, 1, 1)   # univariate
    update(x, P, 1, 1, H)

    Parameters
    ----------
    x : numpy.array(dim_x, 1), or float
        State estimate vector
    P : numpy.array(dim_x, dim_x), or float
        Covariance matrix
    z : (dim_z, 1): array_like
        measurement for this update. z can be a scalar if dim_z is 1,
        otherwise it must be convertible to a column vector.
    R : numpy.array(dim_z, dim_z), or float
        Measurement noise matrix
    H : numpy.array(dim_x, dim_x), or float, optional
        Measurement function. If not provided, a value of 1 is assumed.
    return_all : bool, default False
        If true, y, K, S, and log_likelihood are returned, otherwise
        only x and P are returned.

    Returns
    -------
    x : numpy.array
        Posterior state estimate vector
    P : numpy.array
        Posterior covariance matrix
    y : numpy.array or scalar
        Residual. Difference between measurement and state in measurement space
    K : numpy.array
        Kalman gain
    S : numpy.array
        System uncertainty in measurement space
    log_likelihood : float
        log likelihood of the measurement
    """

    # pylint: disable=bare-except

    if z is None:
        if return_all:
            return x, P, None, None, None, None
        return x, P

    if H is None:
        H = np.array([1])

    if np.isscalar(H):
        H = np.array([H])

    Hx = np.atleast_1d(dot(H, x))
    z = reshape_z(z, Hx.shape[0], x.ndim)

    # error (residual) between measurement and prediction
    y = z - Hx

    # project system uncertainty into measurement space
    S = dot(dot(H, P), H.T) + R

    # map system uncertainty into kalman gain
    try:
        K = dot(dot(P, H.T), linalg.inv(S))
    except linalg.LinAlgError:
        # can't invert a 1D array, annoyingly
        K = dot(dot(P, H.T), 1.0 / S)

    # predict new x with residual scaled by the kalman gain
    x = x + dot(K, y)

    # P = (I-KH)P(I-KH)' + KRK'
    KH = dot(K, H)

    try:
        I_KH = np.eye(KH.shape[0]) - KH
    except (linalg.LinAlgError, AttributeError, IndexError):
        # KH is a scalar in the univariate case, so np.eye fails above
        I_KH = np.array([1 - KH])
    P = dot(dot(I_KH, P), I_KH.T) + dot(dot(K, R), K.T)

    if return_all:
        # compute log likelihood
        log_likelihood = logpdf(z, dot(H, x), S)
        return x, P, y, K, S, log_likelihood
    return x, P
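
A minimal sketch of this functional `update()` on a univariate problem, assuming the same import path as above. It uses numpy scalars rather than plain Python floats, since `reshape_z` reads `x.ndim`, which plain floats do not have:

```python
import numpy as np

from src.gesturedetection.ocsort.kalmanfilter import update

# np.float64 (rather than a plain float) so that x.ndim exists
x, P = np.float64(0.0), np.float64(500.0)

for z in [1.2, 0.9, 1.1]:
    x, P = update(x, P, z, R=3.0)  # H defaults to 1
    print(float(x), float(P))      # estimate converges toward the measurements
```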


def update_steadystate(x, z, K, H=None):
    """
    Add a new measurement (z) to the Kalman filter. If z is None, nothing
    is changed.

    Parameters
    ----------
    x : numpy.array(dim_x, 1), or float
        State estimate vector
    z : (dim_z, 1): array_like
        measurement for this update. z can be a scalar if dim_z is 1,
        otherwise it must be convertible to a column vector.
    K : numpy.array, or float
        Kalman gain matrix
    H : numpy.array(dim_x, dim_x), or float, optional
        Measurement function. If not provided, a value of 1 is assumed.

    Returns
    -------
    x : numpy.array
        Posterior state estimate vector

    Examples
    --------
    This can handle either the multidimensional or unidimensional case. If
    all parameters are floats instead of arrays the filter will still work,
    and return a float for x as the result.

    >>> update_steadystate(1, 2, 1)  # univariate
    >>> update_steadystate(x, z, K, H)
    """

    if z is None:
        return x

    if H is None:
        H = np.array([1])

    if np.isscalar(H):
        H = np.array([H])

    Hx = np.atleast_1d(dot(H, x))
    z = reshape_z(z, Hx.shape[0], x.ndim)

    # error (residual) between measurement and prediction
    y = z - Hx

    # estimate new x with residual scaled by the kalman gain
    return x + dot(K, y)


def predict(x, P, F=1, Q=0, u=0, B=1, alpha=1.0):
    """
    Predict next state (prior) using the Kalman filter state propagation
    equations.

    Parameters
    ----------
    x : numpy.array
        State estimate vector
    P : numpy.array
        Covariance matrix
    F : numpy.array
        State transition matrix
    Q : numpy.array, optional
        Process noise matrix
    u : numpy.array, optional, default 0.
        Control vector. If non-zero, it is multiplied by B
        to create the control input into the system.
    B : numpy.array, optional, default 1.
        Control transition matrix.
    alpha : float, optional, default=1.0
        Fading memory setting. 1.0 gives the normal Kalman filter, and
        values slightly larger than 1.0 (such as 1.02) give a fading
        memory effect - previous measurements have less influence on the
        filter's estimates. This formulation of the fading memory filter
        (there are many) is due to Dan Simon

    Returns
    -------
    x : numpy.array
        Prior state estimate vector
    P : numpy.array
        Prior covariance matrix
    """

    if np.isscalar(F):
        F = np.array(F)
    x = dot(F, x) + dot(B, u)
    P = (alpha * alpha) * dot(dot(F, P), F.T) + Q

    return x, P
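
Putting `predict()` and `update()` together gives one full filter cycle. The sketch below tracks position and velocity from noisy position measurements with a constant-velocity model; the matrices are illustrative values, not from this repo:

```python
import numpy as np

from src.gesturedetection.ocsort.kalmanfilter import predict, update

x = np.array([0.0, 0.0])                   # [position, velocity]
P = np.eye(2) * 100.0                      # initial uncertainty
F = np.array([[1.0, 1.0], [0.0, 1.0]])     # constant-velocity transition
Q = np.eye(2) * 0.01                       # process noise
H = np.array([[1.0, 0.0]])                 # measure position only
R = np.array([[4.0]])                      # measurement noise

for z in [1.1, 2.3, 2.8, 4.2]:
    x, P = predict(x, P, F=F, Q=Q)         # propagate the state forward
    x, P = update(x, P, z, R, H)           # fold in the new measurement
    print(x.round(2))
```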


def predict_steadystate(x, F=1, u=0, B=1):
    """
    Predict next state (prior) using the Kalman filter state propagation
    equations. This steady state form only computes x, assuming that the
    covariance is constant.

    Parameters
    ----------
    x : numpy.array
        State estimate vector
    F : numpy.array
        State transition matrix
    u : numpy.array, optional, default 0.
        Control vector. If non-zero, it is multiplied by B
        to create the control input into the system.
    B : numpy.array, optional, default 1.
        Control transition matrix.

    Returns
    -------
    x : numpy.array
        Prior state estimate vector
    """

    if np.isscalar(F):
        F = np.array(F)
    x = dot(F, x) + dot(B, u)

    return x


def batch_filter(x, P, zs, Fs, Qs, Hs, Rs, Bs=None, us=None, update_first=False, saver=None):
    """
    Batch processes a sequence of measurements.

    Parameters
    ----------
    zs : list-like
        list of measurements at each time step. Missing measurements must be
        represented by None.
    Fs : list-like
        list of values to use for the state transition matrix.
    Qs : list-like
        list of values to use for the process error covariance.
    Hs : list-like
        list of values to use for the measurement matrix.
    Rs : list-like
        list of values to use for the measurement error covariance.
    Bs : list-like, optional
        list of values to use for the control transition matrix;
        if not provided, 0 is used for every time step.
    us : list-like, optional
        list of values to use for the control input vector;
        if not provided, 0 is used for every time step.
    update_first : bool, optional
        controls whether the order of operations is update followed by
        predict, or predict followed by update. Default is predict->update.
    saver : filterpy.common.Saver, optional
        filterpy.common.Saver object. If provided, saver.save() will be
        called after every epoch

    Returns
    -------
    means : np.array((n,dim_x,1))
        array of the state for each time step after the update. Each entry
        is an np.array. In other words `means[k,:]` is the state at step
        `k`.
    covariance : np.array((n,dim_x,dim_x))
        array of the covariances for each time step after the update.
        In other words `covariance[k,:,:]` is the covariance at step `k`.
    means_predictions : np.array((n,dim_x,1))
        array of the state for each time step after the predictions. Each
        entry is an np.array. In other words `means[k,:]` is the state at
        step `k`.
    covariance_predictions : np.array((n,dim_x,dim_x))
        array of the covariances for each time step after the prediction.
        In other words `covariance[k,:,:]` is the covariance at step `k`.

    Examples
    --------
    .. code-block:: Python

        zs = [t + random.randn()*4 for t in range(40)]
        Fs = [kf.F for t in range(40)]
        Hs = [kf.H for t in range(40)]

        (mu, cov, _, _) = kf.batch_filter(zs, Rs=R_list, Fs=Fs, Hs=Hs, Qs=None,
                                          Bs=None, us=None, update_first=False)
        (xs, Ps, Ks, Pps) = kf.rts_smoother(mu, cov, Fs=Fs, Qs=None)
    """

    n = np.size(zs, 0)
    dim_x = x.shape[0]

    # mean estimates from Kalman Filter
    if x.ndim == 1:
        means = zeros((n, dim_x))
        means_p = zeros((n, dim_x))
    else:
        means = zeros((n, dim_x, 1))
        means_p = zeros((n, dim_x, 1))

    # state covariances from Kalman Filter
    covariances = zeros((n, dim_x, dim_x))
    covariances_p = zeros((n, dim_x, dim_x))

    if us is None:
        us = [0.0] * n
    if Bs is None:
        Bs = [0.0] * n

    if update_first:
        for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)):

            x, P = update(x, P, z, R=R, H=H)
            means[i, :] = x
            covariances[i, :, :] = P

            x, P = predict(x, P, u=u, B=B, F=F, Q=Q)
            means_p[i, :] = x
            covariances_p[i, :, :] = P
            if saver is not None:
                saver.save()
    else:
        for i, (z, F, Q, H, R, B, u) in enumerate(zip(zs, Fs, Qs, Hs, Rs, Bs, us)):

            x, P = predict(x, P, u=u, B=B, F=F, Q=Q)
            means_p[i, :] = x
            covariances_p[i, :, :] = P

            x, P = update(x, P, z, R=R, H=H)
            means[i, :] = x
            covariances[i, :, :] = P
            if saver is not None:
                saver.save()

    return (means, covariances, means_p, covariances_p)


def rts_smoother(Xs, Ps, Fs, Qs):
    """
    Runs the Rauch-Tung-Striebel Kalman smoother on a set of
    means and covariances computed by a Kalman filter. The usual input
    would come from the output of `KalmanFilter.batch_filter()`.

    Parameters
    ----------
    Xs : numpy.array
        array of the means (state variable x) of the output of a Kalman
        filter.
    Ps : numpy.array
        array of the covariances of the output of a kalman filter.
    Fs : list-like collection of numpy.array
        State transition matrix of the Kalman filter at each time step.
    Qs : list-like collection of numpy.array, optional
        Process noise of the Kalman filter at each time step.

    Returns
    -------
    x : numpy.ndarray
        smoothed means
    P : numpy.ndarray
        smoothed state covariances
    K : numpy.ndarray
        smoother gain at each step
    pP : numpy.ndarray
        predicted state covariances

    Examples
    --------
    .. code-block:: Python

        zs = [t + random.randn()*4 for t in range(40)]

        (mu, cov, _, _) = kalman.batch_filter(zs)
        (x, P, K, pP) = rts_smoother(mu, cov, kf.F, kf.Q)
    """

    if len(Xs) != len(Ps):
        raise ValueError("length of Xs and Ps must be the same")

    n = Xs.shape[0]
    dim_x = Xs.shape[1]

    # smoother gain
    K = zeros((n, dim_x, dim_x))
    x, P, pP = Xs.copy(), Ps.copy(), Ps.copy()

    for k in range(n - 2, -1, -1):
        pP[k] = dot(dot(Fs[k], P[k]), Fs[k].T) + Qs[k]

        # pylint: disable=bad-whitespace
        K[k] = dot(dot(P[k], Fs[k].T), linalg.inv(pP[k]))
        x[k] += dot(K[k], x[k + 1] - dot(Fs[k], x[k]))
        P[k] += dot(dot(K[k], P[k + 1] - pP[k]), K[k].T)

    return (x, P, K, pP)
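
A short sketch tying the module-level `batch_filter()` and `rts_smoother()` together on a noisy ramp, mirroring the docstring examples; the model matrices are illustrative, not from the repo:

```python
import numpy as np

from src.gesturedetection.ocsort.kalmanfilter import batch_filter, rts_smoother

n = 40
F = np.array([[1.0, 1.0], [0.0, 1.0]])   # constant-velocity transition
Q = np.eye(2) * 0.01
H = np.array([[1.0, 0.0]])
R = np.array([[16.0]])

zs = [t + np.random.randn() * 4 for t in range(n)]  # noisy ramp
x0 = np.array([0.0, 0.0])
P0 = np.eye(2) * 500.0

# forward pass: filtered means/covariances plus the priors
mu, cov, _, _ = batch_filter(x0, P0, zs, Fs=[F] * n, Qs=[Q] * n, Hs=[H] * n, Rs=[R] * n)

# backward pass: RTS smoothing over the filtered estimates
xs, Ps, K, pP = rts_smoother(mu, cov, Fs=[F] * n, Qs=[Q] * n)
print(xs[-1].round(2))  # smoothed [position, velocity] at the final step
```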
src/gesturedetection/onnx_models.py
ADDED
@@ -0,0 +1,194 @@
from abc import ABC

import cv2
import numpy as np
import onnxruntime as ort


class OnnxModel(ABC):
    def __init__(self, model_path, image_size):
        self.model_path = model_path
        self.image_size = image_size
        self.mean = np.array([127, 127, 127], dtype=np.float32)
        self.std = np.array([128, 128, 128], dtype=np.float32)
        options, prov_opts, providers = self.get_onnx_provider()
        self.sess = ort.InferenceSession(
            model_path, sess_options=options, providers=providers, provider_options=prov_opts
        )
        self._get_input_output()

    def preprocess(self, frame):
        """
        Preprocess frame

        Parameters
        ----------
        frame : np.ndarray
            Frame to preprocess

        Returns
        -------
        np.ndarray
            Preprocessed frame
        """
        image = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
        image = cv2.resize(image, self.image_size)
        image = (image - self.mean) / self.std
        image = np.transpose(image, [2, 0, 1])
        image = np.expand_dims(image, axis=0)
        return image

    def _get_input_output(self):
        inputs = self.sess.get_inputs()
        self.inputs = "".join(
            [
                f"\n {i}: {inp.name}" f" Shape: ({','.join(map(str, inp.shape))})" f" Dtype: {inp.type}"
                for i, inp in enumerate(inputs)
            ]
        )

        outputs = self.sess.get_outputs()
        self.outputs = "".join(
            [
                f"\n {i}: {output.name}" f" Shape: ({','.join(map(str, output.shape))})" f" Dtype: {output.type}"
                for i, output in enumerate(outputs)
            ]
        )

    @staticmethod
    def get_onnx_provider():
        """
        Get onnx provider

        Returns
        -------
        options : onnxruntime.SessionOptions
            Session options
        prov_opts : list
            Provider options, parallel to `providers`
        providers : list
            List of providers
        """
        providers = ["CPUExecutionProvider"]
        options = ort.SessionOptions()
        options.enable_mem_pattern = False
        options.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL
        # provider_options must stay the same length as providers, so start
        # with an empty options dict for the CPU provider
        prov_opts = [{}]
        print("Using ONNX Runtime", ort.get_device())

        if "DML" in ort.get_device():
            prov_opts.append({"device_id": 0})
            providers.append("DmlExecutionProvider")

        elif "GPU" in ort.get_device():
            prov_opts.append(
                {
                    "device_id": 0,
                    "arena_extend_strategy": "kNextPowerOfTwo",
                    "gpu_mem_limit": 2 * 1024 * 1024 * 1024,
                    "cudnn_conv_algo_search": "EXHAUSTIVE",
                    "do_copy_in_default_stream": True,
                }
            )
            providers.append("CUDAExecutionProvider")

        return options, prov_opts, providers

    def __repr__(self):
        return (
            f"Providers: {self.sess.get_providers()}\n"
            f"Model: {self.sess.get_modelmeta().description}\n"
            f"Version: {self.sess.get_modelmeta().version}\n"
            f"Inputs: {self.inputs}\n"
            f"Outputs: {self.outputs}"
        )


class HandDetection(OnnxModel):
    def __init__(self, model_path, image_size=(320, 240)):
        super().__init__(model_path, image_size)
        self.input_name = self.sess.get_inputs()[0].name
        self.output_names = [output.name for output in self.sess.get_outputs()]

    def __call__(self, frame):
        input_tensor = self.preprocess(frame)
        boxes, _, probs = self.sess.run(self.output_names, {self.input_name: input_tensor})
        width, height = frame.shape[1], frame.shape[0]
        # the detector outputs normalized coordinates; scale back to pixels
        boxes[:, 0] *= width
        boxes[:, 1] *= height
        boxes[:, 2] *= width
        boxes[:, 3] *= height
        return boxes.astype(np.int32), probs


class HandClassification(OnnxModel):
    def __init__(self, model_path, image_size=(128, 128)):
        super().__init__(model_path, image_size)

    @staticmethod
    def get_square(box, image):
        """
        Expand a box to a square, clipped to the image bounds.

        Parameters
        ----------
        box : np.ndarray
            Box coordinates (x1, y1, x2, y2)
        image : np.ndarray
            Image for shape
        """
        height, width, _ = image.shape
        x0, y0, x1, y1 = box
        w, h = x1 - x0, y1 - y0
        if h < w:
            y0 = y0 - int((w - h) / 2)
            y1 = y0 + w
        if h > w:
            x0 = x0 - int((h - w) / 2)
            x1 = x0 + h
        x0 = max(0, x0)
        y0 = max(0, y0)
        x1 = min(width - 1, x1)
        y1 = min(height - 1, y1)
        return x0, y0, x1, y1

    def get_crops(self, frame, bboxes):
        """
        Get crops from frame

        Parameters
        ----------
        frame : np.ndarray
            Frame to crop from bboxes
        bboxes : np.ndarray
            Bounding boxes

        Returns
        -------
        crops : list of np.ndarray
            Crops from frame
        """
        crops = []
        for bbox in bboxes:
            bbox = self.get_square(bbox, frame)
            crop = frame[bbox[1] : bbox[3], bbox[0] : bbox[2]]
            crops.append(crop)
        return crops

    def __call__(self, image, bboxes):
        """
        Get predictions from model

        Parameters
        ----------
        image : np.ndarray
            Image to predict
        bboxes : np.ndarray
            Bounding boxes

        Returns
        -------
        labels : np.ndarray
            Predicted gesture class id for each box
        """
        crops = self.get_crops(image, bboxes)
        crops = [self.preprocess(crop) for crop in crops]
        input_name = self.sess.get_inputs()[0].name
        outputs = self.sess.run(None, {input_name: np.concatenate(crops, axis=0)})[0]
        labels = np.argmax(outputs, axis=1)
        return labels
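
A minimal sketch of the two-stage pipeline this module implements: detect hand boxes, then classify a gesture for each crop. The model filenames match the `models/` files added in this commit; the image path and the repo-root import path are assumptions:

```python
import cv2

from src.gesturedetection.onnx_models import HandDetection, HandClassification

detector = HandDetection("models/hand_detector.onnx")
classifier = HandClassification("models/crops_classifier.onnx")

frame = cv2.imread("example.jpg")   # any BGR image with a visible hand
boxes, probs = detector(frame)      # pixel-space (x1, y1, x2, y2) int boxes

if len(boxes) > 0:
    # one gesture class id per box; the id-to-name mapping lives in
    # utils.enums.targets, which is not shown in this diff
    labels = classifier(frame, boxes)
    for box, label in zip(boxes, labels):
        print(box, int(label))
```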
src/gesturedetection/utils/__init__.py
ADDED
@@ -0,0 +1,16 @@
from .action_controller import Deque
from .box_utils_numpy import hard_nms
from .drawer import Drawer
from .enums import Event, HandPosition, targets
from .hand import Hand


__all__ = [
    "Deque",
    "hard_nms",
    "Drawer",
    "Event",
    "HandPosition",
    "targets",
    "Hand",
]
src/gesturedetection/utils/__pycache__/__init__.cpython-312.pyc
ADDED
Binary file (504 Bytes)
src/gesturedetection/utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (465 Bytes)
src/gesturedetection/utils/__pycache__/action_controller.cpython-312.pyc
ADDED
Binary file (25.8 kB)
src/gesturedetection/utils/__pycache__/action_controller.cpython-39.pyc
ADDED
Binary file (12 kB)
src/gesturedetection/utils/__pycache__/box_utils_numpy.cpython-312.pyc
ADDED
Binary file (7.47 kB)
src/gesturedetection/utils/__pycache__/box_utils_numpy.cpython-39.pyc
ADDED
Binary file (5.46 kB)
src/gesturedetection/utils/__pycache__/drawer.cpython-312.pyc
ADDED
Binary file (9.86 kB)
src/gesturedetection/utils/__pycache__/drawer.cpython-39.pyc
ADDED
Binary file (4.29 kB)
src/gesturedetection/utils/__pycache__/enums.cpython-312.pyc
ADDED
Binary file (2.59 kB)
src/gesturedetection/utils/__pycache__/enums.cpython-39.pyc
ADDED
Binary file (2.33 kB)
src/gesturedetection/utils/__pycache__/hand.cpython-312.pyc
ADDED
Binary file (1.7 kB)
src/gesturedetection/utils/__pycache__/hand.cpython-39.pyc
ADDED
Binary file (1.18 kB)
src/gesturedetection/utils/action_controller.py
ADDED
@@ -0,0 +1,598 @@
from scipy.spatial import distance
from collections import deque

from .enums import Event, HandPosition, targets
from .hand import Hand


class Deque:
    def __init__(self, maxlen=30, min_frames=20):
        self.maxlen = maxlen
        self._deque = []
        self.action = None
        self.min_absolute_distance = 1.5
        self.min_frames = min_frames
        self.action_deque = deque(maxlen=5)

    def __len__(self):
        return len(self._deque)

    def index_position(self, x):
        for i in range(len(self._deque)):
            if self._deque[i].position == x:
                return i

    def index_gesture(self, x):
        for i in range(len(self._deque)):
            if self._deque[i].gesture == x:
                return i

    def __getitem__(self, index):
        return self._deque[index]

    def __setitem__(self, index, value):
        self._deque[index] = value

    def __delitem__(self, index):
        del self._deque[index]

    def __iter__(self):
        return iter(self._deque)

    def __reversed__(self):
        return reversed(self._deque)

    def append(self, x):
        if self.maxlen is not None and len(self) >= self.maxlen:
            self._deque.pop(0)
        self.set_hand_position(x)
        self._deque.append(x)
        self.check_is_action(x)

    def check_duration(self, start_index, min_frames=None):
        """
        Check duration of swipe.

        Parameters
        ----------
        start_index : int
            Index of start position of swipe.

        Returns
        -------
        bool
            True if duration of swipe is at least min_frames.
        """
        if min_frames is None:
            min_frames = self.min_frames
        return len(self) - start_index >= min_frames

    def check_duration_max(self, start_index, max_frames=10):
        """
        Check duration of swipe.

        Parameters
        ----------
        start_index : int
            Index of start position of swipe.

        Returns
        -------
        bool
            True if duration of swipe is at most max_frames.
        """
        return len(self) - start_index <= max_frames

    def check_is_action(self, x):
        """
        Check if gesture is action.

        Parameters
        ----------
        x : Hand
            Hand object.

        Returns
        -------
        bool
            True if gesture is action.
        """
        if x.position == HandPosition.LEFT_END and HandPosition.RIGHT_START in self:
            start_index = self.index_position(HandPosition.RIGHT_START)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_LEFT
                self.clear()
                return True

        elif x.position == HandPosition.RIGHT_END and HandPosition.LEFT_START in self:
            start_index = self.index_position(HandPosition.LEFT_START)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_RIGHT
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.UP_END and HandPosition.DOWN_START in self:
            start_index = self.index_position(HandPosition.DOWN_START)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_UP
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.DOWN_END and HandPosition.UP_START in self:
            start_index = self.index_position(HandPosition.UP_START)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_DOWN
                self.clear()
                return True
            else:
                self.clear()

        elif x.gesture == 18:  # grip
            if self.action is None:
                start_index = self.index_gesture(18)
                if self.check_duration(start_index):
                    self.action = Event.DRAG2
                    return True

        elif self.action == Event.DRAG2 and x.gesture in [11, 12]:  # hand heart
            self.action = Event.DROP2
            self.clear()
            return True

        elif x.gesture == 29:  # ok
            if self.action is None:
                start_index = self.index_gesture(29)
                if self.check_duration(start_index):
                    self.action = Event.DRAG3
                    return True

        elif self.action == Event.DRAG3 and x.gesture in [11, 12]:  # hand heart
            self.action = Event.DROP3
            self.clear()
            return True

        elif x.position == HandPosition.FAST_SWIPE_UP_END and HandPosition.FAST_SWIPE_UP_START in self:
            start_index = self.index_position(HandPosition.FAST_SWIPE_UP_START)
            if (
                self.check_duration(start_index, min_frames=20)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.FAST_SWIPE_UP
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.FAST_SWIPE_DOWN_END and HandPosition.FAST_SWIPE_DOWN_START in self:
            start_index = self.index_position(HandPosition.FAST_SWIPE_DOWN_START)
            if (
                self.check_duration(start_index, min_frames=20)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.FAST_SWIPE_DOWN
                self.clear()
                return True

        elif x.position == HandPosition.ZOOM_IN_END and HandPosition.ZOOM_IN_START in self:
            start_index = self.index_position(HandPosition.ZOOM_IN_START)
            if (
                self.check_duration(start_index, min_frames=20)
                and self.check_vertical_swipe(self._deque[start_index], x)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.ZOOM_IN
                self.clear()
                return True

        elif x.position == HandPosition.ZOOM_OUT_END and HandPosition.ZOOM_OUT_START in self:
            start_index = self.index_position(HandPosition.ZOOM_OUT_START)
            if (
                self.check_duration(start_index, min_frames=20)
                and self.check_vertical_swipe(self._deque[start_index], x)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.ZOOM_OUT
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.LEFT_END2 and HandPosition.RIGHT_START2 in self:
            start_index = self.index_position(HandPosition.RIGHT_START2)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_LEFT2
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.RIGHT_END2 and HandPosition.LEFT_START2 in self:
            start_index = self.index_position(HandPosition.LEFT_START2)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_RIGHT2
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.UP_END2 and HandPosition.DOWN_START2 in self:
            start_index = self.index_position(HandPosition.DOWN_START2)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_UP2
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.LEFT_END3 and HandPosition.RIGHT_START3 in self:
            start_index = self.index_position(HandPosition.RIGHT_START3)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_LEFT3  # two
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.RIGHT_END3 and HandPosition.LEFT_START3 in self:
            start_index = self.index_position(HandPosition.LEFT_START3)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_duration(start_index)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_RIGHT3
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.UP_END3 and HandPosition.DOWN_START3 in self:
            start_index = self.index_position(HandPosition.DOWN_START3)
            if (
                self.check_duration(start_index, min_frames=15)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_UP3
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.DOWN_END3 and HandPosition.UP_START3 in self:
            start_index = self.index_position(HandPosition.UP_START3)
            if (
                self.check_duration(start_index, min_frames=15)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_DOWN3
                self.clear()
                return True
            else:
                self.clear()

        elif HandPosition.DRAG_START in self and x.gesture == 25:  # fist
            if self.action is None:
                start_index = self.index_gesture(17)  # grabbing

                if self.check_duration(start_index, min_frames=3):
                    self.action = Event.DRAG
                    return True
                else:
                    self.clear()

        elif HandPosition.ZOOM_IN_START in self and x.gesture == 19:  # point
            start_index = self.index_position(HandPosition.ZOOM_IN_START)
            if (
                self.check_duration(start_index, min_frames=8)
                and self.check_vertical_swipe(self._deque[start_index], x)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action = Event.TAP
                self.clear()
                return True
            elif (
                self.check_duration(start_index, min_frames=2)
                and self.check_duration_max(start_index, max_frames=8)
                and self.check_vertical_swipe(self._deque[start_index], x)
                and self.check_horizontal_swipe(self._deque[start_index], x)
            ):
                self.action_deque.append(Event.TAP)
                if (
                    len(self.action_deque) >= 2
                    and self.action_deque[-1] == Event.TAP
                    and self.action_deque[-2] == Event.TAP
                ):
                    self.action_deque.pop()
                    self.action_deque.pop()
                    self.action = Event.DOUBLE_TAP
                    self.clear()
                    return True
            else:
                self.clear()

        elif x.position == HandPosition.DOWN_END2 and HandPosition.ZOOM_OUT_START in self:
            start_index = self.index_position(HandPosition.ZOOM_OUT_START)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_DOWN2
                self.clear()
                return True
            else:
                self.clear()

        elif x.position == HandPosition.ZOOM_OUT_START and HandPosition.UP_START2 in self:
            start_index = self.index_position(HandPosition.UP_START2)
            if (
                self.swipe_distance(self._deque[start_index], x)
                and self.check_vertical_swipe(self._deque[start_index], x)
            ):
                self.action = Event.SWIPE_UP2
                self.clear()
                return True
            else:
                self.clear()

        elif self.action == Event.DRAG and x.gesture in [35, 31, 36, 17]:  # [stop, palm, stop_inverted, grabbing]
            self.action = Event.DROP
            self.clear()
            return True
        return False

    @staticmethod
    def check_horizontal_swipe(start_hand, x):
        """
        Check if swipe is horizontal.

        Parameters
        ----------
        start_hand : Hand
            Hand object of start position of swipe.

        x : Hand
            Hand object of end position of swipe.

        Returns
        -------
        bool
            True if swipe is horizontal.
        """
        boundary = [start_hand.bbox[1], start_hand.bbox[3]]
        return boundary[0] < x.center[1] < boundary[1]

    @staticmethod
    def check_vertical_swipe(start_hand, x):
        """
        Check if swipe is vertical.

        Parameters
        ----------
        start_hand : Hand
            Hand object of start position of swipe.

        x : Hand
            Hand object of end position of swipe.

        Returns
        -------
        bool
            True if swipe is vertical.
        """
        boundary = [start_hand.bbox[0], start_hand.bbox[2]]
        return boundary[0] < x.center[0] < boundary[1]

    def __contains__(self, item):
        for x in self._deque:
            if x.position == item:
                return True
        return False

    def set_hand_position(self, hand: Hand):
        """
        Set hand position.

        Parameters
        ----------
        hand : Hand
            Hand object.
        """
        if hand.gesture in [31, 35, 36]:  # [palm, stop, stop_inv]
            if HandPosition.DOWN_START in self:
                hand.position = HandPosition.UP_END
            else:
                hand.position = HandPosition.UP_START

        elif hand.gesture == 0:  # hand_down
            if HandPosition.UP_START in self:
                hand.position = HandPosition.DOWN_END
            else:
                hand.position = HandPosition.DOWN_START

        elif hand.gesture == 1:  # hand_right
            if HandPosition.LEFT_START in self:
                hand.position = HandPosition.RIGHT_END
            else:
                hand.position = HandPosition.RIGHT_START

        elif hand.gesture == 2:  # hand_left
            if HandPosition.RIGHT_START in self:
                hand.position = HandPosition.LEFT_END
            else:
                hand.position = HandPosition.LEFT_START

        elif hand.gesture == 30:  # one
            if HandPosition.FAST_SWIPE_UP_START in self:
                hand.position = HandPosition.FAST_SWIPE_UP_END
            else:
                hand.position = HandPosition.FAST_SWIPE_DOWN_START

        elif hand.gesture == 19:  # point
            if HandPosition.FAST_SWIPE_DOWN_START in self:
                hand.position = HandPosition.FAST_SWIPE_DOWN_END
            else:
                hand.position = HandPosition.FAST_SWIPE_UP_START

        elif hand.gesture == 17:  # grabbing
            hand.position = HandPosition.DRAG_START

        elif hand.gesture == 25:  # fist
            if HandPosition.ZOOM_OUT_START in self:
                hand.position = HandPosition.ZOOM_OUT_END
            else:
                hand.position = HandPosition.ZOOM_IN_START

        elif hand.gesture == 3:  # thumb_index
            if HandPosition.ZOOM_IN_START in self:
                hand.position = HandPosition.ZOOM_IN_END
            else:
                hand.position = HandPosition.ZOOM_OUT_START

        elif hand.gesture == 38:  # three2
            if HandPosition.ZOOM_IN_START in self:
                hand.position = HandPosition.ZOOM_IN_END
            else:
                hand.position = HandPosition.ZOOM_OUT_START

        elif hand.gesture == 5:  # thumb_right
            if HandPosition.LEFT_START2 in self:
                hand.position = HandPosition.RIGHT_END2
            else:
                hand.position = HandPosition.RIGHT_START2

        elif hand.gesture == 4:  # thumb_left
            if HandPosition.RIGHT_START2 in self:
                hand.position = HandPosition.LEFT_END2
            else:
                hand.position = HandPosition.LEFT_START2

        elif hand.gesture == 15:  # two_right
            if HandPosition.LEFT_START3 in self:
                hand.position = HandPosition.RIGHT_END3
            else:
                hand.position = HandPosition.RIGHT_START3

        elif hand.gesture == 14:  # two_left
            if HandPosition.RIGHT_START3 in self:
                hand.position = HandPosition.LEFT_END3
            else:
                hand.position = HandPosition.LEFT_START3

        elif hand.gesture == 39:  # two_up
            if HandPosition.DOWN_START3 in self:
                hand.position = HandPosition.UP_END3
            else:
                hand.position = HandPosition.UP_START3

        elif hand.gesture == 16:  # two_down
            if HandPosition.UP_START3 in self:
                hand.position = HandPosition.DOWN_END3
            else:
                hand.position = HandPosition.DOWN_START3

        elif hand.gesture == 6:  # thumb_down
            if HandPosition.ZOOM_OUT_START in self:
                hand.position = HandPosition.DOWN_END2
            else:
                hand.position = HandPosition.UP_START2
        else:
            hand.position = HandPosition.UNKNOWN

    def swipe_distance(
        self,
        first_hand: Hand,
        last_hand: Hand,
    ):
        """
        Check if swipe distance is more than min_absolute_distance.

        Parameters
        ----------
        first_hand : Hand
            Hand object of start position of swipe.

        last_hand : Hand
            Hand object of end position of swipe.

        Returns
        -------
        bool
            True if swipe distance (relative to hand size) is more than
            min_absolute_distance.
        """
        hand_dist = distance.euclidean(first_hand.center, last_hand.center)
        hand_size = (first_hand.size + last_hand.size) / 2
        return hand_dist / hand_size > self.min_absolute_distance

    def clear(self):
        self._deque.clear()

    def copy(self):
        return self._deque.copy()

    def count(self, x):
        return self._deque.count(x)

    def extend(self, iterable):
        self._deque.extend(iterable)

    def insert(self, i, x):
        self._deque.insert(i, x)

    def pop(self):
        return self._deque.pop()

    def remove(self, value):
        self._deque.remove(value)

    def reverse(self):
        self._deque.reverse()

    def __str__(self):
        return f"Deque({[hand.gesture for hand in self._deque]})"
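
A minimal sketch of driving the `Deque` action controller with synthetic detections to trigger `Event.SWIPE_LEFT`. The real `Hand` class is not shown in this diff, so a hypothetical stand-in with just the attributes the controller reads (`gesture`, `bbox`, `center`, `size`, `position`) is used; gesture ids 1 and 2 (`hand_right`/`hand_left`) come from the comments in `set_hand_position` above:

```python
from types import SimpleNamespace

from src.gesturedetection.utils import Deque, Event

def fake_hand(x0, gesture):
    # Stand-in detection; position is filled in by set_hand_position()
    return SimpleNamespace(
        gesture=gesture,
        bbox=(x0, 100, x0 + 80, 180),
        center=(x0 + 40, 140),
        size=80,
        position=None,
    )

controller = Deque(maxlen=30, min_frames=20)

# 20 frames of "hand_right" (gesture 1) marching across the frame...
for i in range(20):
    controller.append(fake_hand(400 - i * 10, gesture=1))
# ...ending in a "hand_left" (gesture 2) far from the start position.
controller.append(fake_hand(40, gesture=2))

print(controller.action)  # Event.SWIPE_LEFT once distance/duration checks pass
```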