| |
| from __future__ import annotations |
|
|
| import abc |
| import json |
| import inspect |
| import warnings |
| from types import TracebackType |
| from typing import TYPE_CHECKING, Any, Generic, TypeVar, Iterator, AsyncIterator, cast |
| from typing_extensions import Self, Protocol, TypeGuard, override, get_origin, runtime_checkable |
|
|
| import httpx |
|
|
| from ._utils import is_dict, extract_type_var_from_base |
|
|
| if TYPE_CHECKING: |
| from ._client import Anthropic, AsyncAnthropic |
|
|
|
|
# Type of each item produced by `Stream` / `AsyncStream` (their `cast_to` type).
_T = TypeVar("_T")
|
|
|
|
class _SyncStreamMeta(abc.ABCMeta):
    """Metaclass for `Stream` that adds a deprecated `isinstance()` compatibility shim.

    `isinstance(obj, Stream)` also returns True for `MessageStream` objects, and
    emits a `DeprecationWarning` when it does. Every other object goes through
    the regular `abc.ABCMeta` instance check, so genuine `Stream` instances
    (including subclasses) are still recognised.
    """

    @override
    def __instancecheck__(self, instance: Any) -> bool:
        # Imported inside the method rather than at module top; presumably to
        # avoid an import cycle with `.lib.streaming` -- TODO confirm.
        from .lib.streaming import MessageStream

        if isinstance(instance, MessageStream):
            warnings.warn(
                "Using `isinstance()` to check if a `MessageStream` object is an instance of `Stream` is deprecated & will be removed in the next major version",
                DeprecationWarning,
                stacklevel=2,
            )
            return True

        # Defer to the default ABCMeta check. Returning a hard-coded False here
        # would make `isinstance(stream, Stream)` fail even for real streams.
        return super().__instancecheck__(instance)
|
|
|
|
class Stream(Generic[_T], metaclass=_SyncStreamMeta):
    """Synchronous iterator over the server-sent events of a streaming HTTP response.

    Relevant events are converted, via the client's `_process_response_data`,
    into values of the `cast_to` type. Iterate the object directly, call
    `next()` on it, or use it as a context manager so that the response is
    always released.
    """

    # HTTP response whose body is read as a server-sent event stream.
    response: httpx.Response

    # Decoder (obtained from the client) turning raw bytes into ServerSentEvents.
    _decoder: SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: Anthropic,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        # `__stream__` is a generator function, so nothing is read from the
        # response until the first item is requested.
        self._iterator = self.__stream__()

    def __next__(self) -> _T:
        return next(self._iterator)

    def __iter__(self) -> Iterator[_T]:
        for value in self._iterator:
            yield value

    def _iter_events(self) -> Iterator[ServerSentEvent]:
        """Decode the raw response body into a stream of `ServerSentEvent`s."""
        raw_chunks = self.response.iter_bytes()
        yield from self._decoder.iter_bytes(raw_chunks)

    def __stream__(self) -> Iterator[_T]:
        """Yield one processed value per data-bearing event; raise on `error` events.

        `ping` events and unrecognised event names are skipped. The response is
        closed when this generator finishes, fails, or is closed.
        """
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        events = self._iter_events()
        message_events = (
            "message_start",
            "message_delta",
            "message_stop",
            "content_block_start",
            "content_block_delta",
            "content_block_stop",
        )

        try:
            for sse in events:
                name = sse.event
                if name == "completion":
                    yield process_data(data=sse.json(), cast_to=cast_to, response=response)
                elif name in message_events:
                    payload = sse.json()
                    # Fill in `type` from the SSE event name when the JSON body omits it.
                    if is_dict(payload) and "type" not in payload:
                        payload["type"] = name
                    yield process_data(data=payload, cast_to=cast_to, response=response)
                elif name == "error":
                    # Prefer the parsed JSON body; fall back to the raw text.
                    body = sse.data
                    try:
                        body = sse.json()
                        err_msg = f"{body}"
                    except Exception:
                        err_msg = sse.data or f"Error code: {response.status_code}"
                    raise self._client._make_status_error(
                        err_msg,
                        body=body,
                        response=self.response,
                    )
                # Any other event (e.g. "ping") is ignored.
        finally:
            response.close()

    def __enter__(self) -> Self:
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Returns None, so any in-flight exception keeps propagating.
        self.close()

    def close(self) -> None:
        """
        Close the HTTP response and release its connection.

        `__stream__` also closes the response once iteration ends, whether the
        events were exhausted or an error was raised.
        """
        self.response.close()
|
|
|
|
class _AsyncStreamMeta(abc.ABCMeta):
    """Metaclass for `AsyncStream` that adds a deprecated `isinstance()` compatibility shim.

    `isinstance(obj, AsyncStream)` also returns True for `AsyncMessageStream`
    objects, and emits a `DeprecationWarning` when it does. Every other object
    goes through the regular `abc.ABCMeta` instance check, so genuine
    `AsyncStream` instances (including subclasses) are still recognised.
    """

    @override
    def __instancecheck__(self, instance: Any) -> bool:
        # Imported inside the method rather than at module top; presumably to
        # avoid an import cycle with `.lib.streaming` -- TODO confirm.
        from .lib.streaming import AsyncMessageStream

        if isinstance(instance, AsyncMessageStream):
            warnings.warn(
                "Using `isinstance()` to check if a `AsyncMessageStream` object is an instance of `AsyncStream` is deprecated & will be removed in the next major version",
                DeprecationWarning,
                stacklevel=2,
            )
            return True

        # Defer to the default ABCMeta check. Returning a hard-coded False here
        # would make `isinstance(stream, AsyncStream)` fail even for real streams.
        return super().__instancecheck__(instance)
|
|
|
|
class AsyncStream(Generic[_T], metaclass=_AsyncStreamMeta):
    """Asynchronous iterator over the server-sent events of a streaming HTTP response.

    Relevant events are converted, via the client's `_process_response_data`,
    into values of the `cast_to` type. Consume it with `async for`, await
    `__anext__()`, or use it as an async context manager so that the response
    is always released.
    """

    # HTTP response whose body is read as a server-sent event stream.
    response: httpx.Response

    # Decoder (obtained from the client) turning raw bytes into ServerSentEvents.
    _decoder: SSEDecoder | SSEBytesDecoder

    def __init__(
        self,
        *,
        cast_to: type[_T],
        response: httpx.Response,
        client: AsyncAnthropic,
    ) -> None:
        self.response = response
        self._cast_to = cast_to
        self._client = client
        self._decoder = client._make_sse_decoder()
        # `__stream__` is an async generator function, so nothing is read from
        # the response until the first item is requested.
        self._iterator = self.__stream__()

    async def __anext__(self) -> _T:
        return await self._iterator.__anext__()

    async def __aiter__(self) -> AsyncIterator[_T]:
        async for value in self._iterator:
            yield value

    async def _iter_events(self) -> AsyncIterator[ServerSentEvent]:
        """Decode the raw response body into a stream of `ServerSentEvent`s."""
        raw_chunks = self.response.aiter_bytes()
        async for event in self._decoder.aiter_bytes(raw_chunks):
            yield event

    async def __stream__(self) -> AsyncIterator[_T]:
        """Yield one processed value per data-bearing event; raise on `error` events.

        `ping` events and unrecognised event names are skipped. The response is
        closed when this generator finishes, fails, or is closed.
        """
        cast_to = cast(Any, self._cast_to)
        response = self.response
        process_data = self._client._process_response_data
        events = self._iter_events()
        message_events = (
            "message_start",
            "message_delta",
            "message_stop",
            "content_block_start",
            "content_block_delta",
            "content_block_stop",
        )

        try:
            async for sse in events:
                name = sse.event
                if name == "completion":
                    yield process_data(data=sse.json(), cast_to=cast_to, response=response)
                elif name in message_events:
                    payload = sse.json()
                    # Fill in `type` from the SSE event name when the JSON body omits it.
                    if is_dict(payload) and "type" not in payload:
                        payload["type"] = name
                    yield process_data(data=payload, cast_to=cast_to, response=response)
                elif name == "error":
                    # Prefer the parsed JSON body; fall back to the raw text.
                    body = sse.data
                    try:
                        body = sse.json()
                        err_msg = f"{body}"
                    except Exception:
                        err_msg = sse.data or f"Error code: {response.status_code}"
                    raise self._client._make_status_error(
                        err_msg,
                        body=body,
                        response=self.response,
                    )
                # Any other event (e.g. "ping") is ignored.
        finally:
            await response.aclose()

    async def __aenter__(self) -> Self:
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        # Returns None, so any in-flight exception keeps propagating.
        await self.close()

    async def close(self) -> None:
        """
        Close the HTTP response and release its connection.

        `__stream__` also closes the response once iteration ends, whether the
        events were exhausted or an error was raised.
        """
        await self.response.aclose()
|
|
|
|
| class ServerSentEvent: |
| def __init__( |
| self, |
| *, |
| event: str | None = None, |
| data: str | None = None, |
| id: str | None = None, |
| retry: int | None = None, |
| ) -> None: |
| if data is None: |
| data = "" |
|
|
| self._id = id |
| self._data = data |
| self._event = event or None |
| self._retry = retry |
|
|
| @property |
| def event(self) -> str | None: |
| return self._event |
|
|
| @property |
| def id(self) -> str | None: |
| return self._id |
|
|
| @property |
| def retry(self) -> int | None: |
| return self._retry |
|
|
| @property |
| def data(self) -> str: |
| return self._data |
|
|
| def json(self) -> Any: |
| return json.loads(self.data) |
|
|
| @override |
| def __repr__(self) -> str: |
| return f"ServerSentEvent(event={self.event}, data={self.data}, id={self.id}, retry={self.retry})" |
|
|
|
|
class SSEDecoder:
    """Incremental parser for the `text/event-stream` format.

    Raw bytes go in through `iter_bytes` / `aiter_bytes`, or decoded lines go in
    one at a time through `decode`. A `ServerSentEvent` is produced whenever a
    blank line ends an event that has buffered state.
    """

    _data: list[str]
    _event: str | None
    _retry: int | None
    _last_event_id: str | None

    def __init__(self) -> None:
        self._event = None
        self._data = []
        self._last_event_id = None
        self._retry = None

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Consume a sync iterator of raw bytes and yield every decoded event."""
        for chunk in self._iter_chunks(iterator):
            yield from self._decode_chunk(chunk)

    def _iter_chunks(self, iterator: Iterator[bytes]) -> Iterator[bytes]:
        """Regroup raw bytes into chunks that end on a blank line, i.e. one per event.

        Whatever is left when the input runs out is yielded as a final chunk.
        """
        buffer = b""
        for piece in iterator:
            for line in piece.splitlines(keepends=True):
                buffer += line
                # Two consecutive line terminators (a blank line) end an event.
                if not buffer.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    continue
                yield buffer
                buffer = b""
        if buffer:
            yield buffer

    async def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Consume an async iterator of raw bytes and yield every decoded event."""
        async for chunk in self._aiter_chunks(iterator):
            for sse in self._decode_chunk(chunk):
                yield sse

    async def _aiter_chunks(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[bytes]:
        """Async twin of `_iter_chunks`: regroup raw bytes into blank-line-terminated chunks."""
        buffer = b""
        async for piece in iterator:
            for line in piece.splitlines(keepends=True):
                buffer += line
                # Two consecutive line terminators (a blank line) end an event.
                if not buffer.endswith((b"\r\r", b"\n\n", b"\r\n\r\n")):
                    continue
                yield buffer
                buffer = b""
        if buffer:
            yield buffer

    def _decode_chunk(self, chunk: bytes) -> Iterator[ServerSentEvent]:
        """Feed each line of `chunk` to `decode`, yielding the events it completes."""
        # Split the bytes before decoding so only \r, \n and \r\n are line breaks
        # (str.splitlines would also break on characters such as \x85 or \u2028).
        for raw_line in chunk.splitlines():
            sse = self.decode(raw_line.decode("utf-8"))
            if sse:
                yield sse

    def _flush_event(self) -> ServerSentEvent | None:
        """Build an event from the buffered fields and reset the per-event buffers.

        Returns None, and leaves the buffers untouched, when nothing is buffered.
        """
        nothing_buffered = (
            not self._event and not self._data and not self._last_event_id and self._retry is None
        )
        if nothing_buffered:
            return None

        sse = ServerSentEvent(
            event=self._event,
            data="\n".join(self._data),
            id=self._last_event_id,
            retry=self._retry,
        )

        # `_last_event_id` is deliberately not reset, so it carries over to
        # later events (the SSE spec keeps the last event ID across events).
        self._event = None
        self._data = []
        self._retry = None

        return sse

    def decode(self, line: str) -> ServerSentEvent | None:
        """Process one line of the stream; return an event if this line completes one.

        See: https://html.spec.whatwg.org/multipage/server-sent-events.html#event-stream-interpretation
        """
        if not line:
            return self._flush_event()

        # Lines starting with ":" are comments.
        if line.startswith(":"):
            return None

        fieldname, _, value = line.partition(":")
        # A single leading space after the colon is not part of the value.
        if value.startswith(" "):
            value = value[1:]

        if fieldname == "event":
            self._event = value
        elif fieldname == "data":
            self._data.append(value)
        elif fieldname == "id":
            # Ids containing NUL are ignored.
            if "\0" not in value:
                self._last_event_id = value
        elif fieldname == "retry":
            try:
                self._retry = int(value)
            except (TypeError, ValueError):
                pass
        # Any other field name is ignored.

        return None
|
|
|
|
@runtime_checkable
class SSEBytesDecoder(Protocol):
    """Structural interface for objects that turn raw response bytes into `ServerSentEvent`s.

    Because it is `runtime_checkable`, `isinstance()` can be used against it;
    that check only verifies that both methods are present.
    """

    def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[ServerSentEvent]:
        """Consume a sync iterator of raw bytes and yield each decoded event."""

    def aiter_bytes(self, iterator: AsyncIterator[bytes]) -> AsyncIterator[ServerSentEvent]:
        """Consume an async iterator of raw bytes and yield each decoded event."""
|
|
|
|
def is_stream_class_type(typ: type) -> TypeGuard[type[Stream[object]] | type[AsyncStream[object]]]:
    """Return True when `typ` (or, for a parameterised generic, its origin) is a
    `Stream` / `AsyncStream` class or a subclass of one."""
    origin = get_origin(typ) or typ
    # Non-class objects (e.g. typing special forms) can never be stream classes.
    if not inspect.isclass(origin):
        return False
    return issubclass(origin, (Stream, AsyncStream))
|
|
|
|
def extract_stream_chunk_type(
    stream_cls: type,
    *,
    failure_message: str | None = None,
) -> type:
    """Return the item type `T` of a stream type such as `Stream[T]`.

    Concrete subclasses are handled as well, e.g.
    ```py
    class MyStream(Stream[bytes]):
        ...

    extract_stream_chunk_type(MyStream) -> bytes
    ```
    `failure_message` is forwarded unchanged to `extract_type_var_from_base`.
    """
    from ._base_client import Stream, AsyncStream

    stream_bases = cast("tuple[type, ...]", (Stream, AsyncStream))
    return extract_type_var_from_base(
        stream_cls,
        index=0,
        generic_bases=stream_bases,
        failure_message=failure_message,
    )
|
|