| from __future__ import annotations |
|
|
| import logging |
| import warnings |
| from abc import ABC, abstractmethod |
| from typing import ( |
| TYPE_CHECKING, |
| Any, |
| List, |
| Union, |
| Generic, |
| TypeVar, |
| Callable, |
| Iterable, |
| Iterator, |
| Coroutine, |
| AsyncIterator, |
| ) |
| from contextlib import contextmanager, asynccontextmanager |
| from typing_extensions import TypedDict, override |
|
|
| import httpx |
|
|
| from ..._types import Body, Query, Headers, NotGiven |
| from ..._utils import consume_sync_iterator, consume_async_iterator |
| from ...types.beta import BetaMessage, BetaMessageParam |
| from ._beta_functions import ( |
| BetaFunctionTool, |
| BetaRunnableTool, |
| BetaAsyncFunctionTool, |
| BetaAsyncRunnableTool, |
| BetaBuiltinFunctionTool, |
| BetaAsyncBuiltinFunctionTool, |
| ) |
| from ._beta_compaction_control import DEFAULT_THRESHOLD, DEFAULT_SUMMARY_PROMPT, CompactionControl |
| from ..streaming._beta_messages import BetaMessageStream, BetaAsyncMessageStream |
| from ...types.beta.parsed_beta_message import ResponseFormatT, ParsedBetaMessage, ParsedBetaContentBlock |
| from ...types.beta.message_create_params import ParseMessageCreateParamsBase |
| from ...types.beta.beta_tool_result_block_param import BetaToolResultBlockParam |
|
|
| if TYPE_CHECKING: |
| from ..._client import Anthropic, AsyncAnthropic |
|
|
|
|
# Type variable for the tool objects a runner holds. It is bound to the union
# of the four function-tool classes imported from `._beta_functions`.
AnyFunctionToolT = TypeVar(
    "AnyFunctionToolT",
    bound=Union[
        BetaFunctionTool[Any],
        BetaAsyncFunctionTool[Any],
        BetaBuiltinFunctionTool,
        BetaAsyncBuiltinFunctionTool,
    ],
)

# Type of the items produced when iterating a runner (see the `__next__` /
# `__anext__` return annotations of the base runners).
RunnerItemT = TypeVar("RunnerItemT")

# Module-level logger, named after this module.
log = logging.getLogger(__name__)
|
|
|
|
class RequestOptions(TypedDict, total=False):
    """Per-request options held by a tool runner.

    All keys are optional (``total=False``). The concrete runners unpack this
    mapping as keyword arguments into ``beta.messages.parse(...)`` and
    ``beta.messages.stream(...)`` next to the message parameters.
    """

    # Additional HTTP headers for the request.
    extra_headers: Headers | None
    # Additional query-string parameters for the request.
    extra_query: Query | None
    # Additional body properties for the request.
    extra_body: Body | None
    # Request timeout: plain seconds, an `httpx.Timeout`, `None`, or a `NotGiven` value.
    timeout: float | httpx.Timeout | None | NotGiven
|
|
|
|
class BaseToolRunner(Generic[AnyFunctionToolT, ResponseFormatT]):
    """State and helpers shared by the synchronous and asynchronous tool runners.

    Stores the request parameters, the request options, the tools indexed by
    name, the iteration counter and the optional compaction settings.
    """

    def __init__(
        self,
        *,
        params: ParseMessageCreateParamsBase[ResponseFormatT],
        options: RequestOptions,
        tools: Iterable[AnyFunctionToolT],
        max_iterations: int | None = None,
        compaction_control: CompactionControl | None = None,
    ) -> None:
        """Set up the runner state.

        Args:
            params: Message-creation parameters; must contain a ``"messages"`` entry.
            options: Per-request options kept for the request calls.
            tools: Tools the runner can execute; indexed by their ``name``.
            max_iterations: Upper bound on loop iterations, or ``None`` for no bound.
            compaction_control: Compaction settings, or ``None`` to leave compaction off.
        """
        # Name -> tool lookup used when executing tool_use blocks.
        self._tools_by_name = {tool.name: tool for tool in tools}

        # Shallow copy of the params with a fresh message list, so that later
        # appends do not grow the list object the caller passed in. The
        # individual message objects are still shared with the caller.
        self._params: ParseMessageCreateParamsBase[ResponseFormatT] = {
            **params,
            "messages": list(params["messages"]),
        }
        self._options = options
        self._compaction_control = compaction_control

        # Loop bookkeeping.
        self._max_iterations = max_iterations
        self._iteration_count = 0
        self._messages_modified = False
        self._cached_tool_call_response: BetaMessageParam | None = None

    def set_messages_params(
        self,
        params: ParseMessageCreateParamsBase[ResponseFormatT]
        | Callable[[ParseMessageCreateParamsBase[ResponseFormatT]], ParseMessageCreateParamsBase[ResponseFormatT]],
    ) -> None:
        """
        Replace the parameters used for the next API call.

        Args:
            params (ParseMessageCreateParamsBase[ResponseFormatT] | Callable): Either the new parameters, or a
                function that receives the current parameters and returns the new ones.

        Note:
            Only the stored parameters are replaced here. The cached tool response and the
            "messages modified" flag are not touched by this method; `append_messages()`
            resets them itself.
        """
        self._params = params(self._params) if callable(params) else params

    def append_messages(self, *messages: BetaMessageParam | ParsedBetaMessage[ResponseFormatT]) -> None:
        """Append one or more messages to the end of the conversation history.

        `BetaMessage` instances are converted to ``{"role": ..., "content": ...}`` params;
        anything else is appended as given.

        The cached tool response is dropped as a side effect, i.e. if tools were already
        called, then they will be called again on the next loop iteration.
        """
        converted: List[BetaMessageParam] = []
        for entry in messages:
            if isinstance(entry, BetaMessage):
                converted.append({"role": entry.role, "content": entry.content})
            else:
                converted.append(entry)

        # Tells the run loop that the history was edited by hand this iteration.
        self._messages_modified = True
        self.set_messages_params(lambda current: {**current, "messages": [*self._params["messages"], *converted]})
        self._cached_tool_call_response = None

    def _should_stop(self) -> bool:
        """Return True once `max_iterations` is set and the iteration count has reached it."""
        limit = self._max_iterations
        return limit is not None and self._iteration_count >= limit
|
|
|
|
class BaseSyncToolRunner(BaseToolRunner[BetaRunnableTool, ResponseFormatT], Generic[RunnerItemT, ResponseFormatT], ABC):
    """Synchronous tool-runner loop.

    Iterating the runner repeatedly performs a request through `_handle_request()`,
    yields the resulting item, executes the tools requested by the last message and
    appends the results to the conversation. The loop ends when no tool call was
    requested or when `max_iterations` is reached.
    """

    def __init__(
        self,
        *,
        params: ParseMessageCreateParamsBase[ResponseFormatT],
        options: RequestOptions,
        tools: Iterable[BetaRunnableTool],
        client: Anthropic,
        max_iterations: int | None = None,
        compaction_control: CompactionControl | None = None,
    ) -> None:
        """Store the client and create the (not yet started) run-loop generator.

        Args:
            params: Message-creation parameters for the requests.
            options: Per-request options unpacked into every request call.
            tools: Tools that may be executed by the runner.
            client: Client used for the requests.
            max_iterations: Upper bound on loop iterations, or ``None`` for no bound.
            compaction_control: Compaction settings, or ``None`` to leave compaction off.
        """
        super().__init__(
            params=params,
            options=options,
            tools=tools,
            max_iterations=max_iterations,
            compaction_control=compaction_control,
        )
        self._client = client

        # Calling the generator function does not run any of its body yet; the
        # first request is only made when iteration starts.
        self._iterator = self.__run__()
        # Either the last message itself or a zero-argument callable returning it
        # (see `_get_last_message()`).
        self._last_message: (
            Callable[[], ParsedBetaMessage[ResponseFormatT]] | ParsedBetaMessage[ResponseFormatT] | None
        ) = None

    def __next__(self) -> RunnerItemT:
        """Advance the run loop by one step and return the item it yields."""
        return self._iterator.__next__()

    def __iter__(self) -> Iterator[RunnerItemT]:
        """Iterate over the remaining items of the run loop."""
        for item in self._iterator:
            yield item

    @abstractmethod
    @contextmanager
    def _handle_request(self) -> Iterator[RunnerItemT]:
        """Context manager performing a single request and yielding the item for the caller.

        The run loop calls `_get_last_message()` right after the item was handed out and
        asserts the result is not ``None``, so implementations have to set
        `self._last_message`.
        """
        raise NotImplementedError()
        yield  # unreachable; keeps this function a generator, as `@contextmanager` expects

    def _check_and_compact(self) -> bool:
        """
        Check token usage and compact messages if threshold exceeded.

        When compaction is enabled and the token usage of the last message (input tokens
        including cache creation/read tokens, plus output tokens) has reached the configured
        threshold, a summary of the conversation is requested and the stored message history
        is replaced by a single user message holding that summary.

        Returns:
            True if compaction was performed, False otherwise.

        Raises:
            ValueError: If the first content block of the summary response is not text.
        """
        if self._compaction_control is None or not self._compaction_control["enabled"]:
            return False

        message = self._get_last_message()
        tokens_used = 0
        if message is not None:
            total_input_tokens = (
                message.usage.input_tokens
                + (message.usage.cache_creation_input_tokens or 0)
                + (message.usage.cache_read_input_tokens or 0)
            )
            tokens_used = total_input_tokens + message.usage.output_tokens

        threshold = self._compaction_control.get("context_token_threshold", DEFAULT_THRESHOLD)

        if tokens_used < threshold:
            return False

        # Threshold reached: summarise the conversation so far.
        log.info(f"Token usage {tokens_used} has exceeded the threshold of {threshold}. Performing compaction.")

        model = self._compaction_control.get("model", self._params["model"])

        # Work on a copy of the list; `self._params` must stay intact until the
        # summary has actually been produced.
        messages = list(self._params["messages"])

        if messages[-1]["role"] == "assistant":
            # Strip tool_use blocks from a trailing assistant message (presumably because
            # a tool_use without a matching tool_result would be rejected - TODO confirm).
            non_tool_blocks = [
                block
                for block in messages[-1]["content"]
                if isinstance(block, dict) and block.get("type") != "tool_use"
            ]

            if non_tool_blocks:
                # Swap in a copied message instead of assigning into the existing dict:
                # that dict is shared with `self._params` and with the caller's input, and
                # must not lose its tool_use blocks (e.g. if the request below fails).
                messages[-1] = {**messages[-1], "content": non_tool_blocks}
            else:
                messages.pop()

        messages = [
            *messages,
            BetaMessageParam(
                role="user",
                content=self._compaction_control.get("summary_prompt", DEFAULT_SUMMARY_PROMPT),
            ),
        ]

        response = self._client.beta.messages.create(
            model=model,
            messages=messages,
            max_tokens=self._params["max_tokens"],
            extra_headers={"X-Stainless-Helper": "compaction"},
        )

        log.info(f"Compaction complete. New token usage: {response.usage.output_tokens}")

        first_content = list(response.content)[0]

        if first_content.type != "text":
            raise ValueError("Compaction response content is not of type 'text'")

        # The summary becomes the whole conversation history.
        self.set_messages_params(
            lambda params: {
                **params,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": first_content.text,
                            }
                        ],
                    }
                ],
            }
        )
        return True

    def __run__(self) -> Iterator[RunnerItemT]:
        """Generator implementing the request -> tool call -> append loop."""
        while not self._should_stop():
            with self._handle_request() as item:
                yield item
                # Resolved after the caller is done with the item (for streams the
                # stored value is a callable producing the final message).
                message = self._get_last_message()
                assert message is not None

            self._iteration_count += 1

            # When compaction ran, the history was replaced by the summary, so tool
            # execution is skipped for this iteration.
            if not self._check_and_compact():
                response = self.generate_tool_call_response()
                if response is None:
                    log.debug("Tool call was not requested, exiting from tool runner loop.")
                    return

                # If the caller already edited the history during this iteration,
                # do not append the assistant message and the tool results again.
                if not self._messages_modified:
                    self.append_messages(message, response)

            self._messages_modified = False
            self._cached_tool_call_response = None

    def until_done(self) -> ParsedBetaMessage[ResponseFormatT]:
        """
        Consumes the tool runner stream and returns the last message if it has not been consumed yet.
        If it has, it simply returns the last message.
        """
        consume_sync_iterator(self)
        last_message = self._get_last_message()
        assert last_message is not None
        return last_message

    def generate_tool_call_response(self) -> BetaMessageParam | None:
        """Generate a MessageParam by calling tool functions with any tool use blocks from the last message.

        Note the tool call response is cached, repeated calls to this method will return the same response.

        None can be returned if no tool call was applicable.
        """
        if self._cached_tool_call_response is not None:
            log.debug("Returning cached tool call response.")
            return self._cached_tool_call_response
        response = self._generate_tool_call_response()
        self._cached_tool_call_response = response
        return response

    def _generate_tool_call_response(self) -> BetaMessageParam | None:
        """Run every tool requested by the last assistant message.

        Returns:
            A user message whose content holds one ``tool_result`` block per ``tool_use``
            block, or ``None`` when the last message has no content or no ``tool_use`` blocks.
            Unknown tools and tool exceptions are reported as ``is_error`` results instead of
            being raised.
        """
        content = self._get_last_assistant_message_content()
        if not content:
            return None

        tool_use_blocks = [block for block in content if block.type == "tool_use"]
        if not tool_use_blocks:
            return None

        results: list[BetaToolResultBlockParam] = []

        for tool_use in tool_use_blocks:
            tool = self._tools_by_name.get(tool_use.name)
            if tool is None:
                warnings.warn(
                    f"Tool '{tool_use.name}' not found in tool runner. "
                    f"Available tools: {list(self._tools_by_name.keys())}. "
                    f"If using a raw tool definition, handle the tool call manually and use `append_messages()` to add the result. "
                    f"Otherwise, pass the tool using `beta_tool(func)` or a `@beta_tool` decorated function.",
                    UserWarning,
                    stacklevel=3,
                )
                results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": tool_use.id,
                        "content": f"Error: Tool '{tool_use.name}' not found",
                        "is_error": True,
                    }
                )
                continue

            try:
                result = tool.call(tool_use.input)
            except Exception as exc:
                # A failing tool must not abort the loop; the error is reported
                # back as the tool result instead.
                log.exception(f"Error occurred while calling tool: {tool.name}", exc_info=exc)
                results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": tool_use.id,
                        "content": repr(exc),
                        "is_error": True,
                    }
                )
            else:
                results.append({"type": "tool_result", "tool_use_id": tool_use.id, "content": result})

        return {"role": "user", "content": results}

    def _get_last_message(self) -> ParsedBetaMessage[ResponseFormatT] | None:
        """Return the last message, calling the stored value first if it is a callable."""
        if callable(self._last_message):
            return self._last_message()
        return self._last_message

    def _get_last_assistant_message_content(self) -> list[ParsedBetaContentBlock[ResponseFormatT]] | None:
        """Return the content of the last message, or None if it is missing, not from the assistant, or empty."""
        last_message = self._get_last_message()
        if last_message is None or last_message.role != "assistant" or not last_message.content:
            return None

        return last_message.content
|
|
|
|
class BetaToolRunner(BaseSyncToolRunner[ParsedBetaMessage[ResponseFormatT], ResponseFormatT]):
    """Synchronous, non-streaming runner: every iteration yields a parsed message."""

    @override
    @contextmanager
    def _handle_request(self) -> Iterator[ParsedBetaMessage[ResponseFormatT]]:
        """Call `beta.messages.parse` once, record the result as the last message and yield it."""
        parsed = self._client.beta.messages.parse(**self._params, **self._options)
        self._last_message = parsed
        yield parsed
|
|
|
|
class BetaStreamingToolRunner(BaseSyncToolRunner[BetaMessageStream[ResponseFormatT], ResponseFormatT]):
    """Synchronous streaming runner: every iteration yields an open message stream."""

    @override
    @contextmanager
    def _handle_request(self) -> Iterator[BetaMessageStream[ResponseFormatT]]:
        """Open one `beta.messages.stream` request and yield the stream while it is open."""
        with self._client.beta.messages.stream(**self._params, **self._options) as message_stream:
            # Store the bound method, not its result: the final message is only
            # obtained when `_get_last_message()` calls it.
            self._last_message = message_stream.get_final_message
            yield message_stream
|
|
|
|
class BaseAsyncToolRunner(
    BaseToolRunner[BetaAsyncRunnableTool, ResponseFormatT], Generic[RunnerItemT, ResponseFormatT], ABC
):
    """Asynchronous tool-runner loop.

    Iterating the runner with ``async for`` repeatedly performs a request through
    `_handle_request()`, yields the resulting item, awaits the tools requested by the
    last message and appends the results to the conversation. The loop ends when no
    tool call was requested or when `max_iterations` is reached.
    """

    def __init__(
        self,
        *,
        params: ParseMessageCreateParamsBase[ResponseFormatT],
        options: RequestOptions,
        tools: Iterable[BetaAsyncRunnableTool],
        client: AsyncAnthropic,
        max_iterations: int | None = None,
        compaction_control: CompactionControl | None = None,
    ) -> None:
        """Store the client and create the (not yet started) run-loop async generator.

        Args:
            params: Message-creation parameters for the requests.
            options: Per-request options unpacked into every request call.
            tools: Async tools that may be executed by the runner.
            client: Async client used for the requests.
            max_iterations: Upper bound on loop iterations, or ``None`` for no bound.
            compaction_control: Compaction settings, or ``None`` to leave compaction off.
        """
        super().__init__(
            params=params,
            options=options,
            tools=tools,
            max_iterations=max_iterations,
            compaction_control=compaction_control,
        )
        self._client = client

        # Calling the async generator function does not run any of its body yet;
        # the first request is only made when iteration starts.
        self._iterator = self.__run__()
        # Either the last message itself or a zero-argument callable returning a
        # coroutine that resolves to it (see `_get_last_message()`).
        self._last_message: (
            Callable[[], Coroutine[None, None, ParsedBetaMessage[ResponseFormatT]]]
            | ParsedBetaMessage[ResponseFormatT]
            | None
        ) = None

    async def __anext__(self) -> RunnerItemT:
        """Advance the run loop by one step and return the item it yields."""
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[RunnerItemT]:
        """Iterate over the remaining items of the run loop."""
        async for item in self._iterator:
            yield item

    @abstractmethod
    @asynccontextmanager
    async def _handle_request(self) -> AsyncIterator[RunnerItemT]:
        """Async context manager performing a single request and yielding the item for the caller.

        The run loop awaits `_get_last_message()` right after the item was handed out and
        asserts the result is not ``None``, so implementations have to set
        `self._last_message`.
        """
        raise NotImplementedError()
        yield  # unreachable; keeps this function an async generator, as `@asynccontextmanager` expects

    async def _check_and_compact(self) -> bool:
        """
        Check token usage and compact messages if threshold exceeded.

        When compaction is enabled and the token usage of the last message (input tokens
        including cache creation/read tokens, plus output tokens) has reached the configured
        threshold, a summary of the conversation is requested and the stored message history
        is replaced by a single user message holding that summary.

        Returns:
            True if compaction was performed, False otherwise.

        Raises:
            ValueError: If the first content block of the summary response is not text.
        """
        if self._compaction_control is None or not self._compaction_control["enabled"]:
            return False

        message = await self._get_last_message()
        tokens_used = 0
        if message is not None:
            total_input_tokens = (
                message.usage.input_tokens
                + (message.usage.cache_creation_input_tokens or 0)
                + (message.usage.cache_read_input_tokens or 0)
            )
            tokens_used = total_input_tokens + message.usage.output_tokens

        threshold = self._compaction_control.get("context_token_threshold", DEFAULT_THRESHOLD)

        if tokens_used < threshold:
            return False

        # Threshold reached: summarise the conversation so far.
        log.info(f"Token usage {tokens_used} has exceeded the threshold of {threshold}. Performing compaction.")

        model = self._compaction_control.get("model", self._params["model"])

        # Work on a copy of the list; `self._params` must stay intact until the
        # summary has actually been produced.
        messages = list(self._params["messages"])

        if messages[-1]["role"] == "assistant":
            # Strip tool_use blocks from a trailing assistant message (presumably because
            # a tool_use without a matching tool_result would be rejected - TODO confirm).
            non_tool_blocks = [
                block
                for block in messages[-1]["content"]
                if isinstance(block, dict) and block.get("type") != "tool_use"
            ]

            if non_tool_blocks:
                # Swap in a copied message instead of assigning into the existing dict:
                # that dict is shared with `self._params` and with the caller's input, and
                # must not lose its tool_use blocks (e.g. if the request below fails).
                messages[-1] = {**messages[-1], "content": non_tool_blocks}
            else:
                messages.pop()

        # Build the request from the filtered copy, not from `self._params["messages"]`;
        # otherwise the removal of a tool_use-only assistant message above is lost.
        messages = [
            *messages,
            BetaMessageParam(
                role="user",
                content=self._compaction_control.get("summary_prompt", DEFAULT_SUMMARY_PROMPT),
            ),
        ]

        response = await self._client.beta.messages.create(
            model=model,
            messages=messages,
            max_tokens=self._params["max_tokens"],
            extra_headers={"X-Stainless-Helper": "compaction"},
        )

        log.info(f"Compaction complete. New token usage: {response.usage.output_tokens}")

        first_content = list(response.content)[0]

        if first_content.type != "text":
            raise ValueError("Compaction response content is not of type 'text'")

        # The summary becomes the whole conversation history.
        self.set_messages_params(
            lambda params: {
                **params,
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {
                                "type": "text",
                                "text": first_content.text,
                            }
                        ],
                    }
                ],
            }
        )
        return True

    async def __run__(self) -> AsyncIterator[RunnerItemT]:
        """Async generator implementing the request -> tool call -> append loop."""
        while not self._should_stop():
            async with self._handle_request() as item:
                yield item
                # Resolved after the caller is done with the item (for streams the
                # stored value is a callable producing the final message).
                message = await self._get_last_message()
                assert message is not None

            self._iteration_count += 1

            # When compaction ran, the history was replaced by the summary, so tool
            # execution is skipped for this iteration.
            if not await self._check_and_compact():
                response = await self.generate_tool_call_response()
                if response is None:
                    log.debug("Tool call was not requested, exiting from tool runner loop.")
                    return

                # If the caller already edited the history during this iteration,
                # do not append the assistant message and the tool results again.
                if not self._messages_modified:
                    self.append_messages(message, response)

            self._messages_modified = False
            self._cached_tool_call_response = None

    async def until_done(self) -> ParsedBetaMessage[ResponseFormatT]:
        """
        Consumes the tool runner stream and returns the last message if it has not been consumed yet.
        If it has, it simply returns the last message.
        """
        await consume_async_iterator(self)
        last_message = await self._get_last_message()
        assert last_message is not None
        return last_message

    async def generate_tool_call_response(self) -> BetaMessageParam | None:
        """Generate a MessageParam by calling tool functions with any tool use blocks from the last message.

        Note the tool call response is cached, repeated calls to this method will return the same response.

        None can be returned if no tool call was applicable.
        """
        if self._cached_tool_call_response is not None:
            log.debug("Returning cached tool call response.")
            return self._cached_tool_call_response

        response = await self._generate_tool_call_response()
        self._cached_tool_call_response = response
        return response

    async def _get_last_message(self) -> ParsedBetaMessage[ResponseFormatT] | None:
        """Return the last message, calling and awaiting the stored value first if it is a callable."""
        if callable(self._last_message):
            return await self._last_message()
        return self._last_message

    async def _get_last_assistant_message_content(self) -> list[ParsedBetaContentBlock[ResponseFormatT]] | None:
        """Return the content of the last message, or None if it is missing, not from the assistant, or empty."""
        last_message = await self._get_last_message()
        if last_message is None or last_message.role != "assistant" or not last_message.content:
            return None

        return last_message.content

    async def _generate_tool_call_response(self) -> BetaMessageParam | None:
        """Await every tool requested by the last assistant message, one after another.

        Returns:
            A user message whose content holds one ``tool_result`` block per ``tool_use``
            block, or ``None`` when the last message has no content or no ``tool_use`` blocks.
            Unknown tools and tool exceptions are reported as ``is_error`` results instead of
            being raised.
        """
        content = await self._get_last_assistant_message_content()
        if not content:
            return None

        tool_use_blocks = [block for block in content if block.type == "tool_use"]
        if not tool_use_blocks:
            return None

        results: list[BetaToolResultBlockParam] = []

        for tool_use in tool_use_blocks:
            tool = self._tools_by_name.get(tool_use.name)
            if tool is None:
                warnings.warn(
                    f"Tool '{tool_use.name}' not found in tool runner. "
                    f"Available tools: {list(self._tools_by_name.keys())}. "
                    f"If using a raw tool definition, handle the tool call manually and use `append_messages()` to add the result. "
                    f"Otherwise, pass the tool using `beta_async_tool(func)` or a `@beta_async_tool` decorated function.",
                    UserWarning,
                    stacklevel=3,
                )
                results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": tool_use.id,
                        "content": f"Error: Tool '{tool_use.name}' not found",
                        "is_error": True,
                    }
                )
                continue

            try:
                result = await tool.call(tool_use.input)
            except Exception as exc:
                # A failing tool must not abort the loop; the error is reported
                # back as the tool result instead.
                log.exception(f"Error occurred while calling tool: {tool.name}", exc_info=exc)
                results.append(
                    {
                        "type": "tool_result",
                        "tool_use_id": tool_use.id,
                        "content": repr(exc),
                        "is_error": True,
                    }
                )
            else:
                results.append({"type": "tool_result", "tool_use_id": tool_use.id, "content": result})

        return {"role": "user", "content": results}
|
|
|
|
class BetaAsyncToolRunner(BaseAsyncToolRunner[ParsedBetaMessage[ResponseFormatT], ResponseFormatT]):
    """Asynchronous, non-streaming runner: every iteration yields a parsed message."""

    @override
    @asynccontextmanager
    async def _handle_request(self) -> AsyncIterator[ParsedBetaMessage[ResponseFormatT]]:
        """Await `beta.messages.parse` once, record the result as the last message and yield it."""
        parsed = await self._client.beta.messages.parse(**self._params, **self._options)
        self._last_message = parsed
        yield parsed
|
|
|
|
class BetaAsyncStreamingToolRunner(BaseAsyncToolRunner[BetaAsyncMessageStream[ResponseFormatT], ResponseFormatT]):
    """Asynchronous streaming runner: every iteration yields an open message stream."""

    @override
    @asynccontextmanager
    async def _handle_request(self) -> AsyncIterator[BetaAsyncMessageStream[ResponseFormatT]]:
        """Open one `beta.messages.stream` request and yield the stream while it is open."""
        async with self._client.beta.messages.stream(**self._params, **self._options) as message_stream:
            # Store the bound method, not its result: the final message is only
            # obtained when `_get_last_message()` calls and awaits it.
            self._last_message = message_stream.get_final_message
            yield message_stream
|
|