"""Manage FastAPI background tasks.""" import asyncio import functools import time import traceback import uuid from datetime import datetime, timedelta from enum import Enum from typing import Any, Awaitable, Callable, Iterable, Optional, TypeVar, Union import dask import psutil from dask.distributed import Client, Variable from distributed import Future, get_client, get_worker from pydantic import BaseModel, parse_obj_as from tqdm import tqdm from .utils import log, pretty_timedelta TaskId = str # ID for the step of a task. TaskStepId = tuple[str, int] Task = Union[Callable[..., Any], Callable[..., Awaitable[Any]]] class TaskStatus(str, Enum): """Enum holding a tasks status.""" PENDING = 'pending' COMPLETED = 'completed' ERROR = 'error' class TaskStepInfo(BaseModel): """Information about a step of the task..""" progress: Optional[float] description: Optional[str] details: Optional[str] class TaskInfo(BaseModel): """Metadata about a task.""" name: str status: TaskStatus progress: Optional[float] message: Optional[str] details: Optional[str] # The current step's progress. step_progress: Optional[float] # A task may have multiple progress indicators, e.g. for chained signals that compute 3 signals. steps: Optional[list[TaskStepInfo]] description: Optional[str] start_timestamp: str end_timestamp: Optional[str] error: Optional[str] class TaskManifest(BaseModel): """Information for tasks that are running or completed.""" tasks: dict[str, TaskInfo] progress: Optional[float] STEPS_LOG_KEY = 'steps' class TaskManager: """Manage FastAPI background tasks.""" _tasks: dict[str, TaskInfo] = {} def __init__(self, dask_client: Optional[Client] = None) -> None: """By default, use a dask multi-processing client. A user can pass in a dask client to use a different executor. """ # Set dasks workers to be non-daemonic so they can spawn child processes if they need to. This # is particularly useful for signals that use libraries with multiprocessing support. dask.config.set({'distributed.worker.daemon': False}) total_memory_gb = psutil.virtual_memory().total / (1024**3) self._dask_client = dask_client or Client( asynchronous=True, memory_limit=f'{total_memory_gb} GB') async def _update_tasks(self) -> None: for task_id, task in self._tasks.items(): if task.status == TaskStatus.COMPLETED: continue step_events = self._dask_client.get_events(_progress_event_topic(task_id)) # This allows us to work with both sync and async clients. if not isinstance(step_events, tuple): step_events = await step_events if step_events: _, log_message = step_events[-1] steps = parse_obj_as(list[TaskStepInfo], log_message[STEPS_LOG_KEY]) task.steps = steps if steps: cur_step = 0 for i, step in enumerate(reversed(steps)): if step.progress is not None: cur_step = len(steps) - i - 1 break task.details = steps[cur_step].details task.step_progress = steps[cur_step].progress task.progress = (sum([step.progress or 0.0 for step in steps])) / len(steps) # Don't show an indefinite jump if there are multiple steps. if cur_step > 0 and task.step_progress is None: task.step_progress = 0.0 task.message = f'Step {cur_step+1}/{len(steps)}' if steps[cur_step].description: task.message += f': {steps[cur_step].description}' else: task.progress = None async def manifest(self) -> TaskManifest: """Get all tasks.""" await self._update_tasks() tasks_with_progress = [ task.progress for task in self._tasks.values() if task.progress and task.status != TaskStatus.COMPLETED ] return TaskManifest( tasks=self._tasks, progress=sum(tasks_with_progress) / len(tasks_with_progress) if tasks_with_progress else None) def task_id(self, name: str, description: Optional[str] = None) -> TaskId: """Create a unique ID for a task.""" task_id = uuid.uuid4().hex self._tasks[task_id] = TaskInfo( name=name, status=TaskStatus.PENDING, progress=None, description=description, start_timestamp=datetime.now().isoformat()) return task_id def _set_task_completed(self, task_id: TaskId, task_future: Future) -> None: end_timestamp = datetime.now().isoformat() self._tasks[task_id].end_timestamp = end_timestamp elapsed = datetime.fromisoformat(end_timestamp) - datetime.fromisoformat( self._tasks[task_id].start_timestamp) elapsed_formatted = pretty_timedelta(elapsed) if task_future.status == 'error': self._tasks[task_id].status = TaskStatus.ERROR tb = traceback.format_tb(task_future.traceback()) e = task_future.exception() self._tasks[task_id].error = f'{e}: \n{tb}' raise e else: # This runs in dask callback thread, so we have to make a new event loop. loop = asyncio.new_event_loop() loop.run_until_complete(self._update_tasks()) for step in self._tasks[task_id].steps or []: step.progress = 1.0 self._tasks[task_id].status = TaskStatus.COMPLETED self._tasks[task_id].progress = 1.0 self._tasks[task_id].message = f'Completed in {elapsed_formatted}' log(f'Task completed "{task_id}": "{self._tasks[task_id].name}" in ' f'{elapsed_formatted}.') def execute(self, task_id: str, task: Task, *args: Any) -> None: """Create a unique ID for a task.""" log(f'Scheduling task "{task_id}": "{self._tasks[task_id].name}".') task_info = self._tasks[task_id] task_future = self._dask_client.submit( functools.partial(_execute_task, task, task_info, task_id), *args, key=task_id) task_future.add_done_callback( lambda task_future: self._set_task_completed(task_id, task_future)) async def stop(self) -> None: """Stop the task manager and close the dask client.""" await self._dask_client.close() @functools.cache def task_manager() -> TaskManager: """The global singleton for the task manager.""" return TaskManager() def _execute_task(task: Task, task_info: TaskInfo, task_id: str, *args: Any) -> None: get_worker().state.tasks[task_id].annotations['task_info'] = task_info task(*args) def _progress_event_topic(task_id: TaskId) -> Variable: return f'{task_id}_progress' TProgress = TypeVar('TProgress') def progress(it: Iterable[TProgress], task_step_id: Optional[TaskStepId], estimated_len: Optional[int], step_description: Optional[str] = None, emit_every_s: float = 1.) -> Iterable[TProgress]: """An iterable wrapper that emits progress and yields the original iterable.""" if not task_step_id: yield from it return task_id, step_id = task_step_id steps = get_worker_steps(task_id) if not steps: steps = [TaskStepInfo(description=step_description, progress=0.0)] elif len(steps) <= step_id: # If the step given exceeds the length of the last step, add a new step. steps.append(TaskStepInfo(description=step_description, progress=0.0)) else: steps[step_id].description = step_description steps[step_id].progress = 0.0 set_worker_steps(task_id, steps) estimated_len = max(1, estimated_len) if estimated_len else None task_info: TaskInfo = get_worker().state.tasks[task_id].annotations['task_info'] it_idx = 0 start_time = time.time() last_emit = time.time() - emit_every_s with tqdm(it, desc=task_info.name, total=estimated_len) as tq: for t in tq: cur_time = time.time() if estimated_len and cur_time - last_emit > emit_every_s: it_per_sec = tq.format_dict['rate'] or 0.0 set_worker_task_progress( task_step_id=task_step_id, it_idx=it_idx, elapsed_sec=tq.format_dict['elapsed'] or 0.0, it_per_sec=it_per_sec or 0.0, estimated_total_sec=((estimated_len) / it_per_sec if it_per_sec else 0), estimated_len=estimated_len) last_emit = cur_time yield t it_idx += 1 total_time = time.time() - start_time set_worker_task_progress( task_step_id=task_step_id, it_idx=estimated_len if estimated_len else it_idx, elapsed_sec=total_time, it_per_sec=(estimated_len or it_idx) / total_time, estimated_total_sec=total_time, estimated_len=estimated_len or it_idx) def set_worker_steps(task_id: TaskId, steps: list[TaskStepInfo]) -> None: """Sets up worker steps. Use to provide task step descriptions before they compute.""" get_worker().log_event( _progress_event_topic(task_id), {STEPS_LOG_KEY: [step.dict() for step in steps]}) def get_worker_steps(task_id: TaskId) -> list[TaskStepInfo]: """Gets the last worker steps.""" events = get_client().get_events(_progress_event_topic(task_id)) if not events or not events[-1]: return [] (_, last_event) = events[-1] last_info = last_event.get(STEPS_LOG_KEY) return [TaskStepInfo(**step_info) for step_info in last_info] def set_worker_task_progress(task_step_id: TaskStepId, it_idx: int, elapsed_sec: float, it_per_sec: float, estimated_total_sec: float, estimated_len: int) -> None: """Updates a task step with a progress between 0 and 1. This method does not exist on the TaskManager as it is meant to be a standalone method used by workers running tasks on separate processes so does not have access to task manager state. """ progress = float(it_idx) / estimated_len task_id, step_id = task_step_id steps = get_worker_steps(task_id) if len(steps) <= step_id: raise ValueError(f'No step with idx {step_id} exists. Got steps: {steps}') steps[step_id].progress = progress # 1748/1748 [elapsed 00:16<00:00, 106.30 ex/s] elapsed = f'{pretty_timedelta(timedelta(seconds=elapsed_sec))}' if it_idx != estimated_len: # Only show estimated when in progress. elapsed = f'{elapsed} < {pretty_timedelta(timedelta(seconds=estimated_total_sec))}' steps[step_id].details = (f'{it_idx:,}/{estimated_len:,} ' f'[{elapsed}, {it_per_sec:,.2f} ex/s]') set_worker_steps(task_id, steps)