| from __future__ import annotations |
|
|
# Names exported by ``from <this module> import *``.
__all__ = ("BufferedByteReceiveStream", "BufferedByteStream", "BufferedConnectable")
|
|
| import sys |
| from collections.abc import Callable, Iterable, Mapping |
| from dataclasses import dataclass, field |
| from typing import Any, SupportsIndex |
|
|
| from .. import ClosedResourceError, DelimiterNotFound, EndOfStream, IncompleteRead |
| from ..abc import ( |
| AnyByteReceiveStream, |
| AnyByteStream, |
| AnyByteStreamConnectable, |
| ByteReceiveStream, |
| ByteStream, |
| ByteStreamConnectable, |
| ) |
|
|
| if sys.version_info >= (3, 12): |
| from typing import override |
| else: |
| from typing_extensions import override |
|
|
|
|
@dataclass(eq=False)
class BufferedByteReceiveStream(ByteReceiveStream):
    """
    Wraps any bytes-based receive stream and uses a buffer to provide sophisticated
    receiving capabilities in the form of a byte stream.

    Bytes held in the internal buffer (surplus from earlier reads, or anything added
    through :meth:`feed_data`) are handed out before the wrapped stream is read again.

    :param receive_stream: the stream to wrap
    """

    receive_stream: AnyByteReceiveStream
    # Bytes that have been received (or fed in) but not yet returned to a caller.
    _buffer: bytearray = field(init=False, default_factory=bytearray)
    # Set by aclose(); receive() refuses to operate once this is true.
    _closed: bool = field(init=False, default=False)

    async def aclose(self) -> None:
        """Close the wrapped stream, then mark this wrapper as closed."""
        await self.receive_stream.aclose()
        self._closed = True

    @property
    def buffer(self) -> bytes:
        """The bytes currently in the buffer."""
        # Return an immutable copy so callers cannot alter the internal buffer.
        return bytes(self._buffer)

    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """The extra attributes of the wrapped stream, passed through unchanged."""
        return self.receive_stream.extra_attributes

    def feed_data(self, data: Iterable[SupportsIndex], /) -> None:
        """
        Append data directly into the buffer.

        Any data in the buffer will be consumed by receive operations before receiving
        anything from the wrapped stream.

        :param data: the data to append to the buffer (can be bytes or anything else
            that supports ``__index__()``)

        """
        self._buffer.extend(data)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive at most ``max_bytes`` bytes, taking them from the buffer if it holds
        any data and from the wrapped stream otherwise.

        :param max_bytes: the maximum number of bytes to return
        :return: the bytes read
        :raises ~anyio.ClosedResourceError: if this stream has been closed via
            :meth:`aclose`

        """
        if self._closed:
            raise ClosedResourceError

        if self._buffer:
            # Serve buffered data first; the wrapped stream is not touched.
            chunk = bytes(self._buffer[:max_bytes])
            del self._buffer[:max_bytes]
            return chunk

        if isinstance(self.receive_stream, ByteReceiveStream):
            return await self.receive_stream.receive(max_bytes)

        # The wrapped stream's receive() takes no size limit here, so the chunk may be
        # larger than requested and the surplus has to be dealt with.
        chunk = await self.receive_stream.receive()
        if len(chunk) > max_bytes:
            # Save the surplus bytes in the buffer for the next receive operation.
            self._buffer.extend(chunk[max_bytes:])
            return chunk[:max_bytes]

        return chunk

    async def receive_exactly(self, nbytes: int) -> bytes:
        """
        Read exactly the given amount of bytes from the stream.

        :param nbytes: the number of bytes to read (must not be negative)
        :return: the bytes read
        :raises ValueError: if ``nbytes`` is negative
        :raises ~anyio.IncompleteRead: if the stream was closed before the requested
            amount of bytes could be read from the stream

        """
        # A negative count would otherwise be treated as a negative slice bound below,
        # silently returning and discarding all but the last |nbytes| buffered bytes.
        if nbytes < 0:
            raise ValueError("nbytes must not be negative")

        while True:
            remaining = nbytes - len(self._buffer)
            if remaining <= 0:
                # Enough data is buffered: split off the requested prefix.
                retval = self._buffer[:nbytes]
                del self._buffer[:nbytes]
                return bytes(retval)

            try:
                if isinstance(self.receive_stream, ByteReceiveStream):
                    # Ask only for what is still missing.
                    chunk = await self.receive_stream.receive(remaining)
                else:
                    # No size limit can be passed; any surplus stays in the buffer.
                    chunk = await self.receive_stream.receive()
            except EndOfStream as exc:
                raise IncompleteRead from exc

            self._buffer.extend(chunk)

    async def receive_until(self, delimiter: bytes, max_bytes: int) -> bytes:
        """
        Read from the stream until the delimiter is found or max_bytes have been read.

        The delimiter is removed from the buffer together with the returned bytes.

        :param delimiter: the marker to look for in the stream
        :param max_bytes: maximum number of bytes that will be read before raising
            :exc:`~anyio.DelimiterNotFound`
        :return: the bytes read (not including the delimiter)
        :raises ~anyio.IncompleteRead: if the stream was closed before the delimiter
            was found
        :raises ~anyio.DelimiterNotFound: if the delimiter is not found within the
            bytes read up to the maximum allowed

        """
        delimiter_size = len(delimiter)
        offset = 0
        while True:
            # Check if the delimiter can be found in the current buffer.
            index = self._buffer.find(delimiter, offset)
            if index >= 0:
                found = self._buffer[:index]
                del self._buffer[: index + delimiter_size]
                return bytes(found)

            # Check if the buffer is already at or over the limit.
            if len(self._buffer) >= max_bytes:
                raise DelimiterNotFound(max_bytes)

            # Read more data from the wrapped stream.
            try:
                data = await self.receive_stream.receive()
            except EndOfStream as exc:
                raise IncompleteRead from exc

            # Resume the search just before the old end of the buffer, so that a
            # delimiter straddling the boundary between old and new data is still
            # found without rescanning bytes that were already searched.
            offset = max(len(self._buffer) - delimiter_size + 1, 0)
            self._buffer.extend(data)
|
|
|
|
class BufferedByteStream(BufferedByteReceiveStream, ByteStream):
    """
    Full-duplex counterpart of :class:`BufferedByteReceiveStream`.

    Receiving goes through the inherited buffering logic, while everything that is
    sent is forwarded to the wrapped stream unmodified.
    """

    def __init__(self, stream: AnyByteStream):
        """
        :param stream: the stream to be wrapped

        """
        super().__init__(stream)
        # Keep a reference for the send side; the base class stores the same object
        # as its receive stream.
        self._stream = stream

    @override
    async def send(self, item: bytes) -> None:
        # Writes are not buffered; hand them straight to the wrapped stream.
        await self._stream.send(item)

    @override
    async def send_eof(self) -> None:
        # Delegated to the wrapped stream as-is.
        await self._stream.send_eof()
|
|
|
|
class BufferedConnectable(ByteStreamConnectable):
    """
    Connectable that wraps another connectable and returns each stream it produces
    wrapped in a :class:`BufferedByteStream`.
    """

    def __init__(self, connectable: AnyByteStreamConnectable):
        """
        :param connectable: the connectable to wrap

        """
        self.connectable = connectable

    @override
    async def connect(self) -> BufferedByteStream:
        # Connect through the wrapped connectable, then add buffering on top.
        return BufferedByteStream(await self.connectable.connect())
|
|