Spaces:
Running
on
Zero
Running
on
Zero
from typing import Any, List, Set | |
from networkx import DiGraph | |
from inference.enterprise.workflows.entities.outputs import JsonField | |
from inference.enterprise.workflows.entities.validators import is_selector | |
from inference.enterprise.workflows.entities.workflows_specification import ( | |
InputType, | |
StepType, | |
) | |
def get_input_parameters_selectors(inputs: List[InputType]) -> Set[str]:
    """Collect the selector of every declared workflow input.

    Args:
        inputs: Input definitions; each one is read through its ``name``
            attribute.

    Returns:
        The set of selectors produced by ``construct_input_selector`` for
        each input name (duplicates collapse).
    """
    selectors = set()
    for definition in inputs:
        selectors.add(construct_input_selector(input_name=definition.name))
    return selectors
def construct_input_selector(input_name: str) -> str:
    """Build the selector string that refers to a workflow input.

    Args:
        input_name: Name of the input.

    Returns:
        ``input_name`` prefixed with ``$inputs.``.
    """
    selector = f"$inputs.{input_name}"
    return selector
def get_steps_selectors(steps: List[StepType]) -> Set[str]:
    """Collect the selector of every step.

    Args:
        steps: Step definitions; each one is read through its ``name``
            attribute.

    Returns:
        The set of selectors produced by ``construct_step_selector`` for
        each step name.
    """
    selectors = set()
    for step_definition in steps:
        selectors.add(construct_step_selector(step_name=step_definition.name))
    return selectors
def construct_step_selector(step_name: str) -> str:
    """Build the selector string that refers to a workflow step.

    Args:
        step_name: Name of the step.

    Returns:
        ``step_name`` prefixed with ``$steps.``.
    """
    selector = f"$steps.{step_name}"
    return selector
def get_steps_input_selectors(steps: List[StepType]) -> Set[str]:
    """Merge the input selectors of all given steps into one set.

    Args:
        steps: Step definitions to inspect.

    Returns:
        Union of ``get_step_input_selectors`` results over every step; an
        empty set when ``steps`` is empty.
    """
    per_step_selectors = [get_step_input_selectors(step=step) for step in steps]
    return set().union(*per_step_selectors)
def get_step_input_selectors(step: StepType) -> Set[str]:
    """Collect every selector referenced by the inputs of a single step.

    For each name returned by ``step.get_input_names()`` the matching
    attribute is read from ``step``. The attribute may hold a single value or
    a list of values; every element accepted by ``is_selector`` is kept.

    Args:
        step: Step definition exposing ``get_input_names()`` and one
            attribute per reported input name.

    Returns:
        Set of the input values that ``is_selector`` recognises as selectors.
    """
    result = set()
    for step_input_name in step.get_input_names():
        step_input = getattr(step, step_input_name)
        # Wrap scalar inputs so single values and lists share one code path.
        candidates = step_input if isinstance(step_input, list) else [step_input]
        result.update(
            element
            for element in candidates
            if is_selector(selector_or_value=element)
        )
    return result
def get_steps_output_selectors(steps: List[StepType]) -> Set[str]:
    """Build selectors for every output declared by every step.

    Args:
        steps: Step definitions; each is read through its ``name`` attribute
            and its ``get_output_names()`` method.

    Returns:
        Set of strings shaped ``$steps.<step name>.<output name>``.
    """
    return {
        f"$steps.{step.name}.{output_name}"
        for step in steps
        for output_name in step.get_output_names()
    }
def get_output_names(outputs: List[JsonField]) -> Set[str]:
    """Collect the prefixed name of every workflow output.

    Args:
        outputs: Output definitions; each one is read through its ``name``
            attribute.

    Returns:
        The set of names produced by ``construct_output_name`` for each
        output.
    """
    names = set()
    for output_definition in outputs:
        names.add(construct_output_name(name=output_definition.name))
    return names
def construct_output_name(name: str) -> str:
    """Build the prefixed name that refers to a workflow output.

    Args:
        name: Name of the output.

    Returns:
        ``name`` prefixed with ``$outputs.``.
    """
    prefixed_name = f"$outputs.{name}"
    return prefixed_name
def get_output_selectors(outputs: List[JsonField]) -> Set[str]:
    """Collect the selector each workflow output points at.

    Args:
        outputs: Output definitions; each one is read through its
            ``selector`` attribute.

    Returns:
        Set of the ``selector`` values (duplicates collapse).
    """
    selectors = set()
    for output_definition in outputs:
        selectors.add(output_definition.selector)
    return selectors
def is_input_selector(selector_or_value: Any) -> bool:
    """Tell whether a value is a selector pointing at a workflow input.

    Args:
        selector_or_value: Candidate value of any type.

    Returns:
        ``False`` when ``is_selector`` rejects the value; otherwise whether
        the value starts with ``$inputs``.
    """
    if is_selector(selector_or_value=selector_or_value):
        return selector_or_value.startswith("$inputs")
    return False
def construct_selector_pointing_step_output(selector: str, new_output: str) -> str:
    """Build a selector for ``new_output`` under the step ``selector`` refers to.

    When ``selector`` is itself a step-output selector (as judged by
    ``is_step_output_selector``), it is first reduced to its step selector so
    the old output name is replaced rather than extended.

    Args:
        selector: Step selector or step-output selector.
        new_output: Output name to append.

    Returns:
        ``<base selector>.<new_output>``.
    """
    base_selector = selector
    if is_step_output_selector(selector_or_value=selector):
        base_selector = get_step_selector_from_its_output(
            step_output_selector=selector
        )
    return f"{base_selector}.{new_output}"
def is_step_output_selector(selector_or_value: Any) -> bool:
    """Tell whether a value is a selector pointing at a step output.

    Args:
        selector_or_value: Candidate value of any type.

    Returns:
        ``False`` when ``is_selector`` rejects the value; otherwise ``True``
        only if the value starts with ``$steps.`` and consists of exactly
        three dot-separated chunks.
    """
    if is_selector(selector_or_value=selector_or_value):
        has_step_prefix = selector_or_value.startswith("$steps.")
        # Exactly two dots means exactly three dot-separated chunks.
        return has_step_prefix and selector_or_value.count(".") == 2
    return False
def get_step_selector_from_its_output(step_output_selector: str) -> str:
    """Reduce a step-output selector to the selector of its step.

    Args:
        step_output_selector: Dot-separated selector string.

    Returns:
        The first two dot-separated chunks joined back with a dot; inputs
        with fewer than two dots are returned unchanged.
    """
    leading_chunks = step_output_selector.split(".")[:2]
    return ".".join(leading_chunks)
def get_nodes_of_specific_kind(execution_graph: DiGraph, kind: str) -> Set[str]:
    """Select graph nodes whose ``kind`` attribute equals the requested kind.

    Args:
        execution_graph: Graph whose nodes are iterated together with their
            attribute dictionaries.
        kind: Value the node's ``"kind"`` attribute must equal.

    Returns:
        Set of identifiers of the matching nodes; nodes without a ``"kind"``
        attribute never match a non-``None`` kind.
    """
    matching_nodes = set()
    for node_name, node_data in execution_graph.nodes(data=True):
        if node_data.get("kind") == kind:
            matching_nodes.add(node_name)
    return matching_nodes
def is_condition_step(execution_graph: DiGraph, node: str) -> bool:
    """Tell whether a graph node holds a step definition of type ``Condition``.

    Args:
        execution_graph: Graph whose node attributes include a
            ``"definition"`` entry for ``node``.
        node: Identifier of the node to inspect.

    Returns:
        ``True`` when the definition's ``type`` attribute equals
        ``"Condition"``.
    """
    node_definition = execution_graph.nodes[node]["definition"]
    return node_definition.type == "Condition"