Fucius's picture
Upload 422 files
2eafbc4 verified
raw
history blame
3.24 kB
from typing import Any, List, Set
from networkx import DiGraph
from inference.enterprise.workflows.entities.outputs import JsonField
from inference.enterprise.workflows.entities.validators import is_selector
from inference.enterprise.workflows.entities.workflows_specification import (
InputType,
StepType,
)
def get_input_parameters_selectors(inputs: List[InputType]) -> Set[str]:
    """Collect the selector of every declared workflow input.

    Args:
        inputs: Input definitions; only their `name` attribute is read.

    Returns:
        Set holding the result of `construct_input_selector` for each input
        name (duplicates collapse).
    """
    selectors: Set[str] = set()
    for declared_input in inputs:
        selectors.add(construct_input_selector(input_name=declared_input.name))
    return selectors
def construct_input_selector(input_name: str) -> str:
    """Build the selector string that refers to a workflow input.

    Args:
        input_name: Name of the input.

    Returns:
        The name prefixed with `$inputs.`.
    """
    return "$inputs.{}".format(input_name)
def get_steps_selectors(steps: List[StepType]) -> Set[str]:
    """Collect the selector of every step.

    Args:
        steps: Step definitions; only their `name` attribute is read.

    Returns:
        Set holding the result of `construct_step_selector` for each step
        name (duplicates collapse).
    """
    selectors: Set[str] = set()
    for step_definition in steps:
        selectors.add(construct_step_selector(step_name=step_definition.name))
    return selectors
def construct_step_selector(step_name: str) -> str:
    """Build the selector string that refers to a workflow step.

    Args:
        step_name: Name of the step.

    Returns:
        The name prefixed with `$steps.`.
    """
    return "$steps.{}".format(step_name)
def get_steps_input_selectors(steps: List[StepType]) -> Set[str]:
    """Gather the selectors referenced by the inputs of all given steps.

    Args:
        steps: Step definitions to inspect.

    Returns:
        Union of `get_step_input_selectors` over every step; an empty set
        when `steps` is empty.
    """
    per_step_selectors = (get_step_input_selectors(step=step) for step in steps)
    # `set().union()` with no arguments still yields a fresh empty set.
    return set().union(*per_step_selectors)
def get_step_input_selectors(step: StepType) -> Set[str]:
    """Collect every selector referenced by the inputs of a single step.

    Each name returned by `step.get_input_names()` is read off the step with
    `getattr`. A list value contributes each of its elements; any other value
    is treated as a single candidate. Only candidates accepted by
    `is_selector` are kept.

    Args:
        step: Step definition exposing `get_input_names()` and an attribute
            for each returned name.

    Returns:
        Set of the selectors found among the step's input values.
    """
    result = set()
    for step_input_name in step.get_input_names():
        step_input = getattr(step, step_input_name)
        # Wrap a non-list value so both shapes go through the same filter.
        candidates = step_input if isinstance(step_input, list) else [step_input]
        result.update(
            candidate
            for candidate in candidates
            if is_selector(selector_or_value=candidate)
        )
    return result
def get_steps_output_selectors(steps: List[StepType]) -> Set[str]:
    """Build the selectors for every output of every given step.

    Args:
        steps: Step definitions exposing `name` and `get_output_names()`.

    Returns:
        Set of `$steps.<step name>.<output name>` strings.
    """
    return {
        f"$steps.{step.name}.{output_name}"
        for step in steps
        for output_name in step.get_output_names()
    }
def get_output_names(outputs: List[JsonField]) -> Set[str]:
    """Collect the `$outputs`-style name of every declared output.

    Args:
        outputs: Output definitions; only their `name` attribute is read.

    Returns:
        Set holding the result of `construct_output_name` for each output
        name (duplicates collapse).
    """
    names: Set[str] = set()
    for declared_output in outputs:
        names.add(construct_output_name(name=declared_output.name))
    return names
def construct_output_name(name: str) -> str:
    """Build the string that refers to a workflow output.

    Args:
        name: Name of the output.

    Returns:
        The name prefixed with `$outputs.`.
    """
    return "$outputs.{}".format(name)
def get_output_selectors(outputs: List[JsonField]) -> Set[str]:
    """Collect the `selector` attribute of every declared output.

    Args:
        outputs: Output definitions; only their `selector` attribute is read.

    Returns:
        Set of the distinct selector values.
    """
    selectors: Set[str] = set()
    for declared_output in outputs:
        selectors.add(declared_output.selector)
    return selectors
def is_input_selector(selector_or_value: Any) -> bool:
    """Tell whether a value is a selector pointing at the workflow inputs.

    Args:
        selector_or_value: Arbitrary value to classify.

    Returns:
        False when `is_selector` rejects the value; otherwise whether the
        value starts with `$inputs`.
    """
    if is_selector(selector_or_value=selector_or_value):
        # The prefix carries no trailing dot, so any value beginning with
        # "$inputs" is accepted here.
        return selector_or_value.startswith("$inputs")
    return False
def construct_selector_pointing_step_output(selector: str, new_output: str) -> str:
    """Append an output name to a selector, replacing an existing output part.

    Args:
        selector: Selector to extend. When `is_step_output_selector` accepts
            it, it is first cut back with `get_step_selector_from_its_output`.
        new_output: Output name to append.

    Returns:
        The (possibly shortened) selector joined to `new_output` with a dot.
    """
    points_at_output = is_step_output_selector(selector_or_value=selector)
    base_selector = (
        get_step_selector_from_its_output(step_output_selector=selector)
        if points_at_output
        else selector
    )
    return f"{base_selector}.{new_output}"
def is_step_output_selector(selector_or_value: Any) -> bool:
    """Tell whether a value is a selector of the form `$steps.<x>.<y>`.

    Args:
        selector_or_value: Arbitrary value to classify.

    Returns:
        True only when `is_selector` accepts the value, it starts with
        `$steps.` and it is made of exactly three dot-separated parts.
    """
    if not is_selector(selector_or_value=selector_or_value):
        return False
    has_steps_prefix = selector_or_value.startswith("$steps.")
    # Exactly two dots means exactly three dot-separated parts.
    return has_steps_prefix and selector_or_value.count(".") == 2
def get_step_selector_from_its_output(step_output_selector: str) -> str:
    """Keep only the first two dot-separated parts of a selector.

    Args:
        step_output_selector: Selector string to shorten.

    Returns:
        The first two dot-separated parts re-joined with a dot; the input
        unchanged when it has fewer than three parts.
    """
    # Anything after the second dot is discarded, so splitting can stop there.
    leading_parts = step_output_selector.split(".", 2)[:2]
    return ".".join(leading_parts)
def get_nodes_of_specific_kind(execution_graph: DiGraph, kind: str) -> Set[str]:
    """Select the graph nodes whose `kind` attribute equals the given value.

    Args:
        execution_graph: Graph whose nodes are iterated together with their
            attribute dictionaries.
        kind: Value the node's `"kind"` attribute must equal.

    Returns:
        Set of matching node identifiers. Nodes without a `"kind"` attribute
        never match a non-None `kind`.
    """
    matching_nodes: Set[str] = set()
    for node_name, node_data in execution_graph.nodes(data=True):
        if node_data.get("kind") == kind:
            matching_nodes.add(node_name)
    return matching_nodes
def is_condition_step(execution_graph: DiGraph, node: str) -> bool:
    """Tell whether the definition stored on a graph node is a Condition.

    Args:
        execution_graph: Graph whose node attributes hold a `"definition"`
            entry.
        node: Identifier of the node to inspect.

    Returns:
        True when the `type` attribute of the node's definition equals
        `"Condition"`.
    """
    node_definition = execution_graph.nodes[node]["definition"]
    return node_definition.type == "Condition"