Spaces:
Running
Running
File size: 1,199 Bytes
a80d6bb |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
from contextlib import contextmanager
from typing import Iterator

import torch
from pytorch_lightning.profiler import SimpleProfiler, PassThroughProfiler
from pytorch_lightning.utilities import rank_zero_only
class InferenceProfiler(SimpleProfiler):
    """Profiler that records action durations with CUDA synchronization.

    CUDA work is queued asynchronously, so ``profile`` calls
    ``torch.cuda.synchronize()`` (when CUDA is available) before starting
    and after finishing each action; the recorded time therefore includes
    the GPU work issued inside the profiled block. Use this at test time.

    ``start``, ``stop`` and ``summary`` are wrapped with ``rank_zero_only``
    so that they only take effect on the rank-0 process.
    """

    def __init__(self, **kwargs):
        """Create the profiler.

        Args:
            **kwargs: Forwarded unchanged to ``SimpleProfiler.__init__``
                (none are required).
        """
        super().__init__(**kwargs)
        # Shadow the bound methods with rank-zero-only wrappers on this
        # instance, so other ranks skip recording and summarizing.
        self.start = rank_zero_only(self.start)
        self.stop = rank_zero_only(self.stop)
        self.summary = rank_zero_only(self.summary)

    @staticmethod
    def _synchronize() -> None:
        """Wait for queued CUDA work to finish; do nothing without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    @contextmanager
    def profile(self, action_name: str) -> Iterator[str]:
        """Time the wrapped code and record it under ``action_name``.

        Args:
            action_name: Key under which the duration is recorded.

        Yields:
            ``action_name``, unchanged.
        """
        self._synchronize()
        self.start(action_name)
        # The try covers only the yield: if synchronizing or starting had
        # failed, calling stop() on a never-started action could raise a
        # second error and hide the original one.
        try:
            yield action_name
        finally:
            self._synchronize()
            self.stop(action_name)
def build_profiler(name):
    """Construct the profiler identified by ``name``.

    Args:
        name: ``None`` -> ``PassThroughProfiler``;
            ``'inference'`` -> ``InferenceProfiler``;
            ``'pytorch'`` -> Lightning ``PyTorchProfiler`` built with
            ``use_cuda=True, profile_memory=True, row_limit=100``.

    Returns:
        A new profiler instance.

    Raises:
        ValueError: If ``name`` is not one of the values listed above.
    """
    if name is None:
        return PassThroughProfiler()
    if name == 'inference':
        return InferenceProfiler()
    if name == 'pytorch':
        # Imported locally so this import only runs when the option is chosen.
        from pytorch_lightning.profiler import PyTorchProfiler
        return PyTorchProfiler(use_cuda=True, profile_memory=True, row_limit=100)
    raise ValueError(f'Invalid profiler: {name}')
|