import collections
from typing import Any, Callable, Dict, Optional

import torch
import torch.utils._pytree as pytree

aten = torch.ops.aten


def replace_node_with_constant(gm, node, constant):
    g = gm.graph

    if not hasattr(gm, "_frozen_param_count"):
        gm._frozen_param_count = 0

    # Find the next unused qualified name of the form `_frozen_param{i}`.
    i = gm._frozen_param_count
    while True:
        qualname = f"_frozen_param{i}"
        if not hasattr(gm, qualname):
            break
        i += 1
    gm._frozen_param_count = i + 1

    with g.inserting_before(node):
        new_input_node = g.create_node("get_attr", qualname, (), {})
        node.replace_all_uses_with(new_input_node)
        new_input_node.meta.update(node.meta)
        g.erase_node(node)

    # needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
    gm.register_buffer(qualname, constant)
    setattr(gm, qualname, constant)


class ConstantFolder(torch.fx.Interpreter):
    def __init__(
        self,
        gm,
        skip_constructors=False,
    ):
        super().__init__(gm)
        self.node_replacements: Dict[torch.fx.Node, Any] = {}
        self.replaced_uses: Dict[torch.fx.Node, int] = collections.Counter()
        self.unknown_value = object()
        self.skip_constructors: bool = skip_constructors

        # overwrite this to deallocate env values if their only remaining use
        # is the output
        self.user_to_last_uses = self.node_to_last_non_output_use()

    def is_impure(self, node: torch.fx.node.Node):
        if node.target in [
            torch.ops.quantized_decomposed.dequantize_per_channel.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.default,
            torch.ops.quantized_decomposed.dequantize_per_tensor.tensor,
        ]:
            # For the pattern fp32_weight -> q -> dq, we only fold
            # fp32_weight -> q into an int8_weight and leave the dq
            # in the graph to be fused.
            return True
        return False

    def node_to_last_non_output_use(self):
        last_non_output_use = collections.defaultdict(list)
        seen_uses = set()
        output_node = next(iter(reversed(self.module.graph.nodes)))

        for node in reversed(self.module.graph.nodes):
            if node.target == "output":
                continue

            def add_use(inp):
                if inp in seen_uses:
                    return
                seen_uses.add(inp)
                last_non_output_use[node].append(inp)

            pytree.tree_map_only(torch.fx.Node, add_use, (node.args, node.kwargs))

            # if this node is only used in the output, we want to gc it right away
            if len(node.users) == 1 and output_node in node.users:
                last_non_output_use[node].append(node)

        return last_non_output_use

    def run_node(self, node):
        if node.target == "output":
            # because we remove nodes from env on their last non-output use,
            # re-define them now or we'll get an error in the interpreter
            def set_env(arg):
                self.env[arg] = self.unknown_value

            pytree.tree_map_only(torch.fx.Node, set_env, node.args)
            return super().run_node(node)

        args, kwargs = self.fetch_args_kwargs_from_env(node)
        flattened_inputs = pytree.arg_tree_leaves(*args, **kwargs)

        if self.unknown_value in flattened_inputs:
            return self.unknown_value

        # TODO - fix errors with this
        if (
            node.op == "call_function"
            and node.target == aten._efficientzerotensor.default
        ):
            return self.unknown_value

        # skip constructors, since inductor generates optimal code for them already
        # and turning them into tensors would result in an additional global memory read
        # TODO - more complicated strategy
        if (
            self.skip_constructors
            and node.op != "get_attr"
            and not any(isinstance(e, torch.Tensor) for e in flattened_inputs)
        ):
            return self.unknown_value

        # All mutations should either be removed or be on inputs which we did not make constant
        if (
            isinstance(node.target, torch._ops.OpOverload)
            and torch.Tag.nondeterministic_seeded in node.target.tags
        ):
            return self.unknown_value

        out = super().run_node(node)

        if node.op != "get_attr" and isinstance(out, torch.Tensor):
            if not self.insertable_tensor_check(out):
                return out

            if self.is_impure(node):
                return self.unknown_value

            self.add_node_replacement(node, out)

            flattened_node_inps = pytree.arg_tree_leaves(*node.args, **node.kwargs)

            for n in flattened_node_inps:
                if not isinstance(n, torch.fx.Node):
                    continue
                self.replaced_uses[n] += 1

            # If every use of an input has been replaced by a folded constant,
            # that input no longer needs its own replacement entry.
            for to_delete in self.user_to_last_uses.get(node, []):
                if self.replaced_uses[to_delete] == len(to_delete.users):
                    self.node_replacements.pop(to_delete, None)

        return out

    def insertable_tensor_check(self, tensor: torch.Tensor) -> bool:
        return True

    def add_node_replacement(self, node: torch.fx.Node, tensor: torch.Tensor) -> None:
        self.node_replacements[node] = tensor

    def run(self):
        env = {}
        for n in self.module.graph.nodes:
            if n.op == "placeholder":
                env[n] = self.unknown_value
        return super().run(initial_env=env)


@torch.utils._python_dispatch._disable_current_modes()
def constant_fold(gm, constraint_fn: Optional[Callable[[torch.fx.Node], bool]] = None):
    cf = ConstantFolder(gm, skip_constructors=True)
    cf.run()

    for node, constant in cf.node_replacements.items():
        if constraint_fn is not None and not constraint_fn(node):
            continue
        replace_node_with_constant(gm, node, constant)

    erased_params = []
    for node in gm.graph.nodes:
        if node.op == "get_attr" and len(node.users) == 0:
            if hasattr(gm, node.target):
                delattr(gm, node.target)
            erased_params.append(node)

    for node in erased_params:
        gm.graph.erase_node(node)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
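

# --- Illustrative usage: a minimal sketch, not part of the pass itself ---
# The hypothetical module `M` below has a subgraph, `self.weight + 1`, that
# depends only on a constant buffer. ConstantFolder evaluates it during
# `run()`, and `constant_fold` then swaps the subgraph for a single frozen
# `get_attr` (`_frozen_param0`) and erases the now-unused `weight` attribute.
# Module name and shapes are chosen only for demonstration.
if __name__ == "__main__":
    # Assumption: in some builds the quantized_decomposed ops referenced by
    # `is_impure` are only registered once this module is imported.
    import torch.ao.quantization.fx._decomposed  # noqa: F401

    class M(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.register_buffer("weight", torch.randn(4, 4))

        def forward(self, x):
            # `self.weight + 1` does not depend on `x`, so it is foldable;
            # the matmul consumes the placeholder and stays in the graph.
            return x @ (self.weight + 1)

    gm = torch.fx.symbolic_trace(M())
    constant_fold(gm)
    # After folding, the graph reads `_frozen_param0` instead of computing
    # `weight + 1` at runtime.
    print(gm.graph)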