|
from collections import deque, defaultdict |
|
from functools import wraps |
|
from types import GeneratorType |
|
from typing import Callable |
|
import numpy as np |
|
import time |
|
from ditk import logging |
|
from ding.framework import task |
|
|
|
|
|
class StepTimer:
    """
    Overview:
        Decorator for middleware that records how long each wrapped step spends in its own code and \
        periodically logs the latest cost together with a moving average. For generator middleware, the \
        time during which the wrapper is suspended at its own ``yield`` is not counted.
    """

    def __init__(self, print_per_step: int = 1, smooth_window: int = 10) -> None:
        """
        Overview:
            Print time cost of each step (execute one middleware).
        Arguments:
            - print_per_step (:obj:`int`): Print each N step. Must be >= 1.
            - smooth_window (:obj:`int`): The window size to smooth the mean, in multiples of \
                ``print_per_step`` records (the history keeps ``print_per_step * smooth_window`` entries). \
                Must be >= 1.
        Raises:
            - ValueError: If ``print_per_step`` or ``smooth_window`` is smaller than 1.
        """
        # Validate eagerly: print_per_step=0 would otherwise only surface later as a ZeroDivisionError
        # inside the wrapped step, and a zero-length window makes the logged mean NaN.
        if print_per_step < 1:
            raise ValueError("print_per_step must be >= 1, got {}".format(print_per_step))
        if smooth_window < 1:
            raise ValueError("smooth_window must be >= 1, got {}".format(smooth_window))
        self.print_per_step = print_per_step
        # One bounded history per step name; the oldest records are dropped automatically.
        self.records = defaultdict(lambda: deque(maxlen=print_per_step * smooth_window))

    def __call__(self, fn: Callable) -> Callable:
        """
        Overview:
            Wrap a middleware so that its execution time is recorded and periodically logged.
        Arguments:
            - fn (:obj:`Callable`): The middleware, called as ``fn(ctx)``. It may be a plain function or a \
                generator function that yields once.
        Returns:
            - executor (:obj:`Callable`): A generator function taking ``ctx`` that runs ``fn`` and records \
                its cost under ``fn.__name__`` (or the type name when ``fn`` has no ``__name__``).
        """
        step_name = getattr(fn, "__name__", type(fn).__name__)

        @wraps(fn)
        def executor(ctx):
            # perf_counter is monotonic and high-resolution; time.time() can jump with system clock changes.
            start_time = time.perf_counter()
            g = fn(ctx)
            if isinstance(g, GeneratorType):
                # Time the middleware's code up to its first ``yield`` ...
                try:
                    next(g)
                except StopIteration:
                    pass
                time_cost = time.perf_counter() - start_time
                # ... suspend here without counting (presumably the following middleware runs meanwhile) ...
                yield
                # ... then time the code after its ``yield``.
                start_time = time.perf_counter()
                try:
                    next(g)
                except StopIteration:
                    pass
                time_cost += time.perf_counter() - start_time
            else:
                time_cost = time.perf_counter() - start_time
            self.records[step_name].append(time_cost)
            if ctx.total_step % self.print_per_step == 0:
                logging.info(
                    "[Step Timer][Node:{:>2}] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
                        task.router.node_id or 0, step_name, time_cost * 1000,
                        np.mean(self.records[step_name]) * 1000
                    )
                )

        return executor
|
|