|
from torch.nn import Module |
|
|
|
from torch.fx.graph_module import GraphModule |
|
from torch.fx.graph import Graph |
|
from torch.fx.passes.utils.matcher_utils import SubgraphMatcher |
|
from torch.fx._compatibility import compatibility |
|
|
|
|
|
__all__ = ['HolderModule', 'lift_subgraph_as_module', 'compare_graphs'] |
|
|
|
@compatibility(is_backward_compatible=False) |
|
class HolderModule(Module): |
|
""" |
|
HolderModule is used to copy all the attributes from original module to submodules |
|
that uses the attributes |
|
""" |
|
|
|
def __init__(self, d): |
|
super().__init__() |
|
for k, v in d.items(): |
|
self.add_module(k, v) |
|
|
|
|
|
@compatibility(is_backward_compatible=False) |
|
def lift_subgraph_as_module(gm: GraphModule, subgraph: Graph, class_name: str = 'GraphModule') -> GraphModule: |
|
""" |
|
Create a GraphModule for subgraph, which copies the necessory attributes from the original parent graph_module. |
|
|
|
Args: |
|
gm (GraphModule): parent graph module |
|
|
|
subgraph (Graph): a valid subgraph that contains copied nodes from the parent graph |
|
|
|
class_name (str): name for the submodule |
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
submodule = HolderModule({}) |
|
for n in subgraph.nodes: |
|
if n.op not in ("call_module", "get_attr"): |
|
continue |
|
|
|
target = n.target |
|
assert isinstance(target, str) |
|
target_name_parts = target.split(".") |
|
curr = submodule |
|
orig_gm = gm |
|
|
|
for name in target_name_parts[:-1]: |
|
if not hasattr(curr, name): |
|
curr.add_module(name, HolderModule({})) |
|
|
|
curr = getattr(curr, name) |
|
orig_gm = getattr(orig_gm, name) |
|
|
|
leaf_node_name = target_name_parts[-1] |
|
leaf_node = getattr(orig_gm, leaf_node_name) |
|
|
|
|
|
setattr(curr, leaf_node_name, leaf_node) |
|
|
|
return GraphModule(submodule, subgraph, class_name) |
|
|
|
|
|
@compatibility(is_backward_compatible=False) |
|
def compare_graphs(left: Graph, right: Graph) -> bool: |
|
""" |
|
Return True if two graphs are identical, i.e they |
|
- have the same number of outputs in the same order |
|
- have the same number of inputs in the same order |
|
- have the same set of nodes, and identical connectivity |
|
""" |
|
|
|
matcher = SubgraphMatcher(left, match_output=True, match_placeholder=True) |
|
matches = matcher.match(right) |
|
|
|
return len(matches) > 0 |
|
|