"""Helpers for calling into an asynchronous event loop from other threads."""

from __future__ import annotations

import sys
import threading
from collections.abc import Awaitable, Callable, Generator
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from contextlib import AbstractContextManager, contextmanager
from inspect import isawaitable
from types import TracebackType
from typing import (
    Any,
    AsyncContextManager,
    ContextManager,
    Generic,
    Iterable,
    TypeVar,
    cast,
    overload,
)

from ._core import _eventloop
from ._core._eventloop import get_async_backend, get_cancelled_exc_class, threadlocals
from ._core._synchronization import Event
from ._core._tasks import CancelScope, create_task_group
from .abc import AsyncBackend
from .abc._tasks import TaskStatus

# TypeVarTuple/Unpack live in ``typing`` from Python 3.11 onwards; older
# interpreters get them from the ``typing_extensions`` backport.
if sys.version_info >= (3, 11):
    from typing import TypeVarTuple, Unpack
else:
    from typing_extensions import TypeVarTuple, Unpack

T_Retval = TypeVar("T_Retval")
T_co = TypeVar("T_co", covariant=True)
PosArgsT = TypeVarTuple("PosArgsT")


def run(
    func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Call a coroutine function from a worker thread.

    :param func: a coroutine function
    :param args: positional arguments for the callable
    :return: the return value of the coroutine function
    :raises RuntimeError: if the calling thread does not have the backend and token
        thread-locals set (reported as not being an AnyIO worker thread)

    """
    try:
        # Both attributes are thread-local; a missing one surfaces as
        # AttributeError, which is translated into a friendlier RuntimeError.
        async_backend = threadlocals.current_async_backend
        token = threadlocals.current_token
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None

    return async_backend.run_async_from_thread(func, args, token=token)


def run_sync(
    func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
) -> T_Retval:
    """
    Call a function in the event loop thread from a worker thread.

    :param func: a callable
    :param args: positional arguments for the callable
    :return: the return value of the callable
    :raises RuntimeError: if the calling thread does not have the backend and token
        thread-locals set (reported as not being an AnyIO worker thread)

    """
    try:
        # Same thread-local lookup (and error translation) as in run().
        async_backend = threadlocals.current_async_backend
        token = threadlocals.current_token
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None

    return async_backend.run_sync_from_thread(func, args, token=token)


class _BlockingAsyncContextManager(Generic[T_co], AbstractContextManager):
    """
    Synchronous context manager that drives an async context manager via a portal.

    ``__enter__`` starts :meth:`run_async_cm` as a task through the portal; that
    task calls ``__aenter__()``, waits on an event that ``__exit__`` sets, and
    then calls ``__aexit__()``.

    """

    # Resolved by run_async_cm() with the value (or exception) of __aenter__().
    _enter_future: Future[T_co]
    # Future of the run_async_cm() task itself; resolves to __aexit__()'s result.
    _exit_future: Future[bool | None]
    # Created inside run_async_cm(); set by __exit__() to let the task proceed.
    _exit_event: Event
    # Class-level default used when __aexit__() runs before __exit__() has stored
    # real exception info (see the comment in run_async_cm()).
    _exit_exc_info: tuple[
        type[BaseException] | None, BaseException | None, TracebackType | None
    ] = (None, None, None)

    def __init__(self, async_cm: AsyncContextManager[T_co], portal: BlockingPortal):
        """
        Store the wrapped async context manager and the portal used to run it.

        :param async_cm: the asynchronous context manager to wrap
        :param portal: the portal through which the wrapper task is started

        """
        self._async_cm = async_cm
        self._portal = portal

    async def run_async_cm(self) -> bool | None:
        """
        Enter the async context manager, wait for ``__exit__``, then exit it.

        The outcome of ``__aenter__()`` is published through ``_enter_future`` so
        that the thread blocked in :meth:`__enter__` can pick it up.

        :return: the return value of the wrapped ``__aexit__()``

        """
        try:
            # The event is created here, within the task started by __enter__(),
            # and before the enter future is resolved.
            self._exit_event = Event()
            value = await self._async_cm.__aenter__()
        except BaseException as exc:
            # Hand the failure to the thread waiting in __enter__(), then let it
            # propagate out of this task as well.
            self._enter_future.set_exception(exc)
            raise
        else:
            self._enter_future.set_result(value)

        try:
            # Wait for the sync context manager to exit.
            # This next statement can raise `get_cancelled_exc_class()` if
            # something went wrong in a task group in this async context
            # manager.
            await self._exit_event.wait()
        finally:
            # In case of cancellation, it could be that we end up here before
            # `_BlockingAsyncContextManager.__exit__` is called, and an
            # `_exit_exc_info` has been set.
            result = await self._async_cm.__aexit__(*self._exit_exc_info)
            # NOTE(review): returning from inside ``finally`` discards any
            # exception in flight (e.g. a cancellation raised by the wait above)
            # -- confirm this is intended.
            return result

    def __enter__(self) -> T_co:
        """
        Start the wrapper task and block until ``__aenter__()`` has completed.

        :return: the value returned by the wrapped ``__aenter__()``; an exception
            raised by ``__aenter__()`` is re-raised here via the enter future

        """
        self._enter_future = Future()
        self._exit_future = self._portal.start_task_soon(self.run_async_cm)
        return self._enter_future.result()

    def __exit__(
        self,
        __exc_type: type[BaseException] | None,
        __exc_value: BaseException | None,
        __traceback: TracebackType | None,
    ) -> bool | None:
        """
        Release the wrapper task and block until ``__aexit__()`` has completed.

        :return: the result of the wrapper task, i.e. of the wrapped ``__aexit__()``

        """
        # Store the exception info first so that run_async_cm() forwards it to
        # __aexit__() once the event is set.
        self._exit_exc_info = __exc_type, __exc_value, __traceback
        # The event is set through the portal rather than directly from this
        # thread.
        self._portal.call(self._exit_event.set)
        return self._exit_future.result()


class _BlockingPortalTaskStatus(TaskStatus):
    """Task status object that reports ``started()`` through a thread-safe future."""

    def __init__(self, future: Future):
        """
        :param future: the future to resolve when ``started()`` is called

        """
        self._future = future

    def started(self, value: object = None) -> None:
        """
        Resolve the future with ``value``.

        :param value: the value to pass on to whoever waits on the future

        """
        self._future.set_result(value)


class BlockingPortal:
    """An object that lets external threads run code in an asynchronous event loop."""

    def __new__(cls) -> BlockingPortal:
        # Instance creation is delegated to the current async backend.
        return get_async_backend().create_blocking_portal()

    def __init__(self) -> None:
        # The thread that constructs the portal is recorded as the event loop
        # thread; ``None`` later means "not running" (see stop()).
        self._event_loop_thread_id: int | None = threading.get_ident()
        self._stop_event = Event()
        self._task_group = create_task_group()
        self._cancelled_exc_class = get_cancelled_exc_class()

    async def __aenter__(self) -> BlockingPortal:
        """Enter the internal task group and return the portal itself."""
        await self._task_group.__aenter__()
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Stop the portal, then exit the internal task group.

        :return: whatever the task group's ``__aexit__()`` returns

        """
        await self.stop()
        return await self._task_group.__aexit__(exc_type, exc_val, exc_tb)

    def _check_running(self) -> None:
        """
        Verify that the portal may be used from the calling thread.

        :raises RuntimeError: if the portal has been stopped, or if the caller is
            running in the thread recorded as the event loop thread

        """
        if self._event_loop_thread_id is None:
            raise RuntimeError("This portal is not running")
        if self._event_loop_thread_id == threading.get_ident():
            raise RuntimeError(
                "This method cannot be called from the event loop thread"
            )

    async def sleep_until_stopped(self) -> None:
        """Sleep until :meth:`stop` is called."""
        await self._stop_event.wait()

    async def stop(self, cancel_remaining: bool = False) -> None:
        """
        Signal the portal to shut down.

        This marks the portal as no longer accepting new calls and exits from
        :meth:`sleep_until_stopped`.

        :param cancel_remaining: ``True`` to cancel all the remaining tasks, ``False``
            to let them finish before returning

        """
        # Clearing the thread id makes _check_running() reject further calls.
        self._event_loop_thread_id = None
        self._stop_event.set()
        if cancel_remaining:
            self._task_group.cancel_scope.cancel()

    async def _call_func(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        future: Future[T_Retval],
    ) -> None:
        """
        Call ``func`` and transfer its outcome to ``future``.

        If the call returns an awaitable, it is awaited inside a cancel scope that
        gets cancelled when ``future`` is cancelled.

        NOTE(review): presumably invoked by backend implementations of
        :meth:`_spawn_task_from_thread` -- confirm against the backends.

        :param func: the callable to run
        :param args: positional arguments for ``func``
        :param kwargs: keyword arguments for ``func``
        :param future: the future that receives the return value, the exception or
            the cancellation of the call

        """

        def callback(f: Future[T_Retval]) -> None:
            # Done-callback of ``future``: if it was cancelled while the portal
            # is still running, and this callback is not executing in the event
            # loop thread, cancel the scope via the portal.
            if f.cancelled() and self._event_loop_thread_id not in (
                None,
                threading.get_ident(),
            ):
                self.call(scope.cancel)

        try:
            retval_or_awaitable = func(*args, **kwargs)
            if isawaitable(retval_or_awaitable):
                with CancelScope() as scope:
                    if future.cancelled():
                        # Cancelled before we got here: cancel the scope right
                        # away instead of registering the callback.
                        scope.cancel()
                    else:
                        future.add_done_callback(callback)

                    retval = await retval_or_awaitable
            else:
                retval = retval_or_awaitable
        except self._cancelled_exc_class:
            # Reflect the task's cancellation in the future; the second call
            # notifies waiters of the cancellation.
            future.cancel()
            future.set_running_or_notify_cancel()
        except BaseException as exc:
            if not future.cancelled():
                future.set_exception(exc)

            # Let base exceptions fall through
            if not isinstance(exc, Exception):
                raise
        else:
            if not future.cancelled():
                future.set_result(retval)
        finally:
            # NOTE(review): presumably drops the reference that ``callback``'s
            # closure holds to the cancel scope -- confirm the intent.
            scope = None  # type: ignore[assignment]

    def _spawn_task_from_thread(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        args: tuple[Unpack[PosArgsT]],
        kwargs: dict[str, Any],
        name: object,
        future: Future[T_Retval],
    ) -> None:
        """
        Spawn a new task using the given callable.

        Implementors must ensure that the future is resolved when the task finishes.

        :param func: a callable
        :param args: positional arguments to be passed to the callable
        :param kwargs: keyword arguments to be passed to the callable
        :param name: name of the task (will be coerced to a string if not ``None``)
        :param future: a future that will resolve to the return value of the callable,
            or the exception raised during its execution
        :raises NotImplementedError: always, in this base implementation

        """
        raise NotImplementedError

    @overload
    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        ...

    @overload
    def call(
        self, func: Callable[[Unpack[PosArgsT]], T_Retval], *args: Unpack[PosArgsT]
    ) -> T_Retval:
        ...

    def call(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
    ) -> T_Retval:
        """
        Call the given function in the event loop thread.

        If the callable returns a coroutine object, it is awaited on.

        This blocks the calling thread until the result is available.

        :param func: any callable
        :param args: positional arguments passed to ``func``
        :return: the return value of the callable (or of the awaited coroutine)
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread

        """
        return cast(T_Retval, self.start_task_soon(func, *args).result())

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval]],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        ...

    @overload
    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        ...

    def start_task_soon(
        self,
        func: Callable[[Unpack[PosArgsT]], Awaitable[T_Retval] | T_Retval],
        *args: Unpack[PosArgsT],
        name: object = None,
    ) -> Future[T_Retval]:
        """
        Start a task in the portal's task group.

        The task will be run inside a cancel scope which can be cancelled by cancelling
        the returned future.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a future that resolves with the return value of the callable if the
            task completes successfully, or with the exception raised in the task
        :raises RuntimeError: if the portal is not running or if this method is called
            from within the event loop thread
        :rtype: concurrent.futures.Future[T_Retval]

        .. versionadded:: 3.0

        """
        self._check_running()
        f: Future[T_Retval] = Future()
        # No keyword arguments are forwarded to ``func`` on this path.
        self._spawn_task_from_thread(func, args, {}, name, f)
        return f

    def start_task(
        self,
        func: Callable[..., Awaitable[T_Retval]],
        *args: object,
        name: object = None,
    ) -> tuple[Future[T_Retval], Any]:
        """
        Start a task in the portal's task group and wait until it signals for readiness.

        This method works the same way as :meth:`.abc.TaskGroup.start`.

        :param func: the target function
        :param args: positional arguments passed to ``func``
        :param name: name of the task (will be coerced to a string if not ``None``)
        :return: a tuple of (future, task_status_value) where the ``task_status_value``
            is the value passed to ``task_status.started()`` from within the target
            function
        :raises RuntimeError: if the portal is not running, if this method is called
            from within the event loop thread, or if the task exits without calling
            ``task_status.started()``
        :rtype: tuple[concurrent.futures.Future[T_Retval], Any]

        .. versionadded:: 3.0

        """

        def task_done(future: Future[T_Retval]) -> None:
            # Runs when the task's future completes. If started() was never
            # called, unblock the caller by mirroring the task's cancellation or
            # exception, or by reporting the missing started() call.
            if not task_status_future.done():
                if future.cancelled():
                    task_status_future.cancel()
                elif future.exception():
                    task_status_future.set_exception(future.exception())
                else:
                    exc = RuntimeError(
                        "Task exited without calling task_status.started()"
                    )
                    task_status_future.set_exception(exc)

        self._check_running()
        task_status_future: Future = Future()
        task_status = _BlockingPortalTaskStatus(task_status_future)
        f: Future = Future()
        f.add_done_callback(task_done)
        # The task status object is passed to ``func`` as a keyword argument.
        self._spawn_task_from_thread(func, args, {"task_status": task_status}, name, f)
        # Blocks until started() is called or task_done() resolves the future.
        return f, task_status_future.result()

    def wrap_async_context_manager(
        self, cm: AsyncContextManager[T_co]
    ) -> ContextManager[T_co]:
        """
        Wrap an async context manager as a synchronous context manager via this portal.

        Spawns a task that will call both ``__aenter__()`` and ``__aexit__()``, stopping
        in the middle until the synchronous context manager exits.

        :param cm: an asynchronous context manager
        :return: a synchronous context manager

        .. versionadded:: 2.1

        """
        return _BlockingAsyncContextManager(cm, self)


@contextmanager
def start_blocking_portal(
    backend: str = "asyncio", backend_options: dict[str, Any] | None = None
) -> Generator[BlockingPortal, Any, None]:
    """
    Start a new event loop in a new thread and run a blocking portal in its main task.

    The parameters are the same as for :func:`~anyio.run`.

    :param backend: name of the backend
    :param backend_options: backend options
    :return: a context manager that yields a blocking portal

    .. versionchanged:: 3.0
        Usage as a context manager is now required.

    """

    async def run_portal() -> None:
        # Main task of the new event loop: publish the portal through ``future``
        # (unless that future was cancelled in the meantime) and keep the loop
        # alive until the portal is stopped.
        async with BlockingPortal() as portal_:
            if future.set_running_or_notify_cancel():
                future.set_result(portal_)
                await portal_.sleep_until_stopped()

    future: Future[BlockingPortal] = Future()
    with ThreadPoolExecutor(1) as executor:
        run_future = executor.submit(
            _eventloop.run,  # type: ignore[arg-type]
            run_portal,
            backend=backend,
            backend_options=backend_options,
        )
        try:
            # Wake up as soon as either the portal is available or the event
            # loop run has ended (e.g. because it failed to start).
            wait(
                cast(Iterable[Future], [run_future, future]),
                return_when=FIRST_COMPLETED,
            )
        except BaseException:
            future.cancel()
            run_future.cancel()
            raise

        if future.done():
            portal = future.result()
            cancel_remaining_tasks = False
            try:
                yield portal
            except BaseException:
                # The with-block body raised: cancel whatever is still running.
                cancel_remaining_tasks = True
                raise
            finally:
                try:
                    portal.call(portal.stop, cancel_remaining_tasks)
                except RuntimeError:
                    # Deliberately ignored; portal.call() raises RuntimeError
                    # when the portal is no longer running.
                    pass

        # Wait for the event loop to finish and re-raise any exception from it.
        run_future.result()


def check_cancelled() -> None:
    """
    Check if the cancel scope of the host task running the current worker thread has
    been cancelled.

    If the host task's current cancel scope has indeed been cancelled, the
    backend-specific cancellation exception will be raised.

    :raises RuntimeError: if the current thread was not spawned by
        :func:`.to_thread.run_sync`

    """
    try:
        async_backend: AsyncBackend = threadlocals.current_async_backend
    except AttributeError:
        raise RuntimeError(
            "This function can only be run from an AnyIO worker thread"
        ) from None

    async_backend.check_cancelled()