import torch
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
from contextlib import contextmanager
from typing import Iterator
from pytorch_lightning.utilities import rank_zero_only


def _cuda_synchronize() -> None:
    """Block until queued CUDA work finishes, when a CUDA device is usable.

    ``torch.cuda.synchronize()`` raises when CUDA is unavailable (CPU-only
    build or no visible GPU); in that case there is no device queue to wait
    on, so the barrier is skipped instead of crashing the profiled run.
    """
    if torch.cuda.is_available():
        torch.cuda.synchronize()


class InferenceProfiler(SimpleProfiler):
    """Simple profiler that brackets each action with a CUDA synchronize.

    CUDA kernels launch asynchronously, so without a barrier the recorded
    wall-clock duration would only cover kernel *launch*. Synchronizing
    before start and after stop makes each duration include the GPU work.
    On machines without CUDA the barriers are skipped.
    Use this in test time.
    """

    def __init__(self):
        super().__init__()
        # Wrap the recording/reporting hooks with Lightning's rank_zero_only so
        # that, in multi-process runs, only global rank 0 records and reports.
        self.start = rank_zero_only(self.start)
        self.stop = rank_zero_only(self.stop)
        self.summary = rank_zero_only(self.summary)

    @contextmanager
    def profile(self, action_name: str) -> Iterator[str]:
        """Time the enclosed block under ``action_name``.

        Args:
            action_name: Key under which the duration is recorded.

        Yields:
            ``action_name``, unchanged.
        """
        try:
            # Drain pending GPU work so it is not attributed to this action.
            _cuda_synchronize()
            self.start(action_name)
            yield action_name
        finally:
            # Wait for the action's own GPU work before stopping the timer.
            _cuda_synchronize()
            self.stop(action_name)


def build_profiler(name):
    """Create a profiler from its short name.

    Args:
        name: ``"inference"`` for :class:`InferenceProfiler`, ``"pytorch"`` for
            Lightning's ``PyTorchProfiler`` (CUDA + memory profiling, 100 rows),
            or ``None`` for a no-op ``PassThroughProfiler``.

    Returns:
        The constructed profiler instance.

    Raises:
        ValueError: If ``name`` is not one of the values above.
    """
    if name == "inference":
        return InferenceProfiler()
    elif name == "pytorch":
        # Imported lazily so the heavier profiler is only loaded when requested.
        from pytorch_lightning.profiler import PyTorchProfiler
        return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
    elif name is None:
        return PassThroughProfiler()
    else:
        raise ValueError(f"Invalid profiler: {name}")