from __future__ import annotations

import dataclasses
from collections.abc import Mapping
from dataclasses import fields, replace
from typing import Annotated, Any, Literal, Union

from openai import Omit as _Omit
from openai._types import Body, Query
from openai.types.responses import ResponseIncludable
from openai.types.shared import Reasoning
from pydantic import BaseModel, GetCoreSchemaHandler
from pydantic.dataclasses import dataclass
from pydantic_core import core_schema
from typing_extensions import TypeAlias


class _OmitTypeAnnotation:
    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        # Validate ``None`` into the ``_Omit`` sentinel.
        def validate_from_none(value: None) -> _Omit:
            return _Omit()

        from_none_schema = core_schema.chain_schema(
            [
                core_schema.none_schema(),
                core_schema.no_info_plain_validator_function(validate_from_none),
            ]
        )
        return core_schema.json_or_python_schema(
            json_schema=from_none_schema,
            python_schema=core_schema.union_schema(
                [
                    # Accept an existing ``_Omit`` instance as-is before
                    # falling back to the ``None`` -> ``_Omit`` chain.
                    core_schema.is_instance_schema(_Omit),
                    from_none_schema,
                ]
            ),
            # ``_Omit`` always serializes back to ``None``.
            serialization=core_schema.plain_serializer_function_ser_schema(lambda instance: None),
        )


@dataclass
class MCPToolChoice:
    """Forces the model to call a specific tool on a hosted MCP server."""

    server_label: str
    name: str


Omit = Annotated[_Omit, _OmitTypeAnnotation]
Headers: TypeAlias = Mapping[str, Union[str, Omit]]
ToolChoice: TypeAlias = Union[Literal["auto", "required", "none"], str, MCPToolChoice, None]
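
# A minimal sketch of how these aliases behave; illustrative only, assuming
# pydantic v2's ``TypeAdapter``. ``Omit`` validates ``None`` into the ``_Omit``
# sentinel and always serializes back to ``None``:
#
#     from pydantic import TypeAdapter
#     adapter = TypeAdapter(Omit)
#     isinstance(adapter.validate_python(None), _Omit)    # True
#     adapter.dump_python(adapter.validate_python(None))  # None
#
# ``Headers`` values may be strings or ``Omit`` (e.g. to suppress a default
# header), and ``ToolChoice`` accepts a mode literal, a tool name, or an
# ``MCPToolChoice`` (the names below are hypothetical):
#
#     headers: Headers = {"X-Example": "1", "User-Agent": _Omit()}
#     tool_choice: ToolChoice = MCPToolChoice(server_label="docs", name="search")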


@dataclass
class ModelSettings:
    """Settings to use when calling an LLM.

    This class holds optional model configuration parameters (e.g. temperature,
    top_p, penalties, truncation, etc.).

    Not all models/providers support all of these parameters, so please check the API documentation
    for the specific model and provider you are using.
    """

    temperature: float | None = None
    """The temperature to use when calling the model."""

    top_p: float | None = None
    """The top_p to use when calling the model."""

    frequency_penalty: float | None = None
    """The frequency penalty to use when calling the model."""

    presence_penalty: float | None = None
    """The presence penalty to use when calling the model."""

    tool_choice: ToolChoice | None = None
    """The tool choice to use when calling the model."""

    parallel_tool_calls: bool | None = None
    """Controls whether the model can make multiple parallel tool calls in a single turn.
    If not provided (i.e., set to None), this behavior defers to the underlying
    model provider's default. For most current providers (e.g., OpenAI), this typically
    means parallel tool calls are enabled (True).
    Set to True to explicitly enable parallel tool calls, or False to restrict the
    model to at most one tool call per turn.
    """

    truncation: Literal["auto", "disabled"] | None = None
    """The truncation strategy to use when calling the model.
    See [Responses API documentation](https://platform.openai.com/docs/api-reference/responses/create#responses_create-truncation)
    for more details.
    """

    max_tokens: int | None = None
    """The maximum number of output tokens to generate."""

    reasoning: Reasoning | None = None
    """Configuration options for
    [reasoning models](https://platform.openai.com/docs/guides/reasoning).
    """

    verbosity: Literal["low", "medium", "high"] | None = None
    """Constrains the verbosity of the model's response."""

    metadata: dict[str, str] | None = None
    """Metadata to include with the model response call."""

    store: bool | None = None
    """Whether to store the generated model response for later retrieval.
    For Responses API: automatically enabled when not specified.
    For Chat Completions API: disabled when not specified."""

    include_usage: bool | None = None
    """Whether to include a usage chunk when streaming.
    Only available for the Chat Completions API."""

    response_include: list[ResponseIncludable | str] | None = None
    """Additional output data to include in the model response.
    See the [include parameter](https://platform.openai.com/docs/api-reference/responses/create#responses-create-include)
    for supported values."""

    top_logprobs: int | None = None
    """Number of top tokens to return logprobs for. Setting this will
    automatically include ``"message.output_text.logprobs"`` in the response."""

    extra_query: Query | None = None
    """Additional query fields to provide with the request.
    Defaults to None if not provided."""

    extra_body: Body | None = None
    """Additional body fields to provide with the request.
    Defaults to None if not provided."""

    extra_headers: Headers | None = None
    """Additional headers to provide with the request.
    Defaults to None if not provided."""

    extra_args: dict[str, Any] | None = None
    """Arbitrary keyword arguments to pass to the model API call.
    These will be passed directly to the underlying model provider's API.
    Use with caution as not all models support all parameters."""

    def resolve(self, override: ModelSettings | None) -> ModelSettings:
        """Produce a new ModelSettings by overlaying any non-None values from the
        override on top of this instance."""
        if override is None:
            return self

        changes = {
            field.name: getattr(override, field.name)
            for field in fields(self)
            if getattr(override, field.name) is not None
        }

        # Merge extra_args dictionaries instead of overwriting them wholesale;
        # keys from the override win on conflict.
        if self.extra_args is not None or override.extra_args is not None:
            merged_args = {}
            if self.extra_args:
                merged_args.update(self.extra_args)
            if override.extra_args:
                merged_args.update(override.extra_args)
            changes["extra_args"] = merged_args if merged_args else None

        return replace(self, **changes)
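
    # A quick illustration of ``resolve`` semantics (values are made up):
    #
    #     base = ModelSettings(temperature=0.2, extra_args={"seed": 1})
    #     override = ModelSettings(top_p=0.9, extra_args={"seed": 2})
    #     resolved = base.resolve(override)
    #     resolved.temperature  # 0.2 (kept; the override left it as None)
    #     resolved.top_p        # 0.9 (overlaid from the override)
    #     resolved.extra_args   # {"seed": 2} (dicts merged; override wins)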

    def to_json_dict(self) -> dict[str, Any]:
        """Return a JSON-serializable dict representation of these settings."""
        dataclass_dict = dataclasses.asdict(self)

        json_dict: dict[str, Any] = {}

        for field_name, value in dataclass_dict.items():
            # ``dataclasses.asdict`` leaves pydantic models (e.g. ``reasoning``)
            # untouched, so dump them to JSON-compatible dicts explicitly.
            if isinstance(value, BaseModel):
                json_dict[field_name] = value.model_dump(mode="json")
            else:
                json_dict[field_name] = value

        return json_dict
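

# A minimal, runnable sketch of ``to_json_dict`` (the values are illustrative;
# run this module directly to try it):
if __name__ == "__main__":
    settings = ModelSettings(
        temperature=0.2,
        reasoning=Reasoning(effort="high"),
        extra_args={"seed": 7},
    )
    # ``reasoning`` is a pydantic BaseModel, so it is dumped via
    # ``model_dump(mode="json")``; all other fields pass through unchanged.
    print(settings.to_json_dict())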