|
from typing import Any, Literal, Union |
|
|
|
import gradio as gr |
|
from loguru import logger |
|
|
|
from app_configs import UNSELECTED_VAR_NAME |
|
from components.model_pipeline.state_manager import ModelStepUIState |
|
from components.utils import DIRECTIONS, move_item |
|
from utils import get_model_and_provider |
|
from workflows.structs import FieldType, ModelStep |
|
|
|
|
|
class ModelStepStateManager:
    """Compute the next ``ModelStep`` state and matching Gradio updates for the model-step editor.

    The editor has a fixed number of input-field and output-field UI rows (set at construction).
    Rows beyond the current field count are hidden. Each row holds three components whose values
    are driven by ``make_input_field_updates`` / ``make_output_field_updates``.
    """

    # Field attributes shown, in order, by the three components of each input / output UI row.
    _INPUT_ROW_ATTRS = ("name", "variable", "description")
    _OUTPUT_ROW_ATTRS = ("name", "type", "description")

    def __init__(self, max_input_fields: int, max_output_fields: int):
        """
        Args:
            max_input_fields: Number of input-field UI rows available.
            max_output_fields: Number of output-field UI rows available.
        """
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }

    def update_ui_state(self, ui_state: ModelStepUIState, key: str, value: Any) -> ModelStepUIState:
        """Return the result of ``ui_state.update(key, value)``."""
        return ui_state.update(key, value)

    def update_step_name(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``name`` property to ``value``.

        Only the step state is updated here; any UI label depending on the name must be refreshed
        by the caller.
        """
        return model_step.update_property("name", value)

    def update_temperature(self, model_step: ModelStep, value: float) -> ModelStep:
        """Set the step's sampling ``temperature`` property to ``value``."""
        return model_step.update_property("temperature", value)

    def update_model_and_provider(self, model_step: ModelStep, value: str) -> ModelStep:
        """Split ``value`` with ``get_model_and_provider`` and store both ``model`` and ``provider``."""
        model, provider = get_model_and_provider(value)
        return model_step.update({"model": model, "provider": provider})

    def update_system_prompt(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``system_prompt`` property to ``value``."""
        return model_step.update_property("system_prompt", value)

    def update_input_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``name`` of the input field at ``index``."""
        return model_step.update_field("input", index, "name", value)

    def update_input_field_variable(self, model_step: ModelStep, value: str, name: str, index: int) -> ModelStep:
        """Bind the input field at ``index`` to the variable ``value``.

        Selecting ``UNSELECTED_VAR_NAME`` clears the binding (stores ``""``). When the field has no
        name yet (``name`` is empty or ``None``), it is named after the part of ``value`` following
        the first ``"."``, or the whole of ``value`` if it contains no dot.
        """
        if value == UNSELECTED_VAR_NAME:
            return model_step.update_field("input", index, "variable", "")
        if not name:
            suggested_name = value.split(".", 1)[-1]
            logger.info(f"Updating input field variable to {value}. Suggested name: {suggested_name}")
            model_step = model_step.update_field("input", index, "name", suggested_name)
        return model_step.update_field("input", index, "variable", value)

    def update_input_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``description`` of the input field at ``index``."""
        return model_step.update_field("input", index, "description", value)

    def update_output_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``name`` of the output field at ``index``."""
        return model_step.update_field("output", index, "name", value)

    def update_output_field_type(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``type`` of the output field at ``index``."""
        return model_step.update_field("output", index, "type", value)

    def update_output_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``variable`` of the output field at ``index``."""
        return model_step.update_field("output", index, "variable", value)

    def update_output_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set the ``description`` of the output field at ``index``."""
        return model_step.update_field("output", index, "description", value)

    def _make_row_value_updates(
        self, fields: list[Any], field_type: FieldType, attrs: tuple[str, ...]
    ) -> list[gr.State | dict[str, Any]]:
        """Build ``len(attrs)`` component updates per UI row: field values for used rows, skips otherwise."""
        updates = []
        for i in range(self.max_fields[field_type]):
            if i < len(fields):
                updates.extend(gr.update(value=getattr(fields[i], attr)) for attr in attrs)
            else:
                # Unused (hidden) rows are left untouched.
                updates.extend(gr.skip() for _ in attrs)
        return updates

    def make_input_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return ``[name, variable, description]`` updates for every input UI row, flattened."""
        return self._make_row_value_updates(model_step.input_fields, "input", self._INPUT_ROW_ATTRS)

    def make_output_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return ``[name, type, description]`` updates for every output UI row, flattened."""
        return self._make_row_value_updates(model_step.output_fields, "output", self._OUTPUT_ROW_ATTRS)

    def _step_and_visibility_updates(
        self, new_step: ModelStep, field_type: FieldType
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Return ``(new_step, field_count, *row_visibility_updates)`` for ``field_type`` rows."""
        n_fields = len(new_step.fields(field_type))
        row_updates = [gr.update(visible=i < n_fields) for i in range(self.max_fields[field_type])]
        return new_step, n_fields, *row_updates

    def _add_field(
        self, model_step: ModelStep, field_type: FieldType, index: int = -1, input_var: str | None = None
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Add a field via ``ModelStep.add_field``; return ``(new_step, field_count, *row_visibility)``."""
        new_step = model_step.add_field(field_type, index, input_var)
        return self._step_and_visibility_updates(new_step, field_type)

    def _delete_field(
        self, model_step: ModelStep, field_type: FieldType, index: int
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Delete a field via ``ModelStep.delete_field``; return ``(new_step, field_count, *row_visibility)``."""
        new_step = model_step.delete_field(field_type, index)
        return self._step_and_visibility_updates(new_step, field_type)

    def add_input_field(self, model_step: ModelStep, index: int = -1):
        """Add an input field at ``index`` (default ``-1``) with input variable ``UNSELECTED_VAR_NAME``.

        Returns ``(new_step, field_count, *row_visibility, *row_values)``.
        """
        updates = self._add_field(model_step, "input", index, input_var=UNSELECTED_VAR_NAME)
        # Row values must reflect the *new* step so the added and shifted rows show current data.
        return *updates, *self.make_input_field_updates(updates[0])

    def add_output_field(self, model_step: ModelStep, index: int = -1):
        """Add an output field at ``index`` (default ``-1``).

        Returns ``(new_step, field_count, *row_visibility, *row_values)``.
        """
        updates = self._add_field(model_step, "output", index)
        return *updates, *self.make_output_field_updates(updates[0])

    def delete_input_field(self, model_step: ModelStep, index: int):
        """Delete the input field at ``index``.

        Returns ``(new_step, field_count, *row_visibility, *row_values)``.
        """
        updates = self._delete_field(model_step, "input", index)
        # Using the new step makes rows after ``index`` display their shifted-up fields.
        return *updates, *self.make_input_field_updates(updates[0])

    def delete_output_field(self, model_step: ModelStep, index: int):
        """Delete the output field at ``index``.

        Returns ``(new_step, field_count, *row_visibility, *row_values)``.
        """
        updates = self._delete_field(model_step, "output", index)
        return *updates, *self.make_output_field_updates(updates[0])

    def move_output_field(
        self, model_step: ModelStep, index: int, direction: DIRECTIONS
    ) -> tuple[Union[ModelStep, dict[str, Any]], ...]:
        """
        Move an output field in the list either up or down.

        Args:
            model_step: Current step; it is not modified.
            index: Index of the output field to move.
            direction: Direction to move the field ('up' or 'down').

        Returns:
            tuple: ``(new_step, *field_value_updates)``.
        """
        # A plain ``model_copy()`` is shallow (pydantic), so ``output_fields`` would be the same
        # list object as in ``model_step`` and the in-place ``move_item`` would also reorder the
        # caller's previous state. Deep-copy so only the new step changes.
        new_step = model_step.model_copy(deep=True)
        move_item(new_step.output_fields, index, direction)
        return new_step, *self.make_output_field_updates(new_step)
|
|