# Julian Bilcke: "we are going to hack into finetrainers" (revision 9fd1204)
# The lines above the imports were file-viewer residue (raw / history / blame, 3.42 kB).
import time
from dataclasses import dataclass
from enum import Enum
import torch
from finetrainers.constants import FINETRAINERS_ENABLE_TIMING
from finetrainers.logging import get_logger
# Module-level logger obtained from finetrainers.logging; Timer uses it to emit warnings.
logger = get_logger()
class TimerDevice(str, Enum):
    """Devices a timer can be bound to.

    The enum also derives from ``str``, so each member compares equal to its
    plain string value (for example ``TimerDevice.CPU == "cpu"``).
    """

    # Time measured on the host.
    CPU = "cpu"
    # Time measured on a CUDA device.
    CUDA = "cuda"
@dataclass
class TimerData:
    """Record describing one timer: its label, target device and timestamps."""

    # Human-readable label for the timer.
    name: str
    # Device the timer is bound to.
    device: TimerDevice
    # Start timestamp in seconds; stays 0.0 until a value is stored.
    start_time: float = 0.0
    # End timestamp in seconds; stays 0.0 until a value is stored.
    end_time: float = 0.0
class Timer:
def __init__(self, name: str, device: TimerDevice, device_sync: bool = False):
self.data = TimerData(name=name, device=device)
self._device_sync = device_sync
self._start_event = None
self._end_event = None
self._active = False
self._enabled = FINETRAINERS_ENABLE_TIMING
def __enter__(self):
self.start()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.end()
return False
def start(self):
if self._active:
logger.warning(f"Timer {self.data.name} is already running. Please stop it before starting again.")
return
self._active = True
if not self._enabled:
return
if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
self._start_cuda()
else:
self._start_cpu()
if not self.data.device == TimerDevice.CPU:
logger.warning(
f"Timer device {self.data.device} is either not supported or incorrect device selected. Falling back to CPU."
)
def end(self):
if not self._active:
logger.warning(f"Timer {self.data.name} is not running. Please start it before stopping.")
return
self._active = False
if not self._enabled:
return
if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
self._end_cuda()
else:
self._end_cpu()
if not self.data.device == TimerDevice.CPU:
logger.warning(
f"Timer device {self.data.device} is either not supported or incorrect device selected. Falling back to CPU."
)
@property
def elapsed_time(self) -> float:
if self._active:
if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
premature_end_event = torch.cuda.Event(enable_timing=True)
premature_end_event.record()
premature_end_event.synchronize()
return self._start_event.elapsed_time(premature_end_event) / 1000.0
else:
return time.time() - self.data.start_time
else:
if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
return self._start_event.elapsed_time(self._end_event) / 1000.0
else:
return self.data.end_time - self.data.start_time
def _start_cpu(self):
self.data.start_time = time.time()
def _start_cuda(self):
torch.cuda.synchronize()
self._start_event = torch.cuda.Event(enable_timing=True)
self._end_event = torch.cuda.Event(enable_timing=True)
self._start_event.record()
def _end_cpu(self):
self.data.end_time = time.time()
def _end_cuda(self):
if self._device_sync:
torch.cuda.synchronize()
self._end_event.record()