import copy
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple, Type, Union

import torch.fx
from torch.fx._compatibility import compatibility
from torch.fx.graph import map_arg
from torch.fx.passes.utils import HolderModule, lift_subgraph_as_module

from .tools_common import NodeList

__all__ = ["getattr_recursive", "setattr_recursive", "Component", "split_by_tags"]

@compatibility(is_backward_compatible=False)
def getattr_recursive(obj, name):
    for layer in name.split("."):
        if hasattr(obj, layer):
            obj = getattr(obj, layer)
        else:
            return None
    return obj
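
# Illustrative usage (hypothetical attribute path, not part of this file):
# getattr_recursive(model, "encoder.layer.weight") walks the dotted path one
# attribute at a time and returns None when any segment is missing, whereas a
# plain getattr(model, "encoder.layer.weight") would raise AttributeError.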

@compatibility(is_backward_compatible=False)
def setattr_recursive(obj, attr, value):
    if "." not in attr:
        setattr(obj, attr, value)
    else:
        layer = attr.split(".")
        setattr_recursive(getattr(obj, layer[0]), ".".join(layer[1:]), value)
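
# Illustrative usage (hypothetical names): setattr_recursive(model,
# "encoder.layer.weight", new_weight) resolves "encoder.layer" with getattr
# and then sets "weight" on the result; unlike getattr_recursive above, the
# intermediate attributes are assumed to exist.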

@compatibility(is_backward_compatible=False)
@dataclass
class Component:
    """
    A component serves as a container for a subgraph we want to create afterwards.
    """

    graph: torch.fx.Graph
    order: int
    name: str

    # Stores the placeholder nodes in `graph`.
    input_placeholders: List = field(default_factory=list)

    # Stores the nodes in the original graph that are placeholders in `graph`.
    orig_inputs: List = field(default_factory=list)

    # Stores the nodes in the original graph that are outputs in `graph`.
    orig_outputs: List = field(default_factory=list)

    # Mapping from get_attr node in the original graph to get_attr node in `graph`.
    getattr_maps: Dict[torch.fx.Node, torch.fx.Node] = field(default_factory=dict)
    constructor_args: List[str] = field(default_factory=list)
    gm: Optional[torch.fx.GraphModule] = None
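
# How the fields above are used (summary; see split_by_tags below):
# orig_inputs and orig_outputs hold nodes of the *original* graph, while
# graph and input_placeholders hold the corresponding nodes of the new
# subgraph; gm is only populated at the end, via lift_subgraph_as_module.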

@compatibility(is_backward_compatible=False)
def split_by_tags(
    gm: torch.fx.GraphModule,
    tags: List[str],
    return_fqn_mapping: bool = False,
    return_tuple: bool = False,
    GraphModuleCls: Type[torch.fx.GraphModule] = torch.fx.GraphModule,
) -> Union[torch.fx.GraphModule, Tuple[torch.fx.GraphModule, Dict[str, str]]]:
""" | |
Splits a GraphModule using tags on its graph nodes. We honor the order of | |
tags. For example, we have tags = ["a", "b", "c"], the function will create | |
the initial submodules in the order of "a", "b", "c". | |
To set a tag: | |
gm.graph.nodes[idx].tag = "mytag" | |
This will result in all nodes with the same tag being extracted and placed in their | |
own submodule. For placeholder, output and get_attr node, the tag is ignored. placeholder | |
and output nodes are created when needed while get_attr nodes get copied to submodules | |
where they are used. | |
Given the following module def: | |
class SimpleModule(torch.nn.Module): | |
def __init__(self): | |
super().__init__() | |
self.linear1 = torch.nn.Linear(...) | |
self.linear2 = torch.nn.Linear(...) | |
self.linear3 = torch.nn.Linear(...) | |
def forward(self, in1, in2): | |
r1 = self.linear1(in1) | |
r2 = self.linear2(in2) | |
r3 = torch.cat([r1, r2]) | |
return self.linear3(r3) | |
Marking the node corresponding to in1 with the tag sc.REQUEST_ONLY.lower() results in the following split: | |
ro: | |
def forward(self, in1): | |
self = self.root | |
linear1 = self.linear1(in1) | |
return linear1 | |
main: | |
def forward(self, in2, linear1): | |
self = self.root | |
linear2 = self.linear2(in2) | |
cat_1 = torch.cat([linear1, linear2]) | |
linear3 = self.linear3(cat_1) | |
return linear3 | |
main: | |
def forward(self, in1, in2): | |
self = self.root | |
ro_0 = self.ro_0(in1) | |
main_1 = self.main_1(in2, ro_0) | |
return main_1 | |
Returns: | |
split_gm: torch fx graph after split | |
orig_to_split_fqn_mapping: a map between the original fqn and the fqn | |
after split for call_module and get_attr. | |
""" | |

    def flatten(x: torch.fx.node.Argument) -> NodeList:
        """
        Collects the nodes found in ``x`` into a list and returns it.
        """
        r: NodeList = []
        map_arg(x, r.append)
        return r
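
    # For example (illustrative): flatten((n1, [n2, {"k": n3}])) returns
    # [n1, n2, n3], since map_arg visits every Node inside nested
    # tuples/lists/dicts and applies the given function to each one.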

    # Mapping from node in original module to node in created submodule.
    node_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Mapping from node in original module or created submodules to
    # corresponding component.
    node_to_component: Dict[torch.fx.Node, Component] = {}

    # Mapping from tag to the corresponding component.
    tag_to_component: Dict[str, Component] = {}

    # Stores all components.
    all_components: List[Component] = []

    # Stores nodes that will be used in the main graph.
    used_in_main: Dict[torch.fx.Node, None] = {}

    # Main graph after split.
    main_g = torch.fx.Graph()

    # Mapping from node in original module to node in main graph after split.
    main_remapping: Dict[torch.fx.Node, torch.fx.Node] = {}

    # Output node of original module.
    output_node: Optional[torch.fx.Node] = None

    # Create a component for each tag; we don't expect to create other
    # components afterwards.
    for tag in tags:
        comp = Component(torch.fx.Graph(), len(all_components), f"{tag}")
        all_components.append(comp)
        tag_to_component[tag] = comp
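
    # The component order (0, 1, 2, ...) mirrors each tag's position in
    # `tags`; it is used below to check that data only flows from earlier
    # components to later ones.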

    # Traverse the nodes in the original graph and take care of them.
    for node in gm.graph.nodes:
        if node.op == "output":
            if output_node is not None:
                raise RuntimeError("Multiple output nodes in graph!")
            output_node = node
            continue

        # Placeholders in the original graph get copied to the main graph.
        if node.op == "placeholder":
            main_remapping[node] = main_g.placeholder(node.name, type_expr=node.type)
            main_remapping[node].meta = copy.copy(node.meta)
            continue

        # get_attr nodes are ignored because we are not tagging them.
        # Instead, we copy them directly into the submodules that use them.
        if node.op == "get_attr":
            continue

        # Now we process callable nodes, i.e. nodes with op call_module,
        # call_function or call_method. Every callable node should be tagged.
        assert hasattr(node, "tag")

        upstream_components = [
            node_to_component[x]
            for x in flatten(node.args) + flatten(node.kwargs)
            if x.op not in {"placeholder", "get_attr"}
        ]

        comp = tag_to_component[node.tag]
        node_to_component[node] = comp

        # Max order of upstream components.
        mx = max((c.order for c in upstream_components), default=0)
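
        # This relies on gm.graph.nodes being in topological order: every
        # input of `node` has already been assigned to a component by the
        # time `node` itself is visited.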
        # The component for `node` must not come before any of its upstream
        # components; otherwise the submodule call order would be invalid.
        assert comp.order >= mx

        # Map an input of `node` to a node in the component's graph.
        def remap_func(x):
            # If the input is a get_attr node, copy it into the current
            # component's graph and return the copied get_attr node.
            if x.op == "get_attr":
                if x not in comp.getattr_maps:
                    comp.getattr_maps[x] = comp.graph.get_attr(
                        x.target, type_expr=x.type
                    )
                return comp.getattr_maps[x]

            # If the input is not a placeholder, it should already have been
            # put into a component. If that component is the current one,
            # return the corresponding node in the component's graph.
            if x.op != "placeholder" and node_to_component[x] == comp:
                return node_remapping[x]

            # If the input is a placeholder, or it lives in another component,
            # turn it into a placeholder of the current component's graph.
            if x not in comp.orig_inputs:
                comp.orig_inputs.append(x)
                placeholder = comp.graph.placeholder(x.name, type_expr=x.type)
                placeholder.meta = copy.copy(x.meta)
                comp.input_placeholders.append(placeholder)
                used_in_main[x] = None

            return comp.input_placeholders[comp.orig_inputs.index(x)]
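
        # node_copy clones `node` into the component's graph, rewriting each
        # input through remap_func above; the clone is recorded so that later
        # nodes (and the main graph) can refer to it.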
        n = comp.graph.node_copy(node, remap_func)
        n.tag = node.tag  # type: ignore[attr-defined]
        node_remapping[node] = n
        node_to_component[n] = comp

    if output_node is None:
        raise RuntimeError("Graph had no output node!")

    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            # We don't need a component mapping for get_attr nodes consumed
            # by the output; we only need to create corresponding
            # counterparts in the resulting graph.
            main_remapping[x] = main_g.get_attr(x.name, type_expr=x.type)
        else:
            # All component results consumed by the output node should be
            # marked as "used in main".
            used_in_main[x] = None

    # If a node is used in the main graph, mark it as an output of the
    # component it belongs to.
    for n in used_in_main:
        if n.op != "placeholder":
            node_to_component[n].orig_outputs.append(n)
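
    # Placeholders used in the main graph need no component output: they are
    # already available in the main graph itself through main_remapping.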

    # Now we create a GraphModule for each component.
    orig_to_split_fqn_mapping: Dict[str, str] = {}
    for comp in all_components:
        outs = tuple(map(node_remapping.__getitem__, comp.orig_outputs))

        if return_tuple:
            comp.graph.output(outs)
        else:
            # Take care of the args of the FX output node. If there's a single
            # output then the output node's args is (output_single,), while if
            # there are multiple outputs the args is ((output_0, output_1, ...),).
            comp.graph.output(outs[0] if len(outs) == 1 else outs)

        comp.gm, comp_orig_to_split_fqn_mapping = lift_subgraph_as_module(
            gm, subgraph=comp.graph, comp_name=comp.name
        )
        orig_to_split_fqn_mapping.update(comp_orig_to_split_fqn_mapping)

        # Create a call_module node in the main graph.
        main_node = main_g.call_module(
            comp.name,
            args=tuple(map(main_remapping.__getitem__, comp.orig_inputs)),
            kwargs=None,
        )

        if len(outs) == 1 and not return_tuple:
            main_remapping[comp.orig_outputs[0]] = main_node
        else:
            for i, o in enumerate(comp.orig_outputs):
                # Use a Proxy to record the getitem access.
                main_remapping[o] = torch.fx.Proxy(main_node)[i].node  # type: ignore[index]
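
    # The main graph's output mirrors the structure of the original output
    # args, with each original node swapped for its main-graph counterpart
    # through main_remapping.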
    main_g.output(map_arg(output_node.args[0], main_remapping.__getitem__))

    main_root = HolderModule({comp.name: comp.gm for comp in all_components})
    main_g._codegen = gm.graph._codegen

    # If the output node consumes a get_attr directly in the original graph,
    # make sure that get_attr is copied onto the new root module.
    for x in flatten(output_node.args[0]):
        if x.op == "get_attr":
            setattr(main_root, x.name, getattr_recursive(gm, x.target))  # type: ignore[arg-type]

    result_gm = GraphModuleCls(main_root, main_g)
    if return_fqn_mapping:
        return result_gm, orig_to_split_fqn_mapping
    return result_gm
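

if __name__ == "__main__":
    # A minimal, hedged usage sketch based on the SimpleModule example in the
    # docstring above. The tag names ("ro", "main") and layer sizes are
    # illustrative assumptions, not part of this file's API. Because this file
    # uses a relative import, run it as a module, e.g.
    # `python -m torch.fx.passes.split_utils`.
    class SimpleModule(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear1 = torch.nn.Linear(4, 4)
            self.linear2 = torch.nn.Linear(4, 4)
            self.linear3 = torch.nn.Linear(8, 4)

        def forward(self, in1, in2):
            r1 = self.linear1(in1)
            r2 = self.linear2(in2)
            r3 = torch.cat([r1, r2], dim=-1)
            return self.linear3(r3)

    traced = torch.fx.symbolic_trace(SimpleModule())

    # Tag every callable node; tags on placeholder/output/get_attr are ignored.
    for node in traced.graph.nodes:
        if node.op in {"call_module", "call_function", "call_method"}:
            node.tag = "ro" if node.target == "linear1" else "main"

    split_gm = split_by_tags(traced, ["ro", "main"])
    print(split_gm.graph)  # main graph calling the "ro" and "main" submodules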