from typing import Any, List, Optional, Union

from pydantic import Field, ValidationInfo, field_validator

from inference.core.entities.requests.inference import (
    BaseRequest,
    InferenceRequestImage,
)
from inference.core.env import SAM_VERSION_ID


class SamInferenceRequest(BaseRequest):
    """SAM inference request.

    Attributes:
        api_key (Optional[str]): Roboflow API Key.
        sam_version_id (Optional[str]): The version ID of SAM to be used for this request.
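        model_id (Optional[str]): The model ID, derived from sam_version_id when not set explicitly.

    Example:
        A minimal construction sketch, assuming BaseRequest requires no
        additional fields beyond those shown:

        >>> request = SamInferenceRequest(sam_version_id="vit_h")
        >>> request.model_id
        'sam/vit_h'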
    """

    sam_version_id: Optional[str] = Field(
        default=SAM_VERSION_ID,
        examples=["vit_h"],
        description="The version ID of SAM to be used for this request. Must be one of vit_h, vit_l, or vit_b.",
    )

    model_id: Optional[str] = Field(default=None, validate_default=True)

    @field_validator("model_id")
    @classmethod
    def validate_model_id(
        cls, value: Optional[str], info: ValidationInfo
    ) -> Optional[str]:
        # Derive the model ID from sam_version_id when one is not set explicitly.
        if value is not None:
            return value
        if info.data.get("sam_version_id") is None:
            return None
        return f"sam/{info.data['sam_version_id']}"


class SamEmbeddingRequest(SamInferenceRequest):
    """SAM embedding request.

    Attributes:
        image (Optional[inference.core.entities.requests.inference.InferenceRequestImage]): The image to be embedded.
        image_id (Optional[str]): The ID of the image to be embedded, used to cache the embedding.
        format (Optional[str]): The format of the response. Must be one of json or binary.
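
    Example:
        A minimal construction sketch; the URL and image ID below are
        placeholders, not values from this codebase:

        >>> request = SamEmbeddingRequest(
        ...     image=InferenceRequestImage(type="url", value="https://example.com/image.jpg"),
        ...     image_id="example-image",
        ...     format="json",
        ... )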
    """

    image: Optional[InferenceRequestImage] = Field(
        default=None,
        description="The image to be embedded",
    )
    image_id: Optional[str] = Field(
        default=None,
        examples=["image_id"],
        description="The ID of the image to be embedded used to cache the embedding.",
    )
    format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the response. Must be one of json or binary. If binary, embedding is returned as a binary numpy array.",
    )


class SamSegmentationRequest(SamInferenceRequest):
    """SAM segmentation request.

    Attributes:
        embeddings (Optional[Union[List[List[List[List[float]]]], Any]]): The embeddings to be decoded.
        embeddings_format (Optional[str]): The format of the embeddings.
        format (Optional[str]): The format of the response.
        image (Optional[InferenceRequestImage]): The image to be segmented.
        image_id (Optional[str]): The ID of the image to be segmented, used to retrieve cached embeddings.
        has_mask_input (Optional[bool]): Whether or not the request includes a mask input.
        mask_input (Optional[Union[List[List[List[float]]], Any]]): The mask input, typically the low resolution mask output from a previous inference.
        mask_input_format (Optional[str]): The format of the mask input.
        orig_im_size (Optional[List[int]]): The original size of the image used to generate the embeddings.
        point_coords (Optional[List[List[float]]]): The coordinates of the interactive points used during decoding.
        point_labels (Optional[List[float]]): The labels of the interactive points used during decoding.
        use_mask_input_cache (Optional[bool]): Whether or not to use the mask input cache.
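
    Example:
        A minimal construction sketch with one positive interactive point,
        assuming an embedding for ``image_id`` has already been cached; the
        ID and coordinates below are placeholders:

        >>> request = SamSegmentationRequest(
        ...     image_id="example-image",
        ...     point_coords=[[100.0, 150.0]],
        ...     point_labels=[1],
        ... )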
    """

    embeddings: Optional[Union[List[List[List[List[float]]]], Any]] = Field(
        None,
        examples=["[[[[0.1, 0.2, 0.3, ...] ...] ...]]"],
        description="The embeddings to be decoded. The dimensions of the embeddings are 1 x 256 x 64 x 64. If embeddings is not provided, image must be provided.",
    )
    embeddings_format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the embeddings. Must be one of json or binary. If binary, embeddings are expected to be a binary numpy array.",
    )
    format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the response. Must be one of json or binary. If binary, masks are returned as binary numpy arrays. If json, masks are converted to polygons, then returned as json.",
    )
    image: Optional[InferenceRequestImage] = Field(
        default=None,
        description="The image to be segmented. Only required if embeddings are not provided.",
    )
    image_id: Optional[str] = Field(
        default=None,
        examples=["image_id"],
        description="The ID of the image to be segmented used to retrieve cached embeddings. If an embedding is cached, it will be used instead of generating a new embedding. If no embedding is cached, a new embedding will be generated and cached.",
    )
    has_mask_input: Optional[bool] = Field(
        default=False,
        examples=[True],
        description="Whether or not the request includes a mask input. If true, the mask input must be provided.",
    )
    mask_input: Optional[Union[List[List[List[float]]], Any]] = Field(
        default=None,
        description="The set of output masks. If request format is json, masks is a list of polygons, where each polygon is a list of points, where each point is a tuple containing the x,y pixel coordinates of the point. If request format is binary, masks is a list of binary numpy arrays. The dimensions of each mask are 256 x 256. This is the same as the output, low resolution mask from the previous inference.",
    )
    mask_input_format: Optional[str] = Field(
        default="json",
        examples=["json"],
        description="The format of the mask input. Must be one of json or binary. If binary, mask input is expected to be a binary numpy array.",
    )
    orig_im_size: Optional[List[int]] = Field(
        default=None,
        examples=[[640, 320]],
        description="The original size of the image used to generate the embeddings. This is only required if the image is not provided.",
    )
    point_coords: Optional[List[List[float]]] = Field(
        default=[[0.0, 0.0]],
        examples=[[[10.0, 10.0]]],
        description="The coordinates of the interactive points used during decoding. Each point (x,y pair) corresponds to a label in point_labels.",
    )
    point_labels: Optional[List[float]] = Field(
        default=[-1],
        examples=[[1]],
        description="The labels of the interactive points used during decoding. A 1 represents a positive point (part of the object to be segmented). A -1 represents a negative point (not part of the object to be segmented). Each label corresponds to a point in point_coords.",
    )
    use_mask_input_cache: Optional[bool] = Field(
        default=True,
        examples=[True],
        description="Whether or not to use the mask input cache. If true, the mask input cache will be used if it exists. If false, the mask input cache will not be used.",
    )