"""Minimal call-site timer: bracket code with ``timer.start()`` / ``timer.end()``."""

import inspect
import time
from collections import defaultdict

import torch


class Timer():
    """Collect elapsed-time samples keyed by the location that calls ``end()``.

    Attributes:
        timings: Maps a ``'<parent dir>/<file>:<line>'`` marker (the call site
            of ``end()``) to the list of durations, in seconds, recorded there.
        start_time: Clock reading stored by the most recent ``start()``;
            ``0`` until ``start()`` has been called.
    """

    def __init__(self) -> None:
        self.timings: dict[str, list[float]] = {}
        self.start_time = 0

    def start(self) -> None:
        """Remember the current clock reading as the start of an interval."""
        # perf_counter is monotonic, so intervals cannot come out negative
        # when the wall clock is adjusted.
        self.start_time = time.perf_counter()

    def end(self) -> None:
        """Record the time since the last ``start()`` under the caller's location.

        Every call appends one sample, including the first one seen for a
        given call site. If ``start()`` was never called, the sample is
        measured from the initial ``start_time`` of ``0`` and is meaningless.
        """
        # Read the clock before any frame introspection so that the cost of
        # resolving the call site is not counted as part of the measurement.
        elapsed = float(time.perf_counter() - self.start_time)

        current = inspect.currentframe()
        caller = current.f_back if current is not None else None
        if caller is None:
            # Interpreter without frame support, or no calling frame.
            marker = '<unknown>:0'
        else:
            # context=0 resolves file name and line number only; it skips
            # loading source lines, and only the caller frame is inspected.
            info = inspect.getframeinfo(caller, context=0)
            # Keep just "<parent dir>/<file>" to make markers short.
            filename = '/'.join(info.filename.split('/')[-2:])
            marker = f'{filename}:{info.lineno}'

        self.timings.setdefault(marker, []).append(elapsed)

    def report(self) -> None:
        """Print the mean recorded duration for every call site."""
        if not self.timings:
            print('[TIMER]: No record')
            return

        print('[TIMER]:')
        for marker, samples in self.timings.items():
            print(f'{marker}: {torch.FloatTensor(samples).mean().item()}')


# Shared module-level instance for convenient use across a code base.
timer = Timer()