File size: 6,652 Bytes
9756440 c225678 9756440 c225678 9756440 c225678 9756440 da814b0 9756440 da814b0 9756440 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 |
from typing import Any, Literal

from pydantic import BaseModel, ConfigDict, Field, model_validator

from workflows.structs import ModelStep, TossupWorkflow, Workflow
def make_step_id(step_number: int) -> str:
    """Make a step id from a zero-based step number.

    Ids are bijective base-26 letter strings, like spreadsheet column names:
    0 -> "A", 25 -> "Z", 26 -> "AA", 27 -> "AB", ..., 701 -> "ZZ", 702 -> "AAA".
    There is no upper limit; longer ids are produced as needed.

    Args:
        step_number: Zero-based index of the step.

    Returns:
        The uppercase letter id for ``step_number``.

    Raises:
        ValueError: If ``step_number`` is negative.
    """
    if step_number < 0:
        raise ValueError(f"step_number must be non-negative, got {step_number}")
    letters = []
    # Bijective base-26 has no zero digit: each position holds 1..26 ("A".."Z"),
    # so work with a one-based value and subtract 1 before every divmod.
    remaining = step_number + 1
    while remaining > 0:
        remaining, digit = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + digit))
    return "".join(reversed(letters))


def make_step_number(step_id: str) -> int:
    """Make a zero-based step number from a step id; the inverse of ``make_step_id``.

    ``make_step_number(make_step_id(n)) == n`` holds for every ``n >= 0``.

    Args:
        step_id: An id as produced by ``make_step_id`` (uppercase ASCII letters).
            Characters outside "A".."Z" are not rejected, but the result for
            such ids is meaningless.

    Returns:
        The zero-based step number encoded by ``step_id``.

    Raises:
        ValueError: If ``step_id`` is empty.
    """
    if not step_id:
        raise ValueError("step_id must be a non-empty string")
    number = 0
    for char in step_id:
        # Each letter is one bijective base-26 digit with value 1..26.
        number = number * 26 + (ord(char) - ord("A") + 1)
    return number - 1
class ModelStepUIState(BaseModel):
    """Represents the UI state for a model step component.

    Instances are frozen; use ``update`` to derive a modified copy.
    """

    # Pydantic v2 configuration (replaces the deprecated inner ``class Config``).
    model_config = ConfigDict(frozen=True)

    # Whether the step component is shown expanded.
    expanded: bool = True
    # Which of the step component's tabs is active.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this UI state with field ``key`` set to ``value``.

        Args:
            key: Name of the field to change (e.g. ``"expanded"`` or ``"active_tab"``).
            value: New value for the field; it is validated against the field's type.

        Returns:
            A new instance of the same class; ``self`` is unchanged.

        Raises:
            ValueError: If ``key`` is not a field of this model, or if ``value``
                fails validation (pydantic's ``ValidationError`` is a ``ValueError``).
        """
        fields = type(self).model_fields
        if key not in fields:
            raise ValueError(f"Unknown UI state field: {key!r}. Valid fields: {list(fields)}")
        # ``model_copy(update=...)`` skips validation, which would let e.g. an
        # unknown tab name through; re-validate the merged data instead.
        return type(self).model_validate(self.model_dump() | {key: value})
class PipelineUIState(BaseModel):
    """Represents the UI state for a pipeline component.

    Holds the display order of the pipeline's steps (``step_ids``) and the
    per-step UI state (``steps``). The update methods return a new instance
    and never modify the receiver or the containers it holds.
    """

    # Step ids in display order.
    step_ids: list[str] = Field(default_factory=list)
    # Per-step UI state, keyed by step id.
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        # When only the ordering is supplied, give every step a default UI state.
        if not self.steps and self.step_ids:
            self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Return the index of ``step_id`` in ``step_ids``, or None if it is absent."""
        try:
            return self.step_ids.index(step_id)
        except ValueError:
            return None

    @property
    def n_steps(self) -> int:
        """Get the number of steps in the pipeline."""
        return len(self.step_ids)

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineUIState":
        """Create a UI state listing the workflow's steps, each with a default UI state."""
        step_ids = list(workflow.steps.keys())
        return cls(step_ids=step_ids, steps={step_id: ModelStepUIState() for step_id in step_ids})

    @classmethod
    def from_pipeline_state(cls, pipeline_state: "PipelineState") -> "PipelineUIState":
        """Create a pipeline UI state from a pipeline state's workflow."""
        return cls.from_workflow(pipeline_state.workflow)

    # Update methods

    def insert_step(self, step_id: str, position: int = -1) -> "PipelineUIState":
        """Return a copy with ``step_id`` inserted at ``position``.

        Args:
            step_id: Id of the step to insert; it gets a default ``ModelStepUIState``.
            position: Index to insert at (``list.insert`` semantics); -1 appends.
        """
        if position == -1:
            position = len(self.step_ids)
        # Insert into a copy: ``model_copy`` is shallow, so mutating
        # ``self.step_ids`` would also change this instance and its other copies.
        step_ids = list(self.step_ids)
        step_ids.insert(position, step_id)
        steps = self.steps | {step_id: ModelStepUIState()}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def remove_step(self, step_id: str) -> "PipelineUIState":
        """Return a copy without ``step_id``.

        Raises:
            ValueError: If ``step_id`` is not in ``step_ids``.
        """
        if step_id not in self.step_ids:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        # Work on copies so this instance is left untouched.
        step_ids = list(self.step_ids)
        step_ids.remove(step_id)
        steps = {sid: state for sid, state in self.steps.items() if sid != step_id}
        return self.model_copy(update={"step_ids": step_ids, "steps": steps})

    def update_step(self, step_id: str, ui_state: ModelStepUIState) -> "PipelineUIState":
        """Return a copy with the UI state of ``step_id`` replaced by ``ui_state``.

        Raises:
            ValueError: If ``step_id`` has no entry in ``steps``.
        """
        if step_id not in self.steps:
            raise ValueError(f"Step {step_id} not found in pipeline. Step IDs: {self.step_ids}")
        return self.model_copy(update={"steps": self.steps | {step_id: ui_state}})
class PipelineState(BaseModel):
    """Represents the state for a pipeline component.

    Pairs a ``Workflow`` with its ``PipelineUIState``. The update methods
    return a new instance rather than modifying the receiver.
    """

    workflow: Workflow
    ui_state: PipelineUIState

    @classmethod
    def from_workflow(cls, workflow: Workflow):
        """Create a pipeline state from a workflow, with default UI state for every step."""
        return cls(workflow=workflow, ui_state=PipelineUIState.from_workflow(workflow))

    def update_workflow(self, workflow: Workflow) -> "PipelineState":
        """Return a copy with ``workflow`` replaced; the UI state is kept as is."""
        return self.model_copy(update={"workflow": workflow})

    def insert_step(self, position: int, step: ModelStep) -> "PipelineState":
        """Return a new state with ``step`` added to the workflow and shown at ``position``.

        Args:
            position: UI index in ``[0, n_steps]``, or -1 to append.
            step: The step to add; its ``id`` must not already be in the workflow.

        Raises:
            ValueError: If the step id already exists or ``position`` is invalid.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")
        # Validate position
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")
        workflow = self.workflow.add_step(step)
        # Keep the new UI state local: assigning it to ``self.ui_state`` would
        # modify this instance (new UI ordering, old workflow) as a side effect.
        ui_state = self.ui_state.insert_step(step.id, position)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def remove_step(self, position: int) -> "PipelineState":
        """Return a new state without the step shown at UI index ``position``.

        ``position`` indexes ``ui_state.step_ids`` with normal list indexing, so
        negative values count from the end and out-of-range values raise IndexError.
        """
        step_id = self.ui_state.step_ids[position]
        workflow = self.workflow.remove_step(step_id)
        ui_state = self.ui_state.remove_step(step_id)
        return self.model_copy(update={"workflow": workflow, "ui_state": ui_state})

    def update_step(self, step: ModelStep, ui_state: ModelStepUIState | None = None) -> "PipelineState":
        """Return a new state with ``step`` updated, and optionally its UI state.

        Raises:
            ValueError: If ``step.id`` is not in the workflow.
        """
        if step.id not in self.workflow.steps:
            raise ValueError(f"Step {step.id} not found in pipeline")
        update = {"workflow": self.workflow.update_step(step)}
        if ui_state is not None:
            update["ui_state"] = self.ui_state.update_step(step.id, ui_state)
        return self.model_copy(update=update)

    def get_available_variables(self, model_step_id: str | None = None) -> list[str]:
        """Get the workflow's available variables.

        Args:
            model_step_id: If given, variables prefixed with ``"<model_step_id>."``
                are excluded from the result.
        """
        available_variables = self.available_variables
        if model_step_id is None:
            return available_variables
        prefix = f"{model_step_id}."
        return [var for var in available_variables if not var.startswith(prefix)]

    @property
    def available_variables(self) -> list[str]:
        """Variables reported by ``workflow.get_available_variables()``."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self) -> int:
        """Number of steps in the workflow."""
        return len(self.workflow.steps)

    def get_new_step_id(self) -> str:
        """Get a step ID for a new step.

        Returns "A" for an empty workflow; otherwise the id following the highest
        existing id (as ordered by ``make_step_number``).
        """
        if not self.workflow.steps:
            return "A"
        last_step_number = max(make_step_number(step_id) for step_id in self.workflow.steps)
        return make_step_id(last_step_number + 1)
class TossupPipelineState(PipelineState):
    """Pipeline state whose ``workflow`` field is narrowed to a ``TossupWorkflow``."""

    workflow: TossupWorkflow
|