import functools
import itertools
import logging
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, Union

import sympy
from sympy import Expr

from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch.utils._sympy.functions import FloorDiv, ModularIndexing
from torch.utils._sympy.value_ranges import bound_sympy

from .utils import sympy_subs, sympy_symbol, VarRanges
from .virtualized import V

log = logging.getLogger(__name__)


# This class is a little awkward, because ShapeEnv is doing most of the heavy
# lifting and in some cases we should be directly passing through to ShapeEnv,
# but there is some extra inductor logic that needs to be handled here
class SizeVarAllocator:
    def __init__(self, shape_env=None):
        """
        Wrap `shape_env` (a fresh ShapeEnv is created when None) and set up
        the per-instance caches used by the simplification helpers.
        """
        super().__init__()
        if shape_env is None:
            shape_env = ShapeEnv()
        self.shape_env = shape_env
        # Both of these alias the ShapeEnv's own dicts (no copy is made).
        self.var_to_val = self.shape_env.var_to_val
        self.replacements: Dict[sympy.Symbol, Expr] = self.shape_env.replacements
        # Maps of dynamic sizes that have to be precomputed on the host to the kernel args.
        # The basic idea is if we have some complicated sympy expression
        # f(s0), we may choose to precompute it on the host and then replace
        # all occurrences of that sympy expression with ps0, so that when we
        # codegen we simply reference ps0 directly without repeating
        # f(s0).  Unlike regular size variables, ps variables cannot be
        # guarded upon; so if we are asked to guard on a Sympy expression
        # which potentially could have already had a precomputed replacement
        # on it, we are obligated to invert the precomputed replacements
        # (inv_precomputed_replacements).
        self.precomputed_replacements: Dict[Expr, sympy.Symbol] = {}
        self.inv_precomputed_replacements: Dict[sympy.Symbol, Expr] = {}
        self.stride_vars = self.make_stride_vars_cache()
        self.simplify_with_ranges = self.make_simplify_with_ranges_cache()
        self._simplify_loops = self.make_simplify_loops_cache()

    def simplify(self, expr: Expr):
        """Expand `expr` and apply the known symbol replacements to it."""
        return sympy.expand(expr).xreplace(self.replacements)

    def make_simplify_with_ranges_cache(self) -> Callable[[Expr, VarRanges], Expr]:
        """
        self._simplify_with_ranges() can be expensive, cache its results
        """
        cache: Dict[Tuple[Any, ...], Expr] = {}
        replacement_count = len(self.replacements)

        def simplify_with_ranges(expr: Expr, var_ranges: VarRanges) -> Expr:
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            key = (expr, *var_ranges.items())
            result = cache.get(key)
            if result is None:
                result = self._simplify_with_ranges(expr, var_ranges)
                cache[key] = result
            return result

        return simplify_with_ranges

    def make_simplify_loops_cache(self):
        """
        self._simplify_loops_impl() can be expensive, cache its results
        """
        cache: Dict[Tuple[Any, ...], Any] = {}
        replacement_count = len(self.replacements)

        def simplify_loops(index_vars, sizes, index_formulas):
            nonlocal replacement_count
            if replacement_count != len(self.replacements):
                # new replacements invalidates cached results
                cache.clear()
                replacement_count = len(self.replacements)
            # Keep the three argument groups as separate tuples: flattening
            # them into a single tuple would let differently-partitioned
            # inputs with the same concatenation share one cache entry.
            key = (tuple(index_vars), tuple(sizes), tuple(index_formulas))
            result = cache.get(key)
            if result is None:
                result = self._simplify_loops_impl(index_vars, sizes, index_formulas)
                cache[key] = result
            return result

        return simplify_loops

    def _simplify_with_ranges(self, expr: Expr, var_ranges: VarRanges) -> Expr:
        """
        Simplify indexing expression with knowledge of the ranges of
        iteration variables.
        """

        expr = join_dimensions(self.simplify(expr))
        original_expr = expr

        def remove_zero_terms(base, divisor):
            """Symbols smaller than the divisor are zero"""
            for v in base.free_symbols:
                if v in var_ranges:
                    # var smaller than divisor can be removed
                    # if the rest is guaranteed to be multiple of divisor
                    rest = sympy.Wild("_rest", exclude=[v])
                    m = base.match(v + rest)
                    if m and v not in m[rest].free_symbols:
                        gcd = sympy.gcd(m[rest], divisor)
                        if gcd == divisor:
                            if self.statically_known_leq(var_ranges[v], divisor):
                                base = m[rest]
            return base

        def visit_indexing_div(base, divisor):
            return FloorDiv(remove_zero_terms(base, divisor), divisor)

        def visit_modular_indexing(base, divisor, modulus):
            base = remove_zero_terms(base, divisor)
            base_pos = True
            if isinstance(base, ModularIndexing):
                # for modular indexing, biggest values from the ranges don't necessarily result in
                # the biggest result, the biggest result is modulus - 1
                base_s = base.args[2] - 1
            elif not base.has(ModularIndexing):
                # actual iteration range is to size-1
                iter_ranges_zero = {k: 0 for k in var_ranges}
                base_lowest = sympy_subs(base, iter_ranges_zero)
                # can't replace with indexing div if base can be negative
                base_pos = self.statically_known_leq(0, base_lowest)
                iter_ranges = {k: v - 1 for k, v in var_ranges.items()}
                base_s = sympy_subs(base, iter_ranges)
            else:
                base_s = base
            if self.statically_known_lt(base_s, modulus * divisor) and base_pos:
                return FloorDiv(base, divisor)
            return ModularIndexing(base, divisor, modulus)

        if expr.has(ModularIndexing):
            expr = expr.replace(
                ModularIndexing(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                    sympy.Wild("modulus"),
                ),
                visit_modular_indexing,
            )

        if expr.has(FloorDiv):
            expr = expr.replace(
                FloorDiv(
                    sympy.Wild("base"),
                    sympy.Wild("divisor"),
                ),
                visit_indexing_div,
            )

        if expr != original_expr:
            # Something changed; run again until a fixed point is reached.
            return self._simplify_with_ranges(expr, var_ranges)
        return expr

    def _simplify_loops_impl(
        self, index_vars: List[sympy.Symbol], sizes, index_formulas
    ):
        """
        Try to remove as many axis from loop iterations as possible, by:
            1) removing size==1 dimensions
            2) fuse contiguous dimensions into a single loop

        Returns a 3-tuple `(new_sizes, reindex, prune)`:
            new_sizes: the sizes of the dimensions that were kept
            reindex: maps an index over the kept dimensions back to an index
                over the original dimensions (removed dimensions get 0)
            prune: drops the entries of an original-rank index that belong
                to removed dimensions
        """
        sizes = list(map(self.simplify, sizes))

        strides = [self.stride_vars(x, index_vars) for x in index_formulas]
        assert len(sizes) == len(strides[0]), (len(sizes), len(strides[0]))

        for i, size in enumerate(sizes):
            if size == 1:
                # remove dim
                sizes[i] = None

        def can_merge_dims(a, b):
            for k, stride in enumerate(strides):
                if self.simplify(stride[a] * sizes[a]) == self.simplify(stride[b]):
                    # approximate test passed, try sound version
                    va = index_vars[a]
                    vb = index_vars[b]
                    v = sympy_symbol("_merge_tester")
                    expr1 = sympy_subs(index_formulas[k], {va: v * sizes[a], vb: 0})
                    expr2 = sympy_subs(index_formulas[k], {va: 0, vb: v})
                    if self.simplify(expr1) == self.simplify(expr2):
                        continue
                return False
            return True

        changed = True
        while changed:
            changed = False
            for i, j in itertools.product(
                reversed(range(len(sizes))), reversed(range(len(sizes)))
            ):
                if i == j or sizes[i] is None or sizes[j] is None:
                    continue
                if can_merge_dims(i, j):
                    changed = True
                    sizes[i] = sizes[i] * sizes[j]
                    sizes[j] = None

        def reindex(index):
            it = list(reversed(index))
            new_index = []
            for size in sizes:
                if size is None:
                    new_index.append(sympy.Integer(0))
                else:
                    new_index.append(it.pop())
            assert not it
            return new_index

        def prune(index):
            assert len(index) == len(sizes)
            return [i for i, s in zip(index, sizes) if s is not None]

        return [x for x in sizes if x is not None], reindex, prune

    # Note - [On Statically Known]
    #
    # The statically_known_* family of functions below replaces a prior system, called maybe_guard_*. The prior system
    # operated by providing essentially a question, where the size hinted values were evaluated. If the condition was
    # true, we add a guard and return True, otherwise, False.
    #
    # def maybe_guard_foo(args):
    #   if size_hinted_check(args):
    #       return False # No guard, no optim
    #   guard(args) # Make a guard
    #   return True # Safe to apply optimization
    #
    # The prior system incurred a guard, and green lit an optimization.
    #
    # The new system works in reverse - in the new system, if we know that the inputs are static, and evaluate the
    # condition as true, we green light the optimization, and we do not incur a guard. If we cannot prove that, we
    # return False.
    #
    # def statically_known_foo(args):
    #   if all_static(args):
    #       return True # Safe to apply optimization
    #   else:
    #       return False # No guard, no optim

    # See Note - [On Statically Known]

    def is_expr_static_and_true(self, expr: Union[Expr, int]) -> bool:
        """
        Return True only if `expr` can be evaluated statically by the
        ShapeEnv and the result is truthy. Returns False when the expression
        cannot be evaluated statically or when evaluation raises.
        """
        if expr in (True, False):
            return bool(expr)

        try:
            simplified = self.shape_env._maybe_evaluate_static(expr)
            if simplified is not None:
                return bool(simplified)
        except Exception:
            # Deliberately best-effort: failing to prove the expression just
            # means the optimization is not applied.
            log.debug("Could not simplify %s", expr)

        return False

    def statically_known_equals(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right are equal.
        """
        return self.is_expr_static_and_true(sympy.Eq(left, right))

    # See Note - [On Statically Known]
    def statically_known_list_equals(self, left: List[Expr], right: List[Expr]) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left and right lists are equal.
        """
        if len(left) != len(right):
            return False
        return all(self.statically_known_equals(l, r) for l, r in zip(left, right))

    # See Note - [On Statically Known]
    def statically_known_leq(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than or equal to right.
        """
        expr = left <= right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_lt(self, left: Expr, right: Expr) -> bool:
        """
        Returns a bool indicating if it is sound to optimize as if left is less than right.
        """
        expr = left < right
        return self.is_expr_static_and_true(expr)

    # See Note - [On Statically Known]
    def statically_known_multiple_of(self, numerator: Expr, denominator: Expr) -> bool:
        """
        Return a bool indicating if it is sound to optimize for the numerator being a multiple of the denominator.
        """
        expr = sympy.Eq(numerator % denominator, 0)
        return self.is_expr_static_and_true(expr)

    # The guard functions require you to ALREADY KNOW that a particular
    # condition holds.  If you don't know (you want to guard on an expression
    # being a particular value, and then get access to that value), use
    # the evaluate functions.

    def guard_equals(self, left: Expr, right: Expr) -> Expr:
        """
        Guard on `left == right` (asserting that it holds) and return `left`.
        Precomputed-size symbols are substituted back to their defining
        expressions before the guard is evaluated.
        """
        if isinstance(left, Expr):
            left = sympy_subs(left, self.inv_precomputed_replacements)
        if isinstance(right, Expr):
            right = sympy_subs(right, self.inv_precomputed_replacements)
        assert self.shape_env.evaluate_expr(sympy.Eq(left, right))
        return left

    def guard_leq(self, left: Expr, right: Expr) -> None:
        """Guard on `left <= right`, expressed as `left < right + 1`."""
        return self.guard_lt(left, right + 1)

    def guard_lt(self, left: Expr, right: Expr) -> None:
        """Guard on `left < right`, asserting that it holds."""
        assert self.shape_env.evaluate_expr(sympy.Lt(left, right))

    # The evaluate functions evaluate some symbolic sympy expression
    # (NB: not necessarily an Expr) and return what the concrete result
    # is, guarding on the expression being that result

    # NB: write evaluate_expr(sympy.Lt(a, b)) rather than evaluate_expr(a < b)
    # as this will ensure that you actually have a sympy'ified expression,
    # and will prevent you from incorrectly writing evaluate_expr(a == b)
    # which does the wrong thing if a or b is a sympy expression
    def evaluate_expr(self, left: Union[Expr, sympy.logic.boolalg.Boolean]) -> bool:
        assert isinstance(left, (Expr, sympy.logic.boolalg.Boolean)), type(left)
        return self.shape_env.evaluate_expr(sympy.sympify(left))

    def evaluate_min(self, left: Expr, right: Expr) -> Expr:
        """return the smaller of left and right, and guard on that choice"""
        lv = self.size_hint(left)
        rv = self.size_hint(right)
        if lv <= rv:
            self.guard_leq(left, right)
            return left
        else:
            self.guard_leq(right, left)
            return right

    def evaluate_static_shape(self, left: Expr) -> int:
        """Return the size hint of `left` as an int, guarding on equality."""
        right = self.size_hint(left)
        self.guard_equals(left, sympy.Integer(right))
        return int(right)

    def evaluate_static_shapes(self, left: List[Expr]) -> List[int]:
        """Apply evaluate_static_shape() to every element of `left`."""
        return [self.evaluate_static_shape(x) for x in left]

    def symbolic_hint(self, expr: Expr) -> Expr:
        """
        Substitute the known hint values (var_to_val) into `expr`. Symbols
        without an entry in var_to_val are left in place, so the result may
        still be symbolic.
        """
        # Substitute all hints into expr, but leave unbacked symints alone
        if not isinstance(expr, Expr):
            assert isinstance(expr, int)
            return expr
        free_symbols = expr.free_symbols
        if not free_symbols:
            return int(expr)
        # NOTE(review): termination relies on every free symbol whose name
        # starts with "ps" being a key of inv_precomputed_replacements.
        while any(s.name.startswith("ps") for s in free_symbols):
            expr = sympy_subs(expr, self.inv_precomputed_replacements)
            free_symbols = expr.free_symbols
        return sympy_subs(expr, self.var_to_val)

    def size_hint(self, expr: Expr, *, fallback: Optional[int] = None) -> int:
        """
        Return a concrete integer hint for `expr`.

        If the hinted expression is not an integer and `fallback` is given,
        `fallback` is returned instead; when every free symbol of `expr` has
        an entry in shape_env.var_to_range, the fallback is first clamped to
        the bounds computed by bound_sympy. Without a fallback, a hint that
        cannot be converted to int raises (after logging it).
        """
        out = self.symbolic_hint(expr)
        if not isinstance(out, (int, sympy.Integer)) and fallback is not None:
            # Use the provided heuristic fallback hint
            sym_vrs = {
                s: self.shape_env.var_to_range.get(s, None) for s in expr.free_symbols
            }
            if all(vr is not None for vr in sym_vrs.values()):
                expr_vr = bound_sympy(expr, sym_vrs)
                lower = self.size_hint(expr_vr.lower)
                upper = self.size_hint(expr_vr.upper)
                fallback = min(max(fallback, lower), upper)
            return fallback
        try:
            return int(out)
        except Exception:
            log.debug("failed on: %s", out)
            raise

    def size_hints(
        self,
        exprs: Iterable[Expr],
        *,
        fallback: Optional[int] = None,
    ) -> Tuple[int, ...]:
        """Return size_hint() of every expression in `exprs` as a tuple."""
        return tuple(self.size_hint(x, fallback=fallback) for x in exprs)

    def _lru_cache(self, fn, maxsize=None):
        """
        Wrapper around functools.lru_cache that clears when replacements
        has been invalidated.
        """
        fn_cache = functools.lru_cache(maxsize)(fn)
        prior_len = len(self.replacements)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal prior_len
            if prior_len != len(self.replacements):
                prior_len = len(self.replacements)
                fn_cache.cache_clear()
            return fn_cache(*args, **kwargs)

        return wrapper

    def make_stride_vars_cache(self):
        """
        Build the cached front-end for _stride_vars(). The returned function
        converts its list arguments to tuples so they can be used as cache
        keys, and defaults `support_vars` to `vars` when it is empty or None.
        """
        cache = self._lru_cache(self._stride_vars)

        def stride_vars(
            index: Expr,
            vars: List[sympy.Symbol],
            support_vars: Optional[List[sympy.Symbol]] = None,
        ) -> List[Expr]:
            if not support_vars:
                support_vars = vars
            return cache(index, tuple(vars), tuple(support_vars))

        return stride_vars

    def _stride_vars(
        self, index: Expr, vars: List[sympy.Symbol], support_vars: List[sympy.Symbol]
    ) -> List[Expr]:
        """Convert an indexing expression back into strides

        NOTE: This is only valid if the index is a standard strided offset
        calculation. e.g. 10 * ModularIndexing(i0 + 1, 1, 2) would give a
        stride of -10 because the index wraps around after the first element

        """
        strides = []
        index = self.simplify(index)
        # remove any offset
        index = index - sympy_subs(
            index, {v: sympy.Integer(0) for v in support_vars if v != 0}
        )
        for v in vars:
            # drop all the other dims
            index_dim = sympy_subs(
                index,
                {
                    other: sympy.Integer(0)
                    for other in support_vars
                    if v != other and other != 0
                },
            )
            if v == 0:
                strides.append(sympy.Integer(0))
            else:
                # TODO(jansel): should we use sympy.diff here?
                strides.append(
                    sympy_subs(index_dim, {v: sympy.Integer(1)})
                    - sympy_subs(index_dim, {v: sympy.Integer(0)})
                )
        return strides

    def offset_var(self, index: Expr, vars: List[sympy.Symbol]) -> Expr:
        """Extract offset part of an indexing expression"""
        index = self.simplify(index)
        return sympy_subs(index, {v: sympy.Integer(0) for v in vars if v != 0})

    def stride_hints(
        self,
        index: Expr,
        vars: List[sympy.Symbol],
        support_vars: Optional[List[sympy.Symbol]] = None,
    ) -> List[int]:
        """
        Return integer hints for the strides of `index` with respect to
        `vars`. Free symbols whose name starts with "indirect" are replaced
        by 0 first; a stride whose hint raises TypeError is reported as 0.
        """
        for v in index.free_symbols:
            if v.name.startswith("indirect"):
                index = sympy_subs(index, {v: 0})
        result = []
        for s in self.stride_vars(index, vars, support_vars):
            try:
                result.append(self.size_hint(s))
            except TypeError:
                result.append(0)
        return result

    def stride_order(self, index: Expr, vars: List[sympy.Symbol]) -> List[int]:
        """
        Return the positions of `vars` sorted by increasing absolute stride
        hint, with zero-stride positions placed last.
        """
        strides = tuple(map(abs, self.stride_hints(index, vars)))
        order = list(range(len(strides)))
        order.sort(key=lambda x: (strides[x] == 0, strides[x]))
        return order

    def lookup_precomputed_size(self, expr: Expr) -> sympy.Symbol:
        """
        Return the `ps<N>` symbol standing for `expr`, allocating a new one
        (and recording the inverse mapping) on first use.
        """
        if expr not in self.precomputed_replacements:
            sym = sympy_symbol(f"ps{len(self.precomputed_replacements)}")
            self.precomputed_replacements[expr] = sym
            self.inv_precomputed_replacements[sym] = expr
        return self.precomputed_replacements[expr]

    def free_symbols(self) -> Set[sympy.Symbol]:
        """Return the hinted symbols that have no replacement."""
        return set(self.var_to_val.keys()) - set(self.replacements.keys())


def join_dimensions(expr: Expr) -> Expr:
    """
    Merge split ModularIndexing/FloorDiv terms of a sum back together; see
    _join_dimensions_cached() for the patterns handled.
    """
    if not isinstance(expr, sympy.Add) or not expr.has(ModularIndexing):
        return expr  # fast exit path
    return _join_dimensions_cached(expr)


@functools.lru_cache(256)
def _join_dimensions_cached(expr: Expr) -> Expr:
    """
    ModularIndexing(i0, 1, 32) + 32 * ModularIndexing(i0, 32, 4)
    becomes
    ModularIndexing(i0, 1, 128)
    ModularIndexing(i0, 1, 32) + 32 * FloorDiv(i0, 32)
    becomes i0


    This type of pattern can come from view operations
    """
    assert isinstance(expr, sympy.Add)

    scale = sympy.Wild("scale", exclude=[0])
    base = sympy.Wild("base")
    divisor = sympy.Wild("divisor")
    mod1 = sympy.Wild("modulus")
    mod2 = sympy.Wild("modulus2")
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale]
                    * m1[mod1]
                    * ModularIndexing(m1[base], m1[divisor] * m1[mod1], mod2)
                )
                if m2 and term1 != term2:
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale]
                        * ModularIndexing(m1[base], m1[divisor], m1[mod1] * m2[mod2])
                    )
                    return expr
    for term1 in expr.args:
        m1 = term1.match(scale * ModularIndexing(base, divisor, mod1))
        if m1:
            for term2 in expr.args:
                m2 = term2.match(
                    m1[scale] * m1[mod1] * FloorDiv(m1[base], m1[divisor] * m1[mod1])
                )
                if m2 is not None:  # in case of success we get an empty dict here
                    expr = join_dimensions(
                        expr
                        - term1
                        - term2
                        + m1[scale] * FloorDiv(m1[base], m1[divisor])
                    )
                    return expr
    return expr


class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
    """
    A wrapper around .virtualize.ops that uses var range information to
    simplify ModularIndexing/FloorDiv.
    """

    def __init__(self, inner, var_ranges: VarRanges):
        super().__init__(inner)
        self.name = "SimplifyIndexing"
        # Every index passed through this handler is simplified against the
        # same `var_ranges` via the graph's SizeVarAllocator.
        self._simplify: Callable[
            [Expr], Expr
        ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)

    def load(self, name: str, index: sympy.Expr):
        return self._inner.load(name, self._simplify(index))

    def store(self, name, index, value, mode=None):
        return self._inner.store(name, self._simplify(index), value, mode=mode)

    def store_reduction(self, name, index, value):
        return self._inner.store_reduction(name, self._simplify(index), value)

    def index_expr(self, index, dtype):
        return self._inner.index_expr(self._simplify(index), dtype)