| """ |
| Async utilities for SkyDiscover |
| """ |
|
|
| import asyncio |
| import logging |
| from typing import Any, Callable, List, Optional, Sequence, Tuple |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class TaskPool:
    """
    A simple task pool for managing and limiting concurrent tasks.

    Callables submitted through :meth:`run`, :meth:`create_task` or
    :meth:`gather` all acquire the same :class:`asyncio.Semaphore`, so at most
    ``max_concurrency`` of them execute at any one time.

    Attributes:
        max_concurrency: Number of permits the semaphore is created with.
        tasks: Tasks created via :meth:`create_task` that have not finished yet.
    """

    def __init__(self, max_concurrency: int = 10):
        """Create an empty pool.

        Args:
            max_concurrency: Maximum number of callables allowed to run
                concurrently. Must be at least 1.

        Raises:
            ValueError: If ``max_concurrency`` is smaller than 1.
        """
        # Fail fast: a semaphore with zero permits would make every submitted
        # callable wait forever, and a negative value would otherwise only be
        # rejected by asyncio.Semaphore when the pool is first used.
        if max_concurrency < 1:
            raise ValueError(f"max_concurrency must be >= 1, got {max_concurrency}")
        self.max_concurrency = max_concurrency
        self._semaphore: Optional[asyncio.Semaphore] = None
        self.tasks: List[asyncio.Task] = []

    @property
    def semaphore(self) -> asyncio.Semaphore:
        """Lazy-initialize the semaphore when first needed.

        The semaphore is built from the current value of ``max_concurrency``
        on first access and the same instance is returned afterwards.
        """
        if self._semaphore is None:
            self._semaphore = asyncio.Semaphore(self.max_concurrency)
        return self._semaphore

    def _discard_task(self, task: asyncio.Task) -> None:
        """Stop tracking *task*; do nothing if it is not tracked any more."""
        try:
            self.tasks.remove(task)
        except ValueError:
            # The task was already removed (e.g. a caller cleared ``tasks``);
            # raising from a done-callback would only produce loop error noise.
            pass

    async def run(self, coro: Callable, *args: Any, **kwargs: Any) -> Any:
        """Run a single coroutine function under the concurrency semaphore.

        Args:
            coro: Coroutine *function* (not an already-created coroutine
                object); it is called as ``coro(*args, **kwargs)`` once a
                semaphore permit has been acquired.
            *args: Positional arguments forwarded to ``coro``.
            **kwargs: Keyword arguments forwarded to ``coro``.

        Returns:
            Whatever awaiting ``coro(*args, **kwargs)`` returns.
        """
        async with self.semaphore:
            return await coro(*args, **kwargs)

    def create_task(self, coro: Callable, *args: Any, **kwargs: Any) -> asyncio.Task:
        """Create, track, and return an ``asyncio.Task`` bounded by the pool.

        The task wraps :meth:`run`, is appended to ``self.tasks`` and is
        removed from that list again by a done-callback once it finishes.

        Args:
            coro: Coroutine function to execute.
            *args: Positional arguments forwarded to ``coro``.
            **kwargs: Keyword arguments forwarded to ``coro``.

        Returns:
            The newly created task.
        """
        task = asyncio.create_task(self.run(coro, *args, **kwargs))
        self.tasks.append(task)
        task.add_done_callback(self._discard_task)
        return task

    async def gather(
        self,
        coros: Sequence[Callable],
        args_list: Sequence[Tuple[Any, ...]] = (),
        kwargs_list: Sequence[dict] = (),
        return_exceptions: bool = False,
    ) -> List[Any]:
        """Run *coros* concurrently (bounded by the semaphore), return results in order.

        Args:
            coros: Coroutine functions to call.
            args_list: One tuple of positional arguments per entry in
                ``coros``. When empty, every callable is invoked without
                positional arguments.
            kwargs_list: One dict of keyword arguments per entry in ``coros``.
                When empty, every callable is invoked without keyword
                arguments.
            return_exceptions: Passed through to :func:`asyncio.gather`.

        Returns:
            The results in the same order as ``coros``.

        Raises:
            ValueError: If a non-empty ``args_list`` or ``kwargs_list`` does
                not have the same length as ``coros``.
        """
        n = len(coros)
        # An empty (falsy) sequence means "no arguments for any callable".
        _args = args_list if args_list else [() for _ in range(n)]
        _kwargs = kwargs_list if kwargs_list else [{} for _ in range(n)]

        # Validate before creating any task so that zip() below cannot
        # silently drop callables.
        if len(_args) != n:
            raise ValueError(f"args_list length ({len(_args)}) must match coros length ({n})")
        if len(_kwargs) != n:
            raise ValueError(f"kwargs_list length ({len(_kwargs)}) must match coros length ({n})")

        tasks = [
            self.create_task(coro, *args, **kwargs)
            for coro, args, kwargs in zip(coros, _args, _kwargs)
        ]
        return await asyncio.gather(*tasks, return_exceptions=return_exceptions)
|
|