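# NOTE(review): the following three lines appear to be web-page residue from the
# hosting repository, not part of the module; kept as comments so the file parses.
# UMMJ's picture
# Upload 5875 files
# 9dd3461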
"""Diagnostic components for PyTorch ONNX export."""
import contextlib
from typing import Any, Optional, Tuple, TypeVar
import torch
from torch.onnx._internal.diagnostics import _rules, infra
# This is a workaround for mypy not supporting Self from typing_extensions.
# The TypeVar is used to annotate both `self` and the return type of the
# `with_*` methods, so that they are typed as returning the receiver's own class.
_ExportDiagnostic = TypeVar("_ExportDiagnostic", bound="ExportDiagnostic")


class ExportDiagnostic(infra.Diagnostic):
    """Base class for all export diagnostics.

    This class is used to represent all export diagnostics. It is a subclass of
    infra.Diagnostic, and adds additional methods to add more information to the
    diagnostic.

    The ``with_*`` methods are placeholders: they are annotated as returning the
    diagnostic itself, but none of them is implemented yet and each one
    unconditionally raises ``NotImplementedError``.
    """

    def __init__(
        self,
        rule: infra.Rule,
        level: infra.Level,
        message_args: Optional[Tuple[Any, ...]],
        **kwargs,
    ) -> None:
        """Initialize the diagnostic by forwarding all arguments to the base class.

        Args:
            rule: Forwarded unchanged to ``infra.Diagnostic``.
            level: Forwarded unchanged to ``infra.Diagnostic``.
            message_args: Forwarded unchanged to ``infra.Diagnostic``; may be None.
            **kwargs: Any further keyword arguments accepted by ``infra.Diagnostic``.
        """
        super().__init__(rule, level, message_args, **kwargs)

    def with_cpp_stack(self: _ExportDiagnostic) -> _ExportDiagnostic:
        """Placeholder for attaching a C++ stack to this diagnostic.

        Raises:
            NotImplementedError: Always; the method is not implemented yet.
        """
        # TODO: Implement this.
        # self.stacks.append(...)
        raise NotImplementedError("with_cpp_stack is not implemented yet.")

    def with_python_stack(self: _ExportDiagnostic) -> _ExportDiagnostic:
        """Placeholder for attaching a Python stack to this diagnostic.

        Raises:
            NotImplementedError: Always; the method is not implemented yet.
        """
        # TODO: Implement this.
        # self.stacks.append(...)
        raise NotImplementedError("with_python_stack is not implemented yet.")

    def with_model_source_location(
        self: _ExportDiagnostic,
    ) -> _ExportDiagnostic:
        """Placeholder for attaching a model source location to this diagnostic.

        Raises:
            NotImplementedError: Always; the method is not implemented yet.
        """
        # TODO: Implement this.
        # self.locations.append(...)
        raise NotImplementedError(
            "with_model_source_location is not implemented yet."
        )

    def with_export_source_location(
        self: _ExportDiagnostic,
    ) -> _ExportDiagnostic:
        """Placeholder for attaching an export source location to this diagnostic.

        Raises:
            NotImplementedError: Always; the method is not implemented yet.
        """
        # TODO: Implement this.
        # self.locations.append(...)
        raise NotImplementedError(
            "with_export_source_location is not implemented yet."
        )
class ExportDiagnosticTool(infra.DiagnosticTool):
    """Diagnostic tool description used by the ONNX exporter.

    A subclass of infra.DiagnosticTool whose constructor takes no arguments and
    supplies a fixed set of values to the base class: the tool name
    "torch.onnx.export", the installed torch version, the rules from `_rules`,
    and `ExportDiagnostic` as the diagnostic type.
    """

    def __init__(self) -> None:
        """Initialize the base tool with the exporter's fixed settings."""
        tool_settings = dict(
            name="torch.onnx.export",
            version=torch.__version__,
            rules=_rules.rules,
            diagnostic_type=ExportDiagnostic,
        )
        super().__init__(**tool_settings)
class ExportDiagnosticEngine(infra.DiagnosticEngine):
    """PyTorch ONNX Export diagnostic engine.

    This subclass exists solely to give `diagnose` calls inside the exporter a
    background context to fall back on.

    By design, each `torch.onnx.export` call should set up its own diagnostic
    context, and every `diagnose` call in the exporter should happen within it.
    Because the context is currently reached through a global variable, nothing
    guarantees that it has been initialized. A default background context is
    therefore kept here; without it, any direct use of exporter internals
    (e.g. unit tests) would fail for lack of a diagnostic context.

    This class can be removed once a context is passed explicitly through the
    exporter pipeline.
    """

    # Fallback context used when no export-specific context is active.
    _background_context: infra.DiagnosticContext

    def __init__(self) -> None:
        """Initialize the base engine and create the background context."""
        super().__init__()
        fallback_tool = ExportDiagnosticTool()
        self._background_context = infra.DiagnosticContext(
            fallback_tool, options=None
        )

    @property
    def background_context(self) -> infra.DiagnosticContext:
        """The fallback diagnostic context owned by this engine."""
        return self._background_context

    def clear(self):
        """Clear the engine via the base class, then empty the background context's diagnostics."""
        super().clear()
        background_diagnostics = self._background_context._diagnostics
        background_diagnostics.clear()

    def sarif_log(self):
        """Return the base class's SARIF log with the background context's run appended."""
        combined_log = super().sarif_log()
        combined_log.runs.append(self._background_context.sarif())
        return combined_log
# Module-level engine instance, created when this module is imported.
engine = ExportDiagnosticEngine()
# Diagnostic context currently in effect. It starts as the engine's background
# context and is rebound by `create_export_diagnostic_context` for the duration
# of its `with` block.
context = engine.background_context
@contextlib.contextmanager
def create_export_diagnostic_context():
    """Create a diagnostic context for export.

    Rebinds the module-level `context` to a freshly created diagnostic context
    and yields it. On exit, whether normal or via an exception, `context` is set
    back to the engine's background context (not to whatever was active before).

    This is a workaround for code robustness since diagnostic context is accessed by
    export internals via global variable. See `ExportDiagnosticEngine` for more details.
    """
    global context
    export_context = engine.create_diagnostic_context(ExportDiagnosticTool())
    context = export_context
    try:
        yield export_context
    finally:
        # Always fall back to the background context, even if the body raised.
        context = engine.background_context