Maharshi Gor
commited on
Commit
·
e272e20
1
Parent(s):
7985347
Added more validation layers
Browse files
src/components/model_pipeline/model_pipeline.py
CHANGED
@@ -8,12 +8,11 @@ from components import typed_dicts as td
|
|
8 |
from components.model_pipeline.state_manager import (
|
9 |
BasePipelineValidator,
|
10 |
PipelineStateManager,
|
11 |
-
TossupPipelineStateManager,
|
12 |
)
|
13 |
from components.model_step.model_step import ModelStepComponent
|
14 |
-
from components.structs import ModelStepUIState, PipelineState, PipelineUIState
|
15 |
from components.utils import make_state
|
16 |
-
from workflows.structs import ModelStep,
|
17 |
from workflows.validators import WorkflowValidationError, WorkflowValidator
|
18 |
|
19 |
from .state_manager import get_output_panel_state
|
@@ -24,6 +23,8 @@ DEFAULT_MAX_TEMPERATURE = 5.0
|
|
24 |
class PipelineInterface:
|
25 |
"""UI for the pipeline."""
|
26 |
|
|
|
|
|
27 |
def __init__(
|
28 |
self,
|
29 |
app: gr.Blocks,
|
@@ -43,17 +44,14 @@ class PipelineInterface:
|
|
43 |
self.workflow_state = make_state(workflow.model_dump())
|
44 |
self.variables_state = make_state(workflow.get_available_variables())
|
45 |
self.output_panel_state = make_state(get_output_panel_state(workflow))
|
|
|
46 |
|
47 |
# Maintains the toggle state change for pipeline changes through user input.
|
48 |
self.pipeline_change = gr.State(False)
|
49 |
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
else:
|
54 |
-
pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
|
55 |
-
self.sm = PipelineStateManager(validator)
|
56 |
-
self.pipeline_state = make_state(pipeline_state.model_dump())
|
57 |
|
58 |
def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
|
59 |
"""Get the auxiliary states for the pipeline."""
|
@@ -164,20 +162,15 @@ class PipelineInterface:
|
|
164 |
)
|
165 |
return add_step_btn
|
166 |
|
167 |
-
def
|
168 |
"""Validate the workflow."""
|
169 |
try:
|
170 |
state = self.sm.make_pipeline_state(state_dict)
|
171 |
-
|
172 |
-
|
173 |
-
)
|
174 |
-
if not validator.validate(state.workflow):
|
175 |
-
raise WorkflowValidationError(validator.errors)
|
176 |
except ValueError as e:
|
177 |
logger.exception(e)
|
178 |
-
|
179 |
-
logger.error(f"Could not validate workflow: \n{state_dict_str}")
|
180 |
-
raise gr.Error(e)
|
181 |
|
182 |
def _render_pipeline_header(self):
|
183 |
# Add Step button at top
|
@@ -270,9 +263,8 @@ class PipelineInterface:
|
|
270 |
)
|
271 |
|
272 |
# Connect the export button to show the workflow JSON
|
273 |
-
self.add_triggers_for_pipeline_export(
|
274 |
-
|
275 |
-
fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion]
|
276 |
)
|
277 |
|
278 |
def render(self):
|
@@ -315,21 +307,30 @@ class PipelineInterface:
|
|
315 |
|
316 |
self._render_pipeline_preview()
|
317 |
|
318 |
-
def add_triggers_for_pipeline_export(
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
js = None
|
320 |
if scroll:
|
321 |
js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"
|
322 |
|
323 |
# TODO: modify this validate workflow to user input level and not executable label.
|
324 |
# (workflows that can be converted to UI interface, no logical validation)
|
325 |
-
gr.on(
|
326 |
triggers,
|
327 |
-
self.
|
328 |
inputs=[input_pipeline_state],
|
329 |
-
outputs=[],
|
330 |
).success(
|
331 |
fn=self.sm.get_formatted_config,
|
332 |
inputs=[self.pipeline_state, gr.State("yaml")],
|
333 |
outputs=[self.config_output, self.error_display],
|
334 |
js=js,
|
335 |
)
|
|
|
|
|
|
|
|
8 |
from components.model_pipeline.state_manager import (
|
9 |
BasePipelineValidator,
|
10 |
PipelineStateManager,
|
|
|
11 |
)
|
12 |
from components.model_step.model_step import ModelStepComponent
|
13 |
+
from components.structs import ModelStepUIState, PipelineState, PipelineUIState
|
14 |
from components.utils import make_state
|
15 |
+
from workflows.structs import ModelStep, Workflow
|
16 |
from workflows.validators import WorkflowValidationError, WorkflowValidator
|
17 |
|
18 |
from .state_manager import get_output_panel_state
|
|
|
23 |
class PipelineInterface:
|
24 |
"""UI for the pipeline."""
|
25 |
|
26 |
+
state_manager_cls = PipelineStateManager
|
27 |
+
|
28 |
def __init__(
|
29 |
self,
|
30 |
app: gr.Blocks,
|
|
|
44 |
self.workflow_state = make_state(workflow.model_dump())
|
45 |
self.variables_state = make_state(workflow.get_available_variables())
|
46 |
self.output_panel_state = make_state(get_output_panel_state(workflow))
|
47 |
+
self.ui_validator = validator
|
48 |
|
49 |
# Maintains the toggle state change for pipeline changes through user input.
|
50 |
self.pipeline_change = gr.State(False)
|
51 |
|
52 |
+
self.sm = self.state_manager_cls(validator)
|
53 |
+
pipeline_state_dict = self.sm.create_pipeline_state_dict(workflow=workflow, ui_state=ui_state)
|
54 |
+
self.pipeline_state = make_state(pipeline_state_dict)
|
|
|
|
|
|
|
|
|
55 |
|
56 |
def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
|
57 |
"""Get the auxiliary states for the pipeline."""
|
|
|
162 |
)
|
163 |
return add_step_btn
|
164 |
|
165 |
+
def validate_workflow_ui(self, state_dict: td.PipelineStateDict):
|
166 |
"""Validate the workflow."""
|
167 |
try:
|
168 |
state = self.sm.make_pipeline_state(state_dict)
|
169 |
+
self.ui_validator(state.workflow)
|
170 |
+
return gr.update(visible=False)
|
|
|
|
|
|
|
171 |
except ValueError as e:
|
172 |
logger.exception(e)
|
173 |
+
return gr.update(visible=True, value=str(e))
|
|
|
|
|
174 |
|
175 |
def _render_pipeline_header(self):
|
176 |
# Add Step button at top
|
|
|
263 |
)
|
264 |
|
265 |
# Connect the export button to show the workflow JSON
|
266 |
+
self.add_triggers_for_pipeline_export(
|
267 |
+
[export_btn.click], self.pipeline_state, scroll=True, expand_accordion=True
|
|
|
268 |
)
|
269 |
|
270 |
def render(self):
|
|
|
307 |
|
308 |
self._render_pipeline_preview()
|
309 |
|
310 |
+
def add_triggers_for_pipeline_export(
|
311 |
+
self,
|
312 |
+
triggers: list,
|
313 |
+
input_pipeline_state: gr.State,
|
314 |
+
scroll: bool = False,
|
315 |
+
expand_accordion: bool = False,
|
316 |
+
):
|
317 |
js = None
|
318 |
if scroll:
|
319 |
js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"
|
320 |
|
321 |
# TODO: modify this validate workflow to user input level and not executable label.
|
322 |
# (workflows that can be converted to UI interface, no logical validation)
|
323 |
+
event = gr.on(
|
324 |
triggers,
|
325 |
+
self.validate_workflow_ui,
|
326 |
inputs=[input_pipeline_state],
|
327 |
+
outputs=[self.error_display],
|
328 |
).success(
|
329 |
fn=self.sm.get_formatted_config,
|
330 |
inputs=[self.pipeline_state, gr.State("yaml")],
|
331 |
outputs=[self.config_output, self.error_display],
|
332 |
js=js,
|
333 |
)
|
334 |
+
|
335 |
+
if expand_accordion:
|
336 |
+
event.then(fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion])
|
src/components/model_pipeline/state_manager.py
CHANGED
@@ -11,7 +11,7 @@ from pydantic import BaseModel, ValidationError
|
|
11 |
from app_configs import UNSELECTED_VAR_NAME
|
12 |
from components import typed_dicts as td
|
13 |
from components import utils
|
14 |
-
from components.structs import ModelStepUIState, PipelineState, TossupPipelineState
|
15 |
from envs import DOCS_REPO_BRANCH, DOCS_REPO_URL
|
16 |
from workflows.factory import create_new_llm_step
|
17 |
from workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
|
@@ -72,6 +72,10 @@ class PipelineStateManager:
|
|
72 |
"""Make a state from a state dictionary."""
|
73 |
return self.pipeline_state_cls(**state_dict)
|
74 |
|
|
|
|
|
|
|
|
|
75 |
def add_step(
|
76 |
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
|
77 |
) -> td.PipelineStateDict:
|
@@ -212,6 +216,7 @@ class PipelineStateManager:
|
|
212 |
"""Update a workflow from a YAML string."""
|
213 |
try:
|
214 |
workflow = self.parse_yaml_workflow(yaml_str, strict=True)
|
|
|
215 |
self.validator and self.validator(workflow)
|
216 |
state = self.pipeline_state_cls.from_workflow(workflow)
|
217 |
return state.model_dump(), not change_state, gr.update(visible=False)
|
@@ -229,6 +234,11 @@ class TossupPipelineStateManager(PipelineStateManager):
|
|
229 |
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
|
230 |
return super().make_pipeline_state(state_dict)
|
231 |
|
|
|
|
|
|
|
|
|
|
|
232 |
def update_workflow_from_code(
|
233 |
self, yaml_str: str, change_state: bool
|
234 |
) -> tuple[td.TossupPipelineStateDict, bool, dict]:
|
|
|
11 |
from app_configs import UNSELECTED_VAR_NAME
|
12 |
from components import typed_dicts as td
|
13 |
from components import utils
|
14 |
+
from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
|
15 |
from envs import DOCS_REPO_BRANCH, DOCS_REPO_URL
|
16 |
from workflows.factory import create_new_llm_step
|
17 |
from workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
|
|
|
72 |
"""Make a state from a state dictionary."""
|
73 |
return self.pipeline_state_cls(**state_dict)
|
74 |
|
75 |
+
def create_pipeline_state_dict(self, workflow: Workflow, ui_state: PipelineUIState) -> td.PipelineStateDict:
|
76 |
+
"""Create a pipeline state from a workflow."""
|
77 |
+
return self.pipeline_state_cls(workflow=workflow, ui_state=ui_state).model_dump()
|
78 |
+
|
79 |
def add_step(
|
80 |
self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
|
81 |
) -> td.PipelineStateDict:
|
|
|
216 |
"""Update a workflow from a YAML string."""
|
217 |
try:
|
218 |
workflow = self.parse_yaml_workflow(yaml_str, strict=True)
|
219 |
+
logger.debug(f"Validator: {self.validator}")
|
220 |
self.validator and self.validator(workflow)
|
221 |
state = self.pipeline_state_cls.from_workflow(workflow)
|
222 |
return state.model_dump(), not change_state, gr.update(visible=False)
|
|
|
234 |
def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
|
235 |
return super().make_pipeline_state(state_dict)
|
236 |
|
237 |
+
def create_pipeline_state_dict(
|
238 |
+
self, workflow: TossupWorkflow, ui_state: PipelineUIState
|
239 |
+
) -> td.TossupPipelineStateDict:
|
240 |
+
return super().create_pipeline_state_dict(workflow, ui_state)
|
241 |
+
|
242 |
def update_workflow_from_code(
|
243 |
self, yaml_str: str, change_state: bool
|
244 |
) -> tuple[td.TossupPipelineStateDict, bool, dict]:
|
src/components/model_pipeline/tossup_pipeline.py
CHANGED
@@ -9,7 +9,7 @@ from display.formatting import tiny_styled_warning
|
|
9 |
from workflows.structs import Buzzer, TossupWorkflow
|
10 |
|
11 |
from .model_pipeline import PipelineInterface
|
12 |
-
from .state_manager import BasePipelineValidator
|
13 |
|
14 |
|
15 |
def toggleable_slider(
|
@@ -33,6 +33,10 @@ def toggleable_slider(
|
|
33 |
|
34 |
|
35 |
class TossupPipelineInterface(PipelineInterface):
|
|
|
|
|
|
|
|
|
36 |
def __init__(
|
37 |
self,
|
38 |
app: gr.Blocks,
|
|
|
9 |
from workflows.structs import Buzzer, TossupWorkflow
|
10 |
|
11 |
from .model_pipeline import PipelineInterface
|
12 |
+
from .state_manager import BasePipelineValidator, TossupPipelineStateManager
|
13 |
|
14 |
|
15 |
def toggleable_slider(
|
|
|
33 |
|
34 |
|
35 |
class TossupPipelineInterface(PipelineInterface):
|
36 |
+
"""UI for the tossup pipeline."""
|
37 |
+
|
38 |
+
state_manager_cls = TossupPipelineStateManager
|
39 |
+
|
40 |
def __init__(
|
41 |
self,
|
42 |
app: gr.Blocks,
|
src/components/quizbowl/validation.py
CHANGED
@@ -1,10 +1,10 @@
|
|
1 |
from typing import Literal
|
2 |
|
3 |
-
from app_configs import CONFIGS
|
4 |
from components.structs import PipelineState, TossupPipelineState
|
5 |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
6 |
from workflows.structs import TossupWorkflow, Workflow
|
7 |
-
from workflows.validators import WorkflowValidator
|
8 |
|
9 |
|
10 |
def validate_workflow(
|
@@ -83,3 +83,14 @@ class UserInputWorkflowValidator:
|
|
83 |
"\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
|
84 |
f"\nMake sure you have values set for all the outputs: {default_str}"
|
85 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
from typing import Literal
|
2 |
|
3 |
+
from app_configs import AVAILABLE_MODELS, CONFIGS
|
4 |
from components.structs import PipelineState, TossupPipelineState
|
5 |
from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
|
6 |
from workflows.structs import TossupWorkflow, Workflow
|
7 |
+
from workflows.validators import WorkflowValidationError, WorkflowValidator
|
8 |
|
9 |
|
10 |
def validate_workflow(
|
|
|
83 |
"\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
|
84 |
f"\nMake sure you have values set for all the outputs: {default_str}"
|
85 |
)
|
86 |
+
|
87 |
+
# Validate the workflow
|
88 |
+
allowed_model_names = AVAILABLE_MODELS.keys()
|
89 |
+
self.validator = WorkflowValidator(allowed_model_names=allowed_model_names)
|
90 |
+
try:
|
91 |
+
self.validator.validate(workflow, allow_empty=True)
|
92 |
+
except WorkflowValidationError as e:
|
93 |
+
error_msg_total = f"Found {len(e.errors)} errors in the workflow:\n"
|
94 |
+
error_msg_list = [f"- {err.message}" for err in e.errors]
|
95 |
+
error_msg = error_msg_total + "\n".join(error_msg_list)
|
96 |
+
raise ValueError(error_msg)
|
src/workflows/validators.py
CHANGED
@@ -15,10 +15,14 @@ MAX_DESCRIPTION_LENGTH = 200
|
|
15 |
MAX_SYSTEM_PROMPT_LENGTH = 4000
|
16 |
MAX_TEMPERATURE = 10.0
|
17 |
|
|
|
|
|
18 |
|
19 |
class ValidationErrorType(Enum):
|
20 |
"""Types of validation errors that can occur"""
|
21 |
|
|
|
|
|
22 |
STEP = "step"
|
23 |
DAG = "dag"
|
24 |
VARIABLE = "variable"
|
@@ -38,6 +42,17 @@ class ValidationError:
|
|
38 |
step_id: Optional[str] = None
|
39 |
field_name: Optional[str] = None
|
40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
class WorkflowValidationError(ValueError):
|
43 |
"""Base class for workflow validation errors"""
|
@@ -83,6 +98,9 @@ class WorkflowValidator:
|
|
83 |
max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
|
84 |
max_description_length: int = MAX_DESCRIPTION_LENGTH,
|
85 |
max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
|
|
|
|
|
|
|
86 |
):
|
87 |
self.errors: list[ValidationError] = []
|
88 |
self.workflow: Optional[Workflow] = None
|
@@ -91,99 +109,82 @@ class WorkflowValidator:
|
|
91 |
self.max_field_name_length = max_field_name_length
|
92 |
self.max_description_length = max_description_length
|
93 |
self.max_system_prompt_length = max_system_prompt_length
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
94 |
|
95 |
-
def
|
96 |
-
"""Main validation entry point
|
|
|
|
|
|
|
|
|
97 |
self.errors = []
|
98 |
self.workflow = workflow
|
99 |
|
100 |
# Basic workflow validation
|
101 |
-
if not self._validate_workflow_basic(workflow):
|
102 |
return False
|
103 |
|
104 |
# If it's a single-step workflow, use simple validation
|
105 |
if len(workflow.steps) == 1:
|
106 |
-
return self.validate_simple_workflow(workflow)
|
107 |
|
108 |
# Otherwise use complex validation
|
109 |
-
return self.validate_complex_workflow(workflow)
|
110 |
|
111 |
-
def
|
112 |
-
"""Validates
|
113 |
-
if not self.workflow:
|
114 |
-
return False
|
115 |
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
-
|
120 |
-
|
|
|
|
|
|
|
121 |
return False
|
122 |
|
123 |
-
|
124 |
-
for input_var in workflow.inputs:
|
125 |
if not self._is_valid_external_input(input_var):
|
126 |
self.errors.append(
|
127 |
ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
|
128 |
)
|
129 |
return False
|
130 |
-
|
131 |
-
# Validate output variables
|
132 |
-
for output_name, output_var in workflow.outputs.items():
|
133 |
-
if not output_var:
|
134 |
-
self.errors.append(
|
135 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
|
136 |
-
)
|
137 |
-
return False
|
138 |
-
|
139 |
-
# Check if output variable references a valid step output
|
140 |
-
if not self._is_valid_variable_reference(output_var):
|
141 |
-
self.errors.append(
|
142 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
|
143 |
-
)
|
144 |
-
return False
|
145 |
-
|
146 |
-
# Verify the output field exists in the step
|
147 |
-
_, field_name = _parse_variable_reference(output_var)
|
148 |
-
if not any(field.name == field_name for field in step.output_fields):
|
149 |
-
self.errors.append(
|
150 |
-
ValidationError(
|
151 |
-
ValidationErrorType.VARIABLE,
|
152 |
-
f"Output field '{field_name}' not found in step '{step.id}'",
|
153 |
-
step.id,
|
154 |
-
field_name,
|
155 |
-
)
|
156 |
-
)
|
157 |
-
return False
|
158 |
-
|
159 |
return True
|
160 |
|
161 |
-
def
|
162 |
-
"""Validates
|
163 |
-
if not self.workflow:
|
164 |
-
return False
|
165 |
-
|
166 |
-
# Validate each step
|
167 |
-
for step in workflow.steps.values():
|
168 |
-
if not self._validate_step(step):
|
169 |
-
return False
|
170 |
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
|
177 |
-
|
|
|
|
|
178 |
|
179 |
# Validate output variables
|
180 |
for output_name, output_var in workflow.outputs.items():
|
|
|
181 |
if not output_var:
|
|
|
|
|
182 |
self.errors.append(
|
183 |
ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
|
184 |
)
|
185 |
return False
|
186 |
|
|
|
187 |
if not self._is_valid_variable_reference(output_var):
|
188 |
self.errors.append(
|
189 |
ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
|
@@ -192,9 +193,10 @@ class WorkflowValidator:
|
|
192 |
|
193 |
# Verify the output field exists in the referenced step
|
194 |
step_id, field_name = _parse_variable_reference(output_var)
|
|
|
195 |
if step_id not in workflow.steps:
|
196 |
self.errors.append(
|
197 |
-
ValidationError(ValidationErrorType.VARIABLE, f"Referenced step '{step_id}' not found")
|
198 |
)
|
199 |
return False
|
200 |
|
@@ -203,12 +205,57 @@ class WorkflowValidator:
|
|
203 |
self.errors.append(
|
204 |
ValidationError(
|
205 |
ValidationErrorType.VARIABLE,
|
206 |
-
f"Output field '{field_name}' not found in step '{step_id}'",
|
207 |
step_id,
|
208 |
field_name,
|
209 |
)
|
210 |
)
|
211 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
212 |
|
213 |
dep_graph = create_step_dep_graph(workflow)
|
214 |
if cycle_step_id := detect_cycles(dep_graph):
|
@@ -237,28 +284,17 @@ class WorkflowValidator:
|
|
237 |
|
238 |
return True
|
239 |
|
240 |
-
def _validate_workflow_basic(self, workflow: Workflow) -> bool:
|
241 |
"""Validates basic workflow properties"""
|
242 |
-
# Check for atleast one input
|
243 |
-
if not workflow.inputs:
|
244 |
-
self.errors.append(
|
245 |
-
ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one input")
|
246 |
-
)
|
247 |
-
return False
|
248 |
|
249 |
-
|
250 |
-
|
251 |
-
ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one output")
|
252 |
-
)
|
253 |
return False
|
254 |
|
255 |
-
for output_var in workflow.outputs.values():
|
256 |
-
if output_var is None:
|
257 |
-
self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Output variable cannot be None"))
|
258 |
-
return False
|
259 |
-
|
260 |
# Check for empty workflow
|
261 |
if not workflow.steps:
|
|
|
|
|
262 |
self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one step"))
|
263 |
return False
|
264 |
|
@@ -271,10 +307,26 @@ class WorkflowValidator:
|
|
271 |
return False
|
272 |
return True
|
273 |
|
274 |
-
def _validate_step(self, step: ModelStep) -> bool:
|
275 |
"""Validates a single step"""
|
276 |
# Validate required fields
|
277 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
278 |
self.errors.append(ValidationError(ValidationErrorType.STEP, "Step missing required fields", step.id))
|
279 |
return False
|
280 |
|
@@ -328,7 +380,7 @@ class WorkflowValidator:
|
|
328 |
# Validate input fields
|
329 |
input_names = set()
|
330 |
for field in step.input_fields:
|
331 |
-
if not self._validate_input_field(field):
|
332 |
return False
|
333 |
if field.name in input_names:
|
334 |
self.errors.append(
|
@@ -342,7 +394,7 @@ class WorkflowValidator:
|
|
342 |
# Validate output fields
|
343 |
output_names = set()
|
344 |
for field in step.output_fields:
|
345 |
-
if not self._validate_output_field(field):
|
346 |
return False
|
347 |
if field.name in output_names:
|
348 |
self.errors.append(
|
@@ -355,7 +407,7 @@ class WorkflowValidator:
|
|
355 |
|
356 |
return True
|
357 |
|
358 |
-
def _validate_input_field(self, field: InputField) -> bool:
|
359 |
"""Validates an input field"""
|
360 |
# Validate required fields
|
361 |
if not field.name or not field.description or not field.variable:
|
@@ -365,7 +417,7 @@ class WorkflowValidator:
|
|
365 |
return False
|
366 |
|
367 |
# Validate field name
|
368 |
-
if not self._is_valid_identifier(field.name):
|
369 |
self.errors.append(
|
370 |
ValidationError(
|
371 |
ValidationErrorType.NAMING,
|
@@ -410,7 +462,7 @@ class WorkflowValidator:
|
|
410 |
|
411 |
return True
|
412 |
|
413 |
-
def _validate_output_field(self, field: OutputField) -> bool:
|
414 |
"""Validates an output field"""
|
415 |
# Validate required fields
|
416 |
if not field.name or not field.description:
|
@@ -422,7 +474,7 @@ class WorkflowValidator:
|
|
422 |
return False
|
423 |
|
424 |
# Validate field name
|
425 |
-
if not self._is_valid_identifier(field.name):
|
426 |
self.errors.append(
|
427 |
ValidationError(
|
428 |
ValidationErrorType.NAMING,
|
@@ -528,10 +580,12 @@ class WorkflowValidator:
|
|
528 |
|
529 |
return True
|
530 |
|
531 |
-
def _is_valid_variable_reference(self, var: str) -> bool:
|
532 |
"""Validates if a variable reference is properly formatted"""
|
533 |
if not self.workflow:
|
534 |
return False
|
|
|
|
|
535 |
parts = var.split(".")
|
536 |
if len(parts) == 1:
|
537 |
return True # External input
|
@@ -554,8 +608,8 @@ class WorkflowValidator:
|
|
554 |
return False
|
555 |
return True
|
556 |
|
557 |
-
def _is_valid_identifier(self, name: str) -> bool:
|
558 |
"""Validates if a string is a valid Python identifier"""
|
559 |
if name and name.strip():
|
560 |
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
|
561 |
-
return
|
|
|
15 |
MAX_SYSTEM_PROMPT_LENGTH = 4000
|
16 |
MAX_TEMPERATURE = 10.0
|
17 |
|
18 |
+
from loguru import logger
|
19 |
+
|
20 |
|
21 |
class ValidationErrorType(Enum):
|
22 |
"""Types of validation errors that can occur"""
|
23 |
|
24 |
+
INPUTS = "inputs"
|
25 |
+
OUTPUTS = "outputs"
|
26 |
STEP = "step"
|
27 |
DAG = "dag"
|
28 |
VARIABLE = "variable"
|
|
|
42 |
step_id: Optional[str] = None
|
43 |
field_name: Optional[str] = None
|
44 |
|
45 |
+
def __str__(self):
|
46 |
+
subject = ""
|
47 |
+
if self.step_id:
|
48 |
+
subject = f"Model step '{self.step_id}'"
|
49 |
+
if self.field_name:
|
50 |
+
if self.step_id:
|
51 |
+
subject = f"Field '{self.step_id}.{self.field_name}'"
|
52 |
+
else:
|
53 |
+
subject = f"Field '{self.field_name}'"
|
54 |
+
return f"{self.error_type.value}: {subject} - {self.message}"
|
55 |
+
|
56 |
|
57 |
class WorkflowValidationError(ValueError):
|
58 |
"""Base class for workflow validation errors"""
|
|
|
98 |
max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
|
99 |
max_description_length: int = MAX_DESCRIPTION_LENGTH,
|
100 |
max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
|
101 |
+
allowed_model_names: Optional[list[str]] = None,
|
102 |
+
required_input_vars: Optional[list[str]] = None,
|
103 |
+
required_output_vars: Optional[list[str]] = None,
|
104 |
):
|
105 |
self.errors: list[ValidationError] = []
|
106 |
self.workflow: Optional[Workflow] = None
|
|
|
109 |
self.max_field_name_length = max_field_name_length
|
110 |
self.max_description_length = max_description_length
|
111 |
self.max_system_prompt_length = max_system_prompt_length
|
112 |
+
self.required_input_vars = required_input_vars
|
113 |
+
self.required_output_vars = required_output_vars
|
114 |
+
self.allowed_model_names = set(allowed_model_names) if allowed_model_names else None
|
115 |
+
|
116 |
+
def validate(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
117 |
+
validated = self._validate(workflow, allow_empty)
|
118 |
+
if not validated:
|
119 |
+
raise WorkflowValidationError(self.errors)
|
120 |
+
return True
|
121 |
|
122 |
+
def _validate(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
123 |
+
"""Main validation entry point
|
124 |
+
Args:
|
125 |
+
workflow: The workflow to validate.
|
126 |
+
allow_empty: If True, empty workflow is allowed. This flag is used to validate the intermediate states while User edits the workflow.
|
127 |
+
"""
|
128 |
self.errors = []
|
129 |
self.workflow = workflow
|
130 |
|
131 |
# Basic workflow validation
|
132 |
+
if not self._validate_workflow_basic(workflow, allow_empty):
|
133 |
return False
|
134 |
|
135 |
# If it's a single-step workflow, use simple validation
|
136 |
if len(workflow.steps) == 1:
|
137 |
+
return self.validate_simple_workflow(workflow, allow_empty)
|
138 |
|
139 |
# Otherwise use complex validation
|
140 |
+
return self.validate_complex_workflow(workflow, allow_empty)
|
141 |
|
142 |
+
def _validate_required_inputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
143 |
+
"""Validates that the workflow has the correct inputs"""
|
|
|
|
|
144 |
|
145 |
+
required_input_vars = self.required_input_vars or []
|
146 |
+
input_vars = set(workflow.inputs)
|
147 |
+
for req_var in required_input_vars:
|
148 |
+
if req_var in input_vars:
|
149 |
+
continue
|
150 |
+
self.errors.append(
|
151 |
+
ValidationError(ValidationErrorType.INPUTS, f"Workflow must have '{req_var}' as an input")
|
152 |
+
)
|
153 |
return False
|
154 |
|
155 |
+
for input_var in input_vars:
|
|
|
156 |
if not self._is_valid_external_input(input_var):
|
157 |
self.errors.append(
|
158 |
ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
|
159 |
)
|
160 |
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
161 |
return True
|
162 |
|
163 |
+
def _validate_required_outputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
164 |
+
"""Validates that the workflow has the correct outputs"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
+
required_output_vars = self.required_output_vars or []
|
167 |
+
output_vars = set(workflow.outputs)
|
168 |
+
for req_var in required_output_vars:
|
169 |
+
if req_var in output_vars:
|
170 |
+
continue
|
171 |
+
self.errors.append(
|
172 |
+
ValidationError(ValidationErrorType.OUTPUTS, f"Workflow must produce '{req_var}' as an output")
|
173 |
+
)
|
174 |
+
return False
|
175 |
|
176 |
# Validate output variables
|
177 |
for output_name, output_var in workflow.outputs.items():
|
178 |
+
logger.debug(f"Output name: {output_name}, Output var: {output_var}")
|
179 |
if not output_var:
|
180 |
+
if allow_empty:
|
181 |
+
continue
|
182 |
self.errors.append(
|
183 |
ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
|
184 |
)
|
185 |
return False
|
186 |
|
187 |
+
# Check if output variable references a valid step output
|
188 |
if not self._is_valid_variable_reference(output_var):
|
189 |
self.errors.append(
|
190 |
ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
|
|
|
193 |
|
194 |
# Verify the output field exists in the referenced step
|
195 |
step_id, field_name = _parse_variable_reference(output_var)
|
196 |
+
logger.debug(f"Step ID: {step_id}, Field name: {field_name}, Workflow steps: {workflow.steps.keys()}")
|
197 |
if step_id not in workflow.steps:
|
198 |
self.errors.append(
|
199 |
+
ValidationError(ValidationErrorType.VARIABLE, f"Referenced model step '{step_id}' not found")
|
200 |
)
|
201 |
return False
|
202 |
|
|
|
205 |
self.errors.append(
|
206 |
ValidationError(
|
207 |
ValidationErrorType.VARIABLE,
|
208 |
+
f"Output field '{field_name}' not found in model step '{step_id}'",
|
209 |
step_id,
|
210 |
field_name,
|
211 |
)
|
212 |
)
|
213 |
return False
|
214 |
+
return True
|
215 |
+
|
216 |
+
def validate_input_outputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
217 |
+
"""Validates the input and output variables"""
|
218 |
+
|
219 |
+
self._validate_required_inputs(workflow, allow_empty)
|
220 |
+
self._validate_required_outputs(workflow, allow_empty)
|
221 |
+
|
222 |
+
# Check for atleast one input
|
223 |
+
if not workflow.inputs:
|
224 |
+
self.errors.append(
|
225 |
+
ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one input")
|
226 |
+
)
|
227 |
+
|
228 |
+
# Check for atleast one output
|
229 |
+
if not workflow.outputs:
|
230 |
+
self.errors.append(
|
231 |
+
ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one output")
|
232 |
+
)
|
233 |
+
|
234 |
+
return len(self.errors) == 0
|
235 |
+
|
236 |
+
def validate_simple_workflow(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
237 |
+
"""Validates a single-step workflow"""
|
238 |
+
if not self.workflow:
|
239 |
+
return False
|
240 |
+
|
241 |
+
# Get the single step
|
242 |
+
step = next(iter(workflow.steps.values()))
|
243 |
+
|
244 |
+
# Validate the step itself
|
245 |
+
if not self._validate_step(step, allow_empty):
|
246 |
+
return False
|
247 |
+
|
248 |
+
return True
|
249 |
+
|
250 |
+
def validate_complex_workflow(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
251 |
+
"""Validates a multi-step workflow"""
|
252 |
+
if not self.workflow:
|
253 |
+
return False
|
254 |
+
|
255 |
+
# Validate each step
|
256 |
+
for step in workflow.steps.values():
|
257 |
+
if not self._validate_step(step, allow_empty):
|
258 |
+
return False
|
259 |
|
260 |
dep_graph = create_step_dep_graph(workflow)
|
261 |
if cycle_step_id := detect_cycles(dep_graph):
|
|
|
284 |
|
285 |
return True
|
286 |
|
287 |
+
def _validate_workflow_basic(self, workflow: Workflow, allow_empty: bool = False) -> bool:
|
288 |
"""Validates basic workflow properties"""
|
|
|
|
|
|
|
|
|
|
|
|
|
289 |
|
290 |
+
# Check the workflow inputs and outputs
|
291 |
+
if not self.validate_input_outputs(workflow, allow_empty):
|
|
|
|
|
292 |
return False
|
293 |
|
|
|
|
|
|
|
|
|
|
|
294 |
# Check for empty workflow
|
295 |
if not workflow.steps:
|
296 |
+
if allow_empty:
|
297 |
+
return True
|
298 |
self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one step"))
|
299 |
return False
|
300 |
|
|
|
307 |
return False
|
308 |
return True
|
309 |
|
310 |
+
def _validate_step(self, step: ModelStep, allow_empty: bool = False) -> bool:
|
311 |
"""Validates a single step"""
|
312 |
# Validate required fields
|
313 |
+
|
314 |
+
model_name = step.get_full_model_name()
|
315 |
+
|
316 |
+
if model_name == "/" and not allow_empty:
|
317 |
+
self.errors.append(
|
318 |
+
ValidationError(ValidationErrorType.STEP, "Model name and provider cannot be empty", step.id)
|
319 |
+
)
|
320 |
+
return False
|
321 |
+
|
322 |
+
# Check if the model names are allowed
|
323 |
+
if self.allowed_model_names and model_name not in self.allowed_model_names:
|
324 |
+
self.errors.append(
|
325 |
+
ValidationError(ValidationErrorType.STEP, f"Model name '{model_name}' is not allowed", step.id)
|
326 |
+
)
|
327 |
+
return False
|
328 |
+
|
329 |
+
if not step.id or not step.call_type:
|
330 |
self.errors.append(ValidationError(ValidationErrorType.STEP, "Step missing required fields", step.id))
|
331 |
return False
|
332 |
|
|
|
380 |
# Validate input fields
|
381 |
input_names = set()
|
382 |
for field in step.input_fields:
|
383 |
+
if not self._validate_input_field(field, allow_empty):
|
384 |
return False
|
385 |
if field.name in input_names:
|
386 |
self.errors.append(
|
|
|
394 |
# Validate output fields
|
395 |
output_names = set()
|
396 |
for field in step.output_fields:
|
397 |
+
if not self._validate_output_field(field, allow_empty):
|
398 |
return False
|
399 |
if field.name in output_names:
|
400 |
self.errors.append(
|
|
|
407 |
|
408 |
return True
|
409 |
|
410 |
+
def _validate_input_field(self, field: InputField, allow_empty: bool = False) -> bool:
|
411 |
"""Validates an input field"""
|
412 |
# Validate required fields
|
413 |
if not field.name or not field.description or not field.variable:
|
|
|
417 |
return False
|
418 |
|
419 |
# Validate field name
|
420 |
+
if not self._is_valid_identifier(field.name, allow_empty):
|
421 |
self.errors.append(
|
422 |
ValidationError(
|
423 |
ValidationErrorType.NAMING,
|
|
|
462 |
|
463 |
return True
|
464 |
|
465 |
+
def _validate_output_field(self, field: OutputField, allow_empty: bool = False) -> bool:
|
466 |
"""Validates an output field"""
|
467 |
# Validate required fields
|
468 |
if not field.name or not field.description:
|
|
|
474 |
return False
|
475 |
|
476 |
# Validate field name
|
477 |
+
if not self._is_valid_identifier(field.name, allow_empty):
|
478 |
self.errors.append(
|
479 |
ValidationError(
|
480 |
ValidationErrorType.NAMING,
|
|
|
580 |
|
581 |
return True
|
582 |
|
583 |
+
def _is_valid_variable_reference(self, var: str | None, allow_empty: bool = True) -> bool:
|
584 |
"""Validates if a variable reference is properly formatted"""
|
585 |
if not self.workflow:
|
586 |
return False
|
587 |
+
if var is None:
|
588 |
+
return allow_empty
|
589 |
parts = var.split(".")
|
590 |
if len(parts) == 1:
|
591 |
return True # External input
|
|
|
608 |
return False
|
609 |
return True
|
610 |
|
611 |
+
def _is_valid_identifier(self, name: str, allow_empty: bool = False) -> bool:
|
612 |
"""Validates if a string is a valid Python identifier"""
|
613 |
if name and name.strip():
|
614 |
return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
|
615 |
+
return allow_empty
|