Maharshi Gor committed on
Commit
e272e20
·
1 Parent(s): 7985347

Added more validation layers

Browse files
src/components/model_pipeline/model_pipeline.py CHANGED
@@ -8,12 +8,11 @@ from components import typed_dicts as td
8
  from components.model_pipeline.state_manager import (
9
  BasePipelineValidator,
10
  PipelineStateManager,
11
- TossupPipelineStateManager,
12
  )
13
  from components.model_step.model_step import ModelStepComponent
14
- from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
15
  from components.utils import make_state
16
- from workflows.structs import ModelStep, TossupWorkflow, Workflow
17
  from workflows.validators import WorkflowValidationError, WorkflowValidator
18
 
19
  from .state_manager import get_output_panel_state
@@ -24,6 +23,8 @@ DEFAULT_MAX_TEMPERATURE = 5.0
24
  class PipelineInterface:
25
  """UI for the pipeline."""
26
 
 
 
27
  def __init__(
28
  self,
29
  app: gr.Blocks,
@@ -43,17 +44,14 @@ class PipelineInterface:
43
  self.workflow_state = make_state(workflow.model_dump())
44
  self.variables_state = make_state(workflow.get_available_variables())
45
  self.output_panel_state = make_state(get_output_panel_state(workflow))
 
46
 
47
  # Maintains the toggle state change for pipeline changes through user input.
48
  self.pipeline_change = gr.State(False)
49
 
50
- if isinstance(workflow, TossupWorkflow):
51
- pipeline_state = TossupPipelineState(workflow=workflow, ui_state=ui_state)
52
- self.sm = TossupPipelineStateManager(validator)
53
- else:
54
- pipeline_state = PipelineState(workflow=workflow, ui_state=ui_state)
55
- self.sm = PipelineStateManager(validator)
56
- self.pipeline_state = make_state(pipeline_state.model_dump())
57
 
58
  def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
59
  """Get the auxiliary states for the pipeline."""
@@ -164,20 +162,15 @@ class PipelineInterface:
164
  )
165
  return add_step_btn
166
 
167
- def validate_workflow(self, state_dict: td.PipelineStateDict):
168
  """Validate the workflow."""
169
  try:
170
  state = self.sm.make_pipeline_state(state_dict)
171
- validator = WorkflowValidator(
172
- max_temperature=self.config.get("max_temperature", 10),
173
- )
174
- if not validator.validate(state.workflow):
175
- raise WorkflowValidationError(validator.errors)
176
  except ValueError as e:
177
  logger.exception(e)
178
- state_dict_str = yaml.dump(state_dict, default_flow_style=False, indent=2)
179
- logger.error(f"Could not validate workflow: \n{state_dict_str}")
180
- raise gr.Error(e)
181
 
182
  def _render_pipeline_header(self):
183
  # Add Step button at top
@@ -270,9 +263,8 @@ class PipelineInterface:
270
  )
271
 
272
  # Connect the export button to show the workflow JSON
273
- self.add_triggers_for_pipeline_export([export_btn.click], self.pipeline_state, scroll=True)
274
- export_btn.click(self.validate_workflow, inputs=[self.pipeline_state], outputs=[]).success(
275
- fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion]
276
  )
277
 
278
  def render(self):
@@ -315,21 +307,30 @@ class PipelineInterface:
315
 
316
  self._render_pipeline_preview()
317
 
318
- def add_triggers_for_pipeline_export(self, triggers: list, input_pipeline_state: gr.State, scroll: bool = False):
 
 
 
 
 
 
319
  js = None
320
  if scroll:
321
  js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"
322
 
323
  # TODO: modify this validate workflow to user input level and not executable label.
324
  # (workflows that can be converted to UI interface, no logical validation)
325
- gr.on(
326
  triggers,
327
- self.validate_workflow,
328
  inputs=[input_pipeline_state],
329
- outputs=[],
330
  ).success(
331
  fn=self.sm.get_formatted_config,
332
  inputs=[self.pipeline_state, gr.State("yaml")],
333
  outputs=[self.config_output, self.error_display],
334
  js=js,
335
  )
 
 
 
 
8
  from components.model_pipeline.state_manager import (
9
  BasePipelineValidator,
10
  PipelineStateManager,
 
11
  )
12
  from components.model_step.model_step import ModelStepComponent
13
+ from components.structs import ModelStepUIState, PipelineState, PipelineUIState
14
  from components.utils import make_state
15
+ from workflows.structs import ModelStep, Workflow
16
  from workflows.validators import WorkflowValidationError, WorkflowValidator
17
 
18
  from .state_manager import get_output_panel_state
 
23
  class PipelineInterface:
24
  """UI for the pipeline."""
25
 
26
+ state_manager_cls = PipelineStateManager
27
+
28
  def __init__(
29
  self,
30
  app: gr.Blocks,
 
44
  self.workflow_state = make_state(workflow.model_dump())
45
  self.variables_state = make_state(workflow.get_available_variables())
46
  self.output_panel_state = make_state(get_output_panel_state(workflow))
47
+ self.ui_validator = validator
48
 
49
  # Maintains the toggle state change for pipeline changes through user input.
50
  self.pipeline_change = gr.State(False)
51
 
52
+ self.sm = self.state_manager_cls(validator)
53
+ pipeline_state_dict = self.sm.create_pipeline_state_dict(workflow=workflow, ui_state=ui_state)
54
+ self.pipeline_state = make_state(pipeline_state_dict)
 
 
 
 
55
 
56
  def get_aux_states(pipeline_state_dict: td.PipelineStateDict):
57
  """Get the auxiliary states for the pipeline."""
 
162
  )
163
  return add_step_btn
164
 
165
+ def validate_workflow_ui(self, state_dict: td.PipelineStateDict):
166
  """Validate the workflow."""
167
  try:
168
  state = self.sm.make_pipeline_state(state_dict)
169
+ self.ui_validator(state.workflow)
170
+ return gr.update(visible=False)
 
 
 
171
  except ValueError as e:
172
  logger.exception(e)
173
+ return gr.update(visible=True, value=str(e))
 
 
174
 
175
  def _render_pipeline_header(self):
176
  # Add Step button at top
 
263
  )
264
 
265
  # Connect the export button to show the workflow JSON
266
+ self.add_triggers_for_pipeline_export(
267
+ [export_btn.click], self.pipeline_state, scroll=True, expand_accordion=True
 
268
  )
269
 
270
  def render(self):
 
307
 
308
  self._render_pipeline_preview()
309
 
310
+ def add_triggers_for_pipeline_export(
311
+ self,
312
+ triggers: list,
313
+ input_pipeline_state: gr.State,
314
+ scroll: bool = False,
315
+ expand_accordion: bool = False,
316
+ ):
317
  js = None
318
  if scroll:
319
  js = "() => {document.querySelector('.pipeline-preview').scrollIntoView({behavior: 'smooth'})}"
320
 
321
  # TODO: modify this validate workflow to user input level and not executable label.
322
  # (workflows that can be converted to UI interface, no logical validation)
323
+ event = gr.on(
324
  triggers,
325
+ self.validate_workflow_ui,
326
  inputs=[input_pipeline_state],
327
+ outputs=[self.error_display],
328
  ).success(
329
  fn=self.sm.get_formatted_config,
330
  inputs=[self.pipeline_state, gr.State("yaml")],
331
  outputs=[self.config_output, self.error_display],
332
  js=js,
333
  )
334
+
335
+ if expand_accordion:
336
+ event.then(fn=lambda: gr.update(visible=True, open=True), outputs=[self.config_accordion])
src/components/model_pipeline/state_manager.py CHANGED
@@ -11,7 +11,7 @@ from pydantic import BaseModel, ValidationError
11
  from app_configs import UNSELECTED_VAR_NAME
12
  from components import typed_dicts as td
13
  from components import utils
14
- from components.structs import ModelStepUIState, PipelineState, TossupPipelineState
15
  from envs import DOCS_REPO_BRANCH, DOCS_REPO_URL
16
  from workflows.factory import create_new_llm_step
17
  from workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
@@ -72,6 +72,10 @@ class PipelineStateManager:
72
  """Make a state from a state dictionary."""
73
  return self.pipeline_state_cls(**state_dict)
74
 
 
 
 
 
75
  def add_step(
76
  self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
77
  ) -> td.PipelineStateDict:
@@ -212,6 +216,7 @@ class PipelineStateManager:
212
  """Update a workflow from a YAML string."""
213
  try:
214
  workflow = self.parse_yaml_workflow(yaml_str, strict=True)
 
215
  self.validator and self.validator(workflow)
216
  state = self.pipeline_state_cls.from_workflow(workflow)
217
  return state.model_dump(), not change_state, gr.update(visible=False)
@@ -229,6 +234,11 @@ class TossupPipelineStateManager(PipelineStateManager):
229
  def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
230
  return super().make_pipeline_state(state_dict)
231
 
 
 
 
 
 
232
  def update_workflow_from_code(
233
  self, yaml_str: str, change_state: bool
234
  ) -> tuple[td.TossupPipelineStateDict, bool, dict]:
 
11
  from app_configs import UNSELECTED_VAR_NAME
12
  from components import typed_dicts as td
13
  from components import utils
14
+ from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
15
  from envs import DOCS_REPO_BRANCH, DOCS_REPO_URL
16
  from workflows.factory import create_new_llm_step
17
  from workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
 
72
  """Make a state from a state dictionary."""
73
  return self.pipeline_state_cls(**state_dict)
74
 
75
+ def create_pipeline_state_dict(self, workflow: Workflow, ui_state: PipelineUIState) -> td.PipelineStateDict:
76
+ """Create a pipeline state from a workflow."""
77
+ return self.pipeline_state_cls(workflow=workflow, ui_state=ui_state).model_dump()
78
+
79
  def add_step(
80
  self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
81
  ) -> td.PipelineStateDict:
 
216
  """Update a workflow from a YAML string."""
217
  try:
218
  workflow = self.parse_yaml_workflow(yaml_str, strict=True)
219
+ logger.debug(f"Validator: {self.validator}")
220
  self.validator and self.validator(workflow)
221
  state = self.pipeline_state_cls.from_workflow(workflow)
222
  return state.model_dump(), not change_state, gr.update(visible=False)
 
234
  def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
235
  return super().make_pipeline_state(state_dict)
236
 
237
+ def create_pipeline_state_dict(
238
+ self, workflow: TossupWorkflow, ui_state: PipelineUIState
239
+ ) -> td.TossupPipelineStateDict:
240
+ return super().create_pipeline_state_dict(workflow, ui_state)
241
+
242
  def update_workflow_from_code(
243
  self, yaml_str: str, change_state: bool
244
  ) -> tuple[td.TossupPipelineStateDict, bool, dict]:
src/components/model_pipeline/tossup_pipeline.py CHANGED
@@ -9,7 +9,7 @@ from display.formatting import tiny_styled_warning
9
  from workflows.structs import Buzzer, TossupWorkflow
10
 
11
  from .model_pipeline import PipelineInterface
12
- from .state_manager import BasePipelineValidator
13
 
14
 
15
  def toggleable_slider(
@@ -33,6 +33,10 @@ def toggleable_slider(
33
 
34
 
35
  class TossupPipelineInterface(PipelineInterface):
 
 
 
 
36
  def __init__(
37
  self,
38
  app: gr.Blocks,
 
9
  from workflows.structs import Buzzer, TossupWorkflow
10
 
11
  from .model_pipeline import PipelineInterface
12
+ from .state_manager import BasePipelineValidator, TossupPipelineStateManager
13
 
14
 
15
  def toggleable_slider(
 
33
 
34
 
35
  class TossupPipelineInterface(PipelineInterface):
36
+ """UI for the tossup pipeline."""
37
+
38
+ state_manager_cls = TossupPipelineStateManager
39
+
40
  def __init__(
41
  self,
42
  app: gr.Blocks,
src/components/quizbowl/validation.py CHANGED
@@ -1,10 +1,10 @@
1
  from typing import Literal
2
 
3
- from app_configs import CONFIGS
4
  from components.structs import PipelineState, TossupPipelineState
5
  from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
6
  from workflows.structs import TossupWorkflow, Workflow
7
- from workflows.validators import WorkflowValidator
8
 
9
 
10
  def validate_workflow(
@@ -83,3 +83,14 @@ class UserInputWorkflowValidator:
83
  "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
84
  f"\nMake sure you have values set for all the outputs: {default_str}"
85
  )
 
 
 
 
 
 
 
 
 
 
 
 
1
  from typing import Literal
2
 
3
+ from app_configs import AVAILABLE_MODELS, CONFIGS
4
  from components.structs import PipelineState, TossupPipelineState
5
  from components.typed_dicts import PipelineStateDict, TossupPipelineStateDict
6
  from workflows.structs import TossupWorkflow, Workflow
7
+ from workflows.validators import WorkflowValidationError, WorkflowValidator
8
 
9
 
10
  def validate_workflow(
 
83
  "\nDon't remove the keys from the 'outputs' field in the workflow. Only update their values."
84
  f"\nMake sure you have values set for all the outputs: {default_str}"
85
  )
86
+
87
+ # Validate the workflow
88
+ allowed_model_names = AVAILABLE_MODELS.keys()
89
+ self.validator = WorkflowValidator(allowed_model_names=allowed_model_names)
90
+ try:
91
+ self.validator.validate(workflow, allow_empty=True)
92
+ except WorkflowValidationError as e:
93
+ error_msg_total = f"Found {len(e.errors)} errors in the workflow:\n"
94
+ error_msg_list = [f"- {err.message}" for err in e.errors]
95
+ error_msg = error_msg_total + "\n".join(error_msg_list)
96
+ raise ValueError(error_msg)
src/workflows/validators.py CHANGED
@@ -15,10 +15,14 @@ MAX_DESCRIPTION_LENGTH = 200
15
  MAX_SYSTEM_PROMPT_LENGTH = 4000
16
  MAX_TEMPERATURE = 10.0
17
 
 
 
18
 
19
  class ValidationErrorType(Enum):
20
  """Types of validation errors that can occur"""
21
 
 
 
22
  STEP = "step"
23
  DAG = "dag"
24
  VARIABLE = "variable"
@@ -38,6 +42,17 @@ class ValidationError:
38
  step_id: Optional[str] = None
39
  field_name: Optional[str] = None
40
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
  class WorkflowValidationError(ValueError):
43
  """Base class for workflow validation errors"""
@@ -83,6 +98,9 @@ class WorkflowValidator:
83
  max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
84
  max_description_length: int = MAX_DESCRIPTION_LENGTH,
85
  max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
 
 
 
86
  ):
87
  self.errors: list[ValidationError] = []
88
  self.workflow: Optional[Workflow] = None
@@ -91,99 +109,82 @@ class WorkflowValidator:
91
  self.max_field_name_length = max_field_name_length
92
  self.max_description_length = max_description_length
93
  self.max_system_prompt_length = max_system_prompt_length
 
 
 
 
 
 
 
 
 
94
 
95
- def validate(self, workflow: Workflow) -> bool:
96
- """Main validation entry point"""
 
 
 
 
97
  self.errors = []
98
  self.workflow = workflow
99
 
100
  # Basic workflow validation
101
- if not self._validate_workflow_basic(workflow):
102
  return False
103
 
104
  # If it's a single-step workflow, use simple validation
105
  if len(workflow.steps) == 1:
106
- return self.validate_simple_workflow(workflow)
107
 
108
  # Otherwise use complex validation
109
- return self.validate_complex_workflow(workflow)
110
 
111
- def validate_simple_workflow(self, workflow: Workflow) -> bool:
112
- """Validates a single-step workflow"""
113
- if not self.workflow:
114
- return False
115
 
116
- # Get the single step
117
- step = next(iter(workflow.steps.values()))
118
-
119
- # Validate the step itself
120
- if not self._validate_step(step):
 
 
 
121
  return False
122
 
123
- # Validate input variables
124
- for input_var in workflow.inputs:
125
  if not self._is_valid_external_input(input_var):
126
  self.errors.append(
127
  ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
128
  )
129
  return False
130
-
131
- # Validate output variables
132
- for output_name, output_var in workflow.outputs.items():
133
- if not output_var:
134
- self.errors.append(
135
- ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
136
- )
137
- return False
138
-
139
- # Check if output variable references a valid step output
140
- if not self._is_valid_variable_reference(output_var):
141
- self.errors.append(
142
- ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
143
- )
144
- return False
145
-
146
- # Verify the output field exists in the step
147
- _, field_name = _parse_variable_reference(output_var)
148
- if not any(field.name == field_name for field in step.output_fields):
149
- self.errors.append(
150
- ValidationError(
151
- ValidationErrorType.VARIABLE,
152
- f"Output field '{field_name}' not found in step '{step.id}'",
153
- step.id,
154
- field_name,
155
- )
156
- )
157
- return False
158
-
159
  return True
160
 
161
- def validate_complex_workflow(self, workflow: Workflow) -> bool:
162
- """Validates a multi-step workflow"""
163
- if not self.workflow:
164
- return False
165
-
166
- # Validate each step
167
- for step in workflow.steps.values():
168
- if not self._validate_step(step):
169
- return False
170
 
171
- # Validate input variables
172
- for input_var in workflow.inputs:
173
- if not self._is_valid_external_input(input_var):
174
- self.errors.append(
175
- ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
176
- )
177
- return False
 
 
178
 
179
  # Validate output variables
180
  for output_name, output_var in workflow.outputs.items():
 
181
  if not output_var:
 
 
182
  self.errors.append(
183
  ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
184
  )
185
  return False
186
 
 
187
  if not self._is_valid_variable_reference(output_var):
188
  self.errors.append(
189
  ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
@@ -192,9 +193,10 @@ class WorkflowValidator:
192
 
193
  # Verify the output field exists in the referenced step
194
  step_id, field_name = _parse_variable_reference(output_var)
 
195
  if step_id not in workflow.steps:
196
  self.errors.append(
197
- ValidationError(ValidationErrorType.VARIABLE, f"Referenced step '{step_id}' not found")
198
  )
199
  return False
200
 
@@ -203,12 +205,57 @@ class WorkflowValidator:
203
  self.errors.append(
204
  ValidationError(
205
  ValidationErrorType.VARIABLE,
206
- f"Output field '{field_name}' not found in step '{step_id}'",
207
  step_id,
208
  field_name,
209
  )
210
  )
211
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
212
 
213
  dep_graph = create_step_dep_graph(workflow)
214
  if cycle_step_id := detect_cycles(dep_graph):
@@ -237,28 +284,17 @@ class WorkflowValidator:
237
 
238
  return True
239
 
240
- def _validate_workflow_basic(self, workflow: Workflow) -> bool:
241
  """Validates basic workflow properties"""
242
- # Check for atleast one input
243
- if not workflow.inputs:
244
- self.errors.append(
245
- ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one input")
246
- )
247
- return False
248
 
249
- if not workflow.outputs:
250
- self.errors.append(
251
- ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one output")
252
- )
253
  return False
254
 
255
- for output_var in workflow.outputs.values():
256
- if output_var is None:
257
- self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Output variable cannot be None"))
258
- return False
259
-
260
  # Check for empty workflow
261
  if not workflow.steps:
 
 
262
  self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one step"))
263
  return False
264
 
@@ -271,10 +307,26 @@ class WorkflowValidator:
271
  return False
272
  return True
273
 
274
- def _validate_step(self, step: ModelStep) -> bool:
275
  """Validates a single step"""
276
  # Validate required fields
277
- if not step.id or not step.name or not step.model or not step.provider or not step.call_type:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
278
  self.errors.append(ValidationError(ValidationErrorType.STEP, "Step missing required fields", step.id))
279
  return False
280
 
@@ -328,7 +380,7 @@ class WorkflowValidator:
328
  # Validate input fields
329
  input_names = set()
330
  for field in step.input_fields:
331
- if not self._validate_input_field(field):
332
  return False
333
  if field.name in input_names:
334
  self.errors.append(
@@ -342,7 +394,7 @@ class WorkflowValidator:
342
  # Validate output fields
343
  output_names = set()
344
  for field in step.output_fields:
345
- if not self._validate_output_field(field):
346
  return False
347
  if field.name in output_names:
348
  self.errors.append(
@@ -355,7 +407,7 @@ class WorkflowValidator:
355
 
356
  return True
357
 
358
- def _validate_input_field(self, field: InputField) -> bool:
359
  """Validates an input field"""
360
  # Validate required fields
361
  if not field.name or not field.description or not field.variable:
@@ -365,7 +417,7 @@ class WorkflowValidator:
365
  return False
366
 
367
  # Validate field name
368
- if not self._is_valid_identifier(field.name):
369
  self.errors.append(
370
  ValidationError(
371
  ValidationErrorType.NAMING,
@@ -410,7 +462,7 @@ class WorkflowValidator:
410
 
411
  return True
412
 
413
- def _validate_output_field(self, field: OutputField) -> bool:
414
  """Validates an output field"""
415
  # Validate required fields
416
  if not field.name or not field.description:
@@ -422,7 +474,7 @@ class WorkflowValidator:
422
  return False
423
 
424
  # Validate field name
425
- if not self._is_valid_identifier(field.name):
426
  self.errors.append(
427
  ValidationError(
428
  ValidationErrorType.NAMING,
@@ -528,10 +580,12 @@ class WorkflowValidator:
528
 
529
  return True
530
 
531
- def _is_valid_variable_reference(self, var: str) -> bool:
532
  """Validates if a variable reference is properly formatted"""
533
  if not self.workflow:
534
  return False
 
 
535
  parts = var.split(".")
536
  if len(parts) == 1:
537
  return True # External input
@@ -554,8 +608,8 @@ class WorkflowValidator:
554
  return False
555
  return True
556
 
557
- def _is_valid_identifier(self, name: str) -> bool:
558
  """Validates if a string is a valid Python identifier"""
559
  if name and name.strip():
560
  return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
561
- return False
 
15
  MAX_SYSTEM_PROMPT_LENGTH = 4000
16
  MAX_TEMPERATURE = 10.0
17
 
18
+ from loguru import logger
19
+
20
 
21
  class ValidationErrorType(Enum):
22
  """Types of validation errors that can occur"""
23
 
24
+ INPUTS = "inputs"
25
+ OUTPUTS = "outputs"
26
  STEP = "step"
27
  DAG = "dag"
28
  VARIABLE = "variable"
 
42
  step_id: Optional[str] = None
43
  field_name: Optional[str] = None
44
 
45
+ def __str__(self):
46
+ subject = ""
47
+ if self.step_id:
48
+ subject = f"Model step '{self.step_id}'"
49
+ if self.field_name:
50
+ if self.step_id:
51
+ subject = f"Field '{self.step_id}.{self.field_name}'"
52
+ else:
53
+ subject = f"Field '{self.field_name}'"
54
+ return f"{self.error_type.value}: {subject} - {self.message}"
55
+
56
 
57
  class WorkflowValidationError(ValueError):
58
  """Base class for workflow validation errors"""
 
98
  max_field_name_length: int = MAX_FIELD_NAME_LENGTH,
99
  max_description_length: int = MAX_DESCRIPTION_LENGTH,
100
  max_system_prompt_length: int = MAX_SYSTEM_PROMPT_LENGTH,
101
+ allowed_model_names: Optional[list[str]] = None,
102
+ required_input_vars: Optional[list[str]] = None,
103
+ required_output_vars: Optional[list[str]] = None,
104
  ):
105
  self.errors: list[ValidationError] = []
106
  self.workflow: Optional[Workflow] = None
 
109
  self.max_field_name_length = max_field_name_length
110
  self.max_description_length = max_description_length
111
  self.max_system_prompt_length = max_system_prompt_length
112
+ self.required_input_vars = required_input_vars
113
+ self.required_output_vars = required_output_vars
114
+ self.allowed_model_names = set(allowed_model_names) if allowed_model_names else None
115
+
116
+ def validate(self, workflow: Workflow, allow_empty: bool = False) -> bool:
117
+ validated = self._validate(workflow, allow_empty)
118
+ if not validated:
119
+ raise WorkflowValidationError(self.errors)
120
+ return True
121
 
122
+ def _validate(self, workflow: Workflow, allow_empty: bool = False) -> bool:
123
+ """Main validation entry point
124
+ Args:
125
+ workflow: The workflow to validate.
126
+ allow_empty: If True, empty workflow is allowed. This flag is used to validate the intermediate states while User edits the workflow.
127
+ """
128
  self.errors = []
129
  self.workflow = workflow
130
 
131
  # Basic workflow validation
132
+ if not self._validate_workflow_basic(workflow, allow_empty):
133
  return False
134
 
135
  # If it's a single-step workflow, use simple validation
136
  if len(workflow.steps) == 1:
137
+ return self.validate_simple_workflow(workflow, allow_empty)
138
 
139
  # Otherwise use complex validation
140
+ return self.validate_complex_workflow(workflow, allow_empty)
141
 
142
+ def _validate_required_inputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
143
+ """Validates that the workflow has the correct inputs"""
 
 
144
 
145
+ required_input_vars = self.required_input_vars or []
146
+ input_vars = set(workflow.inputs)
147
+ for req_var in required_input_vars:
148
+ if req_var in input_vars:
149
+ continue
150
+ self.errors.append(
151
+ ValidationError(ValidationErrorType.INPUTS, f"Workflow must have '{req_var}' as an input")
152
+ )
153
  return False
154
 
155
+ for input_var in input_vars:
 
156
  if not self._is_valid_external_input(input_var):
157
  self.errors.append(
158
  ValidationError(ValidationErrorType.VARIABLE, f"Invalid input variable format: {input_var}")
159
  )
160
  return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  return True
162
 
163
+ def _validate_required_outputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
164
+ """Validates that the workflow has the correct outputs"""
 
 
 
 
 
 
 
165
 
166
+ required_output_vars = self.required_output_vars or []
167
+ output_vars = set(workflow.outputs)
168
+ for req_var in required_output_vars:
169
+ if req_var in output_vars:
170
+ continue
171
+ self.errors.append(
172
+ ValidationError(ValidationErrorType.OUTPUTS, f"Workflow must produce '{req_var}' as an output")
173
+ )
174
+ return False
175
 
176
  # Validate output variables
177
  for output_name, output_var in workflow.outputs.items():
178
+ logger.debug(f"Output name: {output_name}, Output var: {output_var}")
179
  if not output_var:
180
+ if allow_empty:
181
+ continue
182
  self.errors.append(
183
  ValidationError(ValidationErrorType.VARIABLE, f"Missing output variable for {output_name}")
184
  )
185
  return False
186
 
187
+ # Check if output variable references a valid step output
188
  if not self._is_valid_variable_reference(output_var):
189
  self.errors.append(
190
  ValidationError(ValidationErrorType.VARIABLE, f"Invalid output variable reference: {output_var}")
 
193
 
194
  # Verify the output field exists in the referenced step
195
  step_id, field_name = _parse_variable_reference(output_var)
196
+ logger.debug(f"Step ID: {step_id}, Field name: {field_name}, Workflow steps: {workflow.steps.keys()}")
197
  if step_id not in workflow.steps:
198
  self.errors.append(
199
+ ValidationError(ValidationErrorType.VARIABLE, f"Referenced model step '{step_id}' not found")
200
  )
201
  return False
202
 
 
205
  self.errors.append(
206
  ValidationError(
207
  ValidationErrorType.VARIABLE,
208
+ f"Output field '{field_name}' not found in model step '{step_id}'",
209
  step_id,
210
  field_name,
211
  )
212
  )
213
  return False
214
+ return True
215
+
216
+ def validate_input_outputs(self, workflow: Workflow, allow_empty: bool = False) -> bool:
217
+ """Validates the input and output variables"""
218
+
219
+ self._validate_required_inputs(workflow, allow_empty)
220
+ self._validate_required_outputs(workflow, allow_empty)
221
+
222
+ # Check for atleast one input
223
+ if not workflow.inputs:
224
+ self.errors.append(
225
+ ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one input")
226
+ )
227
+
228
+ # Check for atleast one output
229
+ if not workflow.outputs:
230
+ self.errors.append(
231
+ ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one output")
232
+ )
233
+
234
+ return len(self.errors) == 0
235
+
236
+ def validate_simple_workflow(self, workflow: Workflow, allow_empty: bool = False) -> bool:
237
+ """Validates a single-step workflow"""
238
+ if not self.workflow:
239
+ return False
240
+
241
+ # Get the single step
242
+ step = next(iter(workflow.steps.values()))
243
+
244
+ # Validate the step itself
245
+ if not self._validate_step(step, allow_empty):
246
+ return False
247
+
248
+ return True
249
+
250
+ def validate_complex_workflow(self, workflow: Workflow, allow_empty: bool = False) -> bool:
251
+ """Validates a multi-step workflow"""
252
+ if not self.workflow:
253
+ return False
254
+
255
+ # Validate each step
256
+ for step in workflow.steps.values():
257
+ if not self._validate_step(step, allow_empty):
258
+ return False
259
 
260
  dep_graph = create_step_dep_graph(workflow)
261
  if cycle_step_id := detect_cycles(dep_graph):
 
284
 
285
  return True
286
 
287
+ def _validate_workflow_basic(self, workflow: Workflow, allow_empty: bool = False) -> bool:
288
  """Validates basic workflow properties"""
 
 
 
 
 
 
289
 
290
+ # Check the workflow inputs and outputs
291
+ if not self.validate_input_outputs(workflow, allow_empty):
 
 
292
  return False
293
 
 
 
 
 
 
294
  # Check for empty workflow
295
  if not workflow.steps:
296
+ if allow_empty:
297
+ return True
298
  self.errors.append(ValidationError(ValidationErrorType.GENERAL, "Workflow must contain at least one step"))
299
  return False
300
 
 
307
  return False
308
  return True
309
 
310
+ def _validate_step(self, step: ModelStep, allow_empty: bool = False) -> bool:
311
  """Validates a single step"""
312
  # Validate required fields
313
+
314
+ model_name = step.get_full_model_name()
315
+
316
+ if model_name == "/" and not allow_empty:
317
+ self.errors.append(
318
+ ValidationError(ValidationErrorType.STEP, "Model name and provider cannot be empty", step.id)
319
+ )
320
+ return False
321
+
322
+ # Check if the model names are allowed
323
+ if self.allowed_model_names and model_name not in self.allowed_model_names:
324
+ self.errors.append(
325
+ ValidationError(ValidationErrorType.STEP, f"Model name '{model_name}' is not allowed", step.id)
326
+ )
327
+ return False
328
+
329
+ if not step.id or not step.call_type:
330
  self.errors.append(ValidationError(ValidationErrorType.STEP, "Step missing required fields", step.id))
331
  return False
332
 
 
380
  # Validate input fields
381
  input_names = set()
382
  for field in step.input_fields:
383
+ if not self._validate_input_field(field, allow_empty):
384
  return False
385
  if field.name in input_names:
386
  self.errors.append(
 
394
  # Validate output fields
395
  output_names = set()
396
  for field in step.output_fields:
397
+ if not self._validate_output_field(field, allow_empty):
398
  return False
399
  if field.name in output_names:
400
  self.errors.append(
 
407
 
408
  return True
409
 
410
+ def _validate_input_field(self, field: InputField, allow_empty: bool = False) -> bool:
411
  """Validates an input field"""
412
  # Validate required fields
413
  if not field.name or not field.description or not field.variable:
 
417
  return False
418
 
419
  # Validate field name
420
+ if not self._is_valid_identifier(field.name, allow_empty):
421
  self.errors.append(
422
  ValidationError(
423
  ValidationErrorType.NAMING,
 
462
 
463
  return True
464
 
465
+ def _validate_output_field(self, field: OutputField, allow_empty: bool = False) -> bool:
466
  """Validates an output field"""
467
  # Validate required fields
468
  if not field.name or not field.description:
 
474
  return False
475
 
476
  # Validate field name
477
+ if not self._is_valid_identifier(field.name, allow_empty):
478
  self.errors.append(
479
  ValidationError(
480
  ValidationErrorType.NAMING,
 
580
 
581
  return True
582
 
583
+ def _is_valid_variable_reference(self, var: str | None, allow_empty: bool = True) -> bool:
584
  """Validates if a variable reference is properly formatted"""
585
  if not self.workflow:
586
  return False
587
+ if var is None:
588
+ return allow_empty
589
  parts = var.split(".")
590
  if len(parts) == 1:
591
  return True # External input
 
608
  return False
609
  return True
610
 
611
+ def _is_valid_identifier(self, name: str, allow_empty: bool = False) -> bool:
612
  """Validates if a string is a valid Python identifier"""
613
  if name and name.strip():
614
  return bool(re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", name))
615
+ return allow_empty