|
import keyword |
|
import re |
|
from dataclasses import dataclass |
|
from enum import Enum |
|
from typing import Optional |
|
|
|
from .structs import CallType, InputField, ModelStep, OutputField, Workflow |
|
from .utils import detect_cycles |
|
|
|
# Type names accepted for a step's output field; anything else is reported
# as an "Unsupported output type" error by the validator.
SUPPORTED_TYPES = {
    "str",
    "int",
    "float",
    "bool",
    "list[str]",
    "list[int]",
    "list[float]",
    "list[bool]",
}

# Default limits used by WorkflowValidator when the caller does not override them.
MAX_FIELD_NAME_LENGTH = 50  # characters allowed in an input/output field name
MAX_DESCRIPTION_LENGTH = 200  # characters allowed in a field description
MAX_SYSTEM_PROMPT_LENGTH = 4000  # characters allowed in an LLM step's system prompt
MAX_TEMPERATURE = 10.0  # inclusive upper bound for an LLM step's temperature
|
|
|
|
|
class ValidationErrorType(Enum):
    """Category attached to every ``ValidationError``.

    The string value of each member is its lowercase name.
    """

    # A step (or one of its fields) is incomplete, duplicated, or mis-keyed.
    STEP = "step"
    # The step graph is malformed: circular dependency or orphaned step.
    DAG = "dag"
    # A variable reference is malformed or points at something that does not exist.
    VARIABLE = "variable"
    # An output field declares a type outside the supported set.
    TYPE = "type"
    # Workflow-level problem such as missing inputs, outputs, or steps.
    GENERAL = "general"
    # A step ID or field name is not a valid identifier.
    NAMING = "naming"
    # A name, description, or system prompt exceeds its maximum length.
    LENGTH = "length"
    # A numeric setting (temperature) is outside its allowed bounds.
    RANGE = "range"
|
|
|
|
|
@dataclass
class ValidationError:
    """One validation failure: its category, a message, and an optional location.

    ``step_id`` and ``field_name`` pinpoint where the problem was found and are
    left as ``None`` when the error is not tied to a specific step or field.
    """

    # Category of the failure.
    error_type: ValidationErrorType
    # Human-readable explanation of what went wrong.
    message: str
    # ID of the step the error relates to, if any.
    step_id: Optional[str] = None
    # Name of the field (or variable) the error relates to, if any.
    field_name: Optional[str] = None
|
|
|
|
|
class WorkflowValidationError(ValueError):
    """``ValueError`` raised for an invalid workflow, carrying the detailed errors.

    Attributes:
        errors: The ``ValidationError`` objects that caused the failure.
    """

    def __init__(self, errors: list[ValidationError]):
        """Store ``errors`` and build a summary message with their count."""
        summary = f"Workflow validation failed with {len(errors)} errors"
        super().__init__(summary)
        self.errors = errors
|
|
|
|
|
def _parse_variable_reference(var: str) -> tuple[Optional[str], str]: |
|
"""Extracts step_id and field_name from variable reference""" |
|
parts = var.split(".") |
|
if len(parts) == 1: |
|
return None, parts[0] |
|
return parts[0], parts[1] |
|
|
|
|
|
def _get_step_dependencies(step: ModelStep) -> set[str]:
    """Return the IDs of the steps whose outputs feed ``step``'s input fields.

    Input fields bound to a bare variable (no ``step.`` prefix) contribute
    nothing, since they do not reference another step.
    """
    referenced_ids = (_parse_variable_reference(input_field.variable)[0] for input_field in step.input_fields)
    # Drop the None / empty prefixes produced by non-step references.
    return {step_id for step_id in referenced_ids if step_id}
|
|
|
|
|
def create_step_dep_graph(workflow: Workflow) -> dict[str, set[str]]:
    """Map every step ID in ``workflow`` to the set of step IDs it depends on."""
    return {step_id: _get_step_dependencies(step) for step_id, step in workflow.steps.items()}
|
|
|
|
|
class WorkflowValidator:
    """Validates workflows for correctness and consistency.

    Validation is fail-fast: each check records one ``ValidationError`` in
    ``self.errors`` and returns ``False`` as soon as it fails, so after a run
    started through ``validate`` the list holds at most one entry.

    Attributes:
        errors: Errors recorded since the last ``validate`` call.
        workflow: The workflow being validated; set by ``validate`` and used to
            resolve ``step.field`` references.
        min_temperature: Inclusive lower bound for an LLM step's temperature.
        max_temperature: Inclusive upper bound for an LLM step's temperature.
        max_field_name_length: Maximum length of a field name.
        max_description_length: Maximum length of a field description.
        max_system_prompt_length: Maximum length of an LLM step's system prompt.
    """

    # ASCII identifier: a letter or underscore followed by letters, digits, or
    # underscores. Compiled once and applied with ``fullmatch`` (see
    # ``_is_valid_identifier``).
    _IDENTIFIER_PATTERN = re.compile(r"[a-zA-Z_][a-zA-Z0-9_]*")

    def __init__(
        self,
        min_temperature: float = 0,
        max_temperature: float = MAX_TEMPERATURE,
        max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
        max_description_length: int = MAX_DESCRIPTION_LENGTH,
        max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
    ):
        """Create a validator with the given limits (module defaults if omitted)."""
        self.errors: list[ValidationError] = []
        self.workflow: Optional[Workflow] = None
        self.min_temperature = min_temperature
        self.max_temperature = max_temperature
        self.max_field_name_length = max_field_name_length
        self.max_description_length = max_description_length
        self.max_system_prompt_length = max_system_prompt_length

    def validate(self, workflow: Workflow) -> bool:
        """Validate ``workflow``; this is the main entry point.

        Resets ``self.errors``, remembers ``workflow`` for reference lookups,
        runs the basic checks, then dispatches to the single-step or
        multi-step validation depending on the number of steps.

        Returns:
            ``True`` if the workflow is valid, ``False`` otherwise (details in
            ``self.errors``).
        """
        self.errors = []
        self.workflow = workflow

        if not self._validate_workflow_basic(workflow):
            return False

        if len(workflow.steps) == 1:
            return self.validate_simple_workflow(workflow)
        return self.validate_complex_workflow(workflow)

    def validate_simple_workflow(self, workflow: Workflow) -> bool:
        """Validate a single-step workflow.

        Checks the step itself, the format of the workflow inputs, and that
        every workflow output resolves to an output field of the step.

        Returns:
            ``True`` when valid. Returns ``False`` without recording an error
            if ``self.workflow`` has not been set (i.e. ``validate`` was not
            called first).
        """
        if not self.workflow:
            return False

        step = next(iter(workflow.steps.values()))
        if not self._validate_step(step):
            return False
        if not self._validate_external_inputs(workflow):
            return False

        for output_name, output_var in workflow.outputs.items():
            if not self._validate_output_reference(output_name, output_var):
                return False

            # Only the field part is looked up: with a single step, a bare
            # field name is enough to identify the output.
            _, field_name = _parse_variable_reference(output_var)
            if not any(field.name == field_name for field in step.output_fields):
                return self._fail(
                    ValidationErrorType.VARIABLE,
                    f"Output field '{field_name}' not found in step '{step.id}'",
                    step.id,
                    field_name,
                )

        return True

    def validate_complex_workflow(self, workflow: Workflow) -> bool:
        """Validate a multi-step workflow.

        In order: every step, the workflow input names, every workflow output
        reference, absence of circular step dependencies, absence of orphaned
        steps, and finally variable-level dependencies.

        Returns:
            ``True`` when valid. Returns ``False`` without recording an error
            if ``self.workflow`` has not been set (i.e. ``validate`` was not
            called first).
        """
        if not self.workflow:
            return False

        for step in workflow.steps.values():
            if not self._validate_step(step):
                return False

        if not self._validate_external_inputs(workflow):
            return False

        for output_name, output_var in workflow.outputs.items():
            if not self._validate_output_reference(output_name, output_var):
                return False

            step_id, field_name = _parse_variable_reference(output_var)
            if step_id not in workflow.steps:
                return self._fail(ValidationErrorType.VARIABLE, f"Referenced step '{step_id}' not found")

            ref_step = workflow.steps[step_id]
            if not any(field.name == field_name for field in ref_step.output_fields):
                return self._fail(
                    ValidationErrorType.VARIABLE,
                    f"Output field '{field_name}' not found in step '{step_id}'",
                    step_id,
                    field_name,
                )

        dep_graph = create_step_dep_graph(workflow)
        if cycle_step_id := detect_cycles(dep_graph):
            return self._fail(
                ValidationErrorType.DAG, f"Circular dependency detected involving step: {cycle_step_id}"
            )

        # A step is "used" if another step depends on it or a workflow output
        # reads from it. Both sets are computed once, outside the step loop.
        referenced_steps = set().union(*dep_graph.values())
        referenced_steps.update(
            _parse_variable_reference(output_var)[0] for output_var in workflow.outputs.values() if output_var
        )
        for step_id in workflow.steps:
            if step_id not in referenced_steps:
                return self._fail(ValidationErrorType.DAG, f"Orphaned step detected: {step_id}")

        return self._validate_variable_dependencies(workflow)

    def _fail(
        self,
        error_type: ValidationErrorType,
        message: str,
        step_id: Optional[str] = None,
        field_name: Optional[str] = None,
    ) -> bool:
        """Record one validation error and return ``False`` for the caller to propagate."""
        self.errors.append(ValidationError(error_type, message, step_id, field_name))
        return False

    def _validate_external_inputs(self, workflow: Workflow) -> bool:
        """Check that every workflow input name is a valid external input."""
        for input_var in workflow.inputs:
            if not self._is_valid_external_input(input_var):
                return self._fail(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
        return True

    def _validate_output_reference(self, output_name: str, output_var: str) -> bool:
        """Check that a workflow output is bound to a well-formed variable reference."""
        if not output_var:
            return self._fail(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
        if not self._is_valid_variable_reference(output_var):
            return self._fail(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
        return True

    def _validate_workflow_basic(self, workflow: Workflow) -> bool:
        """Validate basic workflow properties.

        Requires at least one input, one output and one step, no ``None``
        output variable, and that each key in ``workflow.steps`` equals the
        ``id`` of the step stored under it.
        """
        if not workflow.inputs:
            return self._fail(ValidationErrorType.GENERAL, "Workflow must contain at least one input")

        if not workflow.outputs:
            return self._fail(ValidationErrorType.GENERAL, "Workflow must contain at least one output")

        if any(output_var is None for output_var in workflow.outputs.values()):
            return self._fail(ValidationErrorType.GENERAL, "Output variable cannot be None")

        if not workflow.steps:
            return self._fail(ValidationErrorType.GENERAL, "Workflow must contain at least one step")

        for step_id, step in workflow.steps.items():
            if step_id != step.id:
                return self._fail(
                    ValidationErrorType.STEP, f"Step ID mismatch: {step_id} != {step.id}", step_id
                )
        return True

    def _validate_step(self, step: ModelStep) -> bool:
        """Validate a single step.

        Checks required attributes, the step ID format, LLM-specific settings
        (for ``CallType.LLM`` steps), and every input and output field
        including duplicate field names.
        """
        if not step.id or not step.name or not step.model or not step.provider or not step.call_type:
            return self._fail(ValidationErrorType.STEP, "Step missing required fields", step.id)

        if not self._is_valid_identifier(step.id):
            return self._fail(
                ValidationErrorType.NAMING,
                f"Invalid step ID format: {step.id}. Must be a valid identifier.",
                step.id,
            )

        if step.call_type == CallType.LLM and not self._validate_llm_settings(step):
            return False

        if not self._validate_step_fields(step, step.input_fields, self._validate_input_field, "input"):
            return False
        return self._validate_step_fields(step, step.output_fields, self._validate_output_field, "output")

    def _validate_llm_settings(self, step: ModelStep) -> bool:
        """Check the temperature and system prompt of an LLM step."""
        if step.temperature is None:
            return self._fail(ValidationErrorType.STEP, "LLM step must specify temperature", step.id)

        if not self.min_temperature <= step.temperature <= self.max_temperature:
            return self._fail(
                ValidationErrorType.RANGE,
                f"Temperature must be between {self.min_temperature} and {self.max_temperature}",
                step.id,
            )

        if not step.system_prompt:
            return self._fail(ValidationErrorType.STEP, "LLM step must specify system prompt", step.id)

        if len(step.system_prompt) > self.max_system_prompt_length:
            return self._fail(
                ValidationErrorType.LENGTH,
                f"System prompt exceeds maximum length of {self.max_system_prompt_length} characters",
                step.id,
            )
        return True

    def _validate_step_fields(self, step: ModelStep, fields, validate_field, kind: str) -> bool:
        """Validate each field with ``validate_field`` and reject duplicate names.

        ``fields`` is the step's input or output field list, ``validate_field``
        the matching per-field validator, and ``kind`` ("input" or "output") is
        interpolated into the duplicate-name error message.
        """
        seen_names: set[str] = set()
        for field in fields:
            if not validate_field(field):
                return False
            if field.name in seen_names:
                return self._fail(
                    ValidationErrorType.STEP, f"Duplicate {kind} field name: {field.name}", step.id, field.name
                )
            seen_names.add(field.name)
        return True

    def _validate_field_name_and_description(self, field) -> bool:
        """Run the name-format, name-length and description-length checks shared by all fields."""
        if not self._is_valid_identifier(field.name):
            return self._fail(
                ValidationErrorType.NAMING,
                f"Invalid field name format: {field.name}. Must be a valid Python identifier.",
                field_name=field.name,
            )

        if len(field.name) > self.max_field_name_length:
            return self._fail(
                ValidationErrorType.LENGTH,
                f"Field name exceeds maximum length of {self.max_field_name_length} characters",
                field_name=field.name,
            )

        if len(field.description) > self.max_description_length:
            return self._fail(
                ValidationErrorType.LENGTH,
                f"Description exceeds maximum length of {self.max_description_length} characters",
                field_name=field.name,
            )
        return True

    def _validate_input_field(self, field: InputField) -> bool:
        """Validate an input field: required attributes, name, description, variable reference."""
        if not field.name or not field.description or not field.variable:
            return self._fail(
                ValidationErrorType.STEP, "Input field missing required fields", field_name=field.name
            )

        if not self._validate_field_name_and_description(field):
            return False

        if not self._is_valid_variable_reference(field.variable):
            return self._fail(
                ValidationErrorType.VARIABLE,
                f"Invalid variable reference: {field.variable}",
                field_name=field.name,
            )
        return True

    def _validate_output_field(self, field: OutputField) -> bool:
        """Validate an output field: required attributes, name, description, supported type."""
        if not field.name or not field.description:
            return self._fail(
                ValidationErrorType.STEP, "Output field missing required fields", field_name=field.name
            )

        if not self._validate_field_name_and_description(field):
            return False

        if field.type not in SUPPORTED_TYPES:
            return self._fail(
                ValidationErrorType.TYPE, f"Unsupported output type: {field.type}", field_name=field.name
            )
        return True

    def _validate_simple_workflow_variables(self, workflow: Workflow) -> bool:
        """Validate the input names and the non-empty output references of a simple workflow.

        Unlike ``validate_simple_workflow``, empty output variables are skipped
        rather than reported.
        """
        if not self._validate_external_inputs(workflow):
            return False

        for output_var in workflow.outputs.values():
            if output_var and not self._is_valid_variable_reference(output_var):
                return self._fail(
                    ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}"
                )
        return True

    def _validate_variable_dependencies(self, workflow: Workflow) -> bool:
        """Validate variable dependencies between steps.

        Builds a graph from each consumed variable to the ``step.field``
        outputs of the step consuming it, rejects cycles in that graph, and
        then checks that every bare (non-step) variable used by a step is
        declared in ``workflow.inputs``.
        """
        var_graph: dict[str, set[str]] = {}
        for step_id, step in workflow.steps.items():
            produced = {f"{step_id}.{output.name}" for output in step.output_fields}
            for field in step.input_fields:
                # Every output of the step depends on each variable the step reads.
                var_graph.setdefault(field.variable, set()).update(produced)

        if cycle_var := detect_cycles(var_graph):
            return self._fail(
                ValidationErrorType.VARIABLE, f"Circular variable dependency detected: {cycle_var}"
            )

        external_inputs = set(workflow.inputs)
        for step in workflow.steps.values():
            for field in step.input_fields:
                step_id, field_name = _parse_variable_reference(field.variable)
                if not step_id and field_name not in external_inputs:
                    return self._fail(
                        ValidationErrorType.VARIABLE,
                        f"External input '{field_name}' not found in workflow inputs",
                        field_name=field_name,
                    )
        return True

    def _is_valid_variable_reference(self, var: str) -> bool:
        """Return whether ``var`` is a properly formatted variable reference.

        A reference without a dot is always accepted. A ``step.field``
        reference is accepted only if the step exists in ``self.workflow`` and
        has an output field with that name. Anything with more than one dot is
        rejected, as is any reference when ``self.workflow`` is not set.
        """
        if not self.workflow:
            return False

        parts = var.split(".")
        if len(parts) == 1:
            return True
        if len(parts) != 2:
            return False

        step_id, field_name = parts
        if step_id not in self.workflow.steps:
            return False
        return any(field.name == field_name for field in self.workflow.steps[step_id].output_fields)

    def _is_valid_external_input(self, var: str) -> bool:
        """Return whether ``var`` is a valid external input name.

        The name must be a valid identifier (which already excludes empty
        strings and names containing a dot) and must not be a Python keyword.
        """
        return self._is_valid_identifier(var) and not keyword.iskeyword(var)

    def _is_valid_identifier(self, name: str) -> bool:
        """Return whether ``name`` is an ASCII identifier.

        Accepts a letter or underscore followed by letters, digits or
        underscores. Python keywords are not rejected here. ``None``, empty
        strings and non-string values yield ``False``.
        """
        if not isinstance(name, str):
            return False
        # fullmatch, not match with "$": "$" also matches just before a
        # trailing newline, which would let a name like "abc\n" through.
        return self._IDENTIFIER_PATTERN.fullmatch(name) is not None
|
|