|
import contextlib
import itertools
import logging
from typing import Iterator, List

import torch
from torch.utils._python_dispatch import TorchDispatchMode
from torch.utils._pytree import tree_map


class LoggingTensor(torch.Tensor):
    # Wrapper tensor subclass: `elem` holds the real tensor, and every op that
    # reaches __torch_dispatch__ is logged to the "LoggingTensor" logger.
    elem: torch.Tensor

    __slots__ = ['elem']

    context = contextlib.nullcontext

    # Interpose only at the dispatcher level; __torch_function__ is disabled.
    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        # The wrapper owns no storage of its own; it only mirrors the metadata
        # (sizes, strides, dtype, device, ...) of the tensor it wraps.
        r = torch.Tensor._make_wrapper_subclass(
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # If the wrapper requires grad, keep a detached reference so autograd
        # history is tracked on the wrapper rather than on the wrapped tensor.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}

        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        # Unwrap the inputs, run the real op, re-wrap tensor outputs, then log.
        with cls.context():
            rs = tree_map(wrap, func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs)))
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
        return rs
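
# Illustrative usage sketch (not part of the original module): ops on a
# LoggingTensor are unwrapped, executed on `elem`, re-wrapped, and recorded, e.g.
#
#     with capture_logs() as logs:
#         x = LoggingTensor(torch.ones(2))
#         log_input("x", x)
#         y = x + x
#     # logs now holds entries roughly like:
#     #   $0 = input('x')
#     #   $1 = torch._ops.aten.add.Tensor($0, $0)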
|
|
|
class LoggingTensorMode(TorchDispatchMode):
    # Mode variant: while the mode is active, every dispatched op is logged for
    # all tensors, without wrapping them in a subclass.
    def __torch_dispatch__(self, func, types, args=(), kwargs=None):
        if kwargs is None:
            kwargs = {}
        rs = func(*args, **kwargs)
        logging.getLogger("LoggingTensor").info(f"{func.__module__}.{func.__name__}", args, kwargs, rs)
        return rs


class LoggingTensorReentrant(LoggingTensor):
    # Same as LoggingTensor, except ops run with reentrant dispatch enabled.
    context = torch.overrides.enable_reentrant_dispatch
|
|
|
|
|
class LoggingTensorHandler(logging.Handler):
    # Renders the records emitted above: record.msg is the op name and
    # record.args carries (positional args, kwargs, outputs).
    log_list: List[str]
    next_shortid: int

    def __init__(self, log_list: List[str], use_shortid_for_all_tensors: bool) -> None:
        logging.Handler.__init__(self)
        self.log_list = log_list
        self.next_shortid = 0
        self.use_shortid_for_all_tensors = use_shortid_for_all_tensors

    def _shortid(self, o: object) -> int:
        if not hasattr(o, '_shortid'):
            o._shortid = self.next_shortid
            self.next_shortid += 1
        return o._shortid

    def _fmt(self, a: object) -> str:
        cond_cls = torch.Tensor if self.use_shortid_for_all_tensors else LoggingTensor
        return f'${self._shortid(a)}' if isinstance(a, cond_cls) else repr(a)

    def emit(self, record):
        fmt_args = ", ".join(itertools.chain(
            (self._fmt(a) for a in record.args[0]),
            (f"{k}={self._fmt(v)}" for k, v in record.args[1].items())
        ))
        fmt_rets = ", ".join(self._fmt(a) for a in record.args[2]) \
            if isinstance(record.args[2], (list, tuple)) else self._fmt(record.args[2])
        self.log_list.append(f'{fmt_rets} = {record.msg}({fmt_args})')
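
# Tensor arguments are rendered as short ids ($0, $1, ...) assigned in the order
# they are first seen; with use_shortid_for_all_tensors=True (the mode-based
# capture) every torch.Tensor gets an id, otherwise only LoggingTensors do.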
|
|
|
def log_input(name: str, var: object):
    logging.getLogger("LoggingTensor").info("input", (name,), {}, (var,))


@contextlib.contextmanager
def capture_logs(is_mode=False) -> Iterator[List[str]]:
    logger = logging.getLogger("LoggingTensor")
    log_list: List[str] = []
    handler = LoggingTensorHandler(log_list, use_shortid_for_all_tensors=is_mode)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)
    logger.propagate = False
    try:
        yield log_list
    finally:
        logger.removeHandler(handler)
|
|
|
@contextlib.contextmanager
def capture_logs_with_logging_tensor_mode():
    with LoggingTensorMode(), capture_logs(True) as logs:
        yield logs
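
# Illustrative usage sketch (not part of the original module): the mode-based
# capture logs dispatched ops on plain tensors too, no wrapping required, e.g.
#
#     with capture_logs_with_logging_tensor_mode() as logs:
#         x = torch.ones(2)
#         x + x
#     # every dispatched op (including the ones() factory call) ends up in logs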
|
|