Fucius's picture
Upload 422 files
2eafbc4 verified
from typing import Any, Dict, Generator, Iterable, List, TypeVar, Union
import numpy as np
from inference.enterprise.workflows.complier.steps_executors.types import OutputsLookup
from inference.enterprise.workflows.complier.utils import (
get_step_selector_from_its_output,
is_input_selector,
is_step_output_selector,
)
from inference.enterprise.workflows.entities.steps import (
AbsoluteStaticCrop,
ActiveLearningDataCollector,
ClipComparison,
Crop,
OCRModel,
RelativeStaticCrop,
RoboflowModel,
YoloWorld,
)
from inference.enterprise.workflows.entities.validators import (
get_last_selector_chunk,
is_selector,
)
from inference.enterprise.workflows.errors import ExecutionGraphError
T = TypeVar("T")
def get_image(
    step: Union[
        RoboflowModel,
        OCRModel,
        Crop,
        AbsoluteStaticCrop,
        RelativeStaticCrop,
        ClipComparison,
        ActiveLearningDataCollector,
        YoloWorld,
    ],
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> List[Dict[str, Union[str, np.ndarray]]]:
    """Fetch the image data referenced by ``step.image``.

    The reference is resolved in one of two ways:

    * input selector - the value is read from ``runtime_parameters`` under
      the last chunk of the selector;
    * step-output selector - the output of the producing step is looked up
      in ``outputs_lookup`` and the entry named by the last chunk of the
      selector is returned.

    Args:
        step: Step definition exposing an ``image`` attribute.
        runtime_parameters: Values supplied at execution time, keyed by name.
        outputs_lookup: Outputs of previously executed steps, keyed by step
            selector.

    Returns:
        Whatever is stored under the resolved key (annotated as a list of
        image dictionaries).

    Raises:
        ExecutionGraphError: If ``step.image`` is neither an input selector
            nor a step-output selector.
    """
    image_reference = step.image
    if is_input_selector(selector_or_value=image_reference):
        input_name = get_last_selector_chunk(selector=image_reference)
        return runtime_parameters[input_name]
    if not is_step_output_selector(selector_or_value=image_reference):
        raise ExecutionGraphError("Cannot find image")
    producer_selector = get_step_selector_from_its_output(
        step_output_selector=image_reference
    )
    # Look up the producing step first, then pick the referenced property.
    producer_output = outputs_lookup[producer_selector]
    property_name = get_last_selector_chunk(selector=image_reference)
    return producer_output[property_name]
def resolve_parameter(
    selector_or_value: Any,
    runtime_parameters: Dict[str, Any],
    outputs_lookup: OutputsLookup,
) -> Any:
    """Resolve a step parameter that is either a literal value or a selector.

    Resolution rules:

    * not a selector - the value is returned unchanged;
    * step-output selector - the producing step's output is fetched from
      ``outputs_lookup``; if that output is a list, the referenced property
      is extracted from every element, otherwise from the output itself;
    * any other selector - the value is read from ``runtime_parameters``.

    Args:
        selector_or_value: Literal value or selector to resolve.
        runtime_parameters: Values supplied at execution time, keyed by name.
        outputs_lookup: Outputs of previously executed steps, keyed by step
            selector.

    Returns:
        The literal value, the referenced runtime parameter, or the
        referenced step output property (a list of properties when the step
        output is a list).

    Raises:
        KeyError: Presumably, when the referenced name is missing from the
            mapping being indexed (plain subscript access is used; the
            exact exception depends on the container type).
    """
    if not is_selector(selector_or_value=selector_or_value):
        return selector_or_value
    # The last selector chunk names the property in every remaining branch,
    # so compute it once instead of once per branch / per list element.
    property_name = get_last_selector_chunk(selector=selector_or_value)
    if not is_step_output_selector(selector_or_value=selector_or_value):
        return runtime_parameters[property_name]
    step_selector = get_step_selector_from_its_output(
        step_output_selector=selector_or_value
    )
    step_output = outputs_lookup[step_selector]
    if isinstance(step_output, list):
        # Batched step output: pick the property from each element.
        return [element[property_name] for element in step_output]
    return step_output[property_name]
def make_batches(
    iterable: Iterable[T], batch_size: int
) -> Generator[List[T], None, None]:
    """Lazily split ``iterable`` into consecutive lists of ``batch_size`` items.

    Args:
        iterable: Source of elements; consumed once, in order.
        batch_size: Desired number of elements per batch. Values below 1
            are treated as 1.

    Yields:
        Lists of elements in their original order. Every list holds
        ``batch_size`` items except possibly the last one, which holds the
        remainder. Nothing is yielded for an empty ``iterable``.
    """
    chunk_limit = max(batch_size, 1)
    pending = []
    for item in iterable:
        pending.append(item)
        if len(pending) < chunk_limit:
            continue
        yield pending
        # Start a fresh list so already-yielded batches are never mutated.
        pending = []
    if pending:
        yield pending