|
from dataclasses import dataclass, field
from typing import List, Optional, Dict

import torch.fx
from torch.fx.graph import map_arg
from .tools_common import NodeList
from torch.fx._compatibility import compatibility
from torch.fx.passes.utils import lift_subgraph_as_module, HolderModule

__all__ = ['getattr_recursive', 'setattr_recursive', 'Component', 'split_by_tags']
|
|
@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
    """
    Gets the attribute specified by the dot-separated path ``name``, returning
    None if any intermediate attribute is missing.
    """
    for layer in name.split("."):
        if hasattr(obj, layer):
            obj = getattr(obj, layer)
        else:
            return None
    return obj


@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
    """
    Sets the attribute specified by the dot-separated path ``attr``, descending
    through intermediate attributes.
    """
    if "." not in attr:
        setattr(obj, attr, value)
    else:
        layer = attr.split(".")
        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
|
|
@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A Component serves as a container for a subgraph that will later be
    lifted into its own submodule.
    """

    graph: torch.fx.Graph
    order: int
    name: str
|
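    # Stores the placeholder nodes in `graph`.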
    input_placeholders: List = field(default_factory=list)
|
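    # Nodes in the original graph that become placeholders in `graph`.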
    orig_inputs: List = field(default_factory=list)
|
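    # Nodes in the original graph that are outputs of `graph`.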
    orig_outputs: List = field(default_factory=list)
|
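    # Mapping from a get_attr node in the original graph to its copy in `graph`.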
    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    constructor_args: List[str] = field(default_factory=list)
    gm: Optional[torch.fx.GraphModule] = None
|
|
@compatibility(is_backward_compatible=False)
def split_by_tags(gm: torch.fx.GraphModule, tags: List[str]) -> torch.fx.GraphModule:
    """
    Splits a GraphModule using tags on its graph nodes. The order of the tags
    is honored: e.g., with tags = ["a", "b", "c"], the submodules are created
    in that order and named after their tags.

    To set a tag:
        gm.graph.nodes[idx].tag = "mytag"

    All nodes sharing a tag are extracted into their own submodule. For
    placeholder, output, and get_attr nodes the tag is ignored: placeholder
    and output nodes are created when needed, while get_attr nodes are copied
    into the submodules that use them.

    Given the following module definition:

    class SimpleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(...)
            self.linear2 = torch.nn.Linear(...)
            self.linear3 = torch.nn.Linear(...)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2])
            return self.linear3(r3)

    marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower()
    results in the following split:

    ro_0:
    def forward(self, in1):
        self = self.root
        linear1 = self.linear1(in1)
        return linear1

    main_1:
    def forward(self, in2, linear1):
        self = self.root
        linear2 = self.linear2(in2)
        cat_1 = torch.cat([linear1, linear2])
        linear3 = self.linear3(cat_1)
        return linear3

    main_0:
    def forward(self, in1, in2):
        self = self.root
        ro_0 = self.ro_0(in1)
        main_1 = self.main_1(in2, ro_0)
        return main_1
    """
|
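    # A minimal usage sketch (hypothetical tag names; assumes SimpleModule from
    # the docstring above with concrete Linear sizes):
    #
    #     gm = torch.fx.symbolic_trace(SimpleModule())
    #     for node in gm.graph.nodes:
    #         if node.op in {"call_module", "call_function", "call_method"}:
    #             node.tag = "ro" if node.name == "linear1" else "main"
    #     split_gm = split_by_tags(gm, ["ro", "main"])
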
    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Collects all the Nodes contained in ``x`` into a list and returns it.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r
|
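    # Mapping from a node in the original module to its copy in a created submodule.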
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
|
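    # Mapping from a node (in the original module or in a created submodule)
    # to the Component it belongs to.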
    node_to_component: Dict[torch.fx.Node, Component] = {}
|
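    # Mapping from tag to the corresponding Component.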
    tag_to_component: Dict[str, Component] = {}
|
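    # All created Components, in tag order.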
    all_components: List[Component] = []
|
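    # Nodes that will be used in the main graph. A dict is used as an
    # insertion-ordered set; the values are always None.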
    used_in_main: Dict[torch.fx.Node, None] = {}
|
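    # The main graph after the split.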
    main_g = torch.fx.Graph()
|
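    # Mapping from a node in the original module to its counterpart in the
    # main graph after the split.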
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}
|
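    # Output node of the original module.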
    output_node: Optional[torch.fx.Node] = None
|
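    # Create a Component for each tag; no other Components are created afterwards.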
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp
|
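    # Traverse the nodes of the original graph and dispatch on their op.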
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue
|
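        # Placeholders in the original graph are copied to the main graph.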
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            continue
|
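        # get_attr nodes are never tagged; they are copied directly into the
        # submodules that use them (see remap_func below).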
        if node.op == "get_attr":
            continue
|
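        # What remains are callable nodes (call_module, call_function,
        # call_method); every callable node must carry a tag.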
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp
|
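        # Max order among the upstream components.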
        mx = max((c.order for c in upstream_components), default=0)
|
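        # The component of `node` must not come earlier than any component it
        # consumes from; otherwise the split would not be topologically valid.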
        assert comp.order >= mx
|
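        # Map an input of `node` to the corresponding node in the component's graph.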
        def remap_func(x):
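            # If the input is a get_attr node, copy it into the current
            # component's graph (once) and return that copy.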
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                return comp.getattr_maps[x]
|
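            # If the input is not a placeholder, it has already been assigned
            # to a component. If that component is the current one, return the
            # copied node directly.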
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]
|
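            # Otherwise the input is a placeholder or lives in another
            # component, so it becomes a placeholder of the current
            # component's graph and must be routed through the main graph.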
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                comp.input_placeholders.append(
                    comp.graph.placeholder(x.name, type_expr=x.type)
                )
                used_in_main[x] = None

            return comp.input_placeholders[
                next(i for i, y in enumerate(comp.orig_inputs) if x is y)
            ]

        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag
        node_remapping[node] = n
        node_to_component[n] = comp
|
    if output_node is None:
        raise RuntimeError("Graph had no output node!")
|
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
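            # get_attr nodes consumed by the output need no component
            # bookkeeping; just create a counterpart in the main graph.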
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
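            # Every component result consumed by the output node must be
            # marked as used in the main graph.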
            used_in_main[x] = None
|
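    # A node used in the main graph becomes an output of the component that
    # owns it.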
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)
|
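    # Create a GraphModule for each component and wire it into the main graph.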
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))
|
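        # Mind the args of the FX output node: with a single output the args
        # look like (out,), with multiple outputs like ((out_0, out_1, ...),).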
        comp.graph.output(outs[0] if len(outs) == 1 else outs)

        comp.gm = lift_subgraph_as_module(gm, comp.graph)
|
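        # Create a call_module node in the main graph that invokes this component.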
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
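                # Use a Proxy to record the getitem access on the
                # multi-output call_module node.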
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node
|
    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))
    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
|
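    # If the output node consumes a get_attr directly from the original graph,
    # copy the attribute itself onto the new root module.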
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))

    return torch.fx.GraphModule(main_root, main_g)