# Spaces: Running on Zero
from typing import Any, List, Optional, Union | |
from uuid import uuid4 | |
from pydantic import BaseModel, ConfigDict, Field | |
from inference.core.entities.common import ApiKey, ModelID, ModelType | |
class BaseRequest(BaseModel):
    """Base request for inference.

    Attributes:
        id (str): A unique request identifier; auto-generated when not supplied.
        api_key (Optional[str]): Roboflow API Key that will be passed to the model during initialization for artifact retrieval.
        start (Optional[float]): start time of request
        source (Optional[str]): optional identifier of the request source.
        source_info (Optional[str]): optional additional information about the request source.
    """

    def __init__(self, **kwargs):
        # Respect a caller-supplied id; only mint a fresh UUID when the request
        # arrives without one. (Previously any provided id was clobbered, which
        # forced callers such as request_from_type to restore it afterwards.)
        if not kwargs.get("id"):
            kwargs["id"] = str(uuid4())
        super().__init__(**kwargs)

    # Permit field names starting with "model_" (e.g. model_id in subclasses)
    # without pydantic v2 protected-namespace warnings.
    model_config = ConfigDict(protected_namespaces=())

    id: str
    api_key: Optional[str] = ApiKey
    start: Optional[float] = None
    source: Optional[str] = None
    source_info: Optional[str] = None
class InferenceRequest(BaseRequest):
    """Inference request bound to a specific model.

    Attributes:
        model_id (Optional[str]): A unique model identifier.
        model_type (Optional[str]): The type of the model, usually referring to what task the model performs.
    """

    model_id: Optional[str] = ModelID
    model_type: Optional[str] = ModelType
class InferenceRequestImage(BaseModel):
    """Image data for inference request.

    Attributes:
        type (str): The type of image data provided, one of 'url', 'base64', or 'numpy'.
        value (Optional[Any]): Image data corresponding to the image type.
    """

    type: str = Field(
        examples=["url"],
        description="The type of image data provided, one of 'url', 'base64', or 'numpy'",
    )
    value: Optional[Any] = Field(
        None,
        examples=["http://www.example-image-url.com"],
        description="Image data corresponding to the image type, if type = 'url' then value is a string containing the url of an image, else if type = 'base64' then value is a string containing base64 encoded image data, else if type = 'numpy' then value is binary numpy data serialized using pickle.dumps(); array should 3 dimensions, channels last, with values in the range [0,255].",
    )
class CVInferenceRequest(InferenceRequest):
    """Computer Vision inference request.

    Attributes:
        image (Union[List[InferenceRequestImage], InferenceRequestImage]): Image(s) for inference.
        disable_preproc_auto_orient (Optional[bool]): If true, the auto orient preprocessing step is disabled for this call. Default is False.
        disable_preproc_contrast (Optional[bool]): If true, the auto contrast preprocessing step is disabled for this call. Default is False.
        disable_preproc_grayscale (Optional[bool]): If true, the grayscale preprocessing step is disabled for this call. Default is False.
        disable_preproc_static_crop (Optional[bool]): If true, the static crop preprocessing step is disabled for this call. Default is False.
    """

    image: Union[List[InferenceRequestImage], InferenceRequestImage]
    disable_preproc_auto_orient: Optional[bool] = Field(
        default=False,
        description="If true, the auto orient preprocessing step is disabled for this call.",
    )
    disable_preproc_contrast: Optional[bool] = Field(
        default=False,
        description="If true, the auto contrast preprocessing step is disabled for this call.",
    )
    disable_preproc_grayscale: Optional[bool] = Field(
        default=False,
        description="If true, the grayscale preprocessing step is disabled for this call.",
    )
    disable_preproc_static_crop: Optional[bool] = Field(
        default=False,
        description="If true, the static crop preprocessing step is disabled for this call.",
    )
class ObjectDetectionInferenceRequest(CVInferenceRequest):
    """Object Detection inference request.

    Attributes:
        class_agnostic_nms (Optional[bool]): If true, NMS is applied to all detections at once, if false, NMS is applied per class.
        class_filter (Optional[List[str]]): If provided, only predictions for the listed classes will be returned.
        confidence (Optional[float]): The confidence threshold used to filter out predictions.
        fix_batch_size (Optional[bool]): If true, the batch size will be fixed to the maximum batch size configured for this server.
        iou_threshold (Optional[float]): The IoU threshold that must be met for a box pair to be considered duplicate during NMS.
        max_detections (Optional[int]): The maximum number of detections that will be returned.
        max_candidates (Optional[int]): The maximum number of candidate detections passed to NMS.
        visualization_labels (Optional[bool]): If true, labels will be rendered on prediction visualizations.
        visualization_stroke_width (Optional[int]): The stroke width used when visualizing predictions.
        visualize_predictions (Optional[bool]): If true, the predictions will be drawn on the original image and returned as a base64 string.
        disable_active_learning (Optional[bool]): If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled).
    """

    class_agnostic_nms: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, NMS is applied to all detections at once, if false, NMS is applied per class",
    )
    class_filter: Optional[List[str]] = Field(
        default=None,
        examples=[["class-1", "class-2", "class-n"]],
        description="If provided, only predictions for the listed classes will be returned",
    )
    confidence: Optional[float] = Field(
        default=0.4,
        examples=[0.5],
        description="The confidence threshold used to filter out predictions",
    )
    fix_batch_size: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the batch size will be fixed to the maximum batch size configured for this server",
    )
    iou_threshold: Optional[float] = Field(
        default=0.3,
        examples=[0.5],
        # Typo fix: user-facing API description previously read "threhsold".
        description="The IoU threshold that must be met for a box pair to be considered duplicate during NMS",
    )
    max_detections: Optional[int] = Field(
        default=300,
        examples=[300],
        description="The maximum number of detections that will be returned",
    )
    max_candidates: Optional[int] = Field(
        default=3000,
        description="The maximum number of candidate detections passed to NMS",
    )
    visualization_labels: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, labels will be rendered on prediction visualizations",
    )
    visualization_stroke_width: Optional[int] = Field(
        default=1,
        examples=[1],
        description="The stroke width used when visualizing predictions",
    )
    visualize_predictions: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be drawn on the original image and returned as a base64 string",
    )
    disable_active_learning: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)",
    )
class KeypointsDetectionInferenceRequest(ObjectDetectionInferenceRequest):
    """Keypoints detection inference request.

    Attributes:
        keypoint_confidence (Optional[float]): The confidence threshold used to filter out non visible keypoints.
    """

    keypoint_confidence: Optional[float] = Field(
        default=0.0,
        examples=[0.5],
        description="The confidence threshold used to filter out non visible keypoints",
    )
class InstanceSegmentationInferenceRequest(ObjectDetectionInferenceRequest):
    """Instance Segmentation inference request.

    Attributes:
        mask_decode_mode (Optional[str]): The mode used to decode instance segmentation masks, one of 'accurate', 'fast', 'tradeoff'.
        tradeoff_factor (Optional[float]): The amount to tradeoff between 0='fast' and 1='accurate'.
    """

    mask_decode_mode: Optional[str] = Field(
        default="accurate",
        examples=["accurate"],
        description="The mode used to decode instance segmentation masks, one of 'accurate', 'fast', 'tradeoff'",
    )
    tradeoff_factor: Optional[float] = Field(
        default=0.0,
        examples=[0.5],
        description="The amount to tradeoff between 0='fast' and 1='accurate'",
    )
class ClassificationInferenceRequest(CVInferenceRequest):
    """Classification inference request.

    Attributes:
        confidence (Optional[float]): The confidence threshold used to filter out predictions.
        visualization_stroke_width (Optional[int]): The stroke width used when visualizing predictions.
        visualize_predictions (Optional[bool]): If true, the predictions will be drawn on the original image and returned as a base64 string.
        disable_active_learning (Optional[bool]): If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled).
    """

    confidence: Optional[float] = Field(
        default=0.4,
        examples=[0.5],
        description="The confidence threshold used to filter out predictions",
    )
    visualization_stroke_width: Optional[int] = Field(
        default=1,
        examples=[1],
        description="The stroke width used when visualizing predictions",
    )
    visualize_predictions: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be drawn on the original image and returned as a base64 string",
    )
    disable_active_learning: Optional[bool] = Field(
        default=False,
        examples=[False],
        description="If true, the predictions will be prevented from registration by Active Learning (if the functionality is enabled)",
    )
def request_from_type(model_type, request_dict):
    """Build the inference request subclass matching ``model_type``.

    The id from ``request_dict`` is reinstated after construction, because
    ``BaseRequest.__init__`` generates a fresh UUID while building the object.

    Args:
        model_type (str): One of "classification", "instance-segmentation",
            or "object-detection".
        request_dict (dict): Keyword arguments for the request constructor.

    Returns:
        CVInferenceRequest: A request instance of the matching subclass.

    Raises:
        ValueError: If ``model_type`` is not a recognized task type.
    """
    if model_type == "classification":
        request = ClassificationInferenceRequest(**request_dict)
    elif model_type == "instance-segmentation":
        request = InstanceSegmentationInferenceRequest(**request_dict)
    elif model_type == "object-detection":
        request = ObjectDetectionInferenceRequest(**request_dict)
    else:
        # Typo fix: message previously read "Uknown".
        raise ValueError(f"Unknown task type {model_type}")
    request.id = request_dict.get("id")
    return request