Maharshi Gor
Update ModelStep
e1ce295
raw
history blame
7.52 kB
from typing import Any, Literal, Union
import gradio as gr
from loguru import logger
from app_configs import UNSELECTED_VAR_NAME
from components.model_pipeline.state_manager import ModelStepUIState
from components.utils import DIRECTIONS, move_item
from utils import get_model_and_provider
from workflows.structs import FieldType, ModelStep
class ModelStepStateManager:
    """Helpers that turn UI events into ``ModelStep`` / ``ModelStepUIState`` updates.

    The per-property ``update_*`` methods delegate to the step's own update helpers
    (``update_property``, ``update_field``, ``update``) and return what they return.
    The add/delete/move methods additionally return Gradio updates for a fixed pool
    of field rows, ``self.max_fields[field_type]`` rows per field type.
    """

    # Field attributes shown in each UI row, in the order the row's components
    # receive value updates (three components per row).
    _INPUT_ROW_ATTRS = ("name", "variable", "description")
    _OUTPUT_ROW_ATTRS = ("name", "type", "description")

    def __init__(self, max_input_fields: int, max_output_fields: int):
        # Number of UI rows per field type. Row updates always cover exactly this
        # many rows; rows past the current field count are made invisible.
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }

    # UI state update functions
    def update_ui_state(self, ui_state: ModelStepUIState, key: str, value: Any) -> ModelStepUIState:
        """Return ``ui_state`` updated with ``key`` set to ``value``."""
        return ui_state.update(key, value)

    # Property update functions
    def update_step_name(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``name`` property."""
        return model_step.update_property("name", value)

    def update_temperature(self, model_step: ModelStep, value: float) -> ModelStep:
        """Set the step's ``temperature`` property."""
        return model_step.update_property("temperature", value)

    def update_model_and_provider(self, model_step: ModelStep, value: str) -> ModelStep:
        """Split a combined model selection into model and provider and store both.

        The format of ``value`` is whatever ``get_model_and_provider`` parses.
        """
        model, provider = get_model_and_provider(value)
        return model_step.update({"model": model, "provider": provider})

    def update_system_prompt(self, model_step: ModelStep, value: str) -> ModelStep:
        """Set the step's ``system_prompt`` property."""
        return model_step.update_property("system_prompt", value)

    # Field update functions
    def update_input_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``name`` of the input field at ``index``."""
        return model_step.update_field("input", index, "name", value)

    def update_input_field_variable(self, model_step: ModelStep, value: str, name: str, index: int) -> ModelStep:
        """Bind the input field at ``index`` to the variable ``value``.

        Selecting ``UNSELECTED_VAR_NAME`` clears the binding (variable set to ``""``)
        and leaves the name alone. Otherwise, if ``name`` is empty, the field name is
        auto-filled with the part of ``value`` after its first ``"."`` (or the whole
        of ``value`` when it contains no dot).

        Args:
            model_step: Step being edited.
            value: Selected variable.
            name: Current field name; only auto-filled when it is ``""``.
            index: Position of the input field.
        """
        if value == UNSELECTED_VAR_NAME:
            return model_step.update_field("input", index, "variable", "")
        if name == "":
            suggested_name = value.split(".", 1)[-1]
            logger.info(f"Updating input field variable to {value}. Suggested name: {suggested_name}")
            model_step = model_step.update_field("input", index, "name", suggested_name)
        return model_step.update_field("input", index, "variable", value)

    def update_input_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``description`` of the input field at ``index``."""
        return model_step.update_field("input", index, "description", value)

    def update_output_field_name(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``name`` of the output field at ``index``."""
        return model_step.update_field("output", index, "name", value)

    def update_output_field_type(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``type`` of the output field at ``index``."""
        return model_step.update_field("output", index, "type", value)

    def update_output_field_variable(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``variable`` of the output field at ``index``."""
        return model_step.update_field("output", index, "variable", value)

    def update_output_field_description(self, model_step: ModelStep, value: str, index: int) -> ModelStep:
        """Set ``description`` of the output field at ``index``."""
        return model_step.update_field("output", index, "description", value)

    def _make_field_value_updates(
        self, fields: list[Any], field_type: FieldType, attrs: tuple[str, ...]
    ) -> list[gr.State | dict[str, Any]]:
        """Build one value update per (row, attr) for all ``max_fields[field_type]`` rows.

        Rows with no backing field get ``gr.skip()`` so their components keep
        their current value.
        """
        updates = []
        for i in range(self.max_fields[field_type]):
            if i < len(fields):
                updates.extend(gr.update(value=getattr(fields[i], attr)) for attr in attrs)
            else:
                updates.extend(gr.skip() for _ in attrs)
        return updates

    def make_input_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return value updates for every input row: name, variable, description per row."""
        return self._make_field_value_updates(model_step.input_fields, "input", self._INPUT_ROW_ATTRS)

    def make_output_field_updates(self, model_step: ModelStep) -> list[gr.State | dict[str, Any]]:
        """Return value updates for every output row: name, type, description per row."""
        return self._make_field_value_updates(model_step.output_fields, "output", self._OUTPUT_ROW_ATTRS)

    def _row_visibility_updates(self, field_type: FieldType, n_fields: int) -> list[dict[str, Any]]:
        """Show the first ``n_fields`` rows of ``field_type`` and hide the remaining ones."""
        return [gr.update(visible=i < n_fields) for i in range(self.max_fields[field_type])]

    def _add_field(
        self, model_step: ModelStep, field_type: FieldType, index: int = -1, input_var: str | None = None
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Add a field via ``ModelStep.add_field``.

        Returns:
            ``(new_step, field_count, *row_visibility_updates)``.
        """
        new_step = model_step.add_field(field_type, index, input_var)
        n_fields = len(new_step.fields(field_type))
        return new_step, n_fields, *self._row_visibility_updates(field_type, n_fields)

    def _delete_field(
        self, model_step: ModelStep, field_type: FieldType, index: int
    ) -> tuple[Union[ModelStep, int, dict[str, Any]], ...]:
        """Delete the field at ``index`` via ``ModelStep.delete_field``.

        Returns:
            ``(new_step, field_count, *row_visibility_updates)``.
        """
        new_step = model_step.delete_field(field_type, index)
        n_fields = len(new_step.fields(field_type))
        return new_step, n_fields, *self._row_visibility_updates(field_type, n_fields)

    # Field add/delete functions
    #
    # Each returns (new_step, field_count, *row_visibility, *row_values). Row values
    # are computed from the *new* step: after an insert or delete the fields shift
    # between rows, so values taken from the old step would show stale content.
    def add_input_field(self, model_step: ModelStep, index: int = -1):
        """Add an input field at ``index``, forwarding ``UNSELECTED_VAR_NAME`` as its input variable."""
        updates = self._add_field(model_step, "input", index, input_var=UNSELECTED_VAR_NAME)
        new_step = updates[0]
        return *updates, *self.make_input_field_updates(new_step)

    def add_output_field(self, model_step: ModelStep, index: int = -1):
        """Add an output field at ``index``."""
        updates = self._add_field(model_step, "output", index)
        new_step = updates[0]
        return *updates, *self.make_output_field_updates(new_step)

    def delete_input_field(self, model_step: ModelStep, index: int):
        """Delete the input field at ``index``."""
        updates = self._delete_field(model_step, "input", index)
        new_step = updates[0]
        return *updates, *self.make_input_field_updates(new_step)

    def delete_output_field(self, model_step: ModelStep, index: int):
        """Delete the output field at ``index``."""
        updates = self._delete_field(model_step, "output", index)
        new_step = updates[0]
        return *updates, *self.make_output_field_updates(new_step)

    def move_output_field(
        self, model_step: ModelStep, index: int, direction: DIRECTIONS
    ) -> tuple[Union[ModelStep, dict[str, Any]], ...]:
        """
        Move an output field in the list either up or down.

        Args:
            model_step: Step being edited; it is not modified.
            index: Index of the output field to move.
            direction: Direction to move the field ('up' or 'down').

        Returns:
            tuple: ``(updated_step, *field_value_updates)``.
        """
        # ``move_item`` reorders the list in place. A shallow ``model_copy()`` would
        # share ``output_fields`` with ``model_step`` and silently mutate the caller's
        # step, so take a deep copy (assumes ModelStep is a pydantic v2 model).
        new_step = model_step.model_copy(deep=True)
        move_item(new_step.output_fields, index, direction)
        # Refresh every output row so the UI reflects the new order.
        return new_step, *self.make_output_field_updates(new_step)