File size: 1,922 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 |
from collections import deque, defaultdict
from functools import wraps
from types import GeneratorType
from typing import Callable
import numpy as np
import time
from ditk import logging
from ding.framework import task
class StepTimer:
    """
    Overview:
        Decorator object that measures how long each wrapped step (middleware) takes, keeps a bounded
        history of the measurements per step name, and periodically logs the latest cost and the mean
        of the recorded history.
    """

    def __init__(self, print_per_step: int = 1, smooth_window: int = 10) -> None:
        """
        Overview:
            Print time cost of each step (execute one middleware).
        Arguments:
            - print_per_step (:obj:`int`): Print each N step.
            - smooth_window (:obj:`int`): The window size to smooth the mean.
        """
        self.print_per_step = print_per_step
        # One bounded history per step name. The deque keeps at most
        # ``print_per_step * smooth_window`` samples, so the logged mean only covers recent calls.
        # The size is fixed from the constructor arguments; changing ``self.print_per_step`` later
        # does not resize already (or subsequently) created histories.
        self.records = defaultdict(lambda: deque(maxlen=print_per_step * smooth_window))

    def __call__(self, fn: Callable) -> Callable:
        """
        Overview:
            Wrap ``fn`` so that every invocation is timed and recorded under the step's name.
        Arguments:
            - fn (:obj:`Callable`): The step to time. It is called as ``fn(ctx)`` and may either return \
                directly or return a generator that is advanced once before and once after the wrapper's \
                own ``yield``.
        Returns:
            - executor (:obj:`Callable`): A generator function taking ``ctx``. It always yields at most \
                once (only when ``fn`` returned a generator) and yields ``None``.
        """
        # Callables without ``__name__`` (e.g. instances defining ``__call__``) fall back to their type name.
        # Steps that share a name also share one history in ``self.records``.
        step_name = getattr(fn, "__name__", type(fn).__name__)

        @wraps(fn)
        def executor(ctx):
            # perf_counter is monotonic and high resolution, so measured intervals can neither go
            # negative nor be distorted by wall-clock adjustments, unlike time.time().
            start_time = time.perf_counter()
            g = fn(ctx)
            if isinstance(g, GeneratorType):
                # First half of the step: run the generator up to its first yield (or to completion).
                try:
                    next(g)
                except StopIteration:
                    pass
                time_cost = time.perf_counter() - start_time
                # Hand control back to the caller; time spent while suspended here is not counted.
                yield
                # Second half of the step: resume the generator after the caller comes back.
                start_time = time.perf_counter()
                try:
                    next(g)
                except StopIteration:
                    pass
                time_cost += time.perf_counter() - start_time
            else:
                # Plain function: the whole call already happened inside ``fn(ctx)``.
                time_cost = time.perf_counter() - start_time
            self.records[step_name].append(time_cost)
            # Only log on steps whose ``ctx.total_step`` is a multiple of ``print_per_step``.
            if ctx.total_step % self.print_per_step == 0:
                logging.info(
                    "[Step Timer][Node:{:>2}] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
                        task.router.node_id or 0, step_name, time_cost * 1000,
                        np.mean(self.records[step_name]) * 1000
                    )
                )

        return executor
|