|
|
|
from copy import deepcopy
from enum import Enum
from typing import Any, Literal, Optional

import numpy as np
from pydantic import BaseModel, ConfigDict, Field, model_validator

from .configs import AVAILABLE_MODELS
|
|
|
""" |
|
Core data structures for defining workflows and their components. |
|
|
|
This module defines the primary classes used to model workflows, steps, and their |
|
input/output fields. These data structures serve as the foundation for workflow |
|
definition, validation, and execution throughout the workflows package. |
|
|
|
The primary components are: |
|
- InputField: Represents an input to a model step with name and source variable |
|
- OutputField: Represents an output from a model step with name and type |
|
- ModelStep: Represents a single step in a workflow with inputs and outputs |
|
- Workflow: A collection of interconnected steps with defined inputs and outputs |
|
|
|
All classes use Pydantic's BaseModel for validation and serialization support. |
|
""" |
|
FieldType = Literal["input", "output"] |
|
|
|
|
|
SUPPORTED_TYPES = Literal["str", "int", "float", "bool", "list[str]", "list[int]", "list[float]", "list[bool]"] |
|
"""Supported field types for input and output fields""" |
|
|
|
|
|
class InputField(BaseModel):
    """
    Defines an input field for a model step.

    An input field specifies what data a step requires, where it comes from,
    and optional pre-processing to apply before use.

    Instances are frozen (immutable); derive a modified field with
    ``model_copy(update=...)`` instead of assigning attributes.

    Attributes:
        name: The name of the input field within the step's context
        description: Human-readable description of the input's purpose
        variable: Reference to the source variable (format: "{step_id}.{field_name}"
            or an external input name)
        func: Optional function name to transform the input value before use
    """

    # Pydantic v2 configuration; replaces the deprecated class-based ``Config``.
    model_config = ConfigDict(frozen=True)

    name: str
    description: str
    variable: str
    func: str | None = None
|
|
|
|
|
class OutputField(BaseModel):
    """
    Defines an output field produced by a model step.

    An output field specifies a value that the step will produce, including
    its data type and optional post-processing.

    Instances are frozen (immutable); derive a modified field with
    ``model_copy(update=...)`` instead of assigning attributes.

    Attributes:
        name: The name of the output field within the step's context
        type: The data type of the output (one of SUPPORTED_TYPES); defaults to "str"
        description: Human-readable description of the output's purpose
        func: Optional function name to transform the raw output value
    """

    # Pydantic v2 configuration; replaces the deprecated class-based ``Config``.
    model_config = ConfigDict(frozen=True)

    name: str
    type: SUPPORTED_TYPES = Field(default="str")
    description: str
    func: str | None = None
|
|
|
|
|
class CallType(str, Enum):
    """Kind of operation a ModelStep performs.

    Members subclass ``str``, so each one compares equal to its plain string value
    (for example, ``CallType.LLM == "llm"``).
    """

    LLM = "llm"
    SEARCH = "search"
    PYTHON_FUNC = "python_func"
|
|
|
|
|
class ModelStep(BaseModel):
    """
    Represents a single step in a workflow.

    A model step encapsulates the details of a specific operation within a workflow,
    including what model to use, what inputs it requires, and what outputs it produces.

    The ``update*``, ``add_field`` and ``delete_field`` helpers do not mutate the
    step: each returns a new ModelStep and leaves this instance, including its
    field lists, unchanged.

    Attributes:
        id: Unique identifier for this step within a workflow
        name: Name of the step
        model: The model to use for this step (e.g., "gpt-4")
        provider: The provider of the model (e.g., "openai")
        call_type: The type of operation (e.g., "llm", "search")
        temperature: Optional temperature for the model call; None when not specified
        system_prompt: Instructions for the model
        input_fields: List of input fields required by this step
        output_fields: List of output fields produced by this step
    """

    # Pydantic v2 configuration; replaces the deprecated class-based ``Config``.
    model_config = ConfigDict(use_enum_values=True)

    id: str
    name: str
    model: str
    provider: str
    call_type: CallType = CallType.LLM
    temperature: float | None = None
    system_prompt: str
    input_fields: list[InputField]
    output_fields: list[OutputField]

    @staticmethod
    def _fields_attr(field_type: FieldType) -> str:
        """Return the name of the attribute holding the fields of ``field_type``.

        Raises:
            ValueError: If ``field_type`` is neither "input" nor "output".
        """
        if field_type == "input":
            return "input_fields"
        if field_type == "output":
            return "output_fields"
        raise ValueError(f"Invalid field type: {field_type}")

    def fields(self, field_type: FieldType) -> list[InputField | OutputField]:
        """Return this step's input or output field list (the stored list, not a copy).

        Raises:
            ValueError: If ``field_type`` is neither "input" nor "output".
        """
        return getattr(self, self._fields_attr(field_type))

    def get_full_model_name(self) -> str:
        """Return the model identifier in "{provider}/{model}" form."""
        return f"{self.provider}/{self.model}"

    def get_produced_variables(self) -> list[str]:
        """Return "{step_id}.{field_name}" for every output field with a non-empty name."""
        return [f"{self.id}.{field.name}" for field in self.output_fields if field.name]

    def update(self, update: dict[str, Any]) -> "ModelStep":
        """Return a new copy with the given properties replaced.

        The new values go through ``model_copy(update=...)``, which does not re-validate them.
        """
        return self.model_copy(update=update)

    def update_property(self, field: str, value: Any) -> "ModelStep":
        """Return a new copy with the ``field`` property set to ``value``."""
        return self.update({field: value})

    def update_field(self, field_type: FieldType, index: int, key: str, value: str) -> "ModelStep":
        """Return a copy in which ``key`` of the input/output field at ``index`` is ``value``.

        An ``index`` at or past the end of the list leaves the fields unchanged; a
        copy is still returned.

        Raises:
            ValueError: If ``field_type`` is neither "input" nor "output".
        """
        attr = self._fields_attr(field_type)
        # Work on a new list so the original step's field list is never mutated.
        updated = list(getattr(self, attr))
        if index < len(updated):
            updated[index] = updated[index].model_copy(update={key: value})
        return self.model_copy(update={attr: updated})

    @staticmethod
    def create_new_field(field_type: FieldType, input_var: str | None = None) -> InputField | OutputField:
        """Create an empty input or output field.

        Args:
            field_type: "input" or "output".
            input_var: Source variable for a new input field; ignored for output
                fields. ``None`` becomes an empty string because
                ``InputField.variable`` is a required ``str``.

        Raises:
            ValueError: If ``field_type`` is neither "input" nor "output".
        """
        if field_type == "input":
            return InputField(name="", description="", variable=input_var if input_var is not None else "")
        if field_type == "output":
            return OutputField(name="", description="")
        raise ValueError(f"Invalid field type: {field_type}")

    def add_field(self, field_type: FieldType, index: int = -1, input_var: str | None = None) -> "ModelStep":
        """Add a new empty field to the input or output list.

        Args:
            field_type: Type of field to add ('input' or 'output').
            index: The new field is inserted right after the field at this
                position; -1 appends it to the end.
            input_var: Source variable for a new input field (ignored for outputs).

        Returns:
            A new ModelStep with the updated fields.

        Raises:
            ValueError: If ``field_type`` is neither "input" nor "output".
        """
        attr = self._fields_attr(field_type)
        updated = list(getattr(self, attr))
        new_field = ModelStep.create_new_field(field_type, input_var)
        if index == -1:
            updated.append(new_field)
        else:
            updated.insert(index + 1, new_field)
        return self.model_copy(update={attr: updated})

    def delete_field(self, field_type: FieldType, index: int) -> "ModelStep":
        """
        Delete an input or output field.

        Args:
            field_type: Type of field to delete ('input' or 'output').
            index: Index of the field to delete (-1 deletes the last field).

        Returns:
            A new ModelStep with the updated fields.

        Raises:
            ValueError: If ``field_type`` is neither "input" nor "output".
            IndexError: If ``index`` is out of range.
        """
        attr = self._fields_attr(field_type)
        updated = list(getattr(self, attr))
        updated.pop(index)
        return self.model_copy(update={attr: updated})
|
|
|
|
|
class Workflow(BaseModel):
    """
    Represents a complete workflow composed of interconnected steps.

    The workflow is meant to form a directed acyclic graph of model steps, in which
    outputs from earlier steps feed inputs of later steps. Acyclicity is not
    validated by this class.

    Step-editing helpers (``add_step``, ``remove_step``, ``update_step``) return a
    new Workflow and leave this instance's ``steps`` untouched.

    Attributes:
        inputs: Names of the external input variables the workflow requires
        outputs: Mapping from each workflow output name to the variable that
            provides it ("{step_id}.{field_name}"), or None when unassigned or no
            longer produced
        steps: Mapping from step ID to ModelStep; a list of steps is also accepted
            on construction (see ``dictify_steps``)
    """

    inputs: list[str] = Field(default_factory=list)
    outputs: dict[str, str | None] = Field(default_factory=dict)
    steps: dict[str, ModelStep] = Field(default_factory=dict)

    def model_dump(self, *args, **kwargs) -> dict[str, Any]:
        """Dump like ``BaseModel.model_dump``, but emit ``steps`` as a list of step dicts.

        ``dictify_steps`` accepts this list form, so the dump can be validated back
        into a Workflow.
        """
        data = super().model_dump(*args, **kwargs)
        if "steps" in data:
            data["steps"] = list(data["steps"].values())
        return data

    @model_validator(mode="before")
    @classmethod
    def dictify_steps(cls, data: Any) -> Any:
        """Convert a list of steps (ModelStep instances or dicts) into an id-keyed dict.

        Raises:
            ValueError: If two steps share the same ID.
        """
        if not isinstance(data, dict) or not isinstance(data.get("steps"), list):
            return data
        steps_dict = {}
        for step in data["steps"]:
            step_id = step.id if isinstance(step, ModelStep) else step["id"]
            if step_id in steps_dict:
                raise ValueError(f"Duplicate step ID: {step_id}")
            steps_dict[step_id] = step
        # Return a new dict rather than mutating the caller's input in place.
        return {**data, "steps": steps_dict}

    def get_step_variables(self, step_id: str) -> list[str]:
        """Get the output variables ("{step_id}.{field_name}") of a specific step.

        Raises:
            KeyError: If ``step_id`` is not a step in this workflow.
        """
        return self.steps[step_id].get_produced_variables()

    def get_available_variables(self) -> list[str]:
        """Get the workflow inputs plus every step's output variables, without duplicates.

        Order is deterministic: workflow inputs first, then step outputs in step order.
        """
        # dict.fromkeys gives set-like de-duplication while keeping insertion order.
        ordered = dict.fromkeys(self.inputs)
        for step in self.steps.values():
            ordered.update(dict.fromkeys(step.get_produced_variables()))
        return list(ordered)

    def get_step_model_selections(self) -> dict[str, str]:
        """Map each step ID to its "{provider}/{model}" name."""
        return {step_id: step.get_full_model_name() for step_id, step in self.steps.items()}

    def get_output_model_selections(self) -> dict[str, str | None]:
        """Map each workflow output to the ID of the step providing it (None if unassigned)."""
        return {
            output_var: target_var.split(".")[0] if target_var else None
            for output_var, target_var in self.outputs.items()
        }

    def add_step(self, step: ModelStep) -> "Workflow":
        """Return a copy of the workflow with ``step`` added (replacing any step with the same ID)."""
        steps = self.steps | {step.id: step}
        return self.model_copy(update={"steps": steps})

    def remove_step(self, step_id: str) -> "Workflow":
        """Return a copy of the workflow without ``step_id``, with outputs refreshed.

        Raises:
            KeyError: If ``step_id`` is not a step in this workflow.
        """
        if step_id not in self.steps:
            raise KeyError(step_id)
        steps = {sid: step for sid, step in self.steps.items() if sid != step_id}
        return self.model_copy(update={"steps": steps}).refresh_output_variables()

    def update_step(self, step: ModelStep) -> "Workflow":
        """Return a copy with ``step`` stored under its ID, with outputs refreshed."""
        steps = self.steps | {step.id: step}
        return self.model_copy(update={"steps": steps}).refresh_output_variables()

    def refresh_output_variables(self) -> "Workflow":
        """Reset to None every output whose variable is no longer available.

        Mutates this instance's ``outputs`` in place and returns ``self``.
        """
        produced = set(self.get_available_variables())
        self.outputs = {name: (var if var in produced else None) for name, var in self.outputs.items()}
        return self
|
|
|
|
|
class BuzzerMethod(str, Enum):
    """How Buzzer.run combines its confidence test and probability test.

    AND requires both thresholds to be met; OR requires either one.
    """

    AND = "AND"
    OR = "OR"
|
|
|
|
|
class Buzzer(BaseModel):
    """Configuration for when to buzz in a tossup question.

    Attributes:
        method: How the confidence and probability tests are combined ("AND" or "OR");
            must be "AND" when ``prob_threshold`` is None
        confidence_threshold: Minimum confidence, in [0, 1], required to buzz
        prob_threshold: Optional minimum probability; None disables the probability test
    """

    # Pydantic v2 configuration; replaces the deprecated class-based ``Config``.
    model_config = ConfigDict(use_enum_values=True, frozen=True)

    method: BuzzerMethod = BuzzerMethod.AND
    confidence_threshold: float = Field(default=0.5, ge=0.0, le=1.0)
    prob_threshold: float | None = None

    def update(self, **kwargs) -> "Buzzer":
        """Return a copy with the given attributes replaced.

        Uses ``model_copy(update=...)``, so the new values are not re-validated.
        """
        return self.model_copy(update=kwargs)

    def run(self, confidence: float, prob: float | None = None, logprob: float | None = None) -> bool:
        """Decide whether to buzz.

        Args:
            confidence: Confidence score compared against ``confidence_threshold``.
            prob: Probability compared against ``prob_threshold``.
            logprob: Natural-log probability; converted with exp() when ``prob`` is
                not given. Pass at most one of ``prob`` / ``logprob``.

        Returns:
            True if the buzz condition holds.

        Raises:
            ValueError: If both ``prob`` and ``logprob`` are given, or if neither is
                given while ``prob_threshold`` is set.
        """
        if logprob is not None and prob is not None:
            raise ValueError("Cannot provide both logprob and prob")
        if self.prob_threshold is None:
            return confidence >= self.confidence_threshold
        # Test for None explicitly: prob == 0.0 is a valid (falsy) probability.
        if prob is None:
            if logprob is None:
                raise ValueError("Must provide either logprob or prob if prob_threshold is not None")
            prob = float(np.exp(logprob))
        confident = confidence >= self.confidence_threshold
        probable = prob >= self.prob_threshold
        if self.method == BuzzerMethod.AND:
            return confident and probable
        if self.method == BuzzerMethod.OR:
            return confident or probable
        raise ValueError(f"Invalid buzzer method: {self.method}")

    @model_validator(mode="after")
    def validate_method_with_log_prob(self) -> "Buzzer":
        """Validate that when prob_threshold is None, method is AND.

        OR would be meaningless here, because run() ignores probabilities when
        there is no threshold.
        """
        if self.prob_threshold is None and self.method != BuzzerMethod.AND:
            raise ValueError("If prob_threshold is None, method must be 'and'")
        return self
|
|
|
|
|
class TossupWorkflow(Workflow):
    """Workflow specialized for tossup questions with buzzing capability.

    Attributes:
        buzzer: Thresholds that decide when to buzz (see Buzzer)
    """

    buzzer: Buzzer = Field(default_factory=Buzzer)

    def get_answer_model(self, answer_var: str | None = None) -> str | None:
        """Return the "{provider}/{model}" name of the step that produces the answer.

        Args:
            answer_var: Variable to inspect ("{step_id}.{field_name}"). Defaults to
                the variable mapped to the workflow's "answer" output.

        Returns:
            The full model name, or None when no answer variable is set.

        Raises:
            KeyError: If the variable's step ID is not a step in this workflow.
        """
        # .get(): a workflow without an "answer" output simply has no answer model.
        answer_var = answer_var or self.outputs.get("answer")
        if not answer_var:
            return None
        step_id = answer_var.split(".")[0]
        return self.steps[step_id].get_full_model_name()

    def is_token_probs_supported(self, answer_var: str | None = None) -> bool:
        """Return whether the answer model reports token log-probs.

        Reads the model's "logprobs" entry in AVAILABLE_MODELS, defaulting to False.
        Returns True when no answer model is assigned.
        """
        model_name = self.get_answer_model(answer_var)
        if model_name is None:
            return True
        return AVAILABLE_MODELS[model_name].get("logprobs", False)

    def update_buzzer(self, buzzer: Buzzer) -> "TossupWorkflow":
        """Return a copy of the workflow that uses ``buzzer``."""
        return self.model_copy(update={"buzzer": buzzer})

    def refresh_buzzer(self) -> "TossupWorkflow":
        """Drop the probability threshold if the answer model has no log-probs.

        Without a prob threshold, Buzzer only permits method "AND", so the method
        is reset too. Returns ``self`` unchanged when log-probs are supported.
        """
        if not self.is_token_probs_supported():
            return self.update_buzzer(self.buzzer.update(prob_threshold=None, method="AND"))
        return self
|
|