Fucius's picture
Upload 422 files
2eafbc4 verified
import time
from datetime import datetime
from functools import partial
from queue import Queue
from threading import Thread
from typing import Callable, Generator, List, Optional, Tuple, Union
from inference.core import logger
from inference.core.active_learning.middlewares import (
NullActiveLearningMiddleware,
ThreadingActiveLearningMiddleware,
)
from inference.core.cache import cache
from inference.core.env import (
ACTIVE_LEARNING_ENABLED,
API_KEY,
API_KEY_ENV_NAMES,
DISABLE_PREPROC_AUTO_ORIENT,
PREDICTIONS_QUEUE_SIZE,
RESTART_ATTEMPT_DELAY,
)
from inference.core.exceptions import MissingApiKeyError
from inference.core.interfaces.camera.entities import (
StatusUpdate,
UpdateSeverity,
VideoFrame,
)
from inference.core.interfaces.camera.exceptions import SourceConnectionError
from inference.core.interfaces.camera.utils import get_video_frames_generator
from inference.core.interfaces.camera.video_source import (
BufferConsumptionStrategy,
BufferFillingStrategy,
VideoSource,
)
from inference.core.interfaces.stream.entities import (
ModelConfig,
ObjectDetectionPrediction,
)
from inference.core.interfaces.stream.sinks import active_learning_sink, multi_sink
from inference.core.interfaces.stream.watchdog import (
NullPipelineWatchdog,
PipelineWatchDog,
)
from inference.core.models.roboflow import OnnxRoboflowInferenceModel
from inference.models.utils import get_roboflow_model
# Context string attached to every StatusUpdate emitted by the inference pipeline
# (optionally suffixed with ".<sub_context>").
INFERENCE_PIPELINE_CONTEXT = "inference_pipeline"

# --- Video source connectivity events ---
SOURCE_CONNECTION_ATTEMPT_FAILED_EVENT = "SOURCE_CONNECTION_ATTEMPT_FAILED"
SOURCE_CONNECTION_LOST_EVENT = "SOURCE_CONNECTION_LOST"

# --- Results dispatching events ---
INFERENCE_RESULTS_DISPATCHING_ERROR_EVENT = "INFERENCE_RESULTS_DISPATCHING_ERROR"

# --- Inference thread lifecycle / per-frame events ---
INFERENCE_THREAD_STARTED_EVENT = "INFERENCE_THREAD_STARTED"
INFERENCE_THREAD_FINISHED_EVENT = "INFERENCE_THREAD_FINISHED"
INFERENCE_COMPLETED_EVENT = "INFERENCE_COMPLETED"
INFERENCE_ERROR_EVENT = "INFERENCE_ERROR"
class InferencePipeline:
    """Run a Roboflow model against a video stream using separate worker threads.

    Frames are produced by a `VideoSource` and processed on a dedicated inference
    thread (preprocess -> predict -> postprocess). Results are handed over through a
    `Queue` to a dispatching loop that invokes the `on_prediction` sink. Build
    instances with `InferencePipeline.init(...)`; the constructor only stores
    ready-made collaborators.
    """

    @classmethod
    def init(
        cls,
        model_id: str,
        video_reference: Union[str, int],
        on_prediction: Callable[[ObjectDetectionPrediction, VideoFrame], None],
        api_key: Optional[str] = None,
        max_fps: Optional[Union[float, int]] = None,
        watchdog: Optional[PipelineWatchDog] = None,
        status_update_handlers: Optional[List[Callable[[StatusUpdate], None]]] = None,
        source_buffer_filling_strategy: Optional[BufferFillingStrategy] = None,
        source_buffer_consumption_strategy: Optional[BufferConsumptionStrategy] = None,
        class_agnostic_nms: Optional[bool] = None,
        confidence: Optional[float] = None,
        iou_threshold: Optional[float] = None,
        max_candidates: Optional[int] = None,
        max_detections: Optional[int] = None,
        mask_decode_mode: Optional[str] = "accurate",
        tradeoff_factor: Optional[float] = 0.0,
        active_learning_enabled: Optional[bool] = None,
    ) -> "InferencePipeline":
        """
        Create an InferencePipeline that runs `model_id` against `video_reference`.

        You choose a model from the Roboflow platform and run predictions against
        video streams, just by specifying which model to use and what to do with the
        predictions. Model post-processing parameters can be set via `.init()` or env.
        Updates related to the state of the pipeline can be intercepted via the
        `PipelineWatchDog` abstraction (probably useful only for advanced use-cases).
        For efficiency, the separate chunks of processing (video decoding, inference and
        results dispatching) are handled by separate threads.
        If connectivity with a non-file stream is lost, the pipeline attempts to
        re-connect with a delay.
        Since version 0.9.11 it works not only for object detection models; it is also
        compatible with stubs, classification, instance-segmentation and
        keypoint-detection models.

        Args:
            model_id (str): Name and version of model at Roboflow platform
                (example: "my-model/3").
            video_reference (Union[str, int]): Reference of the source to make
                predictions against. It can be a video file path, a stream URL or a
                device (like camera) id (whatever cv2 handles).
            on_prediction (Callable[[ObjectDetectionPrediction, VideoFrame], None]):
                Function called once a prediction is ready. It receives
                `(predictions, video_frame)`: the prediction dict followed by the
                decoded frame together with its metadata.
            api_key (Optional[str]): Roboflow API key. If not passed, the value of
                `API_KEY` from `inference.core.env` is used (looked up in env under
                "ROBOFLOW_API_KEY" and "API_KEY" variables). Without any key, Active
                Learning is forcibly disabled.
            max_fps (Optional[Union[float, int]]): Upper bound of processing FPS. It is
                useful when running concurrent inference pipelines on a single machine,
                trading the number of frames for the number of streams handled.
                Disabled by default.
            watchdog (Optional[PipelineWatchDog]): Implementation of a class that allows
                profiling of the inference pipeline. If not given, a null implementation
                (doing nothing) is used.
            status_update_handlers (Optional[List[Callable[[StatusUpdate], None]]]):
                List of handlers intercepting status updates of all elements of the
                pipeline. Use it only if detailed inspection of pipeline behaviour over
                time is needed. Please note that handlers should execute fast, otherwise
                they impair pipeline performance. All handler errors are logged as
                warnings without re-raising. The list passed in is copied, not mutated.
                Default: None.
            source_buffer_filling_strategy (Optional[BufferFillingStrategy]): Strategy
                for video stream decoding behaviour. By default it is tweaked to the
                type of source given. See the docs of
                [`VideoSource`](../camera/video_source.py).
            source_buffer_consumption_strategy (Optional[BufferConsumptionStrategy]):
                Strategy for video stream frames consumption. By default it is tweaked
                to the type of source given. See the docs of
                [`VideoSource`](../camera/video_source.py).
            class_agnostic_nms (Optional[bool]): Parameter of model post-processing.
                If not given, the value is taken from env variable "CLASS_AGNOSTIC_NMS"
                with default "False".
            confidence (Optional[float]): Parameter of model post-processing. If not
                given, the value is taken from env variable "CONFIDENCE" with default
                "0.5".
            iou_threshold (Optional[float]): Parameter of model post-processing. If not
                given, the value is taken from env variable "IOU_THRESHOLD" with
                default "0.5".
            max_candidates (Optional[int]): Parameter of model post-processing. If not
                given, the value is taken from env variable "MAX_CANDIDATES" with
                default "3000".
            max_detections (Optional[int]): Parameter of model post-processing. If not
                given, the value is taken from env variable "MAX_DETECTIONS" with
                default "300".
            mask_decode_mode (Optional[str]): Parameter of model post-processing.
                Defaults to "accurate". Applicable for instance segmentation models.
            tradeoff_factor (Optional[float]): Parameter of model post-processing.
                Defaults to 0.0. Applicable for instance segmentation models.
            active_learning_enabled (Optional[bool]): Flag to enable / disable the
                Active Learning middleware. Setting it true does not guarantee any data
                will be collected, as data collection is controlled by the Roboflow
                backend; it only enables the middleware intercepting predictions. If not
                given, env variable `ACTIVE_LEARNING_ENABLED` is used. Please note that
                Active Learning is forcibly disabled when no Roboflow API key is
                available, as a Roboflow account is required for this feature.

        Other ENV variables involved in low-level configuration:
            * INFERENCE_PIPELINE_PREDICTIONS_QUEUE_SIZE - size of buffer for predictions
              that are ready for dispatching
            * INFERENCE_PIPELINE_RESTART_ATTEMPT_DELAY - delay for restarts on stream
              connection drop
            * ACTIVE_LEARNING_ENABLED - controls the Active Learning middleware if the
              explicit parameter is not given

        Returns:
            InferencePipeline: configured pipeline. Threads are only launched by
            `.start()`.

        Raises:
            SourceConnectionError: per the original contract, when the source cannot be
                connected at start. Reconnection is always attempted if the connection
                to a stream is lost later. NOTE(review): `VideoSource.start()` is
                invoked inside the inference thread, where exceptions are reported as
                INFERENCE_ERROR status updates rather than propagated. Confirm which
                call can still raise here.
        """
        if api_key is None:
            api_key = API_KEY
        # Work on a private copy: the watchdog handler is appended below and must not
        # leak into a list owned by the caller. Reusing one list for several pipelines
        # would otherwise accumulate watchdog handlers across pipelines.
        status_update_handlers = (
            [] if status_update_handlers is None else list(status_update_handlers)
        )
        inference_config = ModelConfig.init(
            class_agnostic_nms=class_agnostic_nms,
            confidence=confidence,
            iou_threshold=iou_threshold,
            max_candidates=max_candidates,
            max_detections=max_detections,
            mask_decode_mode=mask_decode_mode,
            tradeoff_factor=tradeoff_factor,
        )
        model = get_roboflow_model(model_id=model_id, api_key=api_key)
        if watchdog is None:
            watchdog = NullPipelineWatchdog()
        status_update_handlers.append(watchdog.on_status_update)
        # The same handler list is given to the video source and to the pipeline
        # itself, so handlers (including the watchdog) see updates from both.
        video_source = VideoSource.init(
            video_reference=video_reference,
            status_update_handlers=status_update_handlers,
            buffer_filling_strategy=source_buffer_filling_strategy,
            buffer_consumption_strategy=source_buffer_consumption_strategy,
        )
        watchdog.register_video_source(video_source=video_source)
        predictions_queue = Queue(maxsize=PREDICTIONS_QUEUE_SIZE)
        active_learning_middleware = NullActiveLearningMiddleware()
        if active_learning_enabled is None:
            logger.info(
                f"`active_learning_enabled` parameter not set - using env `ACTIVE_LEARNING_ENABLED` "
                f"with value: {ACTIVE_LEARNING_ENABLED}"
            )
            active_learning_enabled = ACTIVE_LEARNING_ENABLED
        if api_key is None:
            # Active Learning needs a Roboflow account, so it cannot run without a key.
            logger.info(
                "Roboflow API key not given - Active Learning is forced to be disabled."
            )
            active_learning_enabled = False
        if active_learning_enabled is True:
            active_learning_middleware = ThreadingActiveLearningMiddleware.init(
                api_key=api_key,
                model_id=model_id,
                cache=cache,
            )
            al_sink = partial(
                active_learning_sink,
                active_learning_middleware=active_learning_middleware,
                model_type=model.task_type,
                disable_preproc_auto_orient=DISABLE_PREPROC_AUTO_ORIENT,
            )
            logger.info(
                "AL enabled - wrapping `on_prediction` with multi_sink() and active_learning_sink()"
            )
            # Fan each prediction out to the user sink first, then to the AL sink.
            on_prediction = partial(multi_sink, sinks=[on_prediction, al_sink])
        return cls(
            model=model,
            video_source=video_source,
            on_prediction=on_prediction,
            max_fps=max_fps,
            predictions_queue=predictions_queue,
            watchdog=watchdog,
            status_update_handlers=status_update_handlers,
            inference_config=inference_config,
            active_learning_middleware=active_learning_middleware,
        )

    def __init__(
        self,
        model: OnnxRoboflowInferenceModel,
        video_source: VideoSource,
        on_prediction: Callable[[ObjectDetectionPrediction, VideoFrame], None],
        max_fps: Optional[float],
        predictions_queue: Queue,
        watchdog: PipelineWatchDog,
        status_update_handlers: List[Callable[[StatusUpdate], None]],
        inference_config: ModelConfig,
        active_learning_middleware: Union[
            NullActiveLearningMiddleware, ThreadingActiveLearningMiddleware
        ],
    ):
        """Store pipeline collaborators. Prefer `InferencePipeline.init(...)`.

        Args:
            model: Model whose `preprocess` / `predict` / `postprocess` run per frame.
            video_source: Source of frames. It is started lazily by the inference
                thread.
            on_prediction: Sink called with `(predictions, video_frame)`.
            max_fps: Optional FPS cap forwarded to the frames generator.
            predictions_queue: Hand-over queue between inference and dispatching.
            watchdog: Receives per-stage timing notifications.
            status_update_handlers: Handlers notified of pipeline status updates.
            inference_config: Post-processing parameters for the model.
            active_learning_middleware: Started/stopped alongside the pipeline.
        """
        self._model = model
        self._video_source = video_source
        self._on_prediction = on_prediction
        self._max_fps = max_fps
        self._predictions_queue = predictions_queue
        self._watchdog = watchdog
        self._command_handler_thread: Optional[Thread] = None
        self._inference_thread: Optional[Thread] = None
        self._dispatching_thread: Optional[Thread] = None
        # Set by terminate(); stops reconnection attempts and the frames loop.
        self._stop = False
        self._camera_restart_ongoing = False
        self._status_update_handlers = status_update_handlers
        self._inference_config = inference_config
        self._active_learning_middleware = active_learning_middleware

    def start(self, use_main_thread: bool = True) -> None:
        """Launch the inference thread and begin dispatching results.

        Args:
            use_main_thread: If True, results are dispatched on the calling thread,
                and this call blocks until the inference thread signals the end of
                processing. Otherwise a separate dispatching thread is started and the
                call returns immediately.
        """
        self._stop = False
        self._inference_thread = Thread(target=self._execute_inference)
        self._inference_thread.start()
        if self._active_learning_middleware is not None:
            self._active_learning_middleware.start_registration_thread()
        if use_main_thread:
            self._dispatch_inference_results()
        else:
            self._dispatching_thread = Thread(target=self._dispatch_inference_results)
            self._dispatching_thread.start()

    def terminate(self) -> None:
        """Request a stop: disable reconnection and terminate the video source."""
        self._stop = True
        self._video_source.terminate()

    def pause_stream(self) -> None:
        """Delegate to `VideoSource.pause()` (see VideoSource docs for semantics)."""
        self._video_source.pause()

    def mute_stream(self) -> None:
        """Delegate to `VideoSource.mute()` (see VideoSource docs for semantics)."""
        self._video_source.mute()

    def resume_stream(self) -> None:
        """Delegate to `VideoSource.resume()` (see VideoSource docs for semantics)."""
        self._video_source.resume()

    def join(self) -> None:
        """Wait for the worker threads, then stop the AL registration thread."""
        if self._inference_thread is not None:
            self._inference_thread.join()
            self._inference_thread = None
        if self._dispatching_thread is not None:
            self._dispatching_thread.join()
            self._dispatching_thread = None
        if self._active_learning_middleware is not None:
            self._active_learning_middleware.stop_registration_thread()

    def _execute_inference(self) -> None:
        """Inference-thread body: run the model on every frame and enqueue results.

        For each frame from `_generate_frames()`, this runs preprocess -> predict ->
        postprocess, notifying the watchdog before each stage. It then puts
        `(predictions, video_frame)` onto the predictions queue. Any exception ends the
        loop and is reported as an INFERENCE_ERROR status update; it is not re-raised.
        A `None` sentinel is always enqueued on exit so the dispatcher can stop.
        """
        send_inference_pipeline_status_update(
            severity=UpdateSeverity.INFO,
            event_type=INFERENCE_THREAD_STARTED_EVENT,
            status_update_handlers=self._status_update_handlers,
        )
        logger.info("Inference thread started")
        try:
            for video_frame in self._generate_frames():
                self._watchdog.on_model_preprocessing_started(
                    frame_timestamp=video_frame.frame_timestamp,
                    frame_id=video_frame.frame_id,
                )
                preprocessed_image, preprocessing_metadata = self._model.preprocess(
                    video_frame.image
                )
                self._watchdog.on_model_inference_started(
                    frame_timestamp=video_frame.frame_timestamp,
                    frame_id=video_frame.frame_id,
                )
                predictions = self._model.predict(preprocessed_image)
                self._watchdog.on_model_postprocessing_started(
                    frame_timestamp=video_frame.frame_timestamp,
                    frame_id=video_frame.frame_id,
                )
                postprocessing_args = self._inference_config.to_postprocessing_params()
                predictions = self._model.postprocess(
                    predictions,
                    preprocessing_metadata,
                    **postprocessing_args,
                )
                if isinstance(predictions, list):
                    # Post-processing may return a list of response objects (presumably
                    # pydantic models, given the `.dict(by_alias=..., exclude_none=...)`
                    # call). The first entry is serialised to a plain dict for the sinks.
                    predictions = predictions[0].dict(
                        by_alias=True,
                        exclude_none=True,
                    )
                self._watchdog.on_model_prediction_ready(
                    frame_timestamp=video_frame.frame_timestamp,
                    frame_id=video_frame.frame_id,
                )
                self._predictions_queue.put((predictions, video_frame))
                send_inference_pipeline_status_update(
                    severity=UpdateSeverity.DEBUG,
                    event_type=INFERENCE_COMPLETED_EVENT,
                    payload={
                        "frame_id": video_frame.frame_id,
                        "frame_timestamp": video_frame.frame_timestamp,
                    },
                    status_update_handlers=self._status_update_handlers,
                )
        except Exception as error:
            payload = {
                "error_type": error.__class__.__name__,
                "error_message": str(error),
                "error_context": "inference_thread",
            }
            send_inference_pipeline_status_update(
                severity=UpdateSeverity.ERROR,
                event_type=INFERENCE_ERROR_EVENT,
                payload=payload,
                status_update_handlers=self._status_update_handlers,
            )
            logger.exception(f"Encountered inference error: {error}")
        finally:
            # End-of-stream sentinel consumed by _dispatch_inference_results().
            self._predictions_queue.put(None)
            send_inference_pipeline_status_update(
                severity=UpdateSeverity.INFO,
                event_type=INFERENCE_THREAD_FINISHED_EVENT,
                status_update_handlers=self._status_update_handlers,
            )
            logger.info("Inference thread finished")

    def _dispatch_inference_results(self) -> None:
        """Feed queued predictions to `on_prediction` until the `None` sentinel.

        Sink failures are reported as INFERENCE_RESULTS_DISPATCHING_ERROR status
        updates and logged as warnings, never re-raised. One failing call therefore
        does not stop dispatching of subsequent frames.
        """
        while True:
            inference_results: Optional[Tuple[dict, VideoFrame]] = (
                self._predictions_queue.get()
            )
            if inference_results is None:
                self._predictions_queue.task_done()
                break
            predictions, video_frame = inference_results
            try:
                self._on_prediction(predictions, video_frame)
            except Exception as error:
                payload = {
                    "error_type": error.__class__.__name__,
                    "error_message": str(error),
                    "error_context": "inference_results_dispatching",
                }
                send_inference_pipeline_status_update(
                    severity=UpdateSeverity.ERROR,
                    event_type=INFERENCE_RESULTS_DISPATCHING_ERROR_EVENT,
                    payload=payload,
                    status_update_handlers=self._status_update_handlers,
                )
                logger.warning(f"Error in results dispatching - {error}")
            finally:
                self._predictions_queue.task_done()

    def _generate_frames(
        self,
    ) -> Generator[VideoFrame, None, None]:
        """Start the video source and yield its frames, reconnecting when needed.

        For file sources the pipeline is terminated once the frames are exhausted.
        For other sources, the end of the frames generator is treated as a lost
        connection. A SOURCE_CONNECTION_LOST update is emitted and a restart is
        attempted. This repeats until `terminate()` sets the stop flag or the source
        reports no properties.
        """
        self._video_source.start()
        while True:
            source_properties = self._video_source.describe_source().source_properties
            if source_properties is None:
                break
            allow_reconnect = not source_properties.is_file
            yield from get_video_frames_generator(
                video=self._video_source, max_fps=self._max_fps
            )
            if not allow_reconnect:
                self.terminate()
                break
            if self._stop:
                break
            logger.warning("Lost connection with video source.")
            send_inference_pipeline_status_update(
                severity=UpdateSeverity.WARNING,
                event_type=SOURCE_CONNECTION_LOST_EVENT,
                payload={
                    "source_reference": self._video_source.describe_source().source_reference
                },
                status_update_handlers=self._status_update_handlers,
            )
            self._attempt_restart()

    def _attempt_restart(self) -> None:
        """Restart the video source until it succeeds or `terminate()` is called.

        Each SourceConnectionError emits a SOURCE_CONNECTION_ATTEMPT_FAILED update,
        then waits RESTART_ATTEMPT_DELAY seconds before retrying. Any other exception
        type propagates to the caller.
        """
        succeeded = False
        while not self._stop and not succeeded:
            try:
                self._video_source.restart()
                succeeded = True
            except SourceConnectionError as error:
                payload = {
                    "error_type": error.__class__.__name__,
                    "error_message": str(error),
                    "error_context": "video_frames_generator",
                }
                send_inference_pipeline_status_update(
                    severity=UpdateSeverity.WARNING,
                    event_type=SOURCE_CONNECTION_ATTEMPT_FAILED_EVENT,
                    payload=payload,
                    status_update_handlers=self._status_update_handlers,
                )
                logger.warning(
                    f"Could not connect to video source. Retrying in {RESTART_ATTEMPT_DELAY}s..."
                )
                time.sleep(RESTART_ATTEMPT_DELAY)
def send_inference_pipeline_status_update(
    severity: UpdateSeverity,
    event_type: str,
    status_update_handlers: List[Callable[[StatusUpdate], None]],
    payload: Optional[dict] = None,
    sub_context: Optional[str] = None,
) -> None:
    """Build a pipeline StatusUpdate and deliver it to every handler.

    Args:
        severity: Severity of the update.
        event_type: One of the *_EVENT identifiers of this module.
        status_update_handlers: Callables invoked, in order, with the update.
        payload: Event details; an empty dict is used when omitted.
        sub_context: Optional suffix, appended as
            "<INFERENCE_PIPELINE_CONTEXT>.<sub_context>".

    Handler exceptions are logged as warnings and swallowed, so one faulty handler
    never prevents delivery to the remaining ones or breaks the caller.
    """
    context = (
        INFERENCE_PIPELINE_CONTEXT
        if sub_context is None
        else f"{INFERENCE_PIPELINE_CONTEXT}.{sub_context}"
    )
    update = StatusUpdate(
        timestamp=datetime.now(),
        severity=severity,
        event_type=event_type,
        payload={} if payload is None else payload,
        context=context,
    )
    for notify in status_update_handlers:
        try:
            notify(update)
        except Exception as handler_error:
            logger.warning(f"Could not execute handler update. Cause: {handler_error}")