Maharshi Gor
Workflow update.
1eeda1d
# %%
from copy import deepcopy
from enum import Enum
from typing import Any, Literal, Optional
import numpy as np
from pydantic import BaseModel, Field, model_validator
from .configs import AVAILABLE_MODELS
"""
Core data structures for defining workflows and their components.
This module defines the primary classes used to model workflows, steps, and their
input/output fields. These data structures serve as the foundation for workflow
definition, validation, and execution throughout the workflows package.
The primary components are:
- InputField: Represents an input to a model step with name and source variable
- OutputField: Represents an output from a model step with name and type
- ModelStep: Represents a single step in a workflow with inputs and outputs
- Workflow: A collection of interconnected steps with defined inputs and outputs
All classes use Pydantic's BaseModel for validation and serialization support.
"""
# Selector passed to ModelStep helpers to choose between a step's input fields
# and its output fields.
FieldType = Literal["input", "output"]
# Type names accepted by OutputField.type; they are plain strings, not Python types.
SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"]
"""Supported field types for input and output fields"""
class InputField(BaseModel):
    """
    Describes one input consumed by a model step.

    An input binds a step-local name to a source variable and may name a
    function that is applied to the value before it is used.

    Attributes:
        name: Name of the input within the step's context.
        description: Human-readable explanation of what the input is for.
        variable: Reference to the source variable, either
            "{step_id}.{field_name}" or the name of an external input.
        func: Optional name of a function used to transform the value
            before it is passed to the model.
    """

    name: str
    description: str
    variable: str

    # Name of the function to call on the input before passing it to the model.
    func: Optional[str] = None

    class Config:
        # Attribute assignment is rejected; derive changed copies with
        # `model_copy(update=...)` instead.
        frozen = True
class OutputField(BaseModel):
    """
    Describes one output produced by a model step.

    An output declares the value a step will produce, the type name of that
    value, and an optional post-processing function.

    Attributes:
        name: Name of the output within the step's context.
        type: Type name of the output (one of SUPPORTED_TYPES); defaults to "str".
        description: Human-readable explanation of what the output is for.
        func: Optional name of a function used to transform the raw output.
    """

    name: str
    type: SUPPORTED_TYPES = "str"
    description: str

    # Name of the function to call on the output string from the model.
    func: Optional[str] = None

    class Config:
        # Attribute assignment is rejected; derive changed copies with
        # `model_copy(update=...)` instead.
        frozen = True
class CallType(str, Enum):
    """Kinds of call a ModelStep can make (see `ModelStep.call_type`).

    Members are also `str` instances, so each compares equal to its value.
    """

    LLM = "llm"
    SEARCH = "search"
    PYTHON_FUNC = "python_func"
class ModelStep(BaseModel):
    """
    Represents a single step in a workflow.

    A model step encapsulates the details of a specific operation within a workflow,
    including what model to use, what inputs it requires, and what outputs it produces.

    Every "modifying" method on this class returns a new ModelStep and leaves the
    instance it was called on untouched.

    Attributes:
        id: Unique identifier for this step within a workflow
        name: Name of the step
        model: The model to use for this step (e.g., "gpt-4")
        provider: The provider of the model (e.g., "openai")
        call_type: The type of operation (e.g., "llm", "search")
        temperature: Optional temperature setting; may be None
        system_prompt: Instructions for the model
        input_fields: List of input fields required by this step
        output_fields: List of output fields produced by this step
    """

    id: str
    name: str
    model: str
    provider: str
    call_type: CallType = CallType.LLM

    # TODO: Validate that this is not None for call_type = llm
    temperature: Optional[float] = None

    system_prompt: str
    input_fields: list[InputField]
    output_fields: list[OutputField]

    class Config:
        # Store the enum's string value rather than the enum member after validation.
        use_enum_values = True

    def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
        """Return the input fields for "input", otherwise the output fields."""
        return self.input_fields if field_type == "input" else self.output_fields

    def get_full_model_name(self) -> str:
        """Return the model identifier in "{provider}/{model}" form."""
        return f"{self.provider}/{self.model}"

    def get_produced_variables(self) -> list[str]:
        """Return "{step_id}.{field_name}" for every output field with a non-empty name."""
        return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]

    def update(self, update: dict[str, Any]) -> "ModelStep":
        """Returns a new copy with the updated properties."""
        return self.model_copy(update=update)

    def update_property(self, field: str, value: Any) -> "ModelStep":
        """Update the `field` key of the model step with `value`."""
        return self.update({field: value})

    def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
        """Update one attribute of the input or output field at the given index.

        Args:
            field_type: Which list to edit ('input' or 'output').
            index: Position of the field to edit. An index at or beyond the end
                of the list is ignored and an unchanged copy is returned.
            key: Attribute of the field to replace (e.g. "name").
            value: New value for that attribute.

        Returns:
            A new ModelStep with the updated field; this instance is not modified.

        Raises:
            ValueError: If `field_type` is neither 'input' nor 'output'.
        """
        if field_type == "input":
            attr = "input_fields"
        elif field_type == "output":
            attr = "output_fields"
        else:
            raise ValueError(f"Invalid field type: {field_type}")

        # Work on a copy so the list owned by this instance is never mutated;
        # a bare model_copy() would share the same list object with the copy.
        fields = deepcopy(getattr(self, attr))
        if index < len(fields):
            fields[index] = fields[index].model_copy(update={key: value})
        return self.model_copy(update={attr: fields})

    @staticmethod
    def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
        """Create a blank input or output field.

        Args:
            field_type: Kind of field to create ('input' or 'output').
            input_var: Source variable for a new input field; unused for outputs.

        Raises:
            ValueError: If `field_type` is neither 'input' nor 'output'.
        """
        if field_type == "input":
            # NOTE(review): InputField.variable is declared as `str`, so leaving
            # input_var as None fails validation here - confirm callers always
            # pass a variable when adding an input.
            return InputField(name="", description="", variable=input_var)
        elif field_type == "output":
            return OutputField(name="", description="")
        else:
            raise ValueError(f"Invalid field type: {field_type}")

    def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
        """Add a new blank field to the step.

        Args:
            field_type: Type of field to add ('input' or 'output').
            index: The new field is inserted directly after this position;
                -1 appends it to the end.
            input_var: Source variable for the new field when adding an input.

        Returns:
            A new ModelStep with the updated fields.
        """
        if field_type == "input":
            attr = "input_fields"
            new_field = ModelStep.create_new_field(field_type, input_var)
        else:
            attr = "output_fields"
            # Raises ValueError for anything other than "output".
            new_field = ModelStep.create_new_field(field_type)

        fields = deepcopy(getattr(self, attr))
        if index == -1:
            fields.append(new_field)
        else:
            fields.insert(index + 1, new_field)
        return self.model_copy(update={attr: fields})

    def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
        """
        Delete an input or output field from the step.

        Args:
            field_type: Type of field to delete ('input'; any other value selects outputs).
            index: Index of the field to delete. [-1 to delete the last field]

        Returns:
            A new ModelStep with the updated fields.

        Raises:
            IndexError: If `index` is out of range for the selected list.
        """
        attr = "input_fields" if field_type == "input" else "output_fields"
        fields = deepcopy(getattr(self, attr))
        fields.pop(index)
        return self.model_copy(update={attr: fields})
class Workflow(BaseModel):
    """
    Represents a complete workflow composed of interconnected steps.

    A workflow defines a directed acyclic graph of model steps, where outputs
    from earlier steps can be used as inputs to later steps.

    Attributes:
        inputs: List of input variables required by the workflow
        outputs: Dictionary mapping each workflow output name to the variable
            that supplies it, or None when no variable is assigned
        steps: Dictionary mapping step IDs to ModelStep instances

    Step-produced variables use the format "{step_id}.{field_name}"
    to uniquely identify them within the workflow.
    """

    # variables of form {node}.{field}
    inputs: list[str] = Field(default_factory=list)

    # variables of form {node}.{field}
    outputs: dict[str, str | None] = Field(default_factory=dict)
    steps: dict[str, ModelStep] = Field(default_factory=dict)

    def model_dump(self, *args, **kwargs):
        """Dump the model, emitting `steps` as a list instead of an id-keyed dict.

        `dictify_steps` performs the inverse conversion when loading.
        """
        data = super().model_dump(*args, **kwargs)
        if "steps" in data:
            data["steps"] = list(data["steps"].values())
        return data

    @model_validator(mode="before")
    @classmethod
    def dictify_steps(cls, data):
        """Accept `steps` given as a list and key it by step id before validation.

        Raises:
            ValueError: If two steps in the list share the same id.
        """
        # Only raw dict input is inspected; anything else is left for Pydantic.
        if not isinstance(data, dict) or not isinstance(data.get("steps"), list):
            return data

        steps_dict = {}
        for step in data["steps"]:
            step_id = step.id if isinstance(step, ModelStep) else step["id"]
            if step_id in steps_dict:
                raise ValueError(f"Duplicate step ID: {step_id}")
            steps_dict[step_id] = step
        # Build a new mapping so the caller's input dict is not modified.
        return {**data, "steps": steps_dict}

    def get_step_variables(self, step_id: str) -> list[str]:
        """Get all variables produced by a specific step (unnamed outputs are skipped)."""
        return self.steps[step_id].get_produced_variables()

    def get_available_variables(self) -> list[str]:
        """Get the workflow inputs plus all output variables from all steps.

        Duplicates are removed; workflow inputs come first, followed by step
        outputs in step order.
        """
        # dict keys de-duplicate like a set but keep a deterministic order.
        variables = dict.fromkeys(self.inputs)
        for step in self.steps.values():
            variables.update(dict.fromkeys(step.get_produced_variables()))
        return list(variables)

    def get_step_model_selections(self) -> dict[str, str]:
        """Get all model selections for all steps."""
        return {step_id: step.get_full_model_name() for step_id, step in self.steps.items()}

    def get_output_model_selections(self) -> dict[str, str | None]:
        """Map each output name to the step id of its variable, or None if unassigned."""
        return {
            output_var: target_var.split(".")[0] if target_var else None
            for output_var, target_var in self.outputs.items()
        }

    # Step update method
    def add_step(self, step: ModelStep) -> "Workflow":
        """Return a new workflow with `step` added (replacing any step with the same id)."""
        steps = self.steps | {step.id: step}
        return self.model_copy(update={"steps": steps})

    def remove_step(self, step_id: str) -> "Workflow":
        """Return a new workflow without the given step; this instance is not modified.

        Outputs that pointed at variables of the removed step are reset to None.

        Raises:
            KeyError: If `step_id` is not a step of this workflow.
        """
        if step_id not in self.steps:
            raise KeyError(step_id)
        # Build a fresh dict instead of popping from self.steps in place.
        steps = {sid: step for sid, step in self.steps.items() if sid != step_id}
        workflow = self.model_copy(update={"steps": steps})
        workflow.refresh_output_variables()
        return workflow

    def update_step(self, step: ModelStep) -> "Workflow":
        """Return a new workflow with `step` stored under its id; this instance is not modified.

        Outputs that point at variables the workflow no longer produces are reset to None.
        """
        steps = self.steps | {step.id: step}
        workflow = self.model_copy(update={"steps": steps})
        workflow.refresh_output_variables()
        return workflow

    # Output variables
    def refresh_output_variables(self) -> "Workflow":
        """Reset to None every output whose variable is no longer available.

        Unlike the step-editing methods, this updates `self.outputs` on this
        instance and returns the same instance.
        """
        produced_variables = set(self.get_available_variables())
        self.outputs = {k: (v if v in produced_variables else None) for k, v in self.outputs.items()}
        return self
class BuzzerMethod(str, Enum):
    """How a Buzzer combines its confidence and probability threshold checks.

    Members are also `str` instances, so each compares equal to its value.
    """

    AND = "AND"
    OR = "OR"
class Buzzer(BaseModel):
    """Configuration for when to buzz in a tossup question."""

    method: BuzzerMethod = BuzzerMethod.AND  # Logic to combine thresholds
    confidence_threshold: float = Field(default=0.5, ge=0.0, le=1.0)  # Minimum confidence to trigger a buzz
    # Optional probability threshold (compared against a probability, not a
    # log probability); None disables the probability check entirely.
    prob_threshold: float | None = None

    class Config:
        use_enum_values = True
        frozen = True

    def update(self, **kwargs) -> "Buzzer":
        """Return a copy of the buzzer with the given fields replaced.

        The copy is made with `model_copy`, which does not re-run validation
        on the updated values.
        """
        return self.model_copy(update=kwargs)

    def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
        """Decide whether to buzz.

        Args:
            confidence: Confidence value compared against `confidence_threshold`.
            prob: Probability compared against `prob_threshold`.
            logprob: Log probability; converted with `exp` and used in place of `prob`.

        Returns:
            True if the configured threshold logic is satisfied.

        Raises:
            ValueError: If both `prob` and `logprob` are given, if neither is
                given while `prob_threshold` is set, or if `method` is invalid.
        """
        if logprob is not None and prob is not None:
            raise ValueError("Cannot provide both logprob and prob")

        confident = confidence >= self.confidence_threshold
        if self.prob_threshold is None:
            return confident

        if logprob is None and prob is None:
            raise ValueError("Must provide either logprob or prob if prob_threshold is not None")

        # Test for None explicitly: a legitimate prob of 0.0 is falsy and must
        # not fall through to exp(logprob) with logprob=None.
        if prob is None:
            prob = float(np.exp(logprob))
        probable = prob >= self.prob_threshold

        if self.method == BuzzerMethod.AND:
            return confident and probable
        elif self.method == BuzzerMethod.OR:
            return confident or probable
        else:
            raise ValueError(f"Invalid buzzer method: {self.method}")

    @model_validator(mode="after")
    def validate_method_with_log_prob(self) -> "Buzzer":
        """Validate that if prob_threshold is None, method must be 'and'."""
        if self.prob_threshold is None and self.method != BuzzerMethod.AND:
            raise ValueError("If prob_threshold is None, method must be 'and'")
        return self
class TossupWorkflow(Workflow):
    """Workflow specialized for tossup questions with buzzing capability."""

    buzzer: Buzzer = Field(default_factory=Buzzer)

    def get_answer_model(self, answer_var: str | None = None) -> str | None:
        """Return the "{provider}/{model}" name of the step that produces the answer.

        Args:
            answer_var: Variable of the form "{step_id}.{field_name}". When
                omitted, the variable assigned to the "answer" output is used.

        Returns:
            The full model name, or None if no answer variable is available.

        Raises:
            KeyError: If the variable refers to a step that is not in the workflow.
        """
        # .get() so a workflow that has no "answer" output yet yields None
        # instead of raising KeyError.
        answer_var = answer_var or self.outputs.get("answer")
        if answer_var is None:
            return None
        step_id = answer_var.split(".")[0]
        return self.steps[step_id].get_full_model_name()

    def is_token_probs_supported(self, answer_var: str | None = None) -> bool:
        """Report whether the answer model's AVAILABLE_MODELS entry enables "logprobs".

        Returns True when no answer model can be determined.
        """
        model_name = self.get_answer_model(answer_var)
        if model_name is None:
            return True
        return AVAILABLE_MODELS[model_name].get("logprobs", False)

    def update_buzzer(self, buzzer: Buzzer) -> "TossupWorkflow":
        """Return a copy of the workflow that uses the given buzzer."""
        return self.model_copy(update={"buzzer": buzzer})

    def refresh_buzzer(self) -> "TossupWorkflow":
        """Drop the probability threshold when the answer model lacks token-probability support.

        Returns:
            A copy whose buzzer has `prob_threshold=None` and `method="AND"` if
            token probabilities are unsupported; otherwise this same instance.
        """
        if not self.is_token_probs_supported():
            return self.update_buzzer(self.buzzer.update(prob_threshold=None, method="AND"))
        return self