|
import json |
|
from typing import Any |
|
|
|
import gradio as gr |
|
from gradio.components import FormComponent |
|
|
|
from app_configs import UNSELECTED_VAR_NAME |
|
from components.model_pipeline.state_manager import ModelStepUIState, PipelineStateManager |
|
from components.typed_dicts import PipelineStateDict |
|
from utils import get_full_model_name |
|
from workflows.structs import ModelStep |
|
|
|
from .state_manager import ModelStepStateManager |
|
from .ui_components import InputRowButtonGroup, OutputRowButtonGroup |
|
|
|
|
|
def _make_accordion_label(model_step: ModelStep):
    """Build the accordion header text for a pipeline step.

    The result looks like ``"<id>: <name> (<in1>, <in2>) → (<out1>, ...)"``,
    listing the names of the step's input and output fields. A falsy step
    name is displayed as ``"Untitled"``.
    """
    display_name = model_step.name or "Untitled"
    joined_inputs = ", ".join(field.name for field in model_step.input_fields)
    joined_outputs = ", ".join(field.name for field in model_step.output_fields)
    return f"{model_step.id}: {display_name} ({joined_inputs}) → ({joined_outputs})"
|
|
|
|
|
class ModelStepComponent(FormComponent):
    """
    A custom Gradio component representing a single Step in a pipeline.

    It contains:
    1. Header – step name, model selection, and temperature slider
    2. Model tab – system prompt
    3. Inputs tab – fields with name, description, and variable used
    4. Outputs tab – fields with name, type, and description

    Exposes hooks for other code to react to its state:
    - on_model_step_change
    - on_ui_change
    """

    def __init__(
        self,
        value: ModelStep | gr.State,
        ui_state: ModelStepUIState | gr.State | None = None,
        model_options: list[str] | None = None,
        input_variables: list[str] | None = None,
        max_input_fields=5,
        max_output_fields=5,
        max_temperature=5.0,
        pipeline_state_manager: PipelineStateManager | None = None,
        **kwargs,
    ):
        """Create the component's state holders, then render the UI and wire its events.

        Args:
            value: The step to edit, either a ``ModelStep`` or an existing ``gr.State``
                holding one. An existing State is reused rather than wrapped again.
            ui_state: Accordion/tab UI state; a fresh ``ModelStepUIState`` when omitted.
            model_options: Choices for the model dropdown (``None`` means no choices).
            input_variables: Pipeline variables selectable by input fields
                (``None`` means only the "unselected" placeholder).
            max_input_fields: Number of input rows pre-rendered (extras start hidden).
            max_output_fields: Number of output rows pre-rendered (extras start hidden).
            max_temperature: Upper bound of the temperature slider.
            pipeline_state_manager: Used by ``refresh_variable_dropdowns`` to list variables.
            **kwargs: Forwarded to ``FormComponent.__init__``.
        """
        self.max_fields = {
            "input": max_input_fields,
            "output": max_output_fields,
        }
        self.max_temperature = max_temperature
        # Normalise the optional lists so later concatenation never sees ``None``.
        self.model_options = list(model_options or [])
        self.input_variables = [UNSELECTED_VAR_NAME, *(input_variables or [])]
        self.sm = ModelStepStateManager(max_input_fields, max_output_fields)
        self.pipeline_sm: PipelineStateManager | None = pipeline_state_manager

        # Reuse a caller-provided State instead of nesting it inside a new one
        # (mirrors the ``ui_state`` handling below).
        self.model_step_state: gr.State = value if isinstance(value, gr.State) else gr.State(value)

        ui_state = ui_state or ModelStepUIState()
        if not isinstance(ui_state, gr.State):
            ui_state = gr.State(ui_state)
        self.ui_state: gr.State = ui_state

        # Count from the unwrapped step so this also works when ``value`` was a State.
        initial_step: ModelStep = self.model_step_state.value
        self.inputs_count_state = gr.State(len(initial_step.input_fields))
        self.outputs_count_state = gr.State(len(initial_step.output_fields))

        # Widget handles, filled in by ``render()``. Gradio's Block constructor calls
        # ``render()`` when ``render=True`` (its default), and the listeners wired
        # after ``super().__init__`` rely on these being populated.
        self.accordion = None
        self.ui = None
        self.step_name_input = None
        self.model_selection = None
        self.temperature_slider = None
        self.system_prompt = None
        self.inputs_column = None
        self.outputs_column = None
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}

        super().__init__(**kwargs)

        self.setup_event_listeners()

    @property
    def model_step(self) -> ModelStep:
        """The ``ModelStep`` currently stored in the component's State."""
        return self.model_step_state.value

    @property
    def step_id(self) -> str:
        """Identifier of the current step."""
        return self.model_step.id

    def get_step_config(self) -> dict:
        """Return the current step serialised via ``model_dump()``."""
        return self.model_step.model_dump()

    def is_open(self) -> bool:
        """Whether the step's accordion is recorded as expanded in the UI state."""
        return self.ui_state.value.expanded

    def get_active_tab(self) -> str:
        """Get the current active tab."""
        return self.ui_state.value.active_tab

    def _render_input_row(self, i: int) -> tuple[gr.Row, tuple, InputRowButtonGroup]:
        """Render a single input row at index i.

        Every row up to ``max_fields["input"]`` is rendered. Rows past the step's
        current input count start hidden; the add/delete handlers list all rows as
        outputs, so they can toggle them. Only row 0 shows labels. Its delete button
        is disabled while it is the only input.

        Returns:
            ``(row, (name_box, variable_dropdown, description_box), button_group)``.
        """
        inputs = self.model_step.input_fields
        is_visible = i < len(inputs)
        label_visible = i == 0
        disable_delete = i == 0 and len(inputs) == 1
        initial_name = inputs[i].name if is_visible else ""
        initial_desc = inputs[i].description if is_visible else ""
        # An unset variable on an existing field falls back to the placeholder choice.
        initial_var = (inputs[i].variable or UNSELECTED_VAR_NAME) if is_visible else UNSELECTED_VAR_NAME

        with gr.Row(visible=is_visible, elem_classes="field-row form") as row:
            button_group = InputRowButtonGroup(disable_delete=disable_delete)

            inp_var = gr.Dropdown(
                choices=self.input_variables,
                label="Variable Used",
                value=initial_var,
                elem_classes="field-variable",
                scale=1,
                show_label=label_visible,
            )
            inp_name = gr.Textbox(
                label="Input Name",
                placeholder="Field name",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            inp_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )
            fields = (inp_name, inp_var, inp_desc)

        return row, fields, button_group

    def _render_output_row(self, i: int) -> tuple[gr.Row, tuple, OutputRowButtonGroup]:
        """Render a single output row at index i.

        Visibility, label and delete-button rules match ``_render_input_row``.

        Returns:
            ``(row, (name_box, type_dropdown, description_box), button_group)``.
        """
        outputs = self.model_step.output_fields
        is_visible = i < len(outputs)
        label_visible = i == 0
        disable_delete = i == 0 and len(outputs) == 1
        initial_name = outputs[i].name if is_visible else ""
        initial_desc = outputs[i].description if is_visible else ""
        initial_type = outputs[i].type if is_visible else "str"

        with gr.Row(visible=is_visible, elem_classes="field-row") as row:
            button_group = OutputRowButtonGroup(disable_delete=disable_delete)

            out_name = gr.Textbox(
                label="Output Field",
                placeholder="Variable identifier",
                value=initial_name,
                elem_classes="field-name",
                scale=1,
                show_label=label_visible,
            )
            out_type = gr.Dropdown(
                choices=["str", "int", "float", "bool"],
                allow_custom_value=True,
                label="Type",
                value=initial_type,
                elem_classes="field-type",
                scale=0,
                show_label=label_visible,
                interactive=True,
            )
            out_desc = gr.Textbox(
                label="Description",
                placeholder="Field description",
                value=initial_desc,
                elem_classes="field-description",
                scale=3,
                show_label=label_visible,
            )

        fields = (out_name, out_type, out_desc)
        return row, fields, button_group

    def _render_prompt_tab_content(self):
        """Render the system-prompt textbox for the Model tab."""
        self.system_prompt = gr.Textbox(
            label="System Prompt",
            placeholder="Enter the system prompt for this step",
            lines=5,
            value=self.model_step.system_prompt,
            elem_classes="system-prompt",
        )

    def _render_inputs_tab_content(self):
        """Render all ``max_fields["input"]`` input rows and record them in ``input_rows``."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.inputs_column:
            for i in range(self.max_fields["input"]):
                self.input_rows.append(self._render_input_row(i))

    def _render_outputs_tab_content(self):
        """Render all ``max_fields["output"]`` output rows and record them in ``output_rows``."""
        with gr.Column(variant="panel", elem_classes="fields-panel") as self.outputs_column:
            for i in range(self.max_fields["output"]):
                self.output_rows.append(self._render_output_row(i))

    def _render_tab_content(self, tab_id: str):
        """Dispatch to the renderer for ``tab_id``; unknown ids render nothing."""
        if tab_id == "model-tab":
            self._render_prompt_tab_content()
        elif tab_id == "inputs-tab":
            self._render_inputs_tab_content()
        elif tab_id == "outputs-tab":
            self._render_outputs_tab_content()

    def _render_header(self, model_options: list[str] | tuple[str, ...] | None):
        """Render the header row: step name, model dropdown and temperature slider.

        ``model_options`` may be any sequence (or ``None``). It is unpacked rather
        than concatenated, so tuples work too.
        """
        with gr.Row(elem_classes="step-header-row"):
            self.step_name_input = gr.Textbox(
                label="",
                value=self.model_step.name,
                elem_classes="step-name",
                show_label=False,
                placeholder="Model name...",
            )
            unselected_choice = "Select Model..."
            current_value = (
                get_full_model_name(self.model_step.model, self.model_step.provider)
                if self.model_step.model
                else unselected_choice
            )
            self.model_selection = gr.Dropdown(
                choices=[unselected_choice, *(model_options or [])],
                label="Model Provider",
                show_label=False,
                value=current_value,
                elem_classes="model-dropdown",
                scale=1,
            )
            self.temperature_slider = gr.Slider(
                value=self.model_step.temperature,
                minimum=0.0,
                maximum=self.max_temperature,
                step=0.05,
                info="Temperature",
                show_label=False,
                show_reset_button=False,
            )

    def render(self):
        """Render the component UI and return the enclosing accordion.

        The row and tab registries are reset first, so widgets from an earlier
        call are not accumulated.
        """
        self.input_rows = []
        self.output_rows = []
        self.tabs = {}

        accordion_label = _make_accordion_label(self.model_step)
        self.accordion = gr.Accordion(label=accordion_label, open=self.is_open(), elem_classes="step-accordion")

        with self.accordion:
            self._render_header(self.model_options)

            selected_tab = self.get_active_tab()
            with gr.Tabs(elem_classes="step-tabs", selected=selected_tab):
                tab_ids = ("model-tab", "inputs-tab", "outputs-tab")
                tab_labels = ("Model", "Inputs", "Outputs")
                for tab_id, label in zip(tab_ids, tab_labels):
                    with gr.TabItem(label, elem_classes="tab-content", id=tab_id) as tab:
                        self._render_tab_content(tab_id)
                    self.tabs[tab_id] = tab

        return self.accordion

    def _setup_event_listeners_for_view_change(self):
        """Record the selected tab and the accordion open/closed state in ``ui_state``."""
        for tab_id, tab in self.tabs.items():
            tab.select(
                fn=self.sm.update_ui_state,
                inputs=[self.ui_state, gr.State("active_tab"), gr.State(tab_id)],
                outputs=[self.ui_state],
            )
        self.accordion.collapse(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(False)],
            outputs=[self.ui_state],
        )
        self.accordion.expand(
            fn=self.sm.update_ui_state,
            inputs=[self.ui_state, gr.State("expanded"), gr.State(True)],
            outputs=[self.ui_state],
        )

    def _setup_event_listeners_model_tab(self):
        """Wire the header controls and the system prompt to the step state."""
        # Renaming also refreshes the accordion label.
        self.step_name_input.blur(
            fn=self._update_state_and_label,
            inputs=[self.model_step_state, self.step_name_input],
            outputs=[self.model_step_state, self.accordion],
        )
        self.temperature_slider.release(
            fn=self.sm.update_temperature,
            inputs=[self.model_step_state, self.temperature_slider],
            outputs=[self.model_step_state],
        )
        self.model_selection.input(
            fn=self.sm.update_model_and_provider,
            inputs=[self.model_step_state, self.model_selection],
            outputs=[self.model_step_state],
        )
        self.system_prompt.blur(
            fn=self.sm.update_system_prompt,
            inputs=[self.model_step_state, self.system_prompt],
            outputs=[self.model_step_state],
        )

    def _setup_event_listeners_inputs_tab(self):
        """Wire each input row's fields and add/delete buttons to the step state."""
        # Add/delete handlers update every row and field, so build these lists once.
        rows = [row for (row, _, _) in self.input_rows]
        input_fields = [field for (_, fields, _) in self.input_rows for field in fields]
        structural_outputs = [self.model_step_state, self.inputs_count_state] + rows + input_fields

        for i, (_, fields, button_group) in enumerate(self.input_rows):
            inp_name, inp_var, inp_desc = fields
            row_index = gr.State(i)

            inp_name.blur(
                fn=self.sm.update_input_field_name,
                inputs=[self.model_step_state, inp_name, row_index],
                outputs=[self.model_step_state],
            )
            inp_var.change(
                fn=self.sm.update_input_field_variable,
                inputs=[self.model_step_state, inp_var, inp_name, row_index],
                outputs=[self.model_step_state],
            )
            inp_desc.blur(
                fn=self.sm.update_input_field_description,
                inputs=[self.model_step_state, inp_desc, row_index],
                outputs=[self.model_step_state],
            )
            button_group.delete(
                fn=self.sm.delete_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=structural_outputs,
            )
            button_group.add(
                fn=self.sm.add_input_field,
                inputs=[self.model_step_state, row_index],
                outputs=structural_outputs,
            )

    def _setup_event_listeners_outputs_tab(self):
        """Wire each output row's fields and add/delete/move buttons to the step state."""
        # Structural handlers update every row/field, so build these lists once.
        rows = [row for (row, _, _) in self.output_rows]
        output_fields = [field for (_, fields, _) in self.output_rows for field in fields]
        structural_outputs = [self.model_step_state, self.outputs_count_state] + rows + output_fields
        move_outputs = [self.model_step_state] + output_fields

        for i, (_, fields, button_group) in enumerate(self.output_rows):
            out_name, out_type, out_desc = fields
            row_index = gr.State(i)

            out_name.blur(
                fn=self.sm.update_output_field_name,
                inputs=[self.model_step_state, out_name, row_index],
                outputs=[self.model_step_state],
            )
            out_type.change(
                fn=self.sm.update_output_field_type,
                inputs=[self.model_step_state, out_type, row_index],
                outputs=[self.model_step_state],
            )
            out_desc.blur(
                fn=self.sm.update_output_field_description,
                inputs=[self.model_step_state, out_desc, row_index],
                outputs=[self.model_step_state],
            )
            button_group.delete(
                fn=self.sm.delete_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=structural_outputs,
            )
            button_group.add(
                fn=self.sm.add_output_field,
                inputs=[self.model_step_state, row_index],
                outputs=structural_outputs,
            )
            button_group.up(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("up")],
                outputs=move_outputs,
            )
            button_group.down(
                fn=self.sm.move_output_field,
                inputs=[self.model_step_state, row_index, gr.State("down")],
                outputs=move_outputs,
            )

    def setup_event_listeners(self):
        """Set up all event listeners for this component"""
        self._setup_event_listeners_for_view_change()
        self._setup_event_listeners_model_tab()
        self._setup_event_listeners_inputs_tab()
        self._setup_event_listeners_outputs_tab()

    def on_model_step_change(self, fn, inputs, outputs):
        """Set up an event listener for the model change event."""
        return self.model_step_state.change(fn, inputs, outputs)

    def on_ui_change(self, fn, inputs, outputs):
        """Set up an event listener for the UI change event."""
        return self.ui_state.change(fn, inputs, outputs)

    def _update_state_and_label(self, model_step: ModelStep, name: str):
        """Rename the step and return ``(new_step, accordion label update)``."""
        new_model_step = self.sm.update_step_name(model_step, name)
        new_label = _make_accordion_label(new_model_step)
        return new_model_step, gr.update(label=new_label)

    def refresh_variable_dropdowns(self, pipeline_state_dict: PipelineStateDict) -> list[dict[str, Any]]:
        """Recompute the variable choices for every input row's dropdown.

        Component instances have no ``.update()`` method in Gradio 4+, so this
        returns one ``gr.update(choices=...)`` per input row, in ``input_rows`` order.
        Use those as the outputs of an event handler that targets the dropdowns.
        ``self.input_variables`` is also synced, so a later ``render()`` uses the
        new list.

        Args:
            pipeline_state_dict: Passed to ``pipeline_sm.get_all_variables``.
                Ignored when no pipeline state manager was given.

        Returns:
            A list of update payloads, one per input row.
        """
        variable_choices: list[str] = []
        if self.pipeline_sm is not None:
            variable_choices = list(self.pipeline_sm.get_all_variables(pipeline_state_dict))
        # Keep the placeholder choice, as __init__ does, so unselected rows stay valid.
        if UNSELECTED_VAR_NAME not in variable_choices:
            variable_choices = [UNSELECTED_VAR_NAME, *variable_choices]

        self.input_variables = variable_choices
        return [gr.update(choices=variable_choices) for _ in self.input_rows]

    def _update_model_and_refresh_ui(self, updated_model_step: ModelStep) -> ModelStep:
        """Store ``updated_model_step`` as the State value and sync the accordion label.

        The label is set as an attribute because Gradio 4+ components have no
        ``.update()`` method. This changes only the server-side component config.
        A running client sees the new label only if it is returned from an event
        handler, as ``_update_state_and_label`` does.
        """
        self.model_step_state.value = updated_model_step
        if self.accordion is not None:
            self.accordion.label = _make_accordion_label(updated_model_step)
        return updated_model_step
|
|