|
|
|
|
|
|
|
|
from __future__ import annotations
|
|
|
|
|
|
|
|
|
"""
|
|
|
This file does three things:
|
|
|
- Contains the definition of SymNode
|
|
|
- Installs all the magic methods into SymBool, SymInt, and SymFloat at import time
|
|
|
- Does not depend on sympy at import time
|
|
|
|
|
|
As this file is imported from within torch/__init__.py, we do not want it to
depend on SymPy, since loading SymPy at import time is *very* slow.
|
|
|
"""
|
|
|
|
|
|
|
|
|
import builtins
|
|
|
import functools
|
|
|
import inspect
|
|
|
import itertools
|
|
|
import logging
|
|
|
import math
|
|
|
import operator
|
|
|
import sys
|
|
|
from functools import lru_cache, update_wrapper
|
|
|
from typing import Optional, TYPE_CHECKING, Union
|
|
|
|
|
|
import torch
|
|
|
import torch._logging.structured as structured
|
|
|
|
|
|
|
|
|
from torch import (
|
|
|
sym_float,
|
|
|
sym_ite,
|
|
|
sym_max,
|
|
|
sym_min,
|
|
|
sym_not,
|
|
|
SymBool,
|
|
|
SymFloat,
|
|
|
SymInt,
|
|
|
)
|
|
|
from torch._logging import dtrace_structured
|
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from torch.fx.experimental.symbolic_shapes import ShapeEnv
|
|
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
sym_node_log = torch._logging.getArtifactLogger(__name__, "sym_node")
|
|
|
|
|
|
|
|
|
__all__ = ["SymNode", "method_to_operator", "magic_methods"]
|
|
|
|
|
|
|
|
|
from torch.types import py_sym_types as SymTypes
|
|
|
|
|
|
|
|
|
def _to_symtype(t):
|
|
|
if t is bool:
|
|
|
return SymBool
|
|
|
if t is int:
|
|
|
return SymInt
|
|
|
if t is float:
|
|
|
return SymFloat
|
|
|
return t
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SymNode:
|
|
|
"""
|
|
|
This is a type-erased SymInt/SymFloat/SymBool, which we use to do the actual operations.
|
|
|
End users don't touch this. Magic methods are NOT defined on this object.
|
|
|
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
_optimized_summation: bool = False
|
|
|
|
|
|
def __init__(
|
|
|
self,
|
|
|
expr,
|
|
|
shape_env,
|
|
|
pytype,
|
|
|
hint: Optional[Union[int, float, bool]],
|
|
|
constant=None,
|
|
|
fx_node=None,
|
|
|
optimized_summation=False,
|
|
|
):
|
|
|
self._expr = expr
|
|
|
self.shape_env = shape_env
|
|
|
self.pytype = pytype
|
|
|
self._optimized_summation = optimized_summation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def compute_hint():
|
|
|
from torch.fx.experimental.symbolic_shapes import has_free_unbacked_symbols
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if has_free_unbacked_symbols(self.expr):
|
|
|
return None
|
|
|
hint = self.shape_env._maybe_evaluate_static(self.expr, compute_hint=True)
|
|
|
if hint is not None:
|
|
|
hint = self.pytype(hint) if not isinstance(hint, SymTypes) else hint
|
|
|
return hint
|
|
|
|
|
|
if hint is not None:
|
|
|
assert type(hint) is pytype or type(hint) is _to_symtype(pytype), (
|
|
|
"Cannot create SymNode of type "
|
|
|
f"{pytype} with incompatible hint of type {type(hint)}"
|
|
|
)
|
|
|
if self.shape_env and self.shape_env._translation_validation_enabled:
|
|
|
|
|
|
|
|
|
computed_hint = compute_hint()
|
|
|
assert hint == computed_hint, (
|
|
|
f"{hint} != {computed_hint} (for {self.expr})"
|
|
|
)
|
|
|
else:
|
|
|
hint = compute_hint()
|
|
|
self._hint = hint
|
|
|
self.constant: Optional[Union[int, float, bool]] = constant
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
tx_validation_en = (
|
|
|
self.shape_env and self.shape_env._translation_validation_enabled
|
|
|
)
|
|
|
self.fx_node = tx_validation_en and fx_node
|
|
|
|
|
|
def with_shape_env(self, shape_env: ShapeEnv) -> SymNode:
|
|
|
return SymNode(
|
|
|
self._expr, shape_env, self.pytype, self._hint, self.constant, self.fx_node
|
|
|
)
|
|
|
|
|
|
def _value_eq(self, other: SymNode) -> bool:
|
|
|
|
|
|
return (
|
|
|
self._expr == other._expr
|
|
|
and self.pytype == other.pytype
|
|
|
and self._hint == other._hint
|
|
|
and self.constant == other.constant
|
|
|
and self.fx_node == other.fx_node
|
|
|
)
|
|
|
|
|
|
def _value_hash(self) -> int:
|
|
|
|
|
|
return hash((self._expr, self.pytype, self._hint, self.constant, self.fx_node))
|
|
|
|
|
|
@property
|
|
|
def expr(self):
|
|
|
return self.shape_env.replace(self._expr)
|
|
|
|
|
|
@property
|
|
|
def hint(self):
|
|
|
return self._hint
|
|
|
|
|
|
def has_hint(self):
|
|
|
return self._hint is not None
|
|
|
|
|
|
def require_hint(self, fallback=None):
|
|
|
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
|
|
|
|
|
if self._hint is None:
|
|
|
if fallback is not None:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
unbacked_symbols = free_unbacked_symbols(self.expr)
|
|
|
replacements = {
|
|
|
s: fallback if s in unbacked_symbols else self.shape_env.var_to_val[s]
|
|
|
for s in self.expr.free_symbols
|
|
|
}
|
|
|
return self.expr.xreplace(replacements)
|
|
|
|
|
|
return self.shape_env.size_hint(self.expr)
|
|
|
return self._hint
|
|
|
|
|
|
def maybe_as_int(self):
|
|
|
if self.expr.is_number:
|
|
|
return int(self.expr)
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
|
|
|
def maybe_as_float(self):
|
|
|
import sympy
|
|
|
|
|
|
if isinstance(self.expr, sympy.Float):
|
|
|
return float(self.expr)
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
def maybe_as_bool(self):
|
|
|
import sympy
|
|
|
|
|
|
if self.expr is sympy.true:
|
|
|
return True
|
|
|
elif self.expr is sympy.false:
|
|
|
return False
|
|
|
else:
|
|
|
return None
|
|
|
|
|
|
def is_int(self):
|
|
|
return self.pytype is int
|
|
|
|
|
|
def is_float(self):
|
|
|
return self.pytype is float
|
|
|
|
|
|
def is_bool(self):
|
|
|
return self.pytype is bool
|
|
|
|
|
|
def is_nested_int(self):
|
|
|
|
|
|
return (
|
|
|
self._hint is not None
|
|
|
and isinstance(self._hint, SymInt)
|
|
|
and self._hint.node.is_nested_int()
|
|
|
)
|
|
|
|
|
|
def wrap_int(self, num):
|
|
|
assert type(num) is int
|
|
|
import sympy
|
|
|
|
|
|
return SymNode(
|
|
|
sympy.Integer(num), self.shape_env, int, num, constant=num, fx_node=num
|
|
|
)
|
|
|
|
|
|
def wrap_float(self, num):
|
|
|
assert type(num) is float
|
|
|
import sympy
|
|
|
|
|
|
return SymNode(
|
|
|
sympy.Float(num), self.shape_env, float, num, constant=num, fx_node=num
|
|
|
)
|
|
|
|
|
|
def wrap_bool(self, num):
|
|
|
assert type(num) is bool
|
|
|
import sympy
|
|
|
|
|
|
return SymNode(
|
|
|
sympy.true if num else sympy.false,
|
|
|
self.shape_env,
|
|
|
bool,
|
|
|
num,
|
|
|
constant=num,
|
|
|
fx_node=num,
|
|
|
)
|
|
|
|
|
|
def clone(self):
|
|
|
return self
|
|
|
|
|
|
def str(self):
|
|
|
return f"{self.expr}"
|
|
|
|
|
|
def __str__(self):
|
|
|
return self.str()
|
|
|
|
|
|
def __repr__(self):
|
|
|
rep = [
|
|
|
f"SymNode({self._expr}, shape_env={self.shape_env}, pytype={self.pytype}",
|
|
|
]
|
|
|
if self._hint is not None:
|
|
|
rep.append(f"hint={self._hint}")
|
|
|
if self.constant is not None:
|
|
|
rep.append(f"constant={self.constant}")
|
|
|
if self.fx_node is not None:
|
|
|
rep.append(f"fx_node={self.fx_node}")
|
|
|
return ", ".join(rep) + ")"
|
|
|
|
|
|
def _graph_repr(self) -> builtins.str:
|
|
|
|
|
|
return self.str()
|
|
|
|
|
|
|
|
|
|
|
|
def abs(self) -> SymNode:
|
|
|
return self._abs()
|
|
|
|
|
|
def pos(self) -> SymNode:
|
|
|
return self._pos()
|
|
|
|
|
|
def round(self, ndigits=None) -> SymNode:
|
|
|
return self._round(ndigits)
|
|
|
|
|
|
def trunc(self) -> SymNode:
|
|
|
return self._trunc()
|
|
|
|
|
|
def add(self, other) -> SymNode:
|
|
|
return self._add(other)
|
|
|
|
|
|
def sub(self, other) -> SymNode:
|
|
|
return self._sub(other)
|
|
|
|
|
|
def mul(self, other) -> SymNode:
|
|
|
return self._mul(other)
|
|
|
|
|
|
def mod(self, other) -> SymNode:
|
|
|
return self._mod(other)
|
|
|
|
|
|
def float_pow(self, other) -> SymNode:
|
|
|
return self._float_pow(other)
|
|
|
|
|
|
def pow_by_natural(self, other) -> SymNode:
|
|
|
return self._pow_by_natural(other)
|
|
|
|
|
|
def and_(self, other) -> SymNode:
|
|
|
return self._and_(other)
|
|
|
|
|
|
def or_(self, other) -> SymNode:
|
|
|
return self._or_(other)
|
|
|
|
|
|
def float_truediv(self, other) -> SymNode:
|
|
|
return self._float_truediv(other)
|
|
|
|
|
|
def int_truediv(self, other) -> SymNode:
|
|
|
return self._int_truediv(other)
|
|
|
|
|
|
def int_floordiv(self, other) -> SymNode:
|
|
|
return self._int_floordiv(other)
|
|
|
|
|
|
def lshift(self, other) -> SymNode:
|
|
|
return self._lshift(other)
|
|
|
|
|
|
def rshift(self, other) -> SymNode:
|
|
|
return self._rshift(other)
|
|
|
|
|
|
def sym_not(self) -> SymNode:
|
|
|
return self._sym_not()
|
|
|
|
|
|
def eq(self, other) -> SymNode:
|
|
|
return self._eq(other)
|
|
|
|
|
|
def ne(self, other) -> SymNode:
|
|
|
return self._ne(other)
|
|
|
|
|
|
def gt(self, other) -> SymNode:
|
|
|
return self._gt(other)
|
|
|
|
|
|
def lt(self, other) -> SymNode:
|
|
|
return self._lt(other)
|
|
|
|
|
|
def le(self, other) -> SymNode:
|
|
|
return self._le(other)
|
|
|
|
|
|
def ge(self, other) -> SymNode:
|
|
|
return self._ge(other)
|
|
|
|
|
|
def floor(self) -> SymNode:
|
|
|
return self._floor()
|
|
|
|
|
|
def is_integer(self) -> SymNode:
|
|
|
return self._is_integer()
|
|
|
|
|
|
def sym_float(self) -> SymNode:
|
|
|
return self._sym_float()
|
|
|
|
|
|
def sym_int(self) -> SymNode:
|
|
|
return self._sym_int()
|
|
|
|
|
|
def ceil(self) -> SymNode:
|
|
|
return self._ceil()
|
|
|
|
|
|
def neg(self) -> SymNode:
|
|
|
return self._neg()
|
|
|
|
|
|
def sym_min(self, other) -> SymNode:
|
|
|
return self._sym_min(other)
|
|
|
|
|
|
def sym_max(self, other) -> SymNode:
|
|
|
return self._sym_max(other)
|
|
|
|
|
|
def sym_ite(self, then_val, else_val) -> SymNode:
|
|
|
return self._sym_ite(then_val, else_val)
|
|
|
|
|
|
def is_contiguous(self, sizes, strides) -> SymNode:
|
|
|
return self._is_contiguous(sizes, strides)
|
|
|
|
|
|
def is_channels_last_contiguous_2d(self, sizes, strides) -> SymNode:
|
|
|
return self._is_channels_last_contiguous_2d(sizes, strides)
|
|
|
|
|
|
def is_channels_last_contiguous_3d(self, sizes, strides) -> SymNode:
|
|
|
return self._is_channels_last_contiguous_3d(sizes, strides)
|
|
|
|
|
|
def is_channels_last_strides_2d(self, sizes, strides) -> SymNode:
|
|
|
return self._is_channels_last_strides_2d(sizes, strides)
|
|
|
|
|
|
def is_channels_last_strides_3d(self, sizes, strides) -> SymNode:
|
|
|
return self._is_channels_last_strides_3d(sizes, strides)
|
|
|
|
|
|
def is_non_overlapping_and_dense_indicator(self, sizes, strides) -> SymNode:
|
|
|
return self._is_non_overlapping_and_dense_indicator(sizes, strides)
|
|
|
|
|
|
|
|
|
def sym_or(self, other):
|
|
|
return self.or_(other)
|
|
|
|
|
|
def sym_and(self, other):
|
|
|
return self.and_(other)
|
|
|
|
|
|
|
|
|
def bitwise_and(self, other):
|
|
|
return self._bitwise_and(other)
|
|
|
|
|
|
def bitwise_or(self, other):
|
|
|
return self._bitwise_or(other)
|
|
|
|
|
|
|
|
|
def truediv(self, other):
|
|
|
return self.float_truediv(other)
|
|
|
|
|
|
def floordiv(self, other) -> SymNode:
|
|
|
return self.int_floordiv(other)
|
|
|
|
|
|
|
|
|
def pow(self, other):
|
|
|
return self.float_pow(other)
|
|
|
|
|
|
def is_non_overlapping_and_dense(self, sizes, strides):
|
|
|
return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(
|
|
|
to_node(self, 1)
|
|
|
)
|
|
|
|
|
|
def int_(self):
|
|
|
return self.guard_int("", 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sym_sum(self, args) -> SymNode:
|
|
|
import sympy
|
|
|
|
|
|
|
|
|
from torch.fx.experimental.proxy_tensor import (
|
|
|
get_proxy_mode,
|
|
|
handle_sym_dispatch,
|
|
|
)
|
|
|
|
|
|
if get_proxy_mode():
|
|
|
return to_node(
|
|
|
self,
|
|
|
handle_sym_dispatch(
|
|
|
torch.sym_sum,
|
|
|
(tuple(wrap_node(a) for a in args),),
|
|
|
{},
|
|
|
),
|
|
|
)
|
|
|
exprs = [a.expr for a in args]
|
|
|
out = sympy.Add(*exprs)
|
|
|
|
|
|
size_hints = []
|
|
|
out_hint = None
|
|
|
for a in args:
|
|
|
if a.hint is None:
|
|
|
break
|
|
|
size_hints.append(a.hint)
|
|
|
else:
|
|
|
out_hint = sum(size_hints)
|
|
|
|
|
|
fx_node, _ = self.shape_env._create_fx_call_function(
|
|
|
torch.sym_sum, (tuple(a.fx_node for a in args),)
|
|
|
)
|
|
|
|
|
|
|
|
|
return SymNode(out, self.shape_env, int, out_hint, fx_node=fx_node)
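# A hedged usage sketch: torch.sym_sum over a list of SymInts sharing a
# ShapeEnv funnels into sym_sum above, building a single n-ary sympy.Add
# instead of a chain of binary adds (s0, s1, s2 are hypothetical SymInts):
#
#     total = torch.sym_sum([s0, s1, s2])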
|
|
|
|
|
|
def evaluate(self, size_oblivious=False):
|
|
|
return self.shape_env.evaluate_sym_node(self, size_oblivious)
|
|
|
|
|
|
|
|
|
def guard_int(self, file, line):
|
|
|
|
|
|
|
|
|
r = self.evaluate()
|
|
|
try:
|
|
|
return int(r)
|
|
|
except Exception:
|
|
|
log.warning("Failed to convert to int: %s", r)
|
|
|
raise
|
|
|
|
|
|
def guard_float(self, file, line):
|
|
|
|
|
|
|
|
|
r = self.evaluate()
|
|
|
try:
|
|
|
return float(r)
|
|
|
except Exception:
|
|
|
log.warning("Failed to convert to float: %s", r)
|
|
|
raise
|
|
|
|
|
|
def guard_bool(self, file, line):
|
|
|
|
|
|
|
|
|
r = self.evaluate()
|
|
|
try:
|
|
|
return bool(r)
|
|
|
except Exception:
|
|
|
log.warning("Failed to convert to bool: %s", r)
|
|
|
raise
|
|
|
|
|
|
def expect_true(self, file, line):
|
|
|
from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
|
|
|
|
|
|
if (
|
|
|
self.has_hint()
|
|
|
and not free_unbacked_symbols(self.expr)
|
|
|
and not self.shape_env.prefer_deferred_runtime_asserts_over_guards
|
|
|
):
|
|
|
|
|
|
return self.guard_bool(file, line)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
return self.shape_env.guard_or_defer_runtime_assert(
|
|
|
self.expr, f"{file}:{line}", fx_node=self.fx_node
|
|
|
)
|
|
|
|
|
|
def expect_size(self, file, line):
|
|
|
from torch.fx.experimental.symbolic_shapes import _advise_is_size
|
|
|
|
|
|
b = self.ge(self.wrap_int(0))
|
|
|
|
|
|
r = b.expect_true(file, line)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if r and not self.has_hint():
|
|
|
_advise_is_size(SymInt(self))
|
|
|
return r
|
|
|
|
|
|
def statically_known_true(self, file, line):
|
|
|
from torch.fx.experimental.symbolic_shapes import statically_known_true
|
|
|
|
|
|
assert self.is_bool()
|
|
|
return statically_known_true(SymBool(self))
|
|
|
|
|
|
def guard_size_oblivious(self, file, line):
|
|
|
"""
|
|
|
Like guard_bool, but if we encounter unbacked symbols that are size-like,
we treat them as >= 2 for the purposes of the analysis.
|
|
|
|
|
|
This CHANGES the runtime semantics, but all size-oblivious sites have been
|
|
|
audited to ensure that the runtime semantics don't change in a material way.
|
|
|
Acceptable runtime semantic changes are, e.g., squeeze() no longer dropping
an unbacked size-one dimension, or an empty tensor reporting as
non-contiguous when it would otherwise have been reported as contiguous.
|
|
|
"""
|
|
|
|
|
|
|
|
|
r = self.evaluate(size_oblivious=True)
|
|
|
try:
|
|
|
return bool(r)
|
|
|
except Exception:
|
|
|
log.warning("Failed to convert to bool: %s", r)
|
|
|
raise
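# A hedged sketch: for a hypothetical size-like unbacked SymInt u0, `u0 != 1`
# has no hint, but size-obliviously it evaluates to True, since u0 is treated
# as >= 2 for the analysis:
#
#     (u0 != 1).node.guard_size_oblivious("", 0)  # -> True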
|
|
|
|
|
|
def guard_or_false(self, file, line):
|
|
|
from torch.fx.experimental.symbolic_shapes import guard_or_false
|
|
|
|
|
|
assert self.is_bool()
|
|
|
return guard_or_false(SymBool(self))
|
|
|
|
|
|
def guard_or_true(self, file, line):
|
|
|
from torch.fx.experimental.symbolic_shapes import guard_or_true
|
|
|
|
|
|
assert self.is_bool()
|
|
|
return guard_or_true(SymBool(self))
|
|
|
|
|
|
def bool_(self):
|
|
|
return self.guard_bool("", 0)
|
|
|
|
|
|
def is_symbolic(self):
|
|
|
return True
|
|
|
|
|
|
def nested_int(self):
|
|
|
return None
|
|
|
|
|
|
def is_constant(self):
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
METHOD_TO_OPERATOR = {
|
|
|
"pos": operator.pos,
|
|
|
"abs": operator.abs,
|
|
|
"add": operator.add,
|
|
|
"and": operator.and_,
|
|
|
"bitwise_and": operator.and_,
|
|
|
"ceil": math.ceil,
|
|
|
"eq": operator.eq,
|
|
|
"floor": math.floor,
|
|
|
"trunc": math.trunc,
|
|
|
"int_floordiv": operator.floordiv,
|
|
|
"ge": operator.ge,
|
|
|
"gt": operator.gt,
|
|
|
"is_integer": lambda x: x.is_integer(),
|
|
|
"le": operator.le,
|
|
|
"lshift": operator.lshift,
|
|
|
"lt": operator.lt,
|
|
|
"mod": operator.mod,
|
|
|
"mul": operator.mul,
|
|
|
"ne": operator.ne,
|
|
|
"neg": operator.neg,
|
|
|
"or": operator.or_,
|
|
|
"bitwise_or": operator.or_,
|
|
|
"float_pow": operator.pow,
|
|
|
"pow_by_natural": operator.pow,
|
|
|
"round": builtins.round,
|
|
|
"rshift": operator.rshift,
|
|
|
"sub": operator.sub,
|
|
|
"sym_float": sym_float,
|
|
|
"sym_ite": sym_ite,
|
|
|
"sym_max": sym_max,
|
|
|
"sym_min": sym_min,
|
|
|
"sym_not": sym_not,
|
|
|
"float_truediv": operator.truediv,
|
|
|
"int_truediv": operator.truediv,
|
|
|
}
|
|
|
|
|
|
unary_magic_methods = {
|
|
|
"abs",
|
|
|
"sym_float",
|
|
|
"sym_int",
|
|
|
"ceil",
|
|
|
"floor",
|
|
|
"neg",
|
|
|
"sym_not",
|
|
|
"pos",
|
|
|
"trunc",
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
def _get_sym_node_fn(name):
|
|
|
def fn(self):
|
|
|
return getattr(self, f"_sym_{name}")()
|
|
|
|
|
|
return fn
|
|
|
|
|
|
|
|
|
math_op_names = (
|
|
|
"sqrt",
|
|
|
"cos",
|
|
|
"cosh",
|
|
|
"sin",
|
|
|
"sinh",
|
|
|
"tan",
|
|
|
"tanh",
|
|
|
"asin",
|
|
|
"acos",
|
|
|
"atan",
|
|
|
"log2",
|
|
|
)
|
|
|
for name in math_op_names:
|
|
|
sym_name = f"sym_{name}"
|
|
|
priv_sym_name = f"_{sym_name}"
|
|
|
setattr(SymNode, sym_name, _get_sym_node_fn(name))
|
|
|
METHOD_TO_OPERATOR[sym_name] = getattr(torch, priv_sym_name)
|
|
|
unary_magic_methods.add(sym_name)
|
|
|
__all__.append(sym_name)
|
|
|
|
|
|
|
|
|
|
|
|
unary_nonmagic_methods = {
|
|
|
"is_integer",
|
|
|
}
|
|
|
|
|
|
unary_methods = unary_magic_methods | unary_nonmagic_methods
|
|
|
|
|
|
|
|
|
|
|
|
only_bool_magic_methods = {"and", "or", "sym_not", "sym_ite"}
|
|
|
|
|
|
bool_becomes_int_magic_methods = {"add", "sub", "mul"}
|
|
|
|
|
|
also_bool_magic_methods = {"eq"}
|
|
|
bool_magic_methods = only_bool_magic_methods | also_bool_magic_methods
|
|
|
|
|
|
|
|
|
only_float_magic_methods = {"is_integer", "round", "sym_int", "sym_log2"}
|
|
|
|
|
|
|
|
|
magic_methods_on_operator_with_trailing_underscore = {"and", "or"}
|
|
|
|
|
|
bitwise_ops = {
|
|
|
"bitwise_and": "and",
|
|
|
"bitwise_or": "or",
|
|
|
}
|
|
|
|
|
|
|
|
|
always_float_magic_methods = {"int_truediv", "float_truediv", "sym_float", "float_pow"}
|
|
|
|
|
|
for name in math_op_names:
|
|
|
sym_name = f"sym_{name}"
|
|
|
always_float_magic_methods.add(sym_name)
|
|
|
|
|
|
|
|
|
always_int_magic_methods = {"ceil", "floor", "trunc", "pow_by_natural"}
|
|
|
always_bool_magic_methods = {
|
|
|
"eq",
|
|
|
"ne",
|
|
|
"gt",
|
|
|
"lt",
|
|
|
"le",
|
|
|
"ge",
|
|
|
"and",
|
|
|
"or",
|
|
|
"sym_not",
|
|
|
"is_non_overlapping_and_dense",
|
|
|
"is_integer",
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sympy_float_truediv(a, b):
|
|
|
from torch.utils._sympy.functions import FloatTrueDiv
|
|
|
|
|
|
return FloatTrueDiv(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_int_truediv(a, b):
|
|
|
from torch.utils._sympy.functions import IntTrueDiv
|
|
|
|
|
|
return IntTrueDiv(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_floordiv(a, b):
|
|
|
from torch.utils._sympy.functions import FloorDiv
|
|
|
|
|
|
return FloorDiv(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_mod(a, b):
|
|
|
from torch.utils._sympy.functions import Mod, PythonMod
|
|
|
|
|
|
if a.is_nonnegative and b.is_nonnegative:
|
|
|
return Mod(a, b)
|
|
|
else:
|
|
|
return PythonMod(a, b)
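# A hedged note on the branch above: Mod and PythonMod agree when both operands
# are nonnegative; otherwise PythonMod is needed to match Python's sign
# convention (e.g. -7 % 3 == 2 in Python).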
|
|
|
|
|
|
|
|
|
def _sympy_pow_by_natural(a, b):
|
|
|
from torch.utils._sympy.functions import PowByNatural
|
|
|
|
|
|
return PowByNatural(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_float_pow(a, b):
|
|
|
from torch.utils._sympy.functions import FloatPow
|
|
|
|
|
|
return FloatPow(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_and(a, b):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.And(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_or(a, b):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Or(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_lshift(a, b):
|
|
|
from torch.utils._sympy.functions import LShift
|
|
|
|
|
|
return LShift(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_rshift(a, b):
|
|
|
from torch.utils._sympy.functions import RShift
|
|
|
|
|
|
return RShift(a, b)
|
|
|
|
|
|
|
|
|
def _binary_search_insert_arg(ordered_args, new_arg):
|
|
|
"""
|
|
|
If new_arg is already present in ordered_args, None is returned; otherwise,
|
|
|
ordered_args is returned with new_arg inserted in sorted position.
|
|
|
"""
|
|
|
if len(ordered_args) == 0:
|
|
|
return [new_arg]
|
|
|
|
|
|
from sympy.core.basic import _args_sortkey as sort_key, Basic
|
|
|
|
|
|
|
|
|
if sort_key(ordered_args[-1]) < sort_key(new_arg):
|
|
|
return ordered_args + [new_arg]
|
|
|
|
|
|
|
|
|
if sort_key(ordered_args[0]) > sort_key(new_arg):
|
|
|
return [new_arg] + ordered_args
|
|
|
|
|
|
low, high = 0, len(ordered_args) - 1
|
|
|
|
|
|
while low <= high:
|
|
|
mid = (low + high) // 2
|
|
|
compare_result = Basic.compare(ordered_args[mid], new_arg)
|
|
|
if compare_result == 0:
|
|
|
return None
|
|
|
elif compare_result < 0:
|
|
|
low = mid + 1
|
|
|
else:
|
|
|
high = mid - 1
|
|
|
|
|
|
ordered_args.insert(low, new_arg)
|
|
|
return ordered_args
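# A hedged example with plain sympy symbols (whose default sort order is
# alphabetical), illustrating the contract in the docstring:
#
#     import sympy
#     x, y, z = sympy.symbols("x y z")
#     _binary_search_insert_arg([x, z], y)      # -> [x, y, z]
#     _binary_search_insert_arg([x, y, z], y)   # -> None (already present)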
|
|
|
|
|
|
|
|
|
def _optimized_add(
|
|
|
lhs, rhs, lhs_is_optimized_summation=False, rhs_is_optimized_summation=False
|
|
|
):
|
|
|
"""
|
|
|
Custom optimization for Add, used to speed up the incremental binary summations
of certain properties. The idea: when we know the expression is a summation of
unique symbols, all we need to maintain is the correct order of those symbols;
no other simplification is needed. We construct the Add with evaluate=False and
the args already in the correct order, which saves the following:
1. We avoid running sympy's other simplifications when the Add is constructed.
2. We manually figure out where the new arg goes in O(log n) comparisons
   instead of O(n log n) (comparing terms is expensive and shows up in the
   profiles).
The function returns a tuple of (1) a boolean that indicates whether the output
is a summation of unique symbols, and (2) the resulting sympy expression.
|
|
|
"""
|
|
|
import sympy
|
|
|
from sympy.core.basic import _args_sortkey as sortkey
|
|
|
|
|
|
def make_optimized(ordered_args):
|
|
|
assert ordered_args is not None
|
|
|
result = sympy.Add(*ordered_args, evaluate=False)
|
|
|
return (True, result)
|
|
|
|
|
|
from torch.utils._sympy.functions import _is_symbols_binary_summation
|
|
|
|
|
|
lhs_is_optimized_summation |= _is_symbols_binary_summation(lhs)
|
|
|
rhs_is_optimized_summation |= _is_symbols_binary_summation(rhs)
|
|
|
|
|
|
if lhs_is_optimized_summation and rhs_is_optimized_summation:
|
|
|
|
|
|
if sortkey(lhs._args[-1]) < sortkey(rhs._args[0]):
|
|
|
return make_optimized(lhs._args + rhs._args)
|
|
|
|
|
|
if sortkey(lhs._args[0]) > sortkey(rhs._args[-1]):
|
|
|
return make_optimized(rhs._args + lhs._args)
|
|
|
|
|
|
|
|
|
if len(lhs._args) <= 2 and len(rhs._args) <= 2:
|
|
|
new_args = list(lhs._args)
|
|
|
for a in rhs._args:
|
|
|
new_args = _binary_search_insert_arg(new_args, a)
|
|
|
if new_args is None:
|
|
|
break
|
|
|
|
|
|
if new_args is not None:
|
|
|
return make_optimized(new_args)
|
|
|
|
|
|
|
|
|
if lhs_is_optimized_summation and rhs.is_symbol:
|
|
|
new_args = _binary_search_insert_arg(list(lhs._args), rhs)
|
|
|
|
|
|
if new_args is not None:
|
|
|
return make_optimized(new_args)
|
|
|
|
|
|
|
|
|
if rhs_is_optimized_summation and lhs.is_symbol:
|
|
|
new_args = _binary_search_insert_arg(list(rhs._args), lhs)
|
|
|
|
|
|
if new_args is not None:
|
|
|
return make_optimized(new_args)
|
|
|
|
|
|
result = sympy.Add(lhs, rhs)
|
|
|
return (_is_symbols_binary_summation(result), result)
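# A hedged example: both inputs are already ordered summations of unique
# symbols and their arg ranges do not interleave, so the fast path above
# concatenates them with evaluate=False instead of re-sorting:
#
#     import sympy
#     a, b, c, d = sympy.symbols("a b c d")
#     lhs = sympy.Add(a, b, evaluate=False)
#     rhs = sympy.Add(c, d, evaluate=False)
#     is_sum, out = _optimized_add(lhs, rhs, True, True)
#     # is_sum is True; out is a + b + c + d with args kept in sorted order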
|
|
|
|
|
|
|
|
|
def _bitwise_and(a, b):
|
|
|
from torch.utils._sympy.functions import BitwiseFn_bitwise_and
|
|
|
|
|
|
return BitwiseFn_bitwise_and(a, b)
|
|
|
|
|
|
|
|
|
def _bitwise_or(a, b):
|
|
|
from torch.utils._sympy.functions import BitwiseFn_bitwise_or
|
|
|
|
|
|
return BitwiseFn_bitwise_or(a, b)
|
|
|
|
|
|
|
|
|
reflectable_magic_methods = {
|
|
|
"add": _optimized_add,
|
|
|
"sub": operator.sub,
|
|
|
"mul": operator.mul,
|
|
|
"mod": _sympy_mod,
|
|
|
"pow_by_natural": _sympy_pow_by_natural,
|
|
|
"float_pow": _sympy_float_pow,
|
|
|
"and": _sympy_and,
|
|
|
"bitwise_and": _bitwise_and,
|
|
|
"or": _sympy_or,
|
|
|
"bitwise_or": _bitwise_or,
|
|
|
"float_truediv": _sympy_float_truediv,
|
|
|
"int_truediv": _sympy_int_truediv,
|
|
|
"int_floordiv": _sympy_floordiv,
|
|
|
"lshift": _sympy_lshift,
|
|
|
"rshift": _sympy_rshift,
|
|
|
}
|
|
|
|
|
|
|
|
|
def _floor_ceil_helper(a, fn):
|
|
|
import sympy
|
|
|
|
|
|
if isinstance(a, sympy.Mul):
|
|
|
aa = a.args
|
|
|
if len(aa) == 2 and isinstance(aa[0], sympy.Float) and aa[1].is_integer:
|
|
|
coef = sympy.Integer(aa[0])
|
|
|
if aa[0] == coef:
|
|
|
return coef * aa[1]
|
|
|
if (
|
|
|
isinstance(a, sympy.Float)
|
|
|
and a == sympy.Integer(a)
|
|
|
or isinstance(a, sympy.Integer)
|
|
|
):
|
|
|
return sympy.Integer(a)
|
|
|
return fn(a)
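# A hedged example: with an integer-assumed symbol n = sympy.Symbol("n",
# integer=True), _floor_ceil_helper(2.0 * n, sympy.floor) short-circuits to
# 2*n, since the Float coefficient 2.0 equals Integer(2); the wrapped fn is
# only applied to genuinely non-integral inputs.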
|
|
|
|
|
|
|
|
|
def _sympy_floor(a):
|
|
|
from torch.utils._sympy.functions import FloorToInt
|
|
|
|
|
|
return FloorToInt(a)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sympy_trunc(a):
|
|
|
from torch.utils._sympy.functions import TruncToInt
|
|
|
|
|
|
return TruncToInt(a)
|
|
|
|
|
|
|
|
|
def _sympy_ceil(a):
|
|
|
from torch.utils._sympy.functions import CeilToInt
|
|
|
|
|
|
return CeilToInt(a)
|
|
|
|
|
|
|
|
|
def _sympy_eq(a, b):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Eq(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_ne(a, b):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Ne(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_gt(a, b):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Gt(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_lt(a, b):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Lt(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_le(a, b):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Le(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_ge(a, b):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Ge(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_min(a, b):
|
|
|
from torch.utils._sympy.functions import Min
|
|
|
|
|
|
return Min(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_max(a, b):
|
|
|
from torch.utils._sympy.functions import Max
|
|
|
|
|
|
return Max(a, b)
|
|
|
|
|
|
|
|
|
def _sympy_ite(a, t, f):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Piecewise((t, a), (f, True))
|
|
|
|
|
|
|
|
|
current_module = sys.modules[__name__]
|
|
|
|
|
|
|
|
|
def _get_sym_math_fn(name):
|
|
|
def fn(a):
|
|
|
import torch.utils._sympy.functions
|
|
|
|
|
|
return getattr(torch.utils._sympy.functions, f"OpaqueUnaryFn_{name}")(a)
|
|
|
|
|
|
return fn
|
|
|
|
|
|
|
|
|
for name in math_op_names:
|
|
|
priv_sympy_name = f"_sympy_{name}"
|
|
|
fn = _get_sym_math_fn(name)
|
|
|
fn.__qualname__ = fn.__name__ = priv_sympy_name
|
|
|
setattr(current_module, priv_sympy_name, fn)
|
|
|
|
|
|
del fn, name, priv_sympy_name
|
|
|
|
|
|
|
|
|
def _sympy_abs(a):
|
|
|
import sympy
|
|
|
|
|
|
return sympy.Abs(a)
|
|
|
|
|
|
|
|
|
def _sympy_round(number, ndigits=None):
|
|
|
from torch.utils._sympy.functions import RoundDecimal, RoundToInt
|
|
|
|
|
|
if ndigits is None:
|
|
|
return RoundToInt(number)
|
|
|
else:
|
|
|
return RoundDecimal(number, ndigits)
|
|
|
|
|
|
|
|
|
def _sympy_sym_float(a):
|
|
|
from torch.utils._sympy.functions import ToFloat
|
|
|
|
|
|
|
|
|
|
|
|
return ToFloat(a)
|
|
|
|
|
|
|
|
|
def _sympy_is_integer(a):
|
|
|
import sympy
|
|
|
|
|
|
from torch.utils._sympy.functions import ToFloat
|
|
|
|
|
|
return sympy.Eq(ToFloat(sympy.floor(a)), a)
|
|
|
|
|
|
|
|
|
magic_methods = {
|
|
|
**reflectable_magic_methods,
|
|
|
"sym_not": operator.invert,
|
|
|
"pos": operator.pos,
|
|
|
"eq": _sympy_eq,
|
|
|
"ne": _sympy_ne,
|
|
|
"gt": _sympy_gt,
|
|
|
"lt": _sympy_lt,
|
|
|
"le": _sympy_le,
|
|
|
"ge": _sympy_ge,
|
|
|
"floor": _sympy_floor,
|
|
|
"trunc": _sympy_trunc,
|
|
|
"sym_float": _sympy_sym_float,
|
|
|
"ceil": _sympy_ceil,
|
|
|
"neg": operator.neg,
|
|
|
"sym_min": _sympy_min,
|
|
|
"sym_max": _sympy_max,
|
|
|
"sym_ite": _sympy_ite,
|
|
|
"abs": _sympy_abs,
|
|
|
"round": _sympy_round,
|
|
|
"is_integer": _sympy_is_integer,
|
|
|
}
|
|
|
|
|
|
|
|
|
for name in math_op_names:
|
|
|
sym_name = f"sym_{name}"
|
|
|
magic_methods[sym_name] = getattr(current_module, f"_sympy_{name}")
|
|
|
|
|
|
del name, sym_name, math_op_names, current_module
|
|
|
|
|
|
|
|
|
def sympy_is_contiguous(sizes, strides):
|
|
|
dim = len(sizes)
|
|
|
return sympy_is_contiguous_generic(sizes, strides, list(range(dim - 1, -1, -1)))
|
|
|
|
|
|
|
|
|
def sympy_is_contiguous_generic(sizes, strides, dim_order):
|
|
|
import sympy
|
|
|
|
|
|
dim = len(sizes)
|
|
|
|
|
|
if len(dim_order) != dim:
|
|
|
return sympy.false
|
|
|
|
|
|
is_contiguous = sympy.true
|
|
|
z = sympy.S.One
|
|
|
|
|
|
for d in dim_order:
|
|
|
is_contiguous &= sympy.Eq(sizes[d], sympy.S.One) | sympy.Eq(strides[d], z)
|
|
|
z *= sizes[d]
|
|
|
|
|
|
for d in range(dim):
|
|
|
is_contiguous |= sympy.Eq(sizes[d], sympy.S.Zero)
|
|
|
return is_contiguous
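# A hedged example: a contiguous tensor of size (2, 3, 4) has strides
# (12, 4, 1); walking dims innermost-first, the running product z takes the
# values 1, 4, 12 and matches each stride, so the conjunction reduces to true:
#
#     sympy_is_contiguous([2, 3, 4], [12, 4, 1])  # -> sympy.true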
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sympy_is_channels_last_contiguous_2d(sizes, strides):
|
|
|
return sympy_is_contiguous_generic(sizes, strides, [1, 3, 2, 0])
|
|
|
|
|
|
|
|
|
def sympy_is_channels_last_contiguous_3d(sizes, strides):
|
|
|
return sympy_is_contiguous_generic(sizes, strides, [1, 4, 3, 2, 0])
|
|
|
|
|
|
|
|
|
def sympy_is_channels_last_strides_generic(sizes, strides, dim_order):
|
|
|
import sympy
|
|
|
|
|
|
from torch.utils._sympy.functions import Max
|
|
|
|
|
|
dim = len(sizes)
|
|
|
|
|
|
if dim != len(dim_order):
|
|
|
return sympy.false
|
|
|
|
|
|
m = sympy.S.Zero
|
|
|
r = sympy.true
|
|
|
|
|
|
|
|
|
r &= sympy.Ne(strides[1], 0)
|
|
|
|
|
|
for d in dim_order:
|
|
|
r &= sympy.Ne(sizes[d], 0) & (strides[d] >= m)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if d == 0:
|
|
|
r &= sympy.Ne(m, strides[1])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
m = strides[d] * Max(sizes[d], 1)
|
|
|
|
|
|
return r
|
|
|
|
|
|
|
|
|
def sympy_is_channels_last_strides_2d(sizes, strides):
|
|
|
return sympy_is_channels_last_strides_generic(sizes, strides, [1, 3, 2, 0])
|
|
|
|
|
|
|
|
|
def sympy_is_channels_last_strides_3d(sizes, strides):
|
|
|
return sympy_is_channels_last_strides_generic(sizes, strides, [1, 4, 3, 2, 0])
|
|
|
|
|
|
|
|
|
def _sympy_is_non_overlapping_and_dense_indicator(sizes, strides):
|
|
|
from torch.utils._sympy.functions import IsNonOverlappingAndDenseIndicator
|
|
|
|
|
|
return IsNonOverlappingAndDenseIndicator(*sizes, *strides)
|
|
|
|
|
|
|
|
|
sizes_strides_methods = {
|
|
|
|
|
|
|
|
|
"is_contiguous": sympy_is_contiguous,
|
|
|
"is_channels_last_contiguous_2d": sympy_is_channels_last_contiguous_2d,
|
|
|
"is_channels_last_contiguous_3d": sympy_is_channels_last_contiguous_3d,
|
|
|
"is_channels_last_strides_2d": sympy_is_channels_last_strides_2d,
|
|
|
"is_channels_last_strides_3d": sympy_is_channels_last_strides_3d,
|
|
|
"is_non_overlapping_and_dense_indicator": _sympy_is_non_overlapping_and_dense_indicator,
|
|
|
}
|
|
|
|
|
|
|
|
|
def to_node(self, num):
|
|
|
if isinstance(num, SymTypes):
|
|
|
return num.node
|
|
|
elif type(num) is bool:
|
|
|
return self.wrap_bool(num)
|
|
|
elif type(num) is int:
|
|
|
return self.wrap_int(num)
|
|
|
elif type(num) is float:
|
|
|
return self.wrap_float(num)
|
|
|
else:
|
|
|
|
|
|
|
|
|
return NotImplemented
|
|
|
|
|
|
|
|
|
def wrap_node(x):
|
|
|
|
|
|
if isinstance(x, SymNode) and x.constant is not None:
|
|
|
return x.constant
|
|
|
if x.is_int():
|
|
|
return SymInt(x)
|
|
|
elif x.is_float():
|
|
|
return SymFloat(x)
|
|
|
elif x.is_bool():
|
|
|
return SymBool(x)
|
|
|
else:
|
|
|
raise AssertionError(f"unrecognized return type {x}")
|
|
|
|
|
|
|
|
|
def method_to_operator(method):
|
|
|
return METHOD_TO_OPERATOR[method]
|
|
|
|
|
|
|
|
|
def _make_node_magic(method, func):
|
|
|
func = lru_cache(256)(func)
|
|
|
|
|
|
if method in magic_methods_on_operator_with_trailing_underscore:
|
|
|
method_attr = f"{method}_"
|
|
|
else:
|
|
|
method_attr = method
|
|
|
|
|
|
def uninteresting_files() -> set[str]:
|
|
|
import torch
|
|
|
|
|
|
mods = [
|
|
|
torch._dynamo.eval_frame,
|
|
|
torch._dynamo.utils,
|
|
|
torch.fx.experimental.sym_node,
|
|
|
torch,
|
|
|
]
|
|
|
import torch._dynamo.guards
|
|
|
|
|
|
return (
|
|
|
{inspect.getfile(m) for m in mods}
|
|
|
| torch._dynamo.guards.uninteresting_files()
|
|
|
| {"<string>"}
|
|
|
)
|
|
|
|
|
|
def capture_provenance(fn):
|
|
|
@functools.wraps(fn)
|
|
|
def wrapper(self, other=None):
|
|
|
if other is None:
|
|
|
result = fn(self)
|
|
|
else:
|
|
|
result = fn(self, other)
|
|
|
if torch._logging._internal.GET_DTRACE_STRUCTURED:
|
|
|
if other is not None:
|
|
|
arguments = [self, other]
|
|
|
else:
|
|
|
arguments = [self]
|
|
|
|
|
|
def get_id(sym_node) -> Optional[int]:
|
|
|
|
|
|
import sympy
|
|
|
|
|
|
if sym_node.constant is not None:
|
|
|
return None
|
|
|
elif id(sym_node) == id(result):
|
|
|
return None
|
|
|
elif isinstance(sym_node.expr, (sympy.Integer, sympy.Float)):
|
|
|
return None
|
|
|
elif sym_node.expr in (sympy.true, sympy.false):
|
|
|
return None
|
|
|
return id(sym_node)
|
|
|
|
|
|
dtrace_structured(
|
|
|
"expression_created",
|
|
|
metadata_fn=lambda: {
|
|
|
"method": method,
|
|
|
"result": str(result),
|
|
|
"result_id": id(result),
|
|
|
"arguments": [str(a) for a in arguments],
|
|
|
"argument_ids": [
|
|
|
arg_id for i in arguments if (arg_id := get_id(i)) is not None
|
|
|
],
|
|
|
"user_stack": structured.get_user_stack(3),
|
|
|
"stack": structured.get_framework_stack(3),
|
|
|
},
|
|
|
)
|
|
|
|
|
|
return result
|
|
|
|
|
|
return wrapper
|
|
|
|
|
|
@capture_provenance
|
|
|
def binary_magic_impl(self, other):
|
|
|
from torch.fx.experimental.proxy_tensor import (
|
|
|
get_proxy_mode,
|
|
|
handle_sym_dispatch,
|
|
|
)
|
|
|
|
|
|
op = method_to_operator(method)
|
|
|
|
|
|
out_hint = None
|
|
|
if self.hint is not None and other.hint is not None:
|
|
|
out_hint = op(self.hint, other.hint)
|
|
|
|
|
|
if get_proxy_mode():
|
|
|
return to_node(
|
|
|
self, handle_sym_dispatch(op, (wrap_node(self), wrap_node(other)), {})
|
|
|
)
|
|
|
assert isinstance(other, SymNode)
|
|
|
optimized_summation = False
|
|
|
try:
|
|
|
if method == "mod":
|
|
|
from torch.utils._sympy.functions import Mod, PythonMod
|
|
|
|
|
|
|
|
|
|
|
|
shape_env = self.shape_env
|
|
|
if (
|
|
|
self.expr.is_nonnegative
|
|
|
or shape_env.bound_sympy(self.expr).lower >= 0
|
|
|
) and (
|
|
|
other.expr.is_nonnegative
|
|
|
or shape_env.bound_sympy(other.expr).lower >= 0
|
|
|
):
|
|
|
out = Mod(self.expr, other.expr)
|
|
|
else:
|
|
|
out = PythonMod(self.expr, other.expr)
|
|
|
elif method == "add":
|
|
|
|
|
|
(optimized_summation, out) = func(
|
|
|
self.expr,
|
|
|
other.expr,
|
|
|
self._optimized_summation,
|
|
|
other._optimized_summation,
|
|
|
)
|
|
|
else:
|
|
|
|
|
|
out = func(self.expr, other.expr)
|
|
|
except Exception:
|
|
|
log.warning("failed to eval %s(%s, %s)", method, self.expr, other.expr)
|
|
|
raise
|
|
|
sym_node_log.debug("%s %s %s -> %s", method, self.expr, other.expr, out)
|
|
|
pytype: type
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if method in always_float_magic_methods:
|
|
|
pytype = float
|
|
|
elif method in always_bool_magic_methods:
|
|
|
pytype = bool
|
|
|
elif self.pytype is float or other.pytype is float:
|
|
|
pytype = float
|
|
|
else:
|
|
|
pytype = self.pytype
|
|
|
|
|
|
if (
|
|
|
pytype is not None
|
|
|
and out_hint is not None
|
|
|
and not isinstance(out_hint, SymTypes)
|
|
|
):
|
|
|
out_hint = pytype(out_hint)
|
|
|
|
|
|
|
|
|
|
|
|
fx_node, _ = self.shape_env._create_fx_call_function(
|
|
|
op, (self.fx_node, other.fx_node)
|
|
|
)
|
|
|
|
|
|
result = SymNode(
|
|
|
out,
|
|
|
self.shape_env,
|
|
|
pytype,
|
|
|
out_hint,
|
|
|
fx_node=fx_node,
|
|
|
optimized_summation=optimized_summation,
|
|
|
)
|
|
|
return result
|
|
|
|
|
|
@capture_provenance
|
|
|
def unary_magic_impl(self):
|
|
|
from torch.fx.experimental.proxy_tensor import (
|
|
|
get_proxy_mode,
|
|
|
handle_sym_dispatch,
|
|
|
)
|
|
|
|
|
|
op = method_to_operator(method)
|
|
|
if get_proxy_mode():
|
|
|
return to_node(self, handle_sym_dispatch(op, (wrap_node(self),), {}))
|
|
|
|
|
|
expr = self.expr
|
|
|
if method == "floor" or method == "ceiling":
|
|
|
expr = self.shape_env._simplify_floor_div(expr)
|
|
|
|
|
|
try:
|
|
|
out = func(expr)
|
|
|
except Exception:
|
|
|
log.warning("failed to eval %s(%s)", method, expr)
|
|
|
raise
|
|
|
sym_node_log.debug("%s %s -> %s", func, expr, out)
|
|
|
out_hint = None
|
|
|
if self.hint is not None:
|
|
|
out_hint = op(self.hint)
|
|
|
pytype: type
|
|
|
if method in always_int_magic_methods:
|
|
|
pytype = int
|
|
|
elif method in always_bool_magic_methods:
|
|
|
pytype = bool
|
|
|
elif method in always_float_magic_methods:
|
|
|
pytype = float
|
|
|
else:
|
|
|
pytype = self.pytype
|
|
|
|
|
|
fx_node, _ = self.shape_env._create_fx_call_function(op, (self.fx_node,))
|
|
|
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
|
|
|
|
|
|
if method in unary_methods:
|
|
|
setattr(SymNode, f"_{method_attr}", unary_magic_impl)
|
|
|
elif method == "sym_ite":
|
|
|
|
|
|
def sym_ite_impl(pred_node, then_node, else_node):
|
|
|
from torch.fx.experimental.proxy_tensor import (
|
|
|
get_proxy_mode,
|
|
|
handle_sym_dispatch,
|
|
|
)
|
|
|
|
|
|
out_hint = then_node.hint if pred_node.hint else else_node.hint
|
|
|
if get_proxy_mode():
|
|
|
return to_node(
|
|
|
pred_node,
|
|
|
handle_sym_dispatch(
|
|
|
sym_ite,
|
|
|
(
|
|
|
wrap_node(pred_node),
|
|
|
wrap_node(then_node),
|
|
|
wrap_node(else_node),
|
|
|
),
|
|
|
{},
|
|
|
),
|
|
|
)
|
|
|
|
|
|
try:
|
|
|
out = func(pred_node.expr, then_node.expr, else_node.expr)
|
|
|
except Exception:
|
|
|
log.warning(
|
|
|
"failed to eval %s(%s, %s, %s)",
|
|
|
method,
|
|
|
pred_node.expr,
|
|
|
then_node.expr,
|
|
|
else_node.expr,
|
|
|
)
|
|
|
raise
|
|
|
|
|
|
fx_node, _ = pred_node.shape_env._create_fx_call_function(
|
|
|
sym_ite, (pred_node.fx_node, then_node.fx_node, else_node.fx_node)
|
|
|
)
|
|
|
return SymNode(
|
|
|
out, pred_node.shape_env, then_node.pytype, out_hint, fx_node=fx_node
|
|
|
)
|
|
|
|
|
|
setattr(SymNode, f"_{method_attr}", sym_ite_impl)
|
|
|
elif method == "round":
|
|
|
|
|
|
def round_impl(self, ndigits=None):
|
|
|
from torch.fx.experimental.proxy_tensor import (
|
|
|
get_proxy_mode,
|
|
|
handle_sym_dispatch,
|
|
|
)
|
|
|
|
|
|
op = builtins.round
|
|
|
if get_proxy_mode():
|
|
|
return to_node(
|
|
|
self, handle_sym_dispatch(op, (wrap_node(self), ndigits), {})
|
|
|
)
|
|
|
|
|
|
expr = self.expr
|
|
|
try:
|
|
|
out = func(expr, ndigits)
|
|
|
except Exception:
|
|
|
log.warning("failed to eval %s(%s, ndigits=%s)", method, expr, ndigits)
|
|
|
raise
|
|
|
|
|
|
if ndigits is None:
|
|
|
pytype = int
|
|
|
else:
|
|
|
pytype = self.pytype
|
|
|
|
|
|
out_hint = None
|
|
|
if self.hint is not None:
|
|
|
out_hint = op(self.hint, ndigits)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
args = [self.fx_node]
|
|
|
if ndigits is not None:
|
|
|
args.append(ndigits)
|
|
|
fx_node, _ = self.shape_env._create_fx_call_function(op, tuple(args))
|
|
|
return SymNode(out, self.shape_env, pytype, out_hint, fx_node=fx_node)
|
|
|
|
|
|
setattr(SymNode, f"_{method_attr}", round_impl)
|
|
|
else:
|
|
|
setattr(SymNode, f"_{method_attr}", binary_magic_impl)
|
|
|
|
|
|
|
|
|
def _make_node_sizes_strides(method, func):
|
|
|
|
|
|
|
|
|
def sizes_strides_impl(self, sizes, strides):
|
|
|
from torch.fx.experimental.proxy_tensor import (
|
|
|
get_proxy_mode,
|
|
|
handle_sym_dispatch,
|
|
|
)
|
|
|
|
|
|
op = getattr(sys.modules[__name__], method)
|
|
|
if get_proxy_mode():
|
|
|
return to_node(
|
|
|
self,
|
|
|
handle_sym_dispatch(
|
|
|
op,
|
|
|
([wrap_node(s) for s in sizes], [wrap_node(s) for s in strides]),
|
|
|
{},
|
|
|
),
|
|
|
)
|
|
|
size_exprs = [s.expr for s in sizes]
|
|
|
stride_exprs = [s.expr for s in strides]
|
|
|
try:
|
|
|
out = func(size_exprs, stride_exprs)
|
|
|
except Exception:
|
|
|
log.warning("failed to eval %s(%s, %s)", method, size_exprs, stride_exprs)
|
|
|
raise
|
|
|
|
|
|
|
|
|
size_hints = []
|
|
|
out_hint = None
|
|
|
for s in sizes:
|
|
|
if s.hint is None:
|
|
|
break
|
|
|
size_hints.append(s.hint)
|
|
|
else:
|
|
|
stride_hints = []
|
|
|
for s in strides:
|
|
|
if s.hint is None:
|
|
|
break
|
|
|
stride_hints.append(s.hint)
|
|
|
else:
|
|
|
out_hint = op(size_hints, stride_hints)
|
|
|
|
|
|
|
|
|
pytype: type
|
|
|
if method.endswith("_indicator"):
|
|
|
pytype = int
|
|
|
else:
|
|
|
pytype = bool
|
|
|
return SymNode(out, self.shape_env, pytype, out_hint)
|
|
|
|
|
|
setattr(SymNode, f"_{method}", sizes_strides_impl)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def sizes_strides_user(sizes, strides):
|
|
|
import sympy
|
|
|
|
|
|
from torch.fx.experimental.symbolic_shapes import (
|
|
|
eval_is_non_overlapping_and_dense,
|
|
|
)
|
|
|
|
|
|
for a in itertools.chain(sizes, strides):
|
|
|
if isinstance(a, SymInt):
|
|
|
return wrap_node(
|
|
|
getattr(a.node, method)(
|
|
|
[to_node(a.node, b) for b in sizes],
|
|
|
[to_node(a.node, b) for b in strides],
|
|
|
)
|
|
|
)
|
|
|
if method == "is_non_overlapping_and_dense_indicator":
|
|
|
return eval_is_non_overlapping_and_dense(sizes, strides)
|
|
|
else:
|
|
|
|
|
|
return bool(
|
|
|
func(
|
|
|
[sympy.sympify(a) for a in sizes],
|
|
|
[sympy.sympify(a) for a in strides],
|
|
|
)
|
|
|
)
|
|
|
|
|
|
|
|
|
if not hasattr(sys.modules[__name__], method):
|
|
|
setattr(sys.modules[__name__], method, sizes_strides_user)
|
|
|
|
|
|
|
|
|
for method, func in magic_methods.items():
|
|
|
_make_node_magic(method, func)
|
|
|
|
|
|
for method, func in sizes_strides_methods.items():
|
|
|
_make_node_sizes_strides(method, func)
|
|
|
|
|
|
|
|
|
def _make_user_magic(method, user_type):
|
|
|
|
|
|
|
|
|
|
|
|
if method in magic_methods_on_operator_with_trailing_underscore:
|
|
|
method_attr = f"sym_{method}"
|
|
|
else:
|
|
|
method_attr = method
|
|
|
|
|
|
def get_constant(x: Union[SymInt, int, SymFloat, float, SymBool, bool]):
|
|
|
if isinstance(x, (int, float, bool)):
|
|
|
return x
|
|
|
if isinstance(x, SymBool):
|
|
|
return x.node.guard_bool("", 0)
|
|
|
raise AssertionError("expect to be called with constant SymBools")
|
|
|
|
|
|
def is_constant(x):
|
|
|
if isinstance(x, (int, float, bool)):
|
|
|
return True
|
|
|
if isinstance(x, (SymInt, SymFloat, SymBool)):
|
|
|
return x.node.is_constant()
|
|
|
return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if method in bool_becomes_int_magic_methods:
|
|
|
|
|
|
def promote(x):
|
|
|
"""Implements True+True=2, which works in python but not sympy"""
|
|
|
if isinstance(x, SymBool):
|
|
|
return SymInt(x.node.wrap_int(int(x)))
|
|
|
return x
|
|
|
|
|
|
else:
|
|
|
|
|
|
def promote(x):
|
|
|
return x
|
|
|
|
|
|
def promote2(self, other):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if method not in [
|
|
|
"add",
|
|
|
"sub",
|
|
|
"mul",
|
|
|
"mod",
|
|
|
"float_pow",
|
|
|
"float_truediv",
|
|
|
"int_floordiv",
|
|
|
"sym_min",
|
|
|
"sym_max",
|
|
|
|
|
|
"eq",
|
|
|
"ne",
|
|
|
"gt",
|
|
|
"lt",
|
|
|
"le",
|
|
|
"ge",
|
|
|
]:
|
|
|
return self, other
|
|
|
f_self = isinstance(self, (float, torch.SymFloat))
|
|
|
f_other = isinstance(other, (float, torch.SymFloat))
|
|
|
if f_self or f_other:
|
|
|
if not f_self:
|
|
|
self = torch.sym_float(self)
|
|
|
if not f_other:
|
|
|
other = torch.sym_float(other)
|
|
|
return self, other
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def unary_magic_impl(self):
|
|
|
self = promote(self)
|
|
|
if is_constant(self):
|
|
|
return (method_to_operator(method))(get_constant(self))
|
|
|
return wrap_node(getattr(self.node, method_attr)())
|
|
|
|
|
|
def binary_magic_impl(self, other):
|
|
|
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
|
|
|
return NotImplemented
|
|
|
sym_node_log.debug("MAGIC %s %s %s", method, self, other)
|
|
|
self = promote(self)
|
|
|
other = promote(other)
|
|
|
self, other = promote2(self, other)
|
|
|
if is_constant(self):
|
|
|
return (method_to_operator(method))(get_constant(self), other)
|
|
|
if is_constant(other):
|
|
|
other = get_constant(other)
|
|
|
other_node = to_node(self.node, other)
|
|
|
if other_node is NotImplemented:
|
|
|
return NotImplemented
|
|
|
ret = wrap_node(getattr(self.node, method_attr)(other_node))
|
|
|
return get_constant(ret) if is_constant(ret) else ret
|
|
|
|
|
|
def rbinary_magic_impl(self, other):
|
|
|
if not isinstance(other, (int, float, bool, SymInt, SymFloat, SymBool)):
|
|
|
return NotImplemented
|
|
|
self = promote(self)
|
|
|
other = promote(other)
|
|
|
self, other = promote2(self, other)
|
|
|
if is_constant(self):
|
|
|
return (method_to_operator(method))(get_constant(self), other)
|
|
|
if is_constant(other):
|
|
|
other = get_constant(other)
|
|
|
other_node = to_node(self.node, other)
|
|
|
if other_node is NotImplemented:
|
|
|
return NotImplemented
|
|
|
ret = wrap_node(getattr(other_node, method_attr)(self.node))
|
|
|
return get_constant(ret) if is_constant(ret) else ret
|
|
|
|
|
|
if method in unary_magic_methods:
|
|
|
setattr(user_type, f"__{method}__", unary_magic_impl)
|
|
|
elif method in unary_nonmagic_methods:
|
|
|
orig = getattr(user_type, method)
|
|
|
setattr(user_type, method, update_wrapper(unary_magic_impl, orig))
|
|
|
elif method == "sym_ite":
|
|
|
|
|
|
def sym_ite_magic_impl(pred, then_val, else_val):
|
|
|
pred_node = pred.node
|
|
|
then_node = to_node(pred_node, then_val)
|
|
|
else_node = to_node(pred_node, else_val)
|
|
|
if then_node is NotImplemented or else_node is NotImplemented:
|
|
|
return NotImplemented
|
|
|
assert (
|
|
|
isinstance(then_node, SymNode)
|
|
|
and isinstance(else_node, SymNode)
|
|
|
and then_node.pytype == else_node.pytype
|
|
|
)
|
|
|
ret = wrap_node(getattr(pred.node, method_attr)(then_node, else_node))
|
|
|
return get_constant(ret) if ret.node.is_constant() else ret
|
|
|
|
|
|
setattr(user_type, f"__{method}__", sym_ite_magic_impl)
|
|
|
elif method == "round":
|
|
|
|
|
|
def round_magic_impl(self, ndigits=None):
|
|
|
if is_constant(self):
|
|
|
return builtins.round(get_constant(self), ndigits)
|
|
|
|
|
|
return wrap_node(getattr(self.node, method)(ndigits))
|
|
|
|
|
|
setattr(user_type, f"__{method}__", round_magic_impl)
|
|
|
else:
|
|
|
method_name = method
|
|
|
if method in bitwise_ops:
|
|
|
method_name = bitwise_ops[method]
|
|
|
setattr(user_type, f"__{method_name}__", binary_magic_impl)
|
|
|
if method in reflectable_magic_methods:
|
|
|
setattr(user_type, f"__r{method_name}__", rbinary_magic_impl)
|
|
|
|
|
|
|
|
|
for method, func in magic_methods.items():
|
|
|
if method in only_bool_magic_methods:
|
|
|
_make_user_magic(method, SymBool)
|
|
|
continue
|
|
|
if method in only_float_magic_methods:
|
|
|
_make_user_magic(method, SymFloat)
|
|
|
continue
|
|
|
if method in also_bool_magic_methods or method in bool_becomes_int_magic_methods:
|
|
|
_make_user_magic(method, SymBool)
|
|
|
_make_user_magic(method, SymInt)
|
|
|
if method not in bitwise_ops:
|
|
|
_make_user_magic(method, SymFloat)
|
|
|
|
|
|
del method
|
|
|
del func
|
|
|
|