|
|
|
import functools |
|
import math |
|
import sys |
|
|
|
import sympy |
|
from sympy import S |
|
|
|
# Public names exported by ``from <module> import *``.
# NOTE(review): several classes defined in this module (e.g. PythonMod, Mod,
# Where, CeilToInt, FloorToInt, TruncToInt, TruncToFloat and the
# OpaqueUnaryFn_* classes) are not listed here — confirm that is intentional.
__all__ = [
    "FloorDiv",
    "ModularIndexing",
    "CleanDiv",
    "CeilDiv",
    "IntTrueDiv",
    "FloatTrueDiv",
    "LShift",
    "RShift",
    "IsNonOverlappingAndDenseIndicator",
    "RoundToInt",
    "RoundDecimal",
    "ToFloat",
    "FloatPow",
    "PowByNatural",
]
|
|
|
|
|
def _keep_float(f):
    """
    Decorate ``f`` so that float-ness of its inputs is preserved in its output.

    If any positional argument is a ``sympy.Float`` and ``f`` returned
    something that is not a ``sympy.Float``, the result is converted with
    ``sympy.Float(float(result))``. Otherwise the result is returned as-is.
    """

    @functools.wraps(f)
    def wrapper(*args):
        result = f(*args)
        saw_float = any(isinstance(arg, sympy.Float) for arg in args)
        if saw_float and not isinstance(result, sympy.Float):
            result = sympy.Float(float(result))
        return result

    return wrapper
|
|
|
|
|
def fuzzy_eq(x, y):
    """
    Three-valued equality for fuzzy booleans.

    Returns ``None`` (unknown) if either operand is ``None``; otherwise
    returns ``x == y``.
    """
    unknown = None in (x, y)
    return None if unknown else x == y
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FloorDiv(sympy.Function):
    """
    We maintain this so that:
    1. We can use divisibility guards to simplify FloorDiv(a, b) to a / b.
    2. Printing out the expression is nicer (compared to say, representing
       a//b as (a - a % b) / b).

    NB: This is Python-style floor division, round to -Inf.
    """

    nargs = (2,)
    # Precedence used to decide when operands need parentheses when printed.
    precedence = 50

    is_integer = True

    @property
    def base(self):
        """The dividend (first argument)."""
        return self.args[0]

    @property
    def divisor(self):
        """The divisor (second argument)."""
        return self.args[1]

    def _sympystr(self, printer):
        # Render as "(base//divisor)", parenthesizing operands as needed.
        base = printer.parenthesize(self.base, self.precedence)
        divisor = printer.parenthesize(self.divisor, self.precedence)
        return f"({base}//{divisor})"

    @classmethod
    def eval(cls, base, divisor):
        """
        Canonicalize ``base // divisor``; returning None leaves it unevaluated.

        Raises:
            ZeroDivisionError: if ``divisor`` is known to be zero.
        """
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        if base.is_zero:
            return sympy.S.Zero
        if base.is_integer and divisor == 1:
            return base
        if base.is_integer and divisor == -1:
            return sympy.Mul(base, -1)
        if isinstance(base, sympy.Integer) and isinstance(divisor, sympy.Integer):
            return sympy.Integer(int(base) // int(divisor))

        # (a // b) // c == a // (b * c) holds for a positive integer c, but not
        # in general: with a=1, b=2, c=-2 the left side is 0 // -2 == 0 while
        # the right side is 1 // -4 == -1. Only collapse when it is sound.
        if (
            isinstance(base, FloorDiv)
            and divisor.is_integer
            and divisor.is_positive
        ):
            return FloorDiv(base.args[0], base.args[1] * divisor)

        # Cancel a common factor: (g*a) // (g*b) == a // b for any g != 0.
        try:
            gcd = sympy.gcd(base, divisor)
            if gcd != 1:
                return FloorDiv(
                    sympy.simplify(base / gcd), sympy.simplify(divisor / gcd)
                )
        except sympy.PolynomialError:
            # sympy.gcd cannot handle every expression; skip cancellation.
            pass

        return None
|
|
|
|
|
class ModularIndexing(sympy.Function):
    """
    ModularIndexing(a, b, c) => (a // b) % c where % is the C modulus

    Constant folding below uses SymPy's ``%`` on Integers, which follows
    Python's floored semantics; the two agree whenever the operands are
    non-negative.
    """

    nargs = (3,)
    is_integer = True

    @classmethod
    def eval(cls, base, divisor, modulus):
        """Canonicalize ``(base // divisor) % modulus``; None leaves it unevaluated."""
        if base == 0 or modulus == 1:
            return sympy.Integer(0)

        if (
            isinstance(base, sympy.Integer)
            and isinstance(divisor, sympy.Integer)
            and isinstance(modulus, sympy.Integer)
        ):
            return (base // divisor) % modulus

        # Cancel a common factor of base and divisor: (g*a) // (g*b) == a // b.
        try:
            if divisor != 1:
                gcd = sympy.gcd(base, divisor)
                if gcd != 1:
                    return ModularIndexing(
                        sympy.simplify(base / gcd),
                        sympy.simplify(divisor / gcd),
                        modulus,
                    )
        except sympy.PolynomialError:
            # sympy.gcd cannot handle every expression; skip cancellation.
            pass

        # Drop summands that are multiples of modulus * divisor: each one adds
        # a multiple of `modulus` to base // divisor and so vanishes mod
        # `modulus`. Give up if a kept summand looks negative (a negative
        # Integer, or a Mul with a negative leading coefficient), presumably
        # because the rewrite is only trusted for a non-negative base.
        if isinstance(base, sympy.Add):
            new_terms = []
            all_positive = True
            for term in base.args:
                if sympy.gcd(term, modulus * divisor) != modulus * divisor:
                    if (isinstance(term, sympy.Integer) and term < 0) or (
                        isinstance(term, sympy.Mul)
                        and isinstance(term.args[0], sympy.Integer)
                        and term.args[0] < 0
                    ):
                        all_positive = False
                        break
                    else:
                        new_terms.append(term)

            if len(new_terms) != len(base.args) and all_positive:
                return ModularIndexing(sum(new_terms), divisor, modulus)

        # ((a // b) // d) % m == (a // (b * d)) % m needs d to be a positive
        # integer (same identity as FloorDiv's nested collapse).
        if (
            isinstance(base, FloorDiv)
            and divisor.is_integer
            and divisor.is_positive
        ):
            return ModularIndexing(base.args[0], base.args[1] * divisor, modulus)

        return None

    def _eval_is_nonnegative(self):
        p, q = self.args[:2]
        return fuzzy_eq(p.is_nonnegative, q.is_nonnegative)

    def _eval_is_positive(self):
        # Positive inputs do not make the result positive: the remainder is 0
        # whenever base // divisor is a multiple of modulus (for example
        # ModularIndexing(4, 1, 2) == 0), so positivity is left undetermined.
        return None
|
|
|
|
|
class Where(sympy.Function):
    """
    Symbolic ternary ``p if c else q``.

    Folds to ``p`` or ``q`` when the condition is the literal ``sympy.true``
    or ``sympy.false``; otherwise stays unevaluated.
    """

    nargs = (3,)

    def _eval_is_integer(self):
        # Integer only if both branches are; otherwise unknown.
        if self.args[1].is_integer and self.args[2].is_integer:
            return True
        return None

    def _eval_is_nonnegative(self):
        if self.args[1].is_nonnegative and self.args[2].is_nonnegative:
            return True
        return None

    def _eval_is_positive(self):
        if self.args[1].is_positive and self.args[2].is_positive:
            return True
        return None

    @classmethod
    def eval(cls, c, p, q):
        if c == sympy.true:
            return p
        if c == sympy.false:
            return q
        return None
|
|
|
|
|
|
|
class PythonMod(sympy.Function):
    """
    Python-style modulus ``p % q``: a non-zero result takes the sign of ``q``.
    """

    nargs = (2,)

    is_integer = True

    @classmethod
    def eval(cls, p, q):
        """
        Canonicalize ``p % q``; returning None leaves it unevaluated.

        Raises:
            ZeroDivisionError: if ``q`` is known to be zero.
        """
        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")

        # 0 % q, q % q, -q % q and p % 1 are all zero.
        if p is S.Zero or p in (q, -q) or q == 1:
            return S.Zero

        # Both numeric: evaluate directly (SymPy numbers use Python semantics).
        if q.is_Number and p.is_Number:
            return p % q

        # Parity shortcut for q == 2.
        if q.is_Number and q == 2:
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        # p is an exact multiple of q.
        r = p / q
        if r.is_integer:
            return S.Zero

        # p % q == p exactly when floor(p / q) == 0, i.e. 0 <= p < q for a
        # positive q, or q < p <= 0 for a negative q. `p < q` alone is not
        # enough for a negative q: -3 < -2, yet -3 % -2 == -1.
        # `is_Boolean` is only set on decided truth values, so an undecidable
        # relational falls through instead of being coerced to bool.
        if r.is_positive:
            if q.is_positive:
                less = p < q
                if less.is_Boolean and bool(less):
                    return p
            elif q.is_negative:
                greater = p > q
                if greater.is_Boolean and bool(greater):
                    return p

        if sympy.Mod(p, q) == 0:
            return S.Zero

        return None

    def _eval_is_nonnegative(self):
        return True if self.args[1].is_positive else None

    def _eval_is_nonpositive(self):
        return True if self.args[1].is_negative else None
|
|
|
|
|
|
|
class Mod(sympy.Function):
    """
    Modulus for a non-negative dividend and a positive divisor.

    Both preconditions are asserted when the operands are numeric, and the
    result is declared non-negative.
    """

    nargs = (2,)

    is_integer = True
    is_nonnegative = True

    @classmethod
    def eval(cls, p, q):
        if q.is_zero:
            raise ZeroDivisionError("Modulo by zero")

        # Zero results that need no arithmetic: 0 % q, (+/-q) % q, p % 1.
        if p is S.Zero or p in (q, -q) or q == 1:
            return S.Zero

        # Concrete operands: check the preconditions, then evaluate.
        if q.is_Number and p.is_Number:
            assert p >= 0, p
            assert q >= 1, q
            return p % q

        # Modulo 2 is decided by known parity.
        if q.is_Number and q == 2:
            if p.is_even:
                return S.Zero
            if p.is_odd:
                return S.One

        ratio = p / q
        if ratio.is_integer:
            return S.Zero

        # A dividend provably below the divisor (same sign) is its own remainder.
        p_below_q = p < q
        if p_below_q.is_Boolean and bool(p_below_q) and ratio.is_positive:
            return p

        return None
|
|
|
|
|
class CleanDiv(FloorDiv):
    """
    Floor division whose operands are assumed to divide evenly, so no
    rounding takes place. Kept as a distinct type to enable future
    optimizations.
    """
|
|
|
|
|
|
|
|
|
class CeilToInt(sympy.Function):
    """
    Ceiling of a number as an integer.

    +/-oo map to the integer sentinels ``sys.maxsize - 1`` and
    ``-sys.maxsize - 1``; other numbers are folded, and anything symbolic
    stays unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, number):
        if number == sympy.oo:
            return sympy.Integer(sys.maxsize - 1)
        if number == -sympy.oo:
            return sympy.Integer(-sys.maxsize - 1)
        # An Integer is its own ceiling; routing it through float would lose
        # precision above 2**53 and overflow for very large magnitudes.
        if isinstance(number, sympy.Integer):
            return number
        if isinstance(number, sympy.Number):
            return sympy.Integer(math.ceil(float(number)))
        return None
|
|
|
|
|
class FloorToInt(sympy.Function):
    """
    Floor of a number as an integer.

    +/-oo map to the integer sentinels ``sys.maxsize - 1`` and
    ``-sys.maxsize - 1``; other numbers are folded, and anything symbolic
    stays unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, number):
        if number == sympy.oo:
            return sympy.Integer(sys.maxsize - 1)
        if number == -sympy.oo:
            return sympy.Integer(-sys.maxsize - 1)
        # An Integer is its own floor; routing it through float would lose
        # precision above 2**53 and overflow for very large magnitudes.
        if isinstance(number, sympy.Integer):
            return number
        if isinstance(number, sympy.Number):
            return sympy.Integer(math.floor(float(number)))
        return None
|
|
|
|
|
class CeilDiv(sympy.Function):
    """
    Div used in indexing that rounds up.

    Construction never produces a CeilDiv node. An exact division becomes a
    CleanDiv; anything else becomes ``FloorDiv(base + (divisor - 1), divisor)``.
    """

    is_integer = True

    def __new__(cls, base, divisor):
        numerator = sympy.sympify(base)
        denominator = sympy.sympify(divisor)
        if sympy.gcd(numerator, denominator) != denominator:
            # Round up by biasing the numerator before flooring.
            return FloorDiv(numerator + (denominator - 1), denominator)
        return CleanDiv(numerator, denominator)
|
|
|
|
|
class LShift(sympy.Function):
    """``base << shift``, rewritten as ``base * 2**shift``."""

    is_integer = True

    @classmethod
    def eval(cls, base, shift):
        """
        Raises:
            ValueError: if ``shift`` is known to be negative.
        """
        # Query the assumption instead of evaluating `shift < 0`: for a
        # symbolic shift of unknown sign that relational cannot be coerced to
        # bool and SymPy raises TypeError. Only a provably negative shift is
        # rejected here.
        if shift.is_negative:
            raise ValueError("negative shift count")
        return base * 2**shift
|
|
|
|
|
class RShift(sympy.Function):
    """``base >> shift``, rewritten as ``base // 2**shift``."""

    is_integer = True

    @classmethod
    def eval(cls, base, shift):
        """
        Raises:
            ValueError: if ``shift`` is known to be negative.
        """
        # Query the assumption instead of evaluating `shift < 0`: for a
        # symbolic shift of unknown sign that relational cannot be coerced to
        # bool and SymPy raises TypeError. Only a provably negative shift is
        # rejected here.
        if shift.is_negative:
            raise ValueError("negative shift count")
        return base // 2**shift
|
|
|
|
|
def safe_pow(base, exp):
    """
    Compute ``base ** exp`` with a saturating magnitude.

    The magnitude is capped at ``sys.maxsize - 1``. A negative base yields a
    negative result for odd exponents.

    Raises:
        ValueError: if ``exp`` is negative.
    """
    if base < 0:
        magnitude = -base
        sign = 1 if exp % 2 == 0 else -1
    else:
        magnitude = base
        sign = 1
    return sign * _safe_pow(magnitude, exp)


def _safe_pow(base, exponent):
    """Saturating exponentiation by squaring; the result never exceeds ``sys.maxsize - 1``."""
    if exponent < 0:
        raise ValueError("Exponent must be non-negative.")

    if exponent == 0:
        return 1

    limit = sys.maxsize - 1

    half = safe_pow(base, exponent // 2)
    if half > limit:
        return limit

    squared = half * half
    if squared > limit:
        return limit

    if exponent % 2 != 1:
        return squared

    full = squared * base
    return limit if full > limit else full
|
|
|
|
|
class PowByNatural(sympy.Function):
    """``base ** exp`` for a natural-number (non-negative integer) exponent."""

    is_integer = True

    @classmethod
    def eval(cls, base, exp):
        if isinstance(base, sympy.Number) and isinstance(exp, sympy.Number):
            # Saturating evaluation; safe_pow raises ValueError for exp < 0.
            return sympy.Integer(safe_pow(base, exp))
        if isinstance(exp, sympy.Integer) and exp.is_nonnegative:
            # Hand off to sympy.Pow. Iterated multiplication collapses into a
            # Pow anyway, and this avoids an O(exp) loop for large exponents.
            return sympy.Pow(base, exp)
        # Symbolic or negative exponents stay unevaluated. A negative
        # exponent is outside this function's natural-number domain, so
        # folding it to 1 would be wrong.
        return None
|
|
|
|
|
|
|
|
|
|
|
|
|
class FloatPow(sympy.Function):
    """
    Floating-point ``base ** exp``.

    Constant-folded in double precision when both operands are numbers.
    """

    is_integer = False
    is_real = True

    @classmethod
    def eval(cls, base, exp):
        if not (isinstance(base, sympy.Number) and isinstance(exp, sympy.Number)):
            return None
        return sympy.Float(float(base) ** float(exp))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FloatTrueDiv(sympy.Function):
    """
    Floating-point true division ``base / divisor``.

    Constant-folded in double precision when both operands are numbers.
    """

    is_integer = False
    is_real = True

    @classmethod
    def eval(cls, base, divisor):
        # Checked first, so a provably-zero divisor fails even for a symbolic base.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        both_numbers = isinstance(base, sympy.Number) and isinstance(
            divisor, sympy.Number
        )
        if not both_numbers:
            return None
        quotient = float(base) / float(divisor)
        return sympy.Float(quotient)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IntTrueDiv(sympy.Function):
    """
    True division of integers with a float result.

    Constant-folded via Python's ``int / int`` when both operands are numbers.
    """

    is_integer = False
    is_real = True

    @classmethod
    def eval(cls, base, divisor):
        # Checked first, so a provably-zero divisor fails even for a symbolic base.
        if divisor.is_zero:
            raise ZeroDivisionError("division by zero")

        both_numbers = isinstance(base, sympy.Number) and isinstance(
            divisor, sympy.Number
        )
        if not both_numbers:
            return None
        quotient = int(base) / int(divisor)
        return sympy.Float(quotient)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class IsNonOverlappingAndDenseIndicator(sympy.Function):
    """
    Symbolic non-overlapping-and-dense indicator.

    The arguments are all sizes followed by all strides, so there is always
    an even number of them. Once every argument is a concrete Integer,
    evaluation is delegated to
    ``torch.fx.experimental.symbolic_shapes.eval_is_non_overlapping_and_dense``;
    otherwise the expression stays symbolic.
    """

    is_integer = True

    @classmethod
    def eval(cls, *args):
        assert len(args) % 2 == 0
        ndim = len(args) // 2

        if not all(isinstance(arg, sympy.Integer) for arg in args):
            return None

        # Imported lazily; NOTE(review): presumably to avoid an import cycle
        # with torch — confirm.
        from torch.fx.experimental.symbolic_shapes import (
            eval_is_non_overlapping_and_dense,
        )

        sizes = [int(arg) for arg in args[:ndim]]
        strides = [int(arg) for arg in args[ndim:]]
        return eval_is_non_overlapping_and_dense(sizes, strides)
|
|
|
|
|
|
|
class TruncToFloat(sympy.Function):
    """
    Truncation toward zero with a float result.

    Constant-folded via ``math.trunc`` on the double value of a number.
    """

    is_integer = False
    is_real = True

    @classmethod
    def eval(cls, number):
        if not isinstance(number, sympy.Number):
            return None
        truncated = math.trunc(float(number))
        return sympy.Float(truncated)
|
|
|
|
|
class TruncToInt(sympy.Function):
    """
    Truncation toward zero as an integer.

    +/-oo map to the integer sentinels ``sys.maxsize - 1`` and
    ``-sys.maxsize - 1``; other numbers are folded, and anything symbolic
    stays unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, number):
        if number == sympy.oo:
            return sympy.Integer(sys.maxsize - 1)
        if number == -sympy.oo:
            return sympy.Integer(-sys.maxsize - 1)
        # An Integer truncates to itself; routing it through float would lose
        # precision above 2**53 and overflow for very large magnitudes.
        if isinstance(number, sympy.Integer):
            return number
        if isinstance(number, sympy.Number):
            return sympy.Integer(math.trunc(float(number)))
        return None
|
|
|
|
|
|
|
class RoundToInt(sympy.Function):
    """
    Python ``round()`` of a Float, as an integer.

    Only ``sympy.Float`` arguments are folded; anything else stays unevaluated.
    """

    is_integer = True

    @classmethod
    def eval(cls, number):
        if not isinstance(number, sympy.Float):
            return None
        # round(x, 0) uses Python's round-half-to-even and yields a float.
        rounded = round(float(number), 0)
        return sympy.Integer(rounded)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class RoundDecimal(sympy.Function):
    """
    Python ``round(number, ndigits)`` with a float result.

    Folded only when ``number`` is a Float and ``ndigits`` is an Integer.
    """

    is_integer = False
    is_real = True

    @classmethod
    def eval(cls, number, ndigits):
        foldable = isinstance(number, sympy.Float) and isinstance(
            ndigits, sympy.Integer
        )
        if not foldable:
            return None
        rounded = round(float(number), int(ndigits))
        return sympy.Float(rounded)
|
|
|
|
|
class ToFloat(sympy.Function):
    """
    Integer-to-float conversion.

    Infinities pass through unchanged, Integers fold to ``sympy.Float``, and
    anything else stays unevaluated.
    """

    is_integer = False
    is_real = True

    @classmethod
    def eval(cls, number):
        infinities = [sympy.oo, -sympy.oo]
        if number in infinities:
            return number
        if not isinstance(number, sympy.Integer):
            return None
        return sympy.Float(int(number))
|
|
|
|
|
def make_opaque_unary_fn(name):
    """
    Build a SymPy function class wrapping ``math.<name>`` / ``sympy.<name>``.

    The returned class is named ``"OpaqueUnaryFn_" + name`` and records
    ``name`` as ``_torch_handler_name``.
    """

    class OpaqueUnaryFn(sympy.Function):
        """
        Unlike the builtin sympy functions on real numbers like sympy.sqrt,
        these equivalents do not do any nontrivial reasoning besides
        constant propagation. This helps avoid performing transformations
        that are valid for real numbers but are invalid for floating point;
        in particular, while we are willing to make optimizations that change
        numerics for Tensor compute, we are NOT willing to make optimizations
        that change numerics for size compute.
        """

        _torch_handler_name = name

        @classmethod
        def eval(cls, a):
            if isinstance(a, (sympy.Integer, sympy.Float)):
                # Fold in double precision via the math module; if that
                # overflows, fall back to SymPy's own function.
                try:
                    value = getattr(math, name)(float(a))
                    return sympy.Float(value)
                except OverflowError:
                    return getattr(sympy, name)(a)

            infinities = [sympy.oo, -sympy.oo, sympy.zoo, -sympy.zoo]
            if a in infinities:
                return getattr(sympy, name)(a)

            return None

    OpaqueUnaryFn.__name__ = "OpaqueUnaryFn_" + name

    return OpaqueUnaryFn
|
|
|
|
|
|
|
# One opaque unary function class per wrapped name, each built by
# make_opaque_unary_fn. Numeric arguments are constant-folded with the
# same-named `math` function; symbolic arguments are left as opaque calls.
OpaqueUnaryFn_sqrt = make_opaque_unary_fn("sqrt")
OpaqueUnaryFn_cos = make_opaque_unary_fn("cos")
OpaqueUnaryFn_cosh = make_opaque_unary_fn("cosh")
OpaqueUnaryFn_sin = make_opaque_unary_fn("sin")
OpaqueUnaryFn_sinh = make_opaque_unary_fn("sinh")
OpaqueUnaryFn_tan = make_opaque_unary_fn("tan")
OpaqueUnaryFn_tanh = make_opaque_unary_fn("tanh")
OpaqueUnaryFn_asin = make_opaque_unary_fn("asin")
OpaqueUnaryFn_acos = make_opaque_unary_fn("acos")
OpaqueUnaryFn_atan = make_opaque_unary_fn("atan")
OpaqueUnaryFn_exp = make_opaque_unary_fn("exp")
OpaqueUnaryFn_log = make_opaque_unary_fn("log")
OpaqueUnaryFn_asinh = make_opaque_unary_fn("asinh")
|
|