from .graph_module import GraphModule
from .graph import Graph
from .node import Node
from ._symbolic_trace import symbolic_trace
from ._compatibility import compatibility

import copy
from typing import Callable, Dict, List, NamedTuple, Optional, Set
import torch

__all__ = ['Match', 'replace_pattern']


@compatibility(is_backward_compatible=True)
class Match(NamedTuple):
    # Node from which the match was found
    anchor: Node
    # Maps nodes in the pattern subgraph to nodes in the larger graph
    nodes_map: Dict[Node, Node]
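
# A minimal usage sketch of the returned ``Match`` objects; the names
# `traced`, `pattern`, and `replacement` below are hypothetical stand-ins
# for a traced GraphModule and a pattern/replacement pair:
#
#     matches = replace_pattern(traced, pattern, replacement)
#     for m in matches:
#         print(m.anchor)  # Node at which the pattern was matched
#         for pattern_node, graph_node in m.nodes_map.items():
#             print(pattern_node.name, "->", graph_node.name)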


def _replace_submodules(gm: GraphModule, replacement: torch.nn.Module) -> None:
    gm.delete_all_unused_submodules()

    if isinstance(replacement, GraphModule):
        replacement.graph.lint()

    def try_get_submodule(mod: torch.nn.Module, target: str) -> Optional[torch.nn.Module]:
        try:
            mod_match = mod.get_submodule(target)
            return mod_match
        except AttributeError:
            return None

    for node in gm.graph.nodes:
        if node.op == "call_module" or node.op == "get_attr":

            gm_submod = try_get_submodule(gm, node.target)

            replacement_submod = try_get_submodule(replacement, node.target)

            # CASE 1: This target already exists as a submodule in our
            # result GraphModule. Whether or not it exists in
            # `replacement`, the existing submodule takes precedence.
            if gm_submod is not None:
                continue

            # CASE 2: The target exists as a submodule in `replacement`
            # only. Copy it over.
            elif replacement_submod is not None:
                new_submod = copy.deepcopy(getattr(replacement, node.target))
                gm.add_submodule(node.target, new_submod)

            # CASE 3: The target doesn't exist as a submodule in `gm`
            # or `replacement`
            else:
                raise RuntimeError(f"Attempted to create a \"{node.op}\" "
                                   "node during subgraph rewriting "
                                   f"with target {node.target}, but "
                                   "the referenced submodule does not "
                                   "exist in either the original "
                                   "GraphModule `gm` or the replacement "
                                   "GraphModule `replacement`")

    gm.graph.lint()
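
# A sketch of the situation `_replace_submodules` handles (hypothetical
# example): if the replacement callable is an nn.Module whose forward
# calls one of its own submodules,
#
#     class Replacement(torch.nn.Module):
#         def __init__(self):
#             super().__init__()
#             self.lin = torch.nn.Linear(4, 4)
#
#         def forward(self, x):
#             return self.lin(x)
#
# then the rewritten graph contains a `call_module` node with target
# "lin", and that submodule must be deep-copied into `gm` (CASE 2 above)
# for the target to resolve.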


@compatibility(is_backward_compatible=True)
def replace_pattern(gm: GraphModule, pattern: Callable, replacement: Callable) -> List[Match]:
    """
    Matches all possible non-overlapping sets of operators and their
    data dependencies (``pattern``) in the Graph of a GraphModule
    (``gm``), then replaces each of these matched subgraphs with another
    subgraph (``replacement``).

    Args:
        ``gm``: The GraphModule that wraps the Graph to operate on
        ``pattern``: The subgraph to match in ``gm`` for replacement
        ``replacement``: The subgraph to replace ``pattern`` with

    Returns:
        List[Match]: A list of ``Match`` objects representing the places
        in the original graph that ``pattern`` was matched to. The list
        is empty if there are no matches. ``Match`` is defined as:

        .. code-block:: python

            class Match(NamedTuple):
                # Node from which the match was found
                anchor: Node
                # Maps nodes in the pattern subgraph to nodes in the larger graph
                nodes_map: Dict[Node, Node]

    Examples:

    .. code-block:: python

        import torch
        from torch.fx import symbolic_trace, subgraph_rewriter

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, x, w1, w2):
                m1 = torch.cat([w1, w2]).sum()
                m2 = torch.cat([w1, w2]).sum()
                return x + torch.max(m1) + torch.max(m2)

        def pattern(w1, w2):
            return torch.cat([w1, w2]).sum()

        def replacement(w1, w2):
            return torch.stack([w1, w2])

        traced_module = symbolic_trace(M())

        subgraph_rewriter.replace_pattern(traced_module, pattern, replacement)

    The above code will first match ``pattern`` in the ``forward``
    method of ``traced_module``. Pattern-matching is done based on
    use-def relationships, not node names. For example, if you had
    ``p = torch.cat([a, b])`` in ``pattern``, you could match
    ``m = torch.cat([a, b])`` in the original ``forward`` function,
    despite the variable names being different (``p`` vs ``m``).

    The ``return`` statement in ``pattern`` is matched based on its
    value only; it may or may not match to the ``return`` statement in
    the larger graph. In other words, the pattern doesn't have to extend
    to the end of the larger graph.
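
    For instance, in the example above, ``pattern`` matches the
    subexpression ``torch.cat([w1, w2]).sum()`` that produces ``m1``,
    even though ``m1`` is consumed by ``torch.max`` rather than
    returned. A smaller hypothetical sketch of the same idea:

    .. code-block:: python

        def forward(self, x):
            y = torch.relu(x)      # matched here, mid-graph...
            return y * 2           # ...not at the return statement

        def pattern(x):
            return torch.relu(x)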
    When the pattern is matched, it will be removed from the larger
    function and replaced by ``replacement``. If there are multiple
    matches for ``pattern`` in the larger function, each non-overlapping
    match will be replaced. In the case of a match overlap, the first
    found match in the set of overlapping matches will be replaced.
    ("First" here being defined as the first in a topological ordering
    of the Nodes' use-def relationships. In most cases, the first Node
    is the parameter that appears directly after ``self``, while the
    last Node is whatever the function returns.)
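
    As a hypothetical sketch of overlap handling: with a pattern of two
    chained calls such as the one below, the expression
    ``torch.relu(torch.relu(torch.relu(x)))`` contains two overlapping
    candidate matches, and only the first one in topological order is
    replaced.

    .. code-block:: python

        def pattern(x):
            return torch.relu(torch.relu(x))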
    One important thing to note is that the parameters of the
    ``pattern`` Callable must be used in the Callable itself,
    and the parameters of the ``replacement`` Callable must match
    the pattern. The first rule is why, in the above code block, the
    ``forward`` function has parameters ``x, w1, w2``, but the
    ``pattern`` function only has parameters ``w1, w2``. ``pattern``
    doesn't use ``x``, so it shouldn't specify ``x`` as a parameter.
    As an example of the second rule, consider replacing

    .. code-block:: python

        def pattern(x, y):
            return torch.neg(x) + torch.relu(y)

    with

    .. code-block:: python

        def replacement(x, y):
            return torch.relu(x)

    In this case, ``replacement`` needs the same number of parameters
    as ``pattern`` (both ``x`` and ``y``), even though the parameter
    ``y`` isn't used in ``replacement``.
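
    A hypothetical sketch of applying this pair (any module containing
    ``torch.neg(x) + torch.relu(y)`` would do):

    .. code-block:: python

        class N(torch.nn.Module):
            def forward(self, x, y):
                return torch.neg(x) + torch.relu(y)

        traced = symbolic_trace(N())
        subgraph_rewriter.replace_pattern(traced, pattern, replacement)
        # After rewriting, traced(x, y) computes torch.relu(x)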
    After calling ``subgraph_rewriter.replace_pattern`` on the first
    example above (the module ``M``), the generated Python code looks
    like this. Note that the matched ``cat``/``sum`` chain is gone
    entirely, since the whole pattern was replaced by ``torch.stack``:

    .. code-block:: python

        def forward(self, x, w1, w2):
            stack_1 = torch.stack([w1, w2])
            stack_2 = torch.stack([w1, w2])
            max_1 = torch.max(stack_1)
            add_1 = x + max_1
            max_2 = torch.max(stack_2)
            add_2 = add_1 + max_2
            return add_2
    """
    from torch.fx.passes.utils.matcher_utils import SubgraphMatcher, InternalMatch

    # Get the graphs for `gm`, `pattern`, and `replacement`
    original_graph: Graph = gm.graph
    pattern_graph: Graph = symbolic_trace(pattern).graph
    replacement_graph: Graph = symbolic_trace(replacement).graph

    matcher = SubgraphMatcher(pattern_graph, match_output=False, match_placeholder=False,
                              remove_overlapping_matches=True)
    _matches: List[InternalMatch] = matcher.match(original_graph)

    replacement_placeholders = [n for n in replacement_graph.nodes if n.op == "placeholder"]

    # As we progressively replace nodes, we'll need to keep track of how the match results should change
    match_changed_node: Dict[Node, Node] = {}

    for match in _matches:
        # Initialize `val_map` with mappings from placeholder nodes in
        # `replacement` to their corresponding nodes in `original_graph`.
        # If an original node was itself replaced in an earlier iteration,
        # map to its replacement instead.
        assert len(match.placeholder_nodes) == len(replacement_placeholders)
        val_map: Dict[Node, Node] = {}
        for rn, gn in zip(replacement_placeholders, match.placeholder_nodes):
            val_map[rn] = match_changed_node.get(gn, gn)

        # Find the first user of the matched subgraph's returning nodes in
        # topological order; the replacement subgraph is inserted before it
        # so that all users come after the copied nodes.
        user_nodes: Set[Node] = set()
        for n in match.returning_nodes:
            for user in n.users:
                user_nodes.add(user)
        assert user_nodes, "The returning_nodes should have at least one user node"

        if len(user_nodes) == 1:
            first_user_node = next(iter(user_nodes))
        else:
            # If there are multiple user nodes, walk the original graph in
            # topological order and take the first one encountered.
            for n in original_graph.nodes:
                if n in user_nodes:
                    first_user_node = n
                    break

        with original_graph.inserting_before(first_user_node):
            # Copy the replacement graph over, remapping its placeholders
            # to the inputs of the matched subgraph via `val_map`
            copied_returning_nodes = original_graph.graph_copy(replacement_graph, val_map)

        # `graph_copy` may return a single Node rather than a tuple;
        # normalize to a tuple
        if isinstance(copied_returning_nodes, Node):
            copied_returning_nodes = (copied_returning_nodes,)

        # Hook the output Nodes of `replacement` into the Graph at the
        # location of the matched subgraph's output Nodes
        assert len(match.returning_nodes) == len(copied_returning_nodes)
        for gn, copied_node in zip(match.returning_nodes, copied_returning_nodes):
            gn.replace_all_uses_with(copied_node)
            match_changed_node[gn] = copied_node

        # Remove the original operator nodes of the matched subgraph,
        # iterating in reverse order so that each node's users are erased
        # before the node itself
        for node in reversed(pattern_graph.nodes):
            if node.op != "placeholder" and node.op != "output":
                gn = match.nodes_map[node]
                gm.graph.erase_node(gn)

    # Update the passed-in GraphModule to reflect the new state of its Graph
    gm.recompile()

    # If `replacement` was an nn.Module, we'll need to make sure that
    # all the submodules have been copied over correctly
    if isinstance(replacement, torch.nn.Module):
        _replace_submodules(gm, replacement)

    matches: List[Match] = [Match(anchor=match.anchors[0], nodes_map=match.nodes_map) for match in _matches]
    return matches