from typing import Any, List, Set

from networkx import DiGraph

from inference.enterprise.workflows.entities.outputs import JsonField
from inference.enterprise.workflows.entities.validators import is_selector
from inference.enterprise.workflows.entities.workflows_specification import (
    InputType,
    StepType,
)


def get_input_parameters_selectors(inputs: List[InputType]) -> Set[str]:
    """Return the `$inputs.<name>` selector of every input definition."""
    return {
        construct_input_selector(input_name=input_definition.name)
        for input_definition in inputs
    }


def construct_input_selector(input_name: str) -> str:
    """Build the selector string addressing the input called `input_name`."""
    return f"$inputs.{input_name}"


def get_steps_selectors(steps: List[StepType]) -> Set[str]:
    """Return the `$steps.<name>` selector of every step."""
    return {construct_step_selector(step_name=step.name) for step in steps}


def construct_step_selector(step_name: str) -> str:
    """Build the selector string addressing the step called `step_name`."""
    return f"$steps.{step_name}"


def get_steps_input_selectors(steps: List[StepType]) -> Set[str]:
    """Return the union of the selectors referenced by the inputs of all steps."""
    return {
        selector
        for step in steps
        for selector in get_step_input_selectors(step=step)
    }


def get_step_input_selectors(step: StepType) -> Set[str]:
    """Collect the selectors referenced by the inputs of a single step.

    Every name reported by `step.get_input_names()` is read from the step as
    an attribute. An attribute holding a list contributes each of its
    elements; any other value is treated as a one-element list. Only the
    elements accepted by `is_selector` are kept.
    """
    result = set()
    for step_input_name in step.get_input_names():
        step_input = getattr(step, step_input_name)
        # Normalise scalar inputs so both shapes share one code path.
        candidates = step_input if isinstance(step_input, list) else [step_input]
        result.update(
            element
            for element in candidates
            if is_selector(selector_or_value=element)
        )
    return result


def get_steps_output_selectors(steps: List[StepType]) -> Set[str]:
    """Return `$steps.<step name>.<output name>` for every output of every step.

    Output names are taken from `step.get_output_names()`.
    """
    return {
        f"{construct_step_selector(step_name=step.name)}.{output_name}"
        for step in steps
        for output_name in step.get_output_names()
    }


def get_output_names(outputs: List[JsonField]) -> Set[str]:
    """Return the `$outputs.<name>` string of every output definition."""
    return {construct_output_name(name=output.name) for output in outputs}


def construct_output_name(name: str) -> str:
    """Build the `$outputs.<name>` string for the output called `name`."""
    return f"$outputs.{name}"


def get_output_selectors(outputs: List[JsonField]) -> Set[str]:
    """Return the `selector` attribute of every output definition."""
    return {output.selector for output in outputs}


def is_input_selector(selector_or_value: Any) -> bool:
    """Tell whether the value is a selector that starts with `$inputs`.

    Values rejected by `is_selector` are never input selectors.
    """
    if not is_selector(selector_or_value=selector_or_value):
        return False
    # NOTE(review): the prefix has no trailing dot, so a selector such as
    # "$inputsfoo" would also match - confirm whether "$inputs." is intended.
    return selector_or_value.startswith("$inputs")


def construct_selector_pointing_step_output(selector: str, new_output: str) -> str:
    """Build a selector pointing at output `new_output` of the selected step.

    When `selector` already addresses a step output, that output part is
    dropped first, so the result is always `<base selector>.<new_output>`.
    """
    if is_step_output_selector(selector_or_value=selector):
        selector = get_step_selector_from_its_output(step_output_selector=selector)
    return f"{selector}.{new_output}"


def is_step_output_selector(selector_or_value: Any) -> bool:
    """Tell whether the value is a selector shaped like `$steps.<step>.<output>`.

    The value must be accepted by `is_selector`, start with `$steps.` and
    consist of exactly three dot-separated segments.
    """
    if not is_selector(selector_or_value=selector_or_value):
        return False
    return (
        selector_or_value.startswith("$steps.")
        and len(selector_or_value.split(".")) == 3
    )


def get_step_selector_from_its_output(step_output_selector: str) -> str:
    """Return the first two dot-separated segments of the given selector."""
    return ".".join(step_output_selector.split(".")[:2])


def get_nodes_of_specific_kind(execution_graph: DiGraph, kind: str) -> Set[str]:
    """Return the graph nodes whose `kind` data attribute equals `kind`.

    Nodes without a `kind` attribute are skipped unless `kind` is None,
    because the attribute is read with `dict.get`.
    """
    return {
        node_name
        for node_name, node_data in execution_graph.nodes(data=True)
        if node_data.get("kind") == kind
    }


def is_condition_step(execution_graph: DiGraph, node: str) -> bool:
    """Tell whether the `definition` stored on `node` has type `"Condition"`."""
    return execution_graph.nodes[node]["definition"].type == "Condition"