chadlinden's picture
Upload folder using huggingface_hub
8fd238c verified
raw
history blame
82 kB
from __future__ import annotations
import array
import asyncio
import concurrent.futures
import math
import socket
import sys
import threading
from asyncio import (
AbstractEventLoop,
CancelledError,
all_tasks,
create_task,
current_task,
get_running_loop,
sleep,
)
from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined]
from collections import OrderedDict, deque
from collections.abc import AsyncIterator, Generator, Iterable
from concurrent.futures import Future
from contextlib import suppress
from contextvars import Context, copy_context
from dataclasses import dataclass
from functools import partial, wraps
from inspect import (
CORO_RUNNING,
CORO_SUSPENDED,
getcoroutinestate,
iscoroutine,
)
from io import IOBase
from os import PathLike
from queue import Queue
from signal import Signals
from socket import AddressFamily, SocketKind
from threading import Thread
from types import TracebackType
from typing import (
IO,
Any,
AsyncGenerator,
Awaitable,
Callable,
Collection,
ContextManager,
Coroutine,
Mapping,
Optional,
Sequence,
Tuple,
TypeVar,
cast,
)
from weakref import WeakKeyDictionary
import sniffio
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc
from .._core._eventloop import claim_worker_thread, threadlocals
from .._core._exceptions import (
BrokenResourceError,
BusyResourceError,
ClosedResourceError,
EndOfStream,
WouldBlock,
)
from .._core._sockets import convert_ipv6_sockaddr
from .._core._streams import create_memory_object_stream
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter
from .._core._synchronization import Event as BaseEvent
from .._core._synchronization import ResourceGuard
from .._core._tasks import CancelScope as BaseCancelScope
from ..abc import (
AsyncBackend,
IPSockAddrType,
SocketListener,
UDPPacketType,
UNIXDatagramPacketType,
)
from ..lowlevel import RunVar
from ..streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
if sys.version_info >= (3, 11):
from asyncio import Runner
from typing import TypeVarTuple, Unpack
else:
import contextvars
import enum
import signal
from asyncio import coroutines, events, exceptions, tasks
from exceptiongroup import BaseExceptionGroup
from typing_extensions import TypeVarTuple, Unpack
class _State(enum.Enum):
    """Lifecycle states of the backported ``Runner`` defined below."""

    # Set by Runner.__init__(); the event loop has not been created yet
    CREATED = "created"
    # Set by Runner._lazy_init() once the loop and context exist
    INITIALIZED = "initialized"
    # Set by Runner.close(); _lazy_init() raises RuntimeError from this state
    CLOSED = "closed"
class Runner:
    """
    Context manager that owns an event loop and runs coroutines in it.

    Only defined on Python versions older than 3.11, where ``asyncio.Runner`` is not
    available.
    """

    # Copied from CPython 3.11
    def __init__(
        self,
        *,
        debug: bool | None = None,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ):
        self._state = _State.CREATED
        self._debug = debug
        self._loop_factory = loop_factory
        self._loop: AbstractEventLoop | None = None
        self._context = None
        # Number of SIGINTs received during the current run() call
        self._interrupt_count = 0
        # True once this runner has installed its loop via events.set_event_loop()
        self._set_event_loop = False

    def __enter__(self) -> Runner:
        self._lazy_init()
        return self

    def __exit__(
        self,
        exc_type: type[BaseException],
        exc_val: BaseException,
        exc_tb: TracebackType,
    ) -> None:
        self.close()

    def close(self) -> None:
        """Shutdown and close event loop."""
        # Nothing to do if the loop was never created or was already closed
        if self._state is not _State.INITIALIZED:
            return
        try:
            loop = self._loop
            _cancel_all_tasks(loop)
            loop.run_until_complete(loop.shutdown_asyncgens())
            # Use the loop's own executor shutdown when it has one; otherwise fall
            # back to the module-level helper
            if hasattr(loop, "shutdown_default_executor"):
                loop.run_until_complete(loop.shutdown_default_executor())
            else:
                loop.run_until_complete(_shutdown_default_executor(loop))
        finally:
            # Only unset the current event loop if this runner was the one to set it
            if self._set_event_loop:
                events.set_event_loop(None)
            loop.close()
            self._loop = None
            self._state = _State.CLOSED

    def get_loop(self) -> AbstractEventLoop:
        """Return embedded event loop."""
        self._lazy_init()
        return self._loop

    def run(self, coro: Coroutine[T_Retval], *, context=None) -> T_Retval:
        """Run a coroutine inside the embedded event loop."""
        if not coroutines.iscoroutine(coro):
            raise ValueError(f"a coroutine was expected, got {coro!r}")

        if events._get_running_loop() is not None:
            # fail fast with short traceback
            raise RuntimeError(
                "Runner.run() cannot be called from a running event loop"
            )

        self._lazy_init()

        # Default to the context captured in _lazy_init()
        if context is None:
            context = self._context
        task = context.run(self._loop.create_task, coro)

        # Install our own SIGINT handler only from the main thread, and only if the
        # default handler is still in place
        if (
            threading.current_thread() is threading.main_thread()
            and signal.getsignal(signal.SIGINT) is signal.default_int_handler
        ):
            sigint_handler = partial(self._on_sigint, main_task=task)
            try:
                signal.signal(signal.SIGINT, sigint_handler)
            except ValueError:
                # `signal.signal` may throw if `threading.main_thread` does
                # not support signals (e.g. embedded interpreter with signals
                # not registered - see gh-91880)
                sigint_handler = None
        else:
            sigint_handler = None

        self._interrupt_count = 0
        try:
            return self._loop.run_until_complete(task)
        except exceptions.CancelledError:
            # Convert the cancellation into KeyboardInterrupt when it was caused by
            # SIGINT and the task has no other cancellation requests pending
            if self._interrupt_count > 0:
                uncancel = getattr(task, "uncancel", None)
                if uncancel is not None and uncancel() == 0:
                    raise KeyboardInterrupt()
            raise  # CancelledError
        finally:
            # Restore the default handler, but only if ours is still the active one
            if (
                sigint_handler is not None
                and signal.getsignal(signal.SIGINT) is sigint_handler
            ):
                signal.signal(signal.SIGINT, signal.default_int_handler)

    def _lazy_init(self) -> None:
        """Create the event loop and context on first use."""
        if self._state is _State.CLOSED:
            raise RuntimeError("Runner is closed")
        if self._state is _State.INITIALIZED:
            return
        if self._loop_factory is None:
            self._loop = events.new_event_loop()
            if not self._set_event_loop:
                # Call set_event_loop only once to avoid calling
                # attach_loop multiple times on child watchers
                events.set_event_loop(self._loop)
                self._set_event_loop = True
        else:
            self._loop = self._loop_factory()
        if self._debug is not None:
            self._loop.set_debug(self._debug)
        self._context = contextvars.copy_context()
        self._state = _State.INITIALIZED

    def _on_sigint(self, signum, frame, main_task: asyncio.Task) -> None:
        """
        SIGINT handler: cancel the main task on the first interrupt, raise
        ``KeyboardInterrupt`` on any later one (or if the task is already done).
        """
        self._interrupt_count += 1
        if self._interrupt_count == 1 and not main_task.done():
            main_task.cancel()
            # wakeup loop if it is blocked by select() with long timeout
            self._loop.call_soon_threadsafe(lambda: None)
            return
        raise KeyboardInterrupt()
def _cancel_all_tasks(loop: AbstractEventLoop) -> None:
to_cancel = tasks.all_tasks(loop)
if not to_cancel:
return
for task in to_cancel:
task.cancel()
loop.run_until_complete(tasks.gather(*to_cancel, return_exceptions=True))
for task in to_cancel:
if task.cancelled():
continue
if task.exception() is not None:
loop.call_exception_handler(
{
"message": "unhandled exception during asyncio.run() shutdown",
"exception": task.exception(),
"task": task,
}
)
async def _shutdown_default_executor(loop: AbstractEventLoop) -> None:
"""Schedule the shutdown of the default executor."""
def _do_shutdown(future: asyncio.futures.Future) -> None:
try:
loop._default_executor.shutdown(wait=True) # type: ignore[attr-defined]
loop.call_soon_threadsafe(future.set_result, None)
except Exception as ex:
loop.call_soon_threadsafe(future.set_exception, ex)
loop._executor_shutdown_called = True
if loop._default_executor is None:
return
future = loop.create_future()
thread = threading.Thread(target=_do_shutdown, args=(future,))
thread.start()
try:
await future
finally:
thread.join()
# Return type of a callable or coroutine being run on behalf of the caller
T_Retval = TypeVar("T_Retval")
# Contravariant type of the value handed to a task status' started() method
T_contra = TypeVar("T_contra", contravariant=True)
# Types of the positional arguments forwarded to a spawned callable
PosArgsT = TypeVarTuple("PosArgsT")
P = ParamSpec("P")

# Holds the root task once find_root_task() has located it
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
def find_root_task() -> asyncio.Task:
    """
    Locate the task that drives the current event loop run.

    The lookup order is: the cached value in ``_root_task``, then a task carrying a
    ``run_until_complete()`` (or uvloop) completion callback, then the host task of
    the outermost cancel scope of the current task, and finally the current task.

    :return: the task considered to be the root task

    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback, _context in candidate._callbacks:
            if (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            ):
                _root_task.set(candidate)
                return candidate

    # Look up the topmost task in the AnyIO task tree, if possible
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if state:
        scope = state.cancel_scope
        while scope and scope._parent_scope is not None:
            scope = scope._parent_scope

        if scope is not None:
            return cast(asyncio.Task, scope._host_task)

    return current
def get_callable_name(func: Callable) -> str:
    """
    Return a dotted name for ``func`` built from its module and qualified name.

    Parts that are missing or empty are left out.

    :param func: the object to name
    :return: ``"<module>.<qualname>"``, or whichever part is available

    """
    parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(filter(None, parts))
#
# Event loop
#
# Per-event-loop storage; entries are weakly keyed by the loop object.
# NOTE(review): presumably the backing store for RunVar values — confirm at the
# point of use, which is outside this part of the file.
_run_vars: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] = WeakKeyDictionary()
def _task_started(task: asyncio.Task) -> bool:
"""Return ``True`` if the task has been started and has not finished."""
try:
return getcoroutinestate(task.get_coro()) in (CORO_RUNNING, CORO_SUSPENDED)
except AttributeError:
# task coro is async_genenerator_asend https://bugs.python.org/issue37771
raise Exception(f"Cannot determine if task {task} has started or not") from None
#
# Timeouts and cancellation
#
class CancelScope(BaseCancelScope):
    """
    Cancel scope for the asyncio backend.

    Keeps track of the tasks and child scopes it directly contains and delivers
    cancellation to them through ``Task.cancel()``, either on request or when the
    deadline passes.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Create the instance directly instead of going through the base class
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._child_scopes: set[CancelScope] = set()
        self._cancel_called = False
        self._cancelled_caught = False
        self._active = False
        # Timer for the deadline, scheduled by _timeout()
        self._timeout_handle: asyncio.TimerHandle | None = None
        # Pending retry of _deliver_cancellation(), if any
        self._cancel_handle: asyncio.Handle | None = None
        # Tasks directly contained in this scope (not those of child scopes)
        self._tasks: set[asyncio.Task] = set()
        # The task that entered this scope
        self._host_task: asyncio.Task | None = None
        # Number of Task.cancel() calls made on behalf of this scope
        self._cancel_calls: int = 0
        # The host task's cancelling() count at entry (Python 3.11+ only)
        self._cancelling: int | None = None

    def __enter__(self) -> CancelScope:
        """
        Make this the current cancel scope of the current task.

        :raises RuntimeError: if the scope is already active

        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # First scope entered by this task: create its task state
            task_state = TaskState(None, self)
            _task_states[host_task] = task_state
        else:
            # Nest under the task's current scope and take over the host task from it
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self
            if self._parent_scope is not None:
                self._parent_scope._child_scopes.add(self)
                self._parent_scope._tasks.remove(host_task)

        self._timeout()
        self._active = True
        if sys.version_info >= (3, 11):
            self._cancelling = self._host_task.cancelling()

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation(self)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Leave the scope and hand the host task back to the parent scope.

        :return: whether a cancellation exception was caught by this scope, if the
            scope was cancelled and an exception is propagating; otherwise ``None``
        :raises RuntimeError: if the scope is not active, is exited from another
            task, or is not the task's current cancel scope

        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)
        if self._parent_scope is not None:
            self._parent_scope._child_scopes.remove(self)
            self._parent_scope._tasks.add(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the closest directly cancelled parent
        # scope if this one was shielded
        self._restart_cancellation_in_parent()

        if self._cancel_called and exc_val is not None:
            # Look for a CancelledError (possibly nested in an exception group) that
            # this scope is responsible for
            for exc in iterate_exceptions(exc_val):
                if isinstance(exc, CancelledError):
                    self._cancelled_caught = self._uncancel(exc)
                    if self._cancelled_caught:
                        break

            return self._cancelled_caught

        return None

    def _uncancel(self, cancelled_exc: CancelledError) -> bool:
        """
        Undo this scope's cancellation requests on the host task.

        :return: ``True`` if the given exception is considered to belong to this
            scope

        """
        if sys.version_info < (3, 9) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Undo all cancellations done by this scope
        if self._cancelling is not None:
            while self._cancel_calls:
                self._cancel_calls -= 1
                if self._host_task.uncancel() <= self._cancelling:
                    return True

        self._cancel_calls = 0
        # Fall back to matching the message that _deliver_cancellation() passed to
        # Task.cancel()
        return f"Cancelled by cancel scope {id(self):x}" in cancelled_exc.args

    def _timeout(self) -> None:
        """Cancel the scope if the deadline has passed, else schedule a re-check."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self, origin: CancelScope) -> bool:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for
        cancellation.

        :param origin: the cancel scope that originated the cancellation
        :return: ``True`` if the delivery needs to be retried on the next cycle

        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            # Skip tasks that already have a cancellation pending
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started
            should_retry = True
            if task is not current and (task is self._host_task or _task_started(task)):
                waiter = task._fut_waiter  # type: ignore[attr-defined]
                # Don't cancel a task whose awaited future has already completed
                if not isinstance(waiter, asyncio.Future) or not waiter.done():
                    self._cancel_calls += 1
                    if sys.version_info >= (3, 9):
                        task.cancel(f"Cancelled by cancel scope {id(origin):x}")
                    else:
                        task.cancel()

        # Deliver cancellation to child scopes that aren't shielded or running their own
        # cancellation callbacks
        for scope in self._child_scopes:
            if not scope._shield and not scope.cancel_called:
                should_retry = scope._deliver_cancellation(origin) or should_retry

        # Schedule another callback if there are still tasks left
        if origin is self:
            if should_retry:
                self._cancel_handle = get_running_loop().call_soon(
                    self._deliver_cancellation, origin
                )
            else:
                self._cancel_handle = None

        return should_retry

    def _restart_cancellation_in_parent(self) -> None:
        """
        Restart the cancellation effort in the closest directly cancelled parent scope.

        """
        scope = self._parent_scope
        while scope is not None:
            if scope._cancel_called:
                # Only restart if that scope has no delivery retry pending
                if scope._cancel_handle is None:
                    scope._deliver_cancellation(scope)

                break

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

    def _parent_cancelled(self) -> bool:
        """Return ``True`` if an enclosing scope below any shield was cancelled."""
        # Check whether any parent has been cancelled
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> None:
        """Cancel the scope; repeated calls have no further effect."""
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            # Delivery is deferred to __enter__() if the scope has no host task yet
            if self._host_task is not None:
                self._deliver_cancellation(self)

    @property
    def deadline(self) -> float:
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        # Drop the timer for the old deadline before scheduling a new one
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        return self._cancel_called

    @property
    def cancelled_caught(self) -> bool:
        return self._cancelled_caught

    @property
    def shield(self) -> bool:
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Removing the shield exposes this scope to a cancelled parent again
            if not value:
                self._restart_cancellation_in_parent()
#
# Task states
#
class TaskState:
    """
    Per-task bookkeeping kept outside the Task object.

    The data lives here because nothing is guaranteed about the implementation of
    the Task instance itself, so attributes cannot be attached to it.

    :param parent_id: ``id()`` of the parent task, if any
    :param cancel_scope: the task's current cancel scope, if any
    """

    __slots__ = ("parent_id", "cancel_scope")

    def __init__(self, parent_id: int | None, cancel_scope: CancelScope | None):
        self.cancel_scope = cancel_scope
        self.parent_id = parent_id
# Maps each task to its TaskState; weakly keyed by the task object
_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
#
# Task groups
#
class _AsyncioTaskStatus(abc.TaskStatus):
    """
    Task status object passed to a task as its ``task_status`` argument.

    :param future: the future resolved with the value passed to :meth:`started`
    :param parent_id: ``id()`` of the task recorded as the parent once the child
        has signalled that it started
    """

    def __init__(self, future: asyncio.Future, parent_id: int):
        self._parent_id = parent_id
        self._future = future

    def started(self, value: T_contra | None = None) -> None:
        """
        Resolve the start future with ``value`` and re-parent the current task.

        :raises RuntimeError: if the future has already been resolved

        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None

        _task_states[cast(asyncio.Task, current_task())].parent_id = self._parent_id
def iterate_exceptions(
exception: BaseException,
) -> Generator[BaseException, None, None]:
if isinstance(exception, BaseExceptionGroup):
for exc in exception.exceptions:
yield from iterate_exceptions(exc)
else:
yield exception
class TaskGroup(abc.TaskGroup):
    """
    Task group for the asyncio backend.

    Child tasks run inside the group's own cancel scope. Exiting the group waits
    for all of them and raises their collected exceptions as an exception group.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        # Exceptions (other than CancelledError) gathered from the host and children
        self._exceptions: list[BaseException] = []
        # Child tasks that have not finished yet
        self._tasks: set[asyncio.Task] = set()

    async def __aenter__(self) -> TaskGroup:
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for all child tasks to finish, then report the outcome.

        :raises BaseExceptionGroup: if the host block or any child task raised an
            exception other than ``CancelledError``

        """
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The host block failed: cancel the children and record the error
            self.cancel_scope.cancel()
            if not isinstance(exc_val, CancelledError):
                self._exceptions.append(exc_val)

        cancelled_exc_while_waiting_tasks: CancelledError | None = None
        while self._tasks:
            try:
                await asyncio.wait(self._tasks)
            except CancelledError as exc:
                # This task was cancelled natively; reraise the CancelledError later
                # unless this task was already interrupted by another exception
                self.cancel_scope.cancel()
                if cancelled_exc_while_waiting_tasks is None:
                    cancelled_exc_while_waiting_tasks = exc

        self._active = False
        if self._exceptions:
            raise BaseExceptionGroup(
                "unhandled errors in a TaskGroup", self._exceptions
            )

        # Raise the CancelledError received while waiting for child tasks to exit,
        # unless the context manager itself was previously exited with another
        # exception, or if any of the child tasks raised an exception other than
        # CancelledError
        if cancelled_exc_while_waiting_tasks:
            if exc_val is None or ignore_exception:
                raise cancelled_exc_while_waiting_tasks

        return ignore_exception

    def _spawn(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        args: tuple[Unpack[PosArgsT]],
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` in the group's cancel scope.

        :param func: a callable returning a coroutine object
        :param args: positional arguments for ``func``
        :param name: task name; derived from ``func`` when ``None``
        :param task_status_future: when given, ``func`` also receives a
            ``task_status`` keyword argument that resolves this future
        :return: the newly created task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if ``func`` did not return a coroutine object

        """

        def task_done(_task: asyncio.Task) -> None:
            # Detach the finished task from its cancel scope and from this group
            task_state = _task_states[_task]
            assert task_state.cancel_scope is not None
            assert _task in task_state.cancel_scope._tasks
            task_state.cancel_scope._tasks.remove(_task)
            self._tasks.remove(task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Walk down the chain of CancelledErrors to the innermost one
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    # A failure (other than cancellation) is recorded and cancels
                    # the group unless a parent scope is already being cancelled
                    if not isinstance(exc, CancelledError):
                        self._exceptions.append(exc)

                    if not self.cancel_scope._parent_cancelled():
                        self.cancel_scope.cancel()
                else:
                    # The task failed before calling task_status.started():
                    # pass the error on through the start future
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        kwargs = {}
        if task_status_future:
            # The task calling start() is the initial parent; the task status object
            # switches the parent to the group's host task in started()
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not iscoroutine(coro):
            prefix = f"{func.__module__}." if hasattr(func, "__module__") else ""
            raise TypeError(
                f"Expected {prefix}{func.__qualname__}() to return a coroutine, but "
                f"the return value ({coro!r}) is not a coroutine object"
            )

        name = get_callable_name(func) if name is None else str(name)
        task = create_task(coro, name=name)
        task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        self._tasks.add(task)
        return task

    def start_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[Any]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> None:
        """Start a new child task without waiting for it to begin."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> Any:
        """
        Start a new child task and wait until it calls ``task_status.started()``.

        :return: the value the child passed to ``task_status.started()``

        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch
        # point between, the task group is cancelled and this method never proceeds to
        # process the completed future. That's why we have to have a shielded cancel
        # scope here.
        try:
            return await future
        except CancelledError:
            # Cancel the task and wait for it to exit before returning
            task.cancel()
            with CancelScope(shield=True), suppress(CancelledError):
                await task

            raise
#
# Threads
#
# A (return value, exception) pair in which either element may be None.
# NOTE(review): not referenced in this part of the file — confirm where it is used.
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
class WorkerThread(Thread):
    """
    Thread that runs queued callables and reports their outcome to the event loop.

    :param root_task: task whose event loop receives the results
    :param workers: the set this worker removes itself from when stopped
    :param idle_workers: the deque this worker joins whenever it finishes a job
    """

    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.loop = root_task._loop
        self.workers = workers
        self.idle_workers = idle_workers
        # A queued ``None`` is the shutdown command
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future, CancelScope] | None
        ] = Queue(2)
        self.idle_since = AsyncIOBackend.current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """Mark this worker idle again and resolve ``future`` with the outcome."""
        self.idle_since = AsyncIOBackend.current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if future.cancelled():
            return

        if exc is None:
            future.set_result(result)
            return

        # A future cannot carry StopIteration, so wrap it in a RuntimeError
        if isinstance(exc, StopIteration):
            wrapped = RuntimeError("coroutine raised StopIteration")
            wrapped.__cause__ = exc
            exc = wrapped

        future.set_exception(exc)

    def run(self) -> None:
        """Process queued jobs until the shutdown command is received."""
        with claim_worker_thread(AsyncIOBackend, self.loop):
            # Iteration ends when the shutdown command (None) is dequeued
            for job in iter(self.queue.get, None):
                context, func, args, future, cancel_scope = job
                if not future.cancelled():
                    outcome = None
                    error: BaseException | None = None
                    threadlocals.current_cancel_scope = cancel_scope
                    try:
                        outcome = context.run(func, *args)
                    except BaseException as exc:
                        error = exc
                    finally:
                        del threadlocals.current_cancel_scope

                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, outcome, error
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """Queue the shutdown command and remove this worker from the pool."""
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        with suppress(ValueError):
            self.idle_workers.remove(self)
# Idle worker threads of the thread pool.
# NOTE(review): presumably the deque handed to WorkerThread as ``idle_workers`` —
# confirm at the call site, which is outside this part of the file.
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
# All worker threads of the thread pool.
# NOTE(review): presumably the set handed to WorkerThread as ``workers`` — confirm.
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the event loop that was running at creation."""

    def __new__(cls) -> BlockingPortal:
        # Create the instance directly instead of going through the base class
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Schedule ``func`` to run as a task in the portal's task group.

        The actual call goes through ``self._call_func``, which receives ``func``,
        its arguments and ``future``.

        """
        start_named_task = partial(self._task_group.start_soon, name=name)
        call_args = (self._call_func, func, args, kwargs, future)
        AsyncIOBackend.run_sync_from_thread(start_named_task, call_args, self._loop)
#
# Subprocesses
#
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Byte receive stream backed by an :class:`asyncio.StreamReader`."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to ``max_bytes`` bytes from the underlying reader.

        :raises EndOfStream: if the reader returned no data

        """
        chunk = await self._stream.read(max_bytes)
        if not chunk:
            raise EndOfStream

        return chunk

    async def aclose(self) -> None:
        """Feed EOF to the underlying reader, then run a checkpoint."""
        self._stream.feed_eof()
        await AsyncIOBackend.checkpoint()
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Byte send stream backed by an :class:`asyncio.StreamWriter`."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write ``item`` to the underlying writer and wait for it to drain."""
        writer = self._stream
        writer.write(item)
        await writer.drain()

    async def aclose(self) -> None:
        """Close the underlying writer, then run a checkpoint."""
        self._stream.close()
        await AsyncIOBackend.checkpoint()
@dataclass(eq=False)
class Process(abc.Process):
    """Subprocess handle wrapping an :class:`asyncio.subprocess.Process`."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """
        Close the standard streams and wait for the process to exit.

        If the wait is interrupted by any exception, the process is killed and
        waited for under a shielded cancel scope before the exception is re-raised.

        """
        with CancelScope(shield=True):
            for stream in (self._stdin, self._stdout, self._stderr):
                if stream:
                    await stream.aclose()

        try:
            await self.wait()
        except BaseException:
            self.kill()
            with CancelScope(shield=True):
                await self.wait()

            raise

    async def wait(self) -> int:
        """Wait for the process to exit and return its return code."""
        return await self._process.wait()

    def terminate(self) -> None:
        """Call ``terminate()`` on the wrapped process."""
        self._process.terminate()

    def kill(self) -> None:
        """Call ``kill()`` on the wrapped process."""
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        """Send the given signal number to the wrapped process."""
        self._process.send_signal(signal)

    @property
    def pid(self) -> int:
        """The process ID reported by the wrapped process."""
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """The return code reported by the wrapped process."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        """The stream wrapping standard input, if any."""
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        """The stream wrapping standard output, if any."""
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        """The stream wrapping standard error, if any."""
        return self._stderr
def _forcibly_shutdown_process_pool_on_exit(
workers: set[Process], _task: object
) -> None:
"""
Forcibly shuts down worker processes belonging to this event loop."""
child_watcher: asyncio.AbstractChildWatcher | None = None
if sys.version_info < (3, 12):
try:
child_watcher = asyncio.get_event_loop_policy().get_child_watcher()
except NotImplementedError:
pass
# Close as much as possible (w/o async/await) to avoid warnings
for process in workers:
if process.returncode is None:
continue
process._stdin._stream._transport.close() # type: ignore[union-attr]
process._stdout._stream._transport.close() # type: ignore[union-attr]
process._stderr._stream._transport.close() # type: ignore[union-attr]
process.kill()
if child_watcher:
child_watcher.remove_child_handler(process.pid)
async def _shutdown_process_pool_on_exit(workers: set[abc.Process]) -> None:
"""
Shuts down worker processes belonging to this event loop.
NOTE: this only works when the event loop was started using asyncio.run() or
anyio.run().
"""
process: abc.Process
try:
await sleep(math.inf)
except asyncio.CancelledError:
for process in workers:
if process.returncode is None:
process.kill()
for process in workers:
await process.aclose()
#
# Sockets and networking
#
class StreamProtocol(asyncio.Protocol):
    """
    Protocol that buffers received chunks and exposes events for flow control.

    ``read_event`` is set when data, EOF or a connection loss has arrived;
    ``write_event`` is set while writing is allowed.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable
        # With a zero limit, the transport pauses writing as soon as data is buffered
        cast(asyncio.Transport, transport).set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            error = BrokenResourceError()
            error.__cause__ = exc
            self.exception = error

        # Wake up anything waiting to read or write
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        # Replace the event with a fresh, unset one
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    Protocol that buffers received datagrams and exposes events for flow control.

    ``read_event`` is set when a datagram has arrived or the connection was lost;
    ``write_event`` is set while writing is allowed.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        writable = asyncio.Event()
        writable.set()
        self.write_event = writable

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up anything waiting to read or write
        for event in (self.read_event, self.write_event):
            event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        packet = (data, convert_ipv6_sockaddr(addr))
        self.read_queue.append(packet)
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
class SocketStream(abc.SocketStream):
    """
    Socket stream built on an asyncio transport and a ``StreamProtocol``.

    The protocol buffers incoming chunks; this class reads them back from the
    protocol's queue and writes through the transport.
    """

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        # Set by aclose(); lets receive()/send() tell a local close from a lost peer
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk of received data, at most ``max_bytes`` long.

        :raises ClosedResourceError: if no data is left and this stream was closed
        :raises EndOfStream: if no data is left and no error was recorded
        :raises Exception: the exception recorded by the protocol, if no data is
            left

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()

            # Only let the transport read while we are actually waiting for data
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
            ):
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                # The event was set without any data: closed, failed or at EOF
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception from None
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait until writing is allowed again.

        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: if the write fails while the transport is
            closing
        :raises Exception: the exception recorded by the protocol, if any

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Close the write end of the transport, ignoring ``OSError``."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the transport, give the loop one cycle, then abort it."""
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            self._transport.close()
            # Yield to the event loop once before aborting the transport
            await sleep(0)
            self._transport.abort()
class _RawSocketMixin:
_receive_future: asyncio.Future | None = None
_send_future: asyncio.Future | None = None
_closing = False
def __init__(self, raw_socket: socket.socket):
self.__raw_socket = raw_socket
self._receive_guard = ResourceGuard("reading from")
self._send_guard = ResourceGuard("writing to")
@property
def _raw_socket(self) -> socket.socket:
return self.__raw_socket
def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
def callback(f: object) -> None:
del self._receive_future
loop.remove_reader(self.__raw_socket)
f = self._receive_future = asyncio.Future()
loop.add_reader(self.__raw_socket, f.set_result, None)
f.add_done_callback(callback)
return f
def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
def callback(f: object) -> None:
del self._send_future
loop.remove_writer(self.__raw_socket)
f = self._send_future = asyncio.Future()
loop.add_writer(self.__raw_socket, f.set_result, None)
f.add_done_callback(callback)
return f
async def aclose(self) -> None:
if not self._closing:
self._closing = True
if self.__raw_socket.fileno() != -1:
self.__raw_socket.close()
if self._receive_future:
self._receive_future.set_result(None)
if self._send_future:
self._send_future.set_result(None)
class UNIXSocketStream(_RawSocketMixin, abc.UNIXSocketStream):
    """
    UNIX socket stream operating directly on a raw socket.

    Every operation retries on ``BlockingIOError`` after waiting for the socket to
    become ready.
    """

    async def send_eof(self) -> None:
        """Shut down the write end of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes from the socket.

        :raises EndOfStream: if the socket returned no data
        :raises ClosedResourceError: if the socket fails while this stream is
            closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self._raw_socket.recv(max_bytes)
                except BlockingIOError:
                    # Nothing to read yet: wait, then retry
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item`` through the socket.

        :raises ClosedResourceError: if the socket fails while this stream is
            closing
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            view = memoryview(item)
            while view:
                try:
                    bytes_sent = self._raw_socket.send(view)
                except BlockingIOError:
                    # The socket cannot take more data yet: wait, then retry
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    # Drop the part that was sent and carry on with the rest
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message together with file descriptors passed as ancillary data.

        :param msglen: maximum length of the message
        :param maxfds: maximum number of file descriptors to accept
        :return: a tuple of (message, file descriptors)
        :raises ValueError: if ``msglen`` is not a non-negative integer or
            ``maxfds`` is not a positive integer
        :raises EndOfStream: if neither a message nor ancillary data was received
        :raises RuntimeError: if ancillary data other than ``SCM_RIGHTS`` arrived

        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, flags, addr = self._raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore any trailing bytes that don't make up a whole descriptor
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send a message along with file descriptors as ancillary data.

        :param message: a non-empty message
        :param fds: a non-empty collection of file descriptors or file objects
        :raises ValueError: if ``message`` or ``fds`` is empty

        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        # NOTE(review): entries that are neither int nor IOBase are skipped silently
        # — confirm whether they should be rejected instead.
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    # The ignore can be removed after mypy picks up
                    # https://github.com/python/typeshed/pull/5545
                    self._raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts TCP connections from a raw listening socket."""

    # Cancel scope surrounding the accept() call currently in progress (if any)
    _accept_scope: CancelScope | None = None
    # Set to True by aclose()
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        """
        :param raw_socket: the listening socket to accept connections from

        """
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        """Return the raw listening socket."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept a connection and wrap it in a :class:`SocketStream`.

        ``TCP_NODELAY`` is enabled on the accepted socket.

        :raises ClosedResourceError: if the listener was already closed, or was
            closed while this call was waiting for a connection

        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await AsyncIOBackend.checkpoint()
            # The scope is stored on the instance so that aclose() can cancel it
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    # A cancellation that coincides with the listener being
                    # closed is reported as a closed resource instead
                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """
        Close the listening socket, cancelling any accept() call in progress.

        Calling this on an already closed listener does nothing.

        """
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            # Yield to the event loop once before closing the socket;
            # presumably this lets the cancelled accept() call unwind first
            await sleep(0)

        self._raw_socket.close()
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts connections from a raw UNIX listening socket."""

    def __init__(self, raw_socket: socket.socket):
        """
        :param raw_socket: the listening socket to accept connections from

        """
        self._closed = False
        self._accept_guard = ResourceGuard("accepting connections from")
        self._loop = get_running_loop()
        self.__raw_socket = raw_socket

    @property
    def _raw_socket(self) -> socket.socket:
        """Return the raw listening socket."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept a connection, switch it to non-blocking mode and wrap it in a
        :class:`UNIXSocketStream`.

        :raises ClosedResourceError: if the operation failed after the listener
            was closed
        :raises BrokenResourceError: if the operation failed for another reason

        """

        def _stop_watching(_: object) -> None:
            self._loop.remove_reader(self.__raw_socket)

        await AsyncIOBackend.checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    # Nothing to accept yet; wait until the socket is readable
                    readable: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(
                        self.__raw_socket, readable.set_result, None
                    )
                    readable.add_done_callback(_stop_watching)
                    await readable
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Mark the listener as closed and close the listening socket."""
        self._closed = True
        self.__raw_socket.close()
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket built on an asyncio datagram transport/protocol pair."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        """
        :param transport: the datagram transport to send through
        :param protocol: the protocol object whose queue and events are used
            for receiving and flow control

        """
        self._closed = False
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._transport = transport
        self._protocol = protocol

    @property
    def _raw_socket(self) -> socket.socket:
        """Return the transport's ``"socket"`` extra info."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next packet from the protocol's read queue.

        :raises ClosedResourceError: if no packet is available and this socket
            was closed via :meth:`aclose`
        :raises BrokenResourceError: if no packet is available for another reason

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()
            protocol = self._protocol

            # If the buffer is empty, ask for more data
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                return protocol.read_queue.popleft()
            except IndexError:
                error = ClosedResourceError if self._closed else BrokenResourceError
                raise error from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a packet after waiting for the protocol's write event.

        :param item: the arguments to pass on to the transport's ``sendto()``
        :raises ClosedResourceError: if this socket was closed via :meth:`aclose`
        :raises BrokenResourceError: if the transport is closing for another
            reason

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(*item)
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket built on an asyncio datagram transport/protocol pair."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        """
        :param transport: the datagram transport to send through
        :param protocol: the protocol object whose queue and events are used
            for receiving and flow control

        """
        self._closed = False
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._transport = transport
        self._protocol = protocol

    @property
    def _raw_socket(self) -> socket.socket:
        """Return the transport's ``"socket"`` extra info."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if self._transport.is_closing():
            return

        self._closed = True
        self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the first element of the next packet in the protocol's read queue.

        :raises ClosedResourceError: if no packet is available and this socket
            was closed via :meth:`aclose`
        :raises BrokenResourceError: if no packet is available for another reason

        """
        with self._receive_guard:
            await AsyncIOBackend.checkpoint()
            protocol = self._protocol

            # If the buffer is empty, ask for more data
            if not (protocol.read_queue or self._transport.is_closing()):
                protocol.read_event.clear()
                await protocol.read_event.wait()

            try:
                packet = protocol.read_queue.popleft()
            except IndexError:
                error = ClosedResourceError if self._closed else BrokenResourceError
                raise error from None

            payload = packet[0]
            return payload

    async def send(self, item: bytes) -> None:
        """
        Send a packet after waiting for the protocol's write event.

        :raises ClosedResourceError: if this socket was closed via :meth:`aclose`
        :raises BrokenResourceError: if the transport is closing for another
            reason

        """
        with self._send_guard:
            await AsyncIOBackend.checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError

            if self._transport.is_closing():
                raise BrokenResourceError

            self._transport.sendto(item)
class UNIXDatagramSocket(_RawSocketMixin, abc.UNIXDatagramSocket):
    """Unconnected UNIX datagram socket operating directly on a raw socket."""

    async def receive(self) -> UNIXDatagramPacketType:
        """
        Return the result of ``recvfrom()`` (reading up to 65536 bytes), waiting
        for the socket to become readable whenever the call would block.

        :raises ClosedResourceError: if the operation failed after this socket
            was closed
        :raises BrokenResourceError: if the operation failed for another reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recvfrom(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def send(self, item: UNIXDatagramPacketType) -> None:
        """
        Pass the elements of ``item`` to ``sendto()``, waiting for the socket to
        become writable whenever the call would block.

        :raises ClosedResourceError: if the operation failed after this socket
            was closed
        :raises BrokenResourceError: if the operation failed for another reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.sendto(*item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
class ConnectedUNIXDatagramSocket(_RawSocketMixin, abc.ConnectedUNIXDatagramSocket):
    """Connected UNIX datagram socket operating directly on a raw socket."""

    async def receive(self) -> bytes:
        """
        Return the result of ``recv()`` (reading up to 65536 bytes), waiting for
        the socket to become readable whenever the call would block.

        :raises ClosedResourceError: if the operation failed after this socket
            was closed
        :raises BrokenResourceError: if the operation failed for another reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._receive_guard:
            while True:
                try:
                    return self._raw_socket.recv(65536)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc

    async def send(self, item: bytes) -> None:
        """
        Pass ``item`` to ``send()``, waiting for the socket to become writable
        whenever the call would block.

        :raises ClosedResourceError: if the operation failed after this socket
            was closed
        :raises BrokenResourceError: if the operation failed for another reason

        """
        loop = get_running_loop()
        await AsyncIOBackend.checkpoint()
        with self._send_guard:
            while True:
                try:
                    self._raw_socket.send(item)
                    return
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None

                    raise BrokenResourceError from exc
# Registries mapping a socket to the asyncio.Event on which a task is currently
# waiting in AsyncIOBackend.wait_socket_readable() / wait_socket_writable()
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")
#
# Synchronization
#
class Event(BaseEvent):
    """Event implementation that delegates to an :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Create the instance directly through object.__new__()
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> None:
        """Set the underlying event."""
        self._event.set()

    def is_set(self) -> bool:
        """Return ``True`` if the underlying event has been set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """
        Wait until the event has been set.

        If it is already set, only a checkpoint is performed.

        """
        if not self.is_set():
            await self._event.wait()
            return

        await AsyncIOBackend.checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics containing the number of waiters on the event."""
        waiter_count = len(self._event._waiters)  # type: ignore[attr-defined]
        return EventStatistics(waiter_count)
class CapacityLimiter(BaseCapacityLimiter):
    """
    Limits the number of borrowers that can simultaneously hold a token.

    Borrowers that cannot get a token right away are queued in FIFO order and
    woken up through an :class:`asyncio.Event` when capacity becomes available.
    """

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Create the instance directly through object.__new__()
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        """
        :param total_tokens: the initial number of tokens (validated by the
            ``total_tokens`` setter)

        """
        self._borrowers: set[Any] = set()
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")

        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        # Only an increase in capacity can let more waiters through
        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Notify waiting tasks that they have acquired the limiter
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens currently held by borrowers."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The number of tokens not currently held by any borrower."""
        return self._total_tokens - len(self._borrowers)

    def _notify_next_waiter(self) -> None:
        """Wake up the first queued waiter if there is free capacity."""
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def acquire_nowait(self) -> None:
        """Acquire a token for the current task without blocking."""
        self.acquire_on_behalf_of_nowait(current_task())

    def acquire_on_behalf_of_nowait(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower`` without blocking.

        :raises RuntimeError: if the borrower already holds a token
        :raises WouldBlock: if there are queued waiters or no tokens are
            available

        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """
        Acquire a token for ``borrower``, waiting in the queue if necessary.

        If waiting is interrupted by an exception, the borrower is removed from
        the queue. If the token could be acquired right away but the following
        checkpoint raises, the token is released again on behalf of
        ``borrower``.

        :raises RuntimeError: if the borrower already holds a token

        """
        await AsyncIOBackend.checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                self._wait_queue.pop(borrower, None)
                if event.is_set():
                    # This waiter had already been taken off the queue and
                    # notified, but it is giving up without taking the token.
                    # Pass the notification on so the capacity isn't left
                    # unused while others are still waiting.
                    self._notify_next_waiter()

                raise

            self._borrowers.add(borrower)
        else:
            try:
                await AsyncIOBackend.cancel_shielded_checkpoint()
            except BaseException:
                # The token was acquired for ``borrower``, which is not
                # necessarily the current task
                self.release_on_behalf_of(borrower)
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by ``borrower``.

        :raises RuntimeError: if the borrower does not hold a token

        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's " "tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        self._notify_next_waiter()

    def statistics(self) -> CapacityLimiterStatistics:
        """
        Return statistics containing the borrowed and total token counts, the
        current borrowers and the number of queued waiters.

        """
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
# Holds the limiter created on demand by
# AsyncIOBackend.current_default_thread_limiter()
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")
#
# Operating system signals
#
class _SignalReceiver:
    """
    Context manager and async iterator that yields signals delivered to the
    event loop.

    Signal handlers are installed on entering the context and removed on exit.
    """

    def __init__(self, signals: tuple[Signals, ...]):
        """
        :param signals: the signals to install handlers for

        """
        self._handled_signals: set[Signals] = set()
        self._signal_queue: deque[Signals] = deque()
        self._future: asyncio.Future = asyncio.Future()
        self._loop = get_running_loop()
        self._signals = signals

    def _deliver(self, signum: Signals) -> None:
        """Queue a received signal and wake up the waiting iterator, if any."""
        self._signal_queue.append(signum)
        pending = self._future
        if pending.done():
            return

        pending.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Duplicates in the requested signals are only registered once
        for signum in set(self._signals):
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # Only remove the handlers that were actually installed
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> Signals:
        """Return the oldest queued signal, waiting for one if the queue is empty."""
        await AsyncIOBackend.checkpoint()
        if not self._signal_queue:
            # Replace the future so _deliver() has something to resolve
            waiter: asyncio.Future = asyncio.Future()
            self._future = waiter
            await waiter

        return self._signal_queue.popleft()
#
# Testing and debugging
#
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """
    Build a :class:`TaskInfo` for the given task.

    The parent ID comes from the task's entry in ``_task_states``, or is
    ``None`` if the task has no entry there.

    """
    state = _task_states.get(task)
    parent_id = state.parent_id if state is not None else None
    return TaskInfo(id(task), parent_id, task.get_name(), task.get_coro())
class TestRunner(abc.TestRunner):
    """
    Runs async tests and fixtures on an event loop managed by a ``Runner``.

    All test and fixture coroutines are awaited in one long-lived runner task,
    which receives them through a memory object stream.
    """

    _send_stream: MemoryObjectSendStream[tuple[Awaitable[Any], asyncio.Future[Any]]]

    def __init__(
        self,
        *,
        debug: bool | None = None,
        use_uvloop: bool = False,
        loop_factory: Callable[[], AbstractEventLoop] | None = None,
    ) -> None:
        """
        :param debug: passed on to the ``Runner``
        :param use_uvloop: if ``True`` and no ``loop_factory`` was given, use
            uvloop's event loop factory
        :param loop_factory: passed on to the ``Runner``

        """
        self._exceptions: list[BaseException] = []
        self._runner_task: asyncio.Task | None = None
        if loop_factory is None and use_uvloop:
            import uvloop

            loop_factory = uvloop.new_event_loop

        self._runner = Runner(debug=debug, loop_factory=loop_factory)

    def __enter__(self) -> TestRunner:
        self._runner.__enter__()
        loop = self.get_loop()
        loop.set_exception_handler(self._exception_handler)
        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self._runner.__exit__(exc_type, exc_val, exc_tb)

    def get_loop(self) -> AbstractEventLoop:
        """Return the event loop of the underlying runner."""
        return self._runner.get_loop()

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """
        Collect :exc:`Exception` instances reported to the event loop; hand
        anything else to the loop's default exception handler.

        """
        exc = context.get("exception")
        if isinstance(exc, Exception):
            self._exceptions.append(exc)
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        """
        Re-raise any exceptions raised in asynchronous callbacks.

        The collected exceptions are cleared. A single exception is raised
        as-is; several are wrapped in an exception group.

        """
        if not self._exceptions:
            return

        collected = self._exceptions
        self._exceptions = []
        if len(collected) == 1:
            raise collected[0]

        raise BaseExceptionGroup(
            "Multiple exceptions occurred in asynchronous callbacks", collected
        )

    @staticmethod
    async def _run_tests_and_fixtures(
        receive_stream: MemoryObjectReceiveStream[
            tuple[Awaitable[T_Retval], asyncio.Future[T_Retval]]
        ],
    ) -> None:
        """
        Await each awaitable received from the stream and relay its outcome to
        the accompanying future (unless that future has been cancelled).

        """
        with receive_stream:
            async for awaitable, outcome in receive_stream:
                try:
                    result = await awaitable
                except BaseException as exc:
                    if not outcome.cancelled():
                        outcome.set_exception(exc)

                    continue

                if not outcome.cancelled():
                    outcome.set_result(result)

    async def _call_in_runner_task(
        self,
        func: Callable[P, Awaitable[T_Retval]],
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> T_Retval:
        """
        Call ``func`` and have the resulting awaitable awaited in the runner
        task, starting that task first if it does not exist yet.

        :return: the awaitable's result

        """
        if not self._runner_task:
            self._send_stream, receive_stream = create_memory_object_stream[
                Tuple[Awaitable[Any], asyncio.Future]
            ](1)
            runner_coro = self._run_tests_and_fixtures(receive_stream)
            self._runner_task = self.get_loop().create_task(runner_coro)

        awaitable = func(*args, **kwargs)
        outcome: asyncio.Future[T_Retval] = self.get_loop().create_future()
        self._send_stream.send_nowait((awaitable, outcome))
        return await outcome

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Run an async generator fixture: yield the value it produces, then resume
        it once more for teardown.

        :raises RuntimeError: if the generator yields a second time

        """
        asyncgen = fixture_func(**kwargs)
        setup = self._call_in_runner_task(asyncgen.asend, None)
        fixturevalue: T_Retval = self.get_loop().run_until_complete(setup)
        self._raise_async_exceptions()

        yield fixturevalue

        try:
            teardown = self._call_in_runner_task(asyncgen.asend, None)
            self.get_loop().run_until_complete(teardown)
        except StopAsyncIteration:
            self._raise_async_exceptions()
        else:
            self.get_loop().run_until_complete(asyncgen.aclose())
            raise RuntimeError("Async generator fixture did not stop")

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture in the runner task and return its value."""
        call = self._call_in_runner_task(fixture_func, **kwargs)
        retval = self.get_loop().run_until_complete(call)
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """
        Run a test function in the runner task.

        An :exc:`Exception` raised by the test is added to the collected
        exceptions, which are then re-raised together.

        """
        try:
            call = self._call_in_runner_task(test_func, **kwargs)
            self.get_loop().run_until_complete(call)
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()
class AsyncIOBackend(AsyncBackend):
    """Backend implementation built on top of :mod:`asyncio`."""

    @classmethod
    def run(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        options: dict[str, Any],
    ) -> T_Retval:
        """
        Run ``func(*args)`` in a new event loop and return its result.

        Recognized ``options`` keys are ``debug``, ``loop_factory`` and
        ``use_uvloop`` (the latter only takes effect when no ``loop_factory``
        was given). ``kwargs`` is not passed on to ``func``.

        """

        @wraps(func)
        async def wrapper() -> T_Retval:
            task = cast(asyncio.Task, current_task())
            task.set_name(get_callable_name(func))
            _task_states[task] = TaskState(None, None)

            try:
                return await func(*args)
            finally:
                del _task_states[task]

        debug = options.get("debug", False)
        loop_factory = options.get("loop_factory", None)
        if loop_factory is None and options.get("use_uvloop", False):
            import uvloop

            loop_factory = uvloop.new_event_loop

        with Runner(debug=debug, loop_factory=loop_factory) as runner:
            return runner.run(wrapper())

    @classmethod
    def current_token(cls) -> object:
        """Return the running event loop as the backend token."""
        return get_running_loop()

    @classmethod
    def current_time(cls) -> float:
        """Return the running event loop's current time."""
        return get_running_loop().time()

    @classmethod
    def cancelled_exception_class(cls) -> type[BaseException]:
        """Return the exception class used for cancellation."""
        return CancelledError

    @classmethod
    async def checkpoint(cls) -> None:
        """Yield control to the event loop once."""
        await sleep(0)

    @classmethod
    async def checkpoint_if_cancelled(cls) -> None:
        """
        Yield control to the event loop only if a cancel scope in effect for
        the current task has had its cancellation requested.

        Scopes are examined from the innermost one outwards, stopping at the
        first shielded scope.

        """
        task = current_task()
        if task is None:
            return

        try:
            cancel_scope = _task_states[task].cancel_scope
        except KeyError:
            return

        while cancel_scope:
            if cancel_scope.cancel_called:
                await sleep(0)
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

    @classmethod
    async def cancel_shielded_checkpoint(cls) -> None:
        """Yield control to the event loop once inside a shielded cancel scope."""
        with CancelScope(shield=True):
            await sleep(0)

    @classmethod
    async def sleep(cls, delay: float) -> None:
        """Sleep for ``delay`` seconds."""
        await sleep(delay)

    @classmethod
    def create_cancel_scope(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        """Create a new cancel scope with the given deadline and shield flag."""
        return CancelScope(deadline=deadline, shield=shield)

    @classmethod
    def current_effective_deadline(cls) -> float:
        """
        Return the nearest deadline among the cancel scopes in effect for the
        current task.

        Scopes are examined from the innermost one outwards, stopping after the
        first shielded scope.

        :return: ``math.inf`` if the task has no state entry or no deadline
            applies, ``-math.inf`` if an examined scope has been cancelled

        """
        try:
            cancel_scope = _task_states[
                current_task()  # type: ignore[index]
            ].cancel_scope
        except KeyError:
            return math.inf

        deadline = math.inf
        while cancel_scope:
            deadline = min(deadline, cancel_scope.deadline)
            if cancel_scope._cancel_called:
                deadline = -math.inf
                break
            elif cancel_scope.shield:
                break
            else:
                cancel_scope = cancel_scope._parent_scope

        return deadline

    @classmethod
    def create_task_group(cls) -> abc.TaskGroup:
        """Create a new task group."""
        return TaskGroup()

    @classmethod
    def create_event(cls) -> abc.Event:
        """Create a new event."""
        return Event()

    @classmethod
    def create_capacity_limiter(cls, total_tokens: float) -> abc.CapacityLimiter:
        """Create a new capacity limiter with the given number of tokens."""
        return CapacityLimiter(total_tokens)

    @classmethod
    async def run_sync_in_worker_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        abandon_on_cancel: bool = False,
        limiter: abc.CapacityLimiter | None = None,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in a worker thread and return its result.

        An idle worker is reused if one is available; otherwise a new one is
        started. The call is shielded from cancellation unless
        ``abandon_on_cancel`` is true.

        :param limiter: the capacity limiter to hold for the duration of the
            call (the default thread limiter if omitted)

        """
        await cls.checkpoint()

        # If this is the first run in this event loop thread, set up the necessary
        # variables
        try:
            idle_workers = _threadpool_idle_workers.get()
            workers = _threadpool_workers.get()
        except LookupError:
            idle_workers = deque()
            workers = set()
            _threadpool_idle_workers.set(idle_workers)
            _threadpool_workers.set(workers)

        async with limiter or cls.current_default_thread_limiter():
            with CancelScope(shield=not abandon_on_cancel) as scope:
                future: asyncio.Future = asyncio.Future()
                root_task = find_root_task()
                if not idle_workers:
                    worker = WorkerThread(root_task, workers, idle_workers)
                    worker.start()
                    workers.add(worker)
                    root_task.add_done_callback(worker.stop)
                else:
                    worker = idle_workers.pop()

                    # Prune any other workers that have been idle for MAX_IDLE_TIME
                    # seconds or longer
                    now = cls.current_time()
                    while idle_workers:
                        if (
                            now - idle_workers[0].idle_since
                            < WorkerThread.MAX_IDLE_TIME
                        ):
                            break

                        expired_worker = idle_workers.popleft()
                        expired_worker.root_task.remove_done_callback(
                            expired_worker.stop
                        )
                        expired_worker.stop()

                # Run the function in a copy of the current context where no
                # async library is reported as running
                context = copy_context()
                context.run(sniffio.current_async_library_cvar.set, None)
                if abandon_on_cancel or scope._parent_scope is None:
                    worker_scope = scope
                else:
                    worker_scope = scope._parent_scope

                worker.queue.put_nowait((context, func, args, future, worker_scope))
                return await future

    @classmethod
    def check_cancelled(cls) -> None:
        """
        Raise :exc:`~asyncio.CancelledError` if a cancel scope applying to the
        current thread has had its cancellation requested.

        Scopes are examined from the thread-local current scope outwards,
        stopping at the first shielded scope.

        """
        scope: CancelScope | None = threadlocals.current_cancel_scope
        while scope is not None:
            if scope.cancel_called:
                raise CancelledError(f"Cancelled by cancel scope {id(scope):x}")

            if scope.shield:
                return

            scope = scope._parent_scope

    @classmethod
    def run_async_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in the event loop given as ``token`` and block the
        calling thread until it completes.

        Cancellation of the coroutine is re-raised in the calling thread as
        :exc:`concurrent.futures.CancelledError`.

        """

        async def task_wrapper(scope: CancelScope) -> T_Retval:
            __tracebackhide__ = True
            task = cast(asyncio.Task, current_task())
            _task_states[task] = TaskState(None, scope)
            scope._tasks.add(task)
            try:
                return await func(*args)
            except CancelledError as exc:
                raise concurrent.futures.CancelledError(str(exc)) from None
            finally:
                scope._tasks.discard(task)

        loop = cast(AbstractEventLoop, token)
        context = copy_context()
        context.run(sniffio.current_async_library_cvar.set, "asyncio")
        wrapper = task_wrapper(threadlocals.current_cancel_scope)
        f: concurrent.futures.Future[T_Retval] = context.run(
            asyncio.run_coroutine_threadsafe, wrapper, loop
        )
        return f.result()

    @classmethod
    def run_sync_from_thread(
        cls,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        args: tuple[Unpack[PosArgsT]],
        token: object,
    ) -> T_Retval:
        """
        Run ``func(*args)`` in the thread of the event loop given as ``token``
        and block the calling thread until it returns or raises.

        """

        @wraps(func)
        def wrapper() -> None:
            try:
                sniffio.current_async_library_cvar.set("asyncio")
                f.set_result(func(*args))
            except BaseException as exc:
                f.set_exception(exc)
                # Non-Exception errors are also propagated in the loop thread
                if not isinstance(exc, Exception):
                    raise

        f: concurrent.futures.Future[T_Retval] = Future()
        loop = cast(AbstractEventLoop, token)
        loop.call_soon_threadsafe(wrapper)
        return f.result()

    @classmethod
    def create_blocking_portal(cls) -> abc.BlockingPortal:
        """Create a new blocking portal."""
        return BlockingPortal()

    @classmethod
    async def open_process(
        cls,
        command: str | bytes | Sequence[str | bytes],
        *,
        shell: bool,
        stdin: int | IO[Any] | None,
        stdout: int | IO[Any] | None,
        stderr: int | IO[Any] | None,
        cwd: str | bytes | PathLike | None = None,
        env: Mapping[str, str] | None = None,
        start_new_session: bool = False,
    ) -> Process:
        """
        Start a subprocess, through the shell if ``shell`` is true, and wrap it
        and its available standard streams in a :class:`Process`.

        """
        await cls.checkpoint()
        if shell:
            process = await asyncio.create_subprocess_shell(
                cast("str | bytes", command),
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )
        else:
            process = await asyncio.create_subprocess_exec(
                *command,
                stdin=stdin,
                stdout=stdout,
                stderr=stderr,
                cwd=cwd,
                env=env,
                start_new_session=start_new_session,
            )

        stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
        stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
        stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
        return Process(process, stdin_stream, stdout_stream, stderr_stream)

    @classmethod
    def setup_process_pool_exit_at_shutdown(cls, workers: set[abc.Process]) -> None:
        """
        Start the process pool shutdown task and register a callback on the
        root task that forcibly shuts down the pool when that task is done.

        """
        create_task(
            _shutdown_process_pool_on_exit(workers),
            name="AnyIO process pool shutdown task",
        )
        find_root_task().add_done_callback(
            partial(_forcibly_shutdown_process_pool_on_exit, workers)
        )

    @classmethod
    async def connect_tcp(
        cls, host: str, port: int, local_address: IPSockAddrType | None = None
    ) -> abc.SocketStream:
        """Connect to the given host and port and return a socket stream."""
        transport, protocol = cast(
            Tuple[asyncio.Transport, StreamProtocol],
            await get_running_loop().create_connection(
                StreamProtocol, host, port, local_addr=local_address
            ),
        )
        transport.pause_reading()
        return SocketStream(transport, protocol)

    @classmethod
    async def connect_unix(cls, path: str | bytes) -> abc.UNIXSocketStream:
        """
        Connect a new non-blocking UNIX socket to ``path``.

        The socket is closed if connecting fails for any reason, including
        cancellation while waiting for the socket to become writable.

        """
        await cls.checkpoint()
        loop = get_running_loop()
        raw_socket = socket.socket(socket.AF_UNIX)
        raw_socket.setblocking(False)
        try:
            while True:
                try:
                    raw_socket.connect(path)
                except BlockingIOError:
                    f: asyncio.Future = asyncio.Future()
                    loop.add_writer(raw_socket, f.set_result, None)
                    try:
                        await f
                    finally:
                        # Unregister right away so the writer is gone before the
                        # socket can get closed below
                        loop.remove_writer(raw_socket)
                else:
                    return UNIXSocketStream(raw_socket)
        except BaseException:
            raw_socket.close()
            raise

    @classmethod
    def create_tcp_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap a listening socket in a TCP listener."""
        return TCPSocketListener(sock)

    @classmethod
    def create_unix_listener(cls, sock: socket.socket) -> SocketListener:
        """Wrap a listening socket in a UNIX listener."""
        return UNIXSocketListener(sock)

    @classmethod
    async def create_udp_socket(
        cls,
        family: AddressFamily,
        local_address: IPSockAddrType | None,
        remote_address: IPSockAddrType | None,
        reuse_port: bool,
    ) -> UDPSocket | ConnectedUDPSocket:
        """
        Create a datagram endpoint and wrap it in a connected or unconnected
        UDP socket, depending on whether ``remote_address`` was given.

        If the protocol has recorded an exception, the transport is closed and
        that exception is raised.

        """
        transport, protocol = await get_running_loop().create_datagram_endpoint(
            DatagramProtocol,
            local_addr=local_address,
            remote_addr=remote_address,
            family=family,
            reuse_port=reuse_port,
        )
        if protocol.exception:
            transport.close()
            raise protocol.exception

        if not remote_address:
            return UDPSocket(transport, protocol)
        else:
            return ConnectedUDPSocket(transport, protocol)

    @classmethod
    async def create_unix_datagram_socket(  # type: ignore[override]
        cls, raw_socket: socket.socket, remote_path: str | bytes | None
    ) -> abc.UNIXDatagramSocket | abc.ConnectedUNIXDatagramSocket:
        """
        Wrap ``raw_socket`` in a UNIX datagram socket, first connecting it to
        ``remote_path`` if one was given.

        The socket is closed if connecting fails for any reason, including
        cancellation while waiting for the socket to become writable.

        """
        await cls.checkpoint()
        loop = get_running_loop()

        if remote_path:
            try:
                while True:
                    try:
                        raw_socket.connect(remote_path)
                    except BlockingIOError:
                        f: asyncio.Future = asyncio.Future()
                        loop.add_writer(raw_socket, f.set_result, None)
                        try:
                            await f
                        finally:
                            # Unregister right away so the writer is gone before
                            # the socket can get closed below
                            loop.remove_writer(raw_socket)
                    else:
                        return ConnectedUNIXDatagramSocket(raw_socket)
            except BaseException:
                raw_socket.close()
                raise
        else:
            return UNIXDatagramSocket(raw_socket)

    @classmethod
    async def getaddrinfo(
        cls,
        host: bytes | str | None,
        port: str | int | None,
        *,
        family: int | AddressFamily = 0,
        type: int | SocketKind = 0,
        proto: int = 0,
        flags: int = 0,
    ) -> list[
        tuple[
            AddressFamily,
            SocketKind,
            int,
            str,
            tuple[str, int] | tuple[str, int, int, int],
        ]
    ]:
        """Resolve an address through the running event loop's ``getaddrinfo()``."""
        return await get_running_loop().getaddrinfo(
            host, port, family=family, type=type, proto=proto, flags=flags
        )

    @classmethod
    async def getnameinfo(
        cls, sockaddr: IPSockAddrType, flags: int = 0
    ) -> tuple[str, str]:
        """Look up a name through the running event loop's ``getnameinfo()``."""
        return await get_running_loop().getnameinfo(sockaddr, flags)

    @classmethod
    async def wait_socket_readable(cls, sock: socket.socket) -> None:
        """
        Wait until the given socket is readable.

        :raises BusyResourceError: if another task is already waiting for the
            socket to become readable
        :raises ClosedResourceError: if the socket's entry was removed from the
            registry of waiters while waiting

        """
        await cls.checkpoint()
        try:
            read_events = _read_events.get()
        except LookupError:
            read_events = {}
            _read_events.set(read_events)

        if read_events.get(sock):
            raise BusyResourceError("reading from") from None

        loop = get_running_loop()
        event = read_events[sock] = asyncio.Event()
        loop.add_reader(sock, event.set)
        try:
            await event.wait()
        finally:
            if read_events.pop(sock, None) is not None:
                loop.remove_reader(sock)
                readable = True
            else:
                readable = False

        if not readable:
            raise ClosedResourceError

    @classmethod
    async def wait_socket_writable(cls, sock: socket.socket) -> None:
        """
        Wait until the given socket is writable.

        :raises BusyResourceError: if another task is already waiting for the
            socket to become writable
        :raises ClosedResourceError: if the socket's entry was removed from the
            registry of waiters while waiting

        """
        await cls.checkpoint()
        try:
            write_events = _write_events.get()
        except LookupError:
            write_events = {}
            _write_events.set(write_events)

        if write_events.get(sock):
            raise BusyResourceError("writing to") from None

        loop = get_running_loop()
        event = write_events[sock] = asyncio.Event()
        # Register with the same object that is used for unregistering below
        loop.add_writer(sock, event.set)
        try:
            await event.wait()
        finally:
            if write_events.pop(sock, None) is not None:
                loop.remove_writer(sock)
                writable = True
            else:
                writable = False

        if not writable:
            raise ClosedResourceError

    @classmethod
    def current_default_thread_limiter(cls) -> CapacityLimiter:
        """
        Return the default thread limiter, creating it with 40 tokens on first
        use.

        """
        try:
            return _default_thread_limiter.get()
        except LookupError:
            limiter = CapacityLimiter(40)
            _default_thread_limiter.set(limiter)
            return limiter

    @classmethod
    def open_signal_receiver(
        cls, *signals: Signals
    ) -> ContextManager[AsyncIterator[Signals]]:
        """Create a signal receiver for the given signals."""
        return _SignalReceiver(signals)

    @classmethod
    def get_current_task(cls) -> TaskInfo:
        """Return information about the current task."""
        return _create_task_info(current_task())  # type: ignore[arg-type]

    @classmethod
    def get_running_tasks(cls) -> list[TaskInfo]:
        """Return information about every task that is not done yet."""
        return [_create_task_info(task) for task in all_tasks() if not task.done()]

    @classmethod
    async def wait_all_tasks_blocked(cls) -> None:
        """
        Wait until every task other than the current one is waiting on a
        future that is not done yet.

        The tasks are re-examined after a 0.1 second sleep for as long as any
        of them fails that condition.

        """
        await cls.checkpoint()
        this_task = current_task()
        while True:
            for task in all_tasks():
                if task is this_task:
                    continue

                waiter = task._fut_waiter  # type: ignore[attr-defined]
                if waiter is None or waiter.done():
                    await sleep(0.1)
                    break
            else:
                return

    @classmethod
    def create_test_runner(cls, options: dict[str, Any]) -> TestRunner:
        """Create a test runner, passing ``options`` as keyword arguments."""
        return TestRunner(**options)
# Module-level alias for the backend class defined in this module
backend_class = AsyncIOBackend