Maharshi Gor
Restructure output panel rendering. This change fixes logprob issue and
1758388
raw
history blame
8.1 kB
import json
from typing import Literal
import yaml
from app_configs import UNSELECTED_VAR_NAME
from components import typed_dicts as td
from components import utils
from components.structs import ModelStepUIState, PipelineState, PipelineUIState, TossupPipelineState
from workflows.factory import create_new_llm_step
from workflows.structs import Buzzer, BuzzerMethod, ModelStep, TossupWorkflow, Workflow
def get_output_panel_state(workflow: Workflow) -> dict:
    """Collect the data the output panel needs for ``workflow``.

    Returns a dict with:
      - ``"variables"``: ``workflow.get_available_variables()``
      - ``"models"``: ``workflow.get_step_model_selections()``
      - ``"output_models"``: ``workflow.get_output_model_selections()``
      - ``"buzzer"``: only for a ``TossupWorkflow``; the buzzer dumped with
        default-valued fields included.
    """
    panel_state = {
        "variables": workflow.get_available_variables(),
        "models": workflow.get_step_model_selections(),
        "output_models": workflow.get_output_model_selections(),
    }
    # Only tossup workflows carry a buzzer; everything else is returned as-is.
    if not isinstance(workflow, TossupWorkflow):
        return panel_state
    panel_state["buzzer"] = workflow.buzzer.model_dump(exclude_defaults=False)
    return panel_state
class PipelineStateManager:
    """Manages a pipeline of multiple steps.

    Each public method takes the serialized pipeline state (``td.PipelineStateDict``)
    and rebuilds the typed state with :meth:`make_pipeline_state`. It then applies a
    single change and returns the re-serialized state (``model_dump()``).

    Structural edits (add/remove/move) also return a ``pipeline_change`` flag. The
    flag is flipped whenever the structure changes. NOTE(review): this is presumably
    used as a UI change trigger; confirm against the callers.
    """

    def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> PipelineState:
        """Build a ``PipelineState`` from its serialized dictionary form."""
        return PipelineState(**state_dict)

    def get_formatted_config(
        self, state_dict: td.PipelineStateDict, format: Literal["json", "yaml"] = "yaml"
    ) -> str:
        """Serialize the pipeline's workflow configuration.

        Args:
            state_dict: Serialized pipeline state.
            format: ``"yaml"`` produces YAML; any other value produces JSON.

        Returns:
            The workflow dumped with default-valued fields omitted, indented by 4
            and with key order preserved. For a ``TossupWorkflow`` the ``buzzer``
            section is written in full, defaults included.
        """
        state = self.make_pipeline_state(state_dict)
        config = state.workflow.model_dump(exclude_defaults=True)
        if isinstance(state.workflow, TossupWorkflow):
            # Override the default-stripped buzzer with a full dump.
            # Presumably this keeps every buzzer setting explicit in the export; confirm.
            config["buzzer"] = state.workflow.buzzer.model_dump(exclude_defaults=False)
        if format == "yaml":
            return yaml.dump(config, default_flow_style=False, sort_keys=False, indent=4)
        return json.dumps(config, indent=4, sort_keys=False)

    def add_step(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int = -1, name=""
    ) -> tuple[td.PipelineStateDict, bool]:
        """Create a new LLM step and insert it into the pipeline.

        Args:
            state_dict: Serialized pipeline state.
            pipeline_change: Current value of the change flag.
            position: Insertion position passed to ``insert_step`` (default ``-1``).
            name: Step name; when empty, ``"Step {n_steps + 1}"`` is used.

        Returns:
            The updated serialized state and the flipped ``pipeline_change`` flag.
        """
        state = self.make_pipeline_state(state_dict)
        step_id = state.get_new_step_id()
        step_name = name or f"Step {state.n_steps + 1}"
        new_step = create_new_llm_step(step_id=step_id, name=step_name)
        state = state.insert_step(position, new_step)
        return state.model_dump(), not pipeline_change

    def remove_step(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Remove the step at ``position`` from the pipeline.

        Returns:
            The updated serialized state and the flipped ``pipeline_change`` flag.

        Raises:
            ValueError: If ``position`` is not in ``[0, n_steps)``.
        """
        state = self.make_pipeline_state(state_dict)
        if 0 <= position < state.n_steps:
            state = state.remove_step(position)
        else:
            raise ValueError(f"Invalid step position: {position}")
        return state.model_dump(), not pipeline_change

    def _move_step(
        self, state_dict: td.PipelineStateDict, position: int, direction: Literal["up", "down"]
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move the step at ``position`` one slot in ``direction``.

        Returns:
            The updated serialized state and whether the step order actually changed.
        """
        state = self.make_pipeline_state(state_dict)
        # Snapshot the order first. utils.move_item is expected to reorder
        # step_ids in place, and the before/after comparison relies on that.
        old_order = list(state.ui_state.step_ids)
        utils.move_item(state.ui_state.step_ids, position, direction)
        return state.model_dump(), old_order != list(state.ui_state.step_ids)

    def move_up(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step up in the pipeline.

        Returns:
            The updated serialized state and ``pipeline_change``. The flag is
            flipped only if the step order actually changed.
        """
        new_state_dict, change = self._move_step(state_dict, position, "up")
        if change:
            pipeline_change = not pipeline_change
        return new_state_dict, pipeline_change

    def move_down(
        self, state_dict: td.PipelineStateDict, pipeline_change: bool, position: int
    ) -> tuple[td.PipelineStateDict, bool]:
        """Move a step down in the pipeline.

        Returns:
            The updated serialized state and ``pipeline_change``. The flag is
            flipped only if the step order actually changed.
        """
        new_state_dict, change = self._move_step(state_dict, position, "down")
        if change:
            pipeline_change = not pipeline_change
        return new_state_dict, pipeline_change

    def update_model_step_state(
        self, state_dict: td.PipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
    ) -> td.PipelineStateDict:
        """Replace a model step and its UI state via ``PipelineState.update_step``."""
        state = self.make_pipeline_state(state_dict)
        state = state.update_step(model_step, ui_state)
        return state.model_dump()

    def update_output_variables(
        self, state_dict: td.PipelineStateDict, target: str, produced_variable: str
    ) -> td.PipelineStateDict:
        """Map the workflow output ``target`` to ``produced_variable``.

        Choosing ``UNSELECTED_VAR_NAME`` stores ``None`` for ``target``, which
        clears the mapping.
        """
        if produced_variable == UNSELECTED_VAR_NAME:
            produced_variable = None
        state = self.make_pipeline_state(state_dict)
        state.workflow.outputs[target] = produced_variable
        return state.model_dump()

    def update_model_step_ui(
        self, state_dict: td.PipelineStateDict, step_ui: ModelStepUIState, step_id: str
    ) -> td.PipelineStateDict:
        """Store the UI state for step ``step_id``."""
        state = self.make_pipeline_state(state_dict)
        # Store a copy so later changes to the caller's object don't leak into the state.
        state.ui_state.steps[step_id] = step_ui.model_copy()
        return state.model_dump()

    def get_all_variables(self, state_dict: td.PipelineStateDict, model_step_id: str | None = None) -> list[str]:
        """Get all variables from all steps.

        NOTE(review): despite the ``list[str]`` annotation, this returns the
        rebuilt pipeline state object, and ``model_step_id`` is ignored. The
        behavior is kept unchanged for existing callers. The intended result is
        probably something like ``state.workflow.get_available_variables()``;
        confirm and fix.
        """
        return self.make_pipeline_state(state_dict)

    def parse_yaml_workflow(self, yaml_str: str) -> Workflow:
        """Parse a YAML string into a ``Workflow``.

        Uses ``yaml.safe_load``, so arbitrary Python objects are not constructed.
        NOTE(review): empty YAML loads as ``None`` and then fails with a
        ``TypeError`` in ``Workflow(**None)``.
        """
        workflow = yaml.safe_load(yaml_str)
        return Workflow(**workflow)

    def update_workflow_from_code(self, yaml_str: str) -> td.PipelineStateDict:
        """Build a fresh serialized pipeline state from a YAML workflow string."""
        workflow = self.parse_yaml_workflow(yaml_str)
        return PipelineState.from_workflow(workflow).model_dump()
class TossupPipelineStateManager(PipelineStateManager):
    """Manages a tossup pipeline state.

    Works with ``TossupPipelineState`` / ``TossupWorkflow``. After any edit that
    changes steps or outputs, it calls ``workflow.refresh_buzzer()``. It also
    exposes :meth:`update_buzzer` to replace the buzzer configuration.
    """

    def make_pipeline_state(self, state_dict: td.PipelineStateDict) -> TossupPipelineState:
        """Build a ``TossupPipelineState`` from its serialized dictionary form."""
        return TossupPipelineState(**state_dict)

    def parse_yaml_workflow(self, yaml_str: str) -> TossupWorkflow:
        """Parse a YAML string into a ``TossupWorkflow``.

        Uses ``yaml.safe_load``. NOTE(review): empty YAML loads as ``None`` and
        then fails with a ``TypeError`` in ``TossupWorkflow(**None)``.
        """
        workflow = yaml.safe_load(yaml_str)
        return TossupWorkflow(**workflow)

    def update_workflow_from_code(self, yaml_str: str, change_state: bool) -> tuple[td.PipelineStateDict, bool]:
        """Build a fresh serialized tossup state from a YAML workflow string.

        Returns:
            The new serialized state and the flipped ``change_state`` flag.
        """
        workflow = self.parse_yaml_workflow(yaml_str)
        return TossupPipelineState.from_workflow(workflow).model_dump(), not change_state

    def update_model_step_state(
        self, state_dict: td.TossupPipelineStateDict, model_step: ModelStep, ui_state: ModelStepUIState
    ) -> td.TossupPipelineStateDict:
        """Replace a model step and its UI state, then refresh the buzzer."""
        state = self.make_pipeline_state(state_dict)
        state = state.update_step(model_step, ui_state)
        state.workflow = state.workflow.refresh_buzzer()
        return state.model_dump()

    def update_output_variables(
        self, state_dict: td.TossupPipelineStateDict, target: str, produced_variable: str
    ) -> td.TossupPipelineStateDict:
        """Map the workflow output ``target`` to ``produced_variable``, then refresh the buzzer.

        Choosing ``UNSELECTED_VAR_NAME`` stores ``None`` for ``target``, which
        clears the mapping.
        """
        if produced_variable == UNSELECTED_VAR_NAME:
            produced_variable = None
        state = self.make_pipeline_state(state_dict)
        state.workflow.outputs[target] = produced_variable
        state.workflow = state.workflow.refresh_buzzer()
        return state.model_dump()

    def update_buzzer(
        self,
        state_dict: td.TossupPipelineStateDict,
        confidence_threshold: float,
        method: str,
        tokens_prob: float | None,
    ) -> td.TossupPipelineStateDict:
        """Replace the workflow's buzzer configuration.

        Args:
            state_dict: Serialized tossup pipeline state.
            confidence_threshold: Passed through to ``Buzzer``.
            method: Buzzer method, compared against ``BuzzerMethod.OR``.
            tokens_prob: Probability threshold. ``None`` or a non-positive value
                means no threshold. With the ``OR`` method, no threshold becomes
                ``0.0`` instead of ``None``.

        Returns:
            The updated serialized state.
        """
        state = self.make_pipeline_state(state_dict)
        prob_threshold = float(tokens_prob) if tokens_prob and tokens_prob > 0 else None
        # OR gets a concrete 0.0 rather than None. NOTE(review): presumably the
        # Buzzer needs a numeric threshold for this method; confirm.
        if method == BuzzerMethod.OR and prob_threshold is None:
            prob_threshold = 0.0
        state.workflow.buzzer = Buzzer(
            method=method, confidence_threshold=confidence_threshold, prob_threshold=prob_threshold
        )
        return state.model_dump()