File size: 6,914 Bytes
193db9d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import json
from typing import Any, Literal

import gradio as gr
import yaml
from pydantic import BaseModel, Field

from components import utils
from workflows.factory import create_new_llm_step
from workflows.structs import ModelStep, Workflow


def make_step_id(step_id: int) -> str:
    """Convert a zero-based step index into an alphabetic step id.

    Uses bijective base-26 ("spreadsheet column") naming, so every
    non-negative index maps to a distinct id made only of uppercase letters:
    0 -> "A", 25 -> "Z", 26 -> "AA", 51 -> "AZ", 701 -> "ZZ", 702 -> "AAA", ...

    Args:
        step_id: Zero-based index of the step.

    Returns:
        The uppercase alphabetic id for that index.

    Raises:
        ValueError: If ``step_id`` is negative.
    """
    if step_id < 0:
        raise ValueError(f"step_id must be non-negative, got {step_id}")
    letters = []
    # Work with a 1-based value; subtracting 1 before each divmod is what makes
    # the numbering bijective (there is no "zero" letter).
    remaining = step_id + 1
    while remaining > 0:
        remaining, digit = divmod(remaining - 1, 26)
        letters.append(chr(ord("A") + digit))
    return "".join(reversed(letters))


class ModelStepUIState(BaseModel):
    """UI-only state of a single model-step component: collapsed/expanded and selected tab."""

    # Whether the step's panel is shown expanded.
    expanded: bool = True
    # Which tab of the step's panel is currently selected.
    active_tab: Literal["model-tab", "inputs-tab", "outputs-tab"] = "model-tab"

    def update(self, key: str, value: Any) -> "ModelStepUIState":
        """Return a copy of this state with field ``key`` set to ``value``.

        The receiver itself is not modified. The copy is produced by pydantic's
        ``model_copy(update=...)``, which does not re-run field validation.
        """
        changes = {key: value}
        return self.model_copy(update=changes)


class PipelineUIState(BaseModel):
    """UI state for a pipeline: display order of the steps plus each step's UI state."""

    # Step ids in display order.
    step_ids: list[str] = Field(default_factory=list)
    # Per-step UI state, keyed by step id.
    steps: dict[str, ModelStepUIState] = Field(default_factory=dict)

    def model_post_init(self, __context: Any) -> None:
        """Create default per-step UI state when only ``step_ids`` was supplied."""
        if not self.steps and self.step_ids:
            self.steps = {step_id: ModelStepUIState() for step_id in self.step_ids}
        return super().model_post_init(__context)

    def get_step_position(self, step_id: str) -> int | None:
        """Return the index of ``step_id`` in the display order, or None if it is absent."""
        try:
            return self.step_ids.index(step_id)
        except ValueError:
            return None

    @classmethod
    def from_workflow(cls, workflow: Workflow) -> "PipelineUIState":
        """Build a UI state with one default step entry per workflow step, in workflow order.

        Uses ``cls`` so subclasses get an instance of their own type.
        """
        step_ids = list(workflow.steps.keys())
        return cls(
            step_ids=step_ids,
            steps={step_id: ModelStepUIState() for step_id in step_ids},
        )


class PipelineState(BaseModel):
    """Combined pipeline state: the workflow definition and its UI state."""

    workflow: Workflow
    ui_state: PipelineUIState

    def insert_step(self, position: int, step: ModelStep):
        """Add ``step`` to the workflow and place it in the UI ordering.

        Args:
            position: Index in ``ui_state.step_ids`` at which to insert the step,
                or -1 to append it at the end.
            step: The step to add; its ``id`` must not already exist in the workflow.

        Returns:
            ``self``, so calls can be chained.

        Raises:
            ValueError: If the step id already exists or ``position`` is out of range.
        """
        if step.id in self.workflow.steps:
            raise ValueError(f"Step {step.id} already exists in pipeline")

        # Validate before mutating anything so a bad position leaves the state untouched.
        if position != -1 and (position < 0 or position > self.n_steps):
            raise ValueError(f"Invalid position: {position}. Must be between 0 and {self.n_steps} or -1")

        self.workflow.steps[step.id] = step

        # Rebind ui_state to a new object. NOTE: model_copy() is shallow by default,
        # so the copy shares the `steps` dict and `step_ids` list with the previous
        # ui_state object; the mutations below are visible through both.
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps[step.id] = ModelStepUIState()
        if position == -1:
            self.ui_state.step_ids.append(step.id)
        else:
            self.ui_state.step_ids.insert(position, step.id)
        return self

    def remove_step(self, position: int):
        """Remove the step at ``position`` (index into ``ui_state.step_ids``).

        The step is dropped from the display order, the workflow, and the per-step
        UI state, and any workflow output that referenced a now-unavailable
        variable is reset to None.

        Returns:
            ``self``, so callers may write ``state = state.remove_step(i)``.

        Raises:
            IndexError: If ``position`` is out of range for ``ui_state.step_ids``.
        """
        step_id = self.ui_state.step_ids.pop(position)
        self.workflow.steps.pop(step_id)
        self.ui_state = self.ui_state.model_copy()
        self.ui_state.steps.pop(step_id)
        self.update_output_variables_mapping()
        return self

    def update_output_variables_mapping(self):
        """Reset to None every workflow output whose mapped variable is no longer available.

        Returns:
            ``self``.
        """
        available_variables = set(self.available_variables)
        # Only values are reassigned (no keys added/removed), so iterating while updating is safe.
        for output_field in self.workflow.outputs:
            if self.workflow.outputs[output_field] not in available_variables:
                self.workflow.outputs[output_field] = None
        return self

    @property
    def available_variables(self):
        """Variables reported by ``Workflow.get_available_variables()``."""
        return self.workflow.get_available_variables()

    @property
    def n_steps(self):
        """Number of steps in the workflow."""
        return len(self.workflow.steps)


class PipelineStateManager:
    """Stateless operations on a ``PipelineState`` used by the pipeline UI callbacks."""

    def get_formatted_config(self, state: PipelineState, format: Literal["json", "yaml"] = "yaml"):
        """Serialize the pipeline's workflow (non-default fields only) as YAML or JSON.

        Any ``format`` other than ``"yaml"`` produces JSON.
        """
        config = state.workflow.model_dump(exclude_defaults=True)
        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        return json.dumps(config, indent=4, sort_keys=False)

    def count_state(self):
        """Wrap the number of steps in a ``gr.State``."""
        # NOTE(review): this class defines no `steps` attribute in this file, so unless
        # a subclass or caller sets one this raises AttributeError. Confirm intent.
        return gr.State(len(self.steps))

    def add_step(self, state: PipelineState, position: int = -1, name=""):
        """Create a new LLM step and insert it at ``position`` (-1 appends).

        The step id is the first alphabetic id (from ``make_step_id``) at or after
        index ``n_steps`` that is not already used. Ids are not renumbered after a
        removal, so ``make_step_id(n_steps)`` alone can collide with a live step.

        Returns:
            Tuple of (state, state.ui_state, state.available_variables).
        """
        index = state.n_steps
        step_id = make_step_id(index)
        while step_id in state.workflow.steps:
            index += 1
            step_id = make_step_id(index)
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state, state.ui_state, state.available_variables

    def remove_step(self, state: PipelineState, position: int):
        """Remove the step at ``position`` from the pipeline.

        Returns:
            Tuple of (state, state.ui_state, state.available_variables).

        Raises:
            ValueError: If ``position`` is not in ``[0, n_steps)``.
        """
        if not 0 <= position < state.n_steps:
            raise ValueError(f"Invalid step position: {position}")
        # Called for its side effect; `state` is mutated in place, so we do not
        # depend on what PipelineState.remove_step returns.
        state.remove_step(position)
        return state, state.ui_state, state.available_variables

    def move_up(self, ui_state: PipelineUIState, position: int):
        """Move the step at ``position`` up in the display order; return a copy of the UI state."""
        # Reordering is delegated to utils.move_item, which mutates step_ids in place.
        utils.move_item(ui_state.step_ids, position, "up")
        return ui_state.model_copy()

    def move_down(self, ui_state: PipelineUIState, position: int):
        """Move the step at ``position`` down in the display order; return a copy of the UI state."""
        utils.move_item(ui_state.step_ids, position, "down")
        return ui_state.model_copy()

    def update_model_step_state(self, state: PipelineState, model_step: ModelStep, ui_state: ModelStepUIState):
        """Store copies of a step's definition and UI state, then prune stale output mappings.

        Returns:
            Tuple of (state, state.ui_state, state.available_variables).
        """
        state.workflow.steps[model_step.id] = model_step.model_copy()
        state.ui_state.steps[model_step.id] = ui_state.model_copy()
        state.ui_state = state.ui_state.model_copy()
        state.update_output_variables_mapping()
        return state, state.ui_state, state.available_variables

    def update_output_variables(self, state: PipelineState, target: str, produced_variable: str):
        """Map workflow output ``target`` to ``produced_variable``.

        The dropdown placeholder "Choose variable..." is stored as None (unmapped).
        """
        if produced_variable == "Choose variable...":
            produced_variable = None
        state.workflow.outputs.update({target: produced_variable})
        return state

    def update_model_step_ui(self, state: PipelineState, step_ui: ModelStepUIState, step_id: str):
        """Store a copy of ``step_ui`` as the UI state for ``step_id``.

        Returns:
            Tuple of (state, state.ui_state).
        """
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state, state.ui_state

    def get_all_variables(self, state: PipelineState, model_step_id: str | None = None) -> list[str]:
        """Return the pipeline's available variables.

        If ``model_step_id`` is given, variables whose name starts with
        ``"<model_step_id>."`` are excluded.
        """
        available_variables = state.available_variables
        if model_step_id is None:
            return available_variables
        prefix = f"{model_step_id}."
        return [var for var in available_variables if not var.startswith(prefix)]

    def get_pipeline_config(self):
        """Return the pipeline's workflow."""
        # NOTE(review): this class defines no `workflow` attribute in this file; unless
        # a subclass or caller sets one this raises AttributeError. Confirm intent.
        return self.workflow