import time
from dataclasses import dataclass
from enum import Enum

import torch

from finetrainers.constants import FINETRAINERS_ENABLE_TIMING
from finetrainers.logging import get_logger


logger = get_logger()


class TimerDevice(str, Enum):
    CPU = "cpu"
    CUDA = "cuda"


@dataclass
class TimerData:
    name: str
    device: TimerDevice
    start_time: float = 0.0
    end_time: float = 0.0


class Timer:
    """A simple wall-clock (CPU) or CUDA-event (GPU) timer, usable as a context manager.

    Timings are only recorded when FINETRAINERS_ENABLE_TIMING is set.
    """

    def __init__(self, name: str, device: TimerDevice, device_sync: bool = False):
        self.data = TimerData(name=name, device=device)

        self._device_sync = device_sync
        # CUDA events are created lazily in _start_cuda() when CUDA timing is used.
        self._start_event = None
        self._end_event = None
        self._active = False
        self._enabled = FINETRAINERS_ENABLE_TIMING

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end()
        return False

    def start(self):
        if self._active:
            logger.warning(f"Timer {self.data.name} is already running. Please call end() before starting it again.")
            return
        self._active = True
        if not self._enabled:
            return
        if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
            self._start_cuda()
        else:
            self._start_cpu()
            if self.data.device != TimerDevice.CPU:
                logger.warning(
                    f"Timer device {self.data.device} is not supported or not available. Falling back to CPU timing."
                )

    def end(self):
        if not self._active:
            logger.warning(f"Timer {self.data.name} is not running. Please call start() before ending it.")
            return
        self._active = False
        if not self._enabled:
            return
        if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
            self._end_cuda()
        else:
            self._end_cpu()
            if self.data.device != TimerDevice.CPU:
                logger.warning(
                    f"Timer device {self.data.device} is not supported or not available. Falling back to CPU timing."
                )

    @property
    def elapsed_time(self) -> float:
        if not self._enabled:
            # Timing is disabled, so no timestamps or CUDA events were recorded.
            return 0.0
        if self._active:
            if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
                # Timer is still running: record and sync a temporary event to measure the time so far.
                premature_end_event = torch.cuda.Event(enable_timing=True)
                premature_end_event.record()
                premature_end_event.synchronize()
                return self._start_event.elapsed_time(premature_end_event) / 1000.0
            else:
                return time.time() - self.data.start_time
        else:
            if self.data.device == TimerDevice.CUDA and torch.cuda.is_available():
                # Event.elapsed_time returns milliseconds; convert to seconds. Both events must have
                # completed on the device before they can be queried.
                return self._start_event.elapsed_time(self._end_event) / 1000.0
            else:
                return self.data.end_time - self.data.start_time

    def _start_cpu(self):
        self.data.start_time = time.time()

    def _start_cuda(self):
        # Drain previously queued work so it is not attributed to this timer.
        torch.cuda.synchronize()
        self._start_event = torch.cuda.Event(enable_timing=True)
        self._end_event = torch.cuda.Event(enable_timing=True)
        self._start_event.record()

    def _end_cpu(self):
        self.data.end_time = time.time()

    def _end_cuda(self):
        if self._device_sync:
            # Optionally wait for all queued GPU work to finish before recording the end event.
            torch.cuda.synchronize()
        self._end_event.record()
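

# Illustrative usage sketch (not part of the upstream module): it exercises the
# context-manager flow defined above, assuming FINETRAINERS_ENABLE_TIMING has been
# set so that timings are actually recorded. The "demo_step" name and the matmul
# workload are arbitrary placeholders.
if __name__ == "__main__":
    demo_device = TimerDevice.CUDA if torch.cuda.is_available() else TimerDevice.CPU
    with Timer("demo_step", demo_device, device_sync=True) as demo_timer:
        x = torch.randn(1024, 1024, device=demo_device.value)
        _ = x @ x  # the work being timed
    if demo_device == TimerDevice.CUDA:
        # Ensure the end event has completed before querying it.
        torch.cuda.synchronize()
    print(f"demo_step took {demo_timer.elapsed_time:.6f} seconds")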