"""Manage FastAPI background tasks."""

import asyncio
import functools
import time
import traceback
import uuid
from datetime import datetime, timedelta
from enum import Enum
from typing import Any, Awaitable, Callable, Iterable, Optional, TypeVar, Union

import dask
import psutil
from dask.distributed import Client, Future, get_client, get_worker
from pydantic import BaseModel, parse_obj_as
from tqdm import tqdm

from .utils import log, pretty_timedelta

TaskId = str
# An ID for a single step of a task: (task id, step index).
TaskStepId = tuple[str, int]
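# A task is a regular function or a coroutine function executed on the dask client.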
Task = Union[Callable[..., Any], Callable[..., Awaitable[Any]]]


class TaskStatus(str, Enum):
  """Enum holding a tasks status."""
  PENDING = 'pending'
  COMPLETED = 'completed'
  ERROR = 'error'


class TaskStepInfo(BaseModel):
  """Information about a step of the task.."""
  progress: Optional[float]
  description: Optional[str]
  details: Optional[str]


class TaskInfo(BaseModel):
  """Metadata about a task."""
  name: str
  status: TaskStatus
  progress: Optional[float]
  message: Optional[str]
  details: Optional[str]
  # The current step's progress.
  step_progress: Optional[float]

  # A task may have multiple progress indicators, e.g. for chained signals that compute 3 signals.
  steps: Optional[list[TaskStepInfo]]
  description: Optional[str]
  start_timestamp: str
  end_timestamp: Optional[str]
  error: Optional[str]


class TaskManifest(BaseModel):
  """Information for tasks that are running or completed."""
  tasks: dict[str, TaskInfo]
  progress: Optional[float]


STEPS_LOG_KEY = 'steps'


class TaskManager:
  """Manage FastAPI background tasks."""

  def __init__(self, dask_client: Optional[Client] = None) -> None:
    """By default, use a dask multi-processing client.

    A user can pass in a dask client to use a different executor.
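
    A minimal construction sketch (the client arguments here are illustrative):

      from dask.distributed import Client
      manager = TaskManager(dask_client=Client(asynchronous=True))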
    """
    # Keep task state on the instance rather than the class so instances don't share state.
    self._tasks: dict[TaskId, TaskInfo] = {}

    # Set dask's workers to be non-daemonic so they can spawn child processes if they need to.
    # This is particularly useful for signals that use libraries with multiprocessing support.
    dask.config.set({'distributed.worker.daemon': False})

    total_memory_gb = psutil.virtual_memory().total / (1024**3)
    self._dask_client = dask_client or Client(
      asynchronous=True, memory_limit=f'{total_memory_gb} GB')

  async def _update_tasks(self) -> None:
    for task_id, task in self._tasks.items():
      if task.status == TaskStatus.COMPLETED:
        continue

      step_events = self._dask_client.get_events(_progress_event_topic(task_id))
      # This allows us to work with both sync and async clients.
      if not isinstance(step_events, tuple):
        step_events = await step_events

      if step_events:
        _, log_message = step_events[-1]
        steps = parse_obj_as(list[TaskStepInfo], log_message[STEPS_LOG_KEY])
        task.steps = steps
        if steps:
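          # Treat the last step that has reported progress as the current step.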
          cur_step = 0
          for i, step in enumerate(reversed(steps)):
            if step.progress is not None:
              cur_step = len(steps) - i - 1
              break
          task.details = steps[cur_step].details
          task.step_progress = steps[cur_step].progress
          task.progress = sum(step.progress or 0.0 for step in steps) / len(steps)
          # Avoid an indeterminate progress bar when a later step hasn't reported progress yet.
          if cur_step > 0 and task.step_progress is None:
            task.step_progress = 0.0

          task.message = f'Step {cur_step+1}/{len(steps)}'
          if steps[cur_step].description:
            task.message += f': {steps[cur_step].description}'
        else:
          task.progress = None

  async def manifest(self) -> TaskManifest:
    """Get all tasks."""
    await self._update_tasks()
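    # Overall progress is the mean progress across tasks that are still running and have
    # reported progress.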
    tasks_with_progress = [
      task.progress
      for task in self._tasks.values()
      if task.progress is not None and task.status != TaskStatus.COMPLETED
    ]
    return TaskManifest(
      tasks=self._tasks,
      progress=sum(tasks_with_progress) / len(tasks_with_progress) if tasks_with_progress else None)

  def task_id(self, name: str, description: Optional[str] = None) -> TaskId:
    """Create a unique ID for a task."""
    task_id = uuid.uuid4().hex
    self._tasks[task_id] = TaskInfo(
      name=name,
      status=TaskStatus.PENDING,
      progress=None,
      description=description,
      start_timestamp=datetime.now().isoformat())
    return task_id

  def _set_task_completed(self, task_id: TaskId, task_future: Future) -> None:
    end_timestamp = datetime.now().isoformat()
    self._tasks[task_id].end_timestamp = end_timestamp

    elapsed = datetime.fromisoformat(end_timestamp) - datetime.fromisoformat(
      self._tasks[task_id].start_timestamp)
    elapsed_formatted = pretty_timedelta(elapsed)

    if task_future.status == 'error':
      self._tasks[task_id].status = TaskStatus.ERROR
      tb = ''.join(traceback.format_tb(task_future.traceback()))
      e = task_future.exception()
      self._tasks[task_id].error = f'{e}:\n{tb}'
      raise e
    else:
      # This runs in the dask callback thread, so we have to make a new event loop.
      loop = asyncio.new_event_loop()
      loop.run_until_complete(self._update_tasks())
      loop.close()
      for step in self._tasks[task_id].steps or []:
        step.progress = 1.0

      self._tasks[task_id].status = TaskStatus.COMPLETED
      self._tasks[task_id].progress = 1.0
      self._tasks[task_id].message = f'Completed in {elapsed_formatted}'

    log(f'Task completed "{task_id}": "{self._tasks[task_id].name}" in '
        f'{elapsed_formatted}.')

  def execute(self, task_id: str, task: Task, *args: Any) -> None:
    """Create a unique ID for a task."""
    log(f'Scheduling task "{task_id}": "{self._tasks[task_id].name}".')

    task_info = self._tasks[task_id]
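    # Key the future by the task id so the worker can look up its own task state in
    # `_execute_task` and `progress`.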
    task_future = self._dask_client.submit(
      functools.partial(_execute_task, task, task_info, task_id), *args, key=task_id)
    task_future.add_done_callback(
      lambda task_future: self._set_task_completed(task_id, task_future))

  async def stop(self) -> None:
    """Stop the task manager and close the dask client."""
    await self._dask_client.close()


@functools.cache
def task_manager() -> TaskManager:
  """The global singleton for the task manager."""
  return TaskManager()


def _execute_task(task: Task, task_info: TaskInfo, task_id: str, *args: Any) -> None:
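  # Stash the TaskInfo on the dask worker's task annotations so `progress()` can read the task
  # name from within the worker process.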
  get_worker().state.tasks[task_id].annotations['task_info'] = task_info
  task(*args)


def _progress_event_topic(task_id: TaskId) -> str:
  return f'{task_id}_progress'


TProgress = TypeVar('TProgress')


def progress(it: Iterable[TProgress],
             task_step_id: Optional[TaskStepId],
             estimated_len: Optional[int],
             step_description: Optional[str] = None,
             emit_every_s: float = 1.) -> Iterable[TProgress]:
  """An iterable wrapper that emits progress and yields the original iterable."""
  if not task_step_id:
    yield from it
    return

  task_id, step_id = task_step_id
  steps = get_worker_steps(task_id)
  if not steps:
    steps = [TaskStepInfo(description=step_description, progress=0.0)]
  elif len(steps) <= step_id:
    # If the given step index exceeds the number of existing steps, append a new step.
    steps.append(TaskStepInfo(description=step_description, progress=0.0))
  else:
    steps[step_id].description = step_description
    steps[step_id].progress = 0.0
  set_worker_steps(task_id, steps)

  estimated_len = max(1, estimated_len) if estimated_len else None

  task_info: TaskInfo = get_worker().state.tasks[task_id].annotations['task_info']

  it_idx = 0
  start_time = time.time()
  last_emit = time.time() - emit_every_s
  with tqdm(it, desc=task_info.name, total=estimated_len) as tq:
    for t in tq:
      cur_time = time.time()
      if estimated_len and cur_time - last_emit > emit_every_s:
        it_per_sec = tq.format_dict['rate'] or 0.0
        set_worker_task_progress(
          task_step_id=task_step_id,
          it_idx=it_idx,
          elapsed_sec=tq.format_dict['elapsed'] or 0.0,
          it_per_sec=it_per_sec,
          estimated_total_sec=(estimated_len / it_per_sec if it_per_sec else 0),
          estimated_len=estimated_len)
        last_emit = cur_time
      yield t
      it_idx += 1

  total_time = time.time() - start_time
  set_worker_task_progress(
    task_step_id=task_step_id,
    it_idx=estimated_len if estimated_len else it_idx,
    elapsed_sec=total_time,
    it_per_sec=(estimated_len or it_idx) / total_time if total_time else 0.0,
    estimated_total_sec=total_time,
    estimated_len=estimated_len or it_idx)


def set_worker_steps(task_id: TaskId, steps: list[TaskStepInfo]) -> None:
  """Sets up worker steps. Use to provide task step descriptions before they compute."""
  get_worker().log_event(
    _progress_event_topic(task_id), {STEPS_LOG_KEY: [step.dict() for step in steps]})


def get_worker_steps(task_id: TaskId) -> list[TaskStepInfo]:
  """Get the most recently logged worker steps for a task."""
  events = get_client().get_events(_progress_event_topic(task_id))
  if not events or not events[-1]:
    return []

  (_, last_event) = events[-1]
  last_info = last_event.get(STEPS_LOG_KEY)
  if not last_info:
    return []
  return [TaskStepInfo(**step_info) for step_info in last_info]


def set_worker_task_progress(task_step_id: TaskStepId, it_idx: int, elapsed_sec: float,
                             it_per_sec: float, estimated_total_sec: float,
                             estimated_len: int) -> None:
  """Updates a task step with a progress between 0 and 1.

  This method does not exist on the TaskManager as it is meant to be a standalone method used by
  workers running tasks on separate processes so does not have access to task manager state.
  """
  progress = float(it_idx) / estimated_len
  task_id, step_id = task_step_id
  steps = get_worker_steps(task_id)
  if len(steps) <= step_id:
    raise ValueError(f'No step with idx {step_id} exists. Got steps: {steps}')
  steps[step_id].progress = progress

  # Example of the details format: 1748/1748 [elapsed 00:16<00:00, 106.30 ex/s]
  elapsed = pretty_timedelta(timedelta(seconds=elapsed_sec))
  if it_idx != estimated_len:
    # Only show the estimated total time while the step is still in progress.
    elapsed = f'{elapsed} < {pretty_timedelta(timedelta(seconds=estimated_total_sec))}'
  steps[step_id].details = (f'{it_idx:,}/{estimated_len:,} '
                            f'[{elapsed}, {it_per_sec:,.2f} ex/s]')

  set_worker_steps(task_id, steps)