Spaces:
Running
Running
import inspect | |
from typing import Any, Callable, Dict, List, Optional, Set, TYPE_CHECKING | |
from collections import OrderedDict | |
import logging | |
import torch | |
from torch.fx._compatibility import compatibility | |
from torch.fx.graph_module import GraphModule | |
from torch.fx.node import Node | |
if TYPE_CHECKING: | |
import sympy # noqa: F401 | |
# Module-level logger; split_module writes its debug output through it.
_LOGGER = logging.getLogger(__name__)

# Public names of this module (what ``from ... import *`` exposes).
__all__ = ["Partition", "split_module"]
@compatibility(is_backward_compatible=True)
class Partition:
    """Bookkeeping for one partition (future submodule) built by ``split_module``.

    ``inputs``, ``outputs``, ``dependencies`` and ``dependents`` map names to
    ``None``; they are used as insertion-ordered sets.
    """

    def __init__(self, name: str):
        # Partition identifier (the stringified value returned by split_callback).
        self.name: str = name
        # Attribute name under which this partition's GraphModule is installed
        # on the module returned by split_module.
        self.submod_name = f"submod_{name}"
        # Names of the original nodes assigned to this partition, in graph order.
        self.node_names: List[str] = []
        # Names of the values this partition consumes from / provides to others.
        self.inputs: Dict[str, None] = {}
        self.outputs: Dict[str, None] = {}
        # Names of the partitions this one depends on / that depend on it.
        self.dependencies: Dict[str, None] = {}
        self.dependents: Dict[str, None] = {}
        # Graph being built for this partition's submodule.
        self.graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
        # Maps nodes of the original graph to their counterparts in self.graph.
        self.environment: Dict[Node, Node] = {}
        # Attributes (submodules, parameters, ...) the submodule needs, keyed by
        # their "_"-joined qualified name.
        self.targets: Dict[str, Any] = {}

    def __repr__(self) -> str:
        return (
            f"name: {self.name},\n"
            f" nodes: {self.node_names},\n"
            f" inputs: {self.inputs},\n"
            f" outputs: {self.outputs},\n"
            f" partitions depended on: {self.dependencies},\n"
            f" partition dependents: {self.dependents}"
        )
# Creates subgraphs out of main graph | |
@compatibility(is_backward_compatible=True)
def split_module(
    m: GraphModule,
    root_m: torch.nn.Module,
    split_callback: Callable[[Node], int],
    qualname_map: Optional[Dict[str, str]] = None,
    keep_original_order: Optional[bool] = False,
    keep_original_node_name: Optional[bool] = False,
):
    """
    Creates subgraphs out of main graph

    Args:
        m (GraphModule): Graph module to split
        root_m (torch.nn.Module): root nn module. Not currently used. Included
            because the root nn module is usually transformed via
            torch.fx._symbolic_trace.symbolic_trace (see example below)
        split_callback (Callable[[Node], int]): Callable function
            that maps a given Node instance to a numeric partition identifier.
            split_module will use this function as the policy for which operations
            appear in which partitions in the output Module.
        qualname_map: Optional[Dict[str, str]]: optional output parameter that returns a
            mapping from new target names in the module after split to old target
            names in the original module.
        keep_original_order: Optional[bool]: keep the original order of the GraphModule
            or use the Topological order of the new constructed GraphModule.
            In both modes the placeholders (inputs) of the returned module keep
            the order of the placeholders of ``m``.
        keep_original_node_name: Optional[bool]: reuse the original node names
            for the placeholders of the returned module and for the nodes copied
            into each submodule graph.

    Returns:
        GraphModule: the module after split.

    Raises:
        AttributeError: if a ``get_attr`` / ``call_module`` target cannot be
            resolved on ``m``.
        RuntimeError: if the dependencies between partitions form a cycle.

    Example:

        This is a sample setup:

            import torch
            from torch.fx.symbolic_trace import symbolic_trace
            from torch.fx.graph_module import GraphModule
            from torch.fx.node import Node
            from torch.fx.passes.split_module import split_module

            class MyModule(torch.nn.Module):
                def __init__(self):
                    super().__init__()
                    self.param = torch.nn.Parameter(torch.rand(3, 4))
                    self.linear = torch.nn.Linear(4, 5)

                def forward(self, x, y):
                    z = self.linear(x + self.param).clamp(min=0.0, max=1.0)
                    w = self.linear(y).clamp(min=0.0, max=1.0)
                    return z + w

            # symbolically trace model
            my_module = MyModule()
            my_module_traced = symbolic_trace(my_module)

            # random mod partitioning
            partition_counter = 0
            NPARTITIONS = 3

            def mod_partition(node: Node):
                global partition_counter
                partition = partition_counter % NPARTITIONS
                partition_counter = (partition_counter + 1) % NPARTITIONS
                return partition

            # split module in module with submodules
            module_with_submodules = split_module(
                my_module_traced, my_module, mod_partition
            )

        Output looks like this. Original graph is broken into partitions

            > print(module_with_submodules)
            GraphModule(
                (submod_0): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_1): GraphModule(
                    (linear): Linear(in_features=4, out_features=5, bias=True)
                )
                (submod_2): GraphModule()
            )

            def forward(self, x, y):
                param = self.param
                submod_0 = self.submod_0(x, param, y);  x = param = y = None
                getitem = submod_0[0]
                getitem_1 = submod_0[1];  submod_0 = None
                submod_1 = self.submod_1(getitem, getitem_1);  getitem = getitem_1 = None
                getitem_2 = submod_1[0]
                getitem_3 = submod_1[1];  submod_1 = None
                submod_2 = self.submod_2(getitem_2, getitem_3);  getitem_2 = getitem_3 = None
                return submod_2

        Output of split module is the same as output of input traced module.
        This is an example within a test setting:

            > orig_out = my_module_traced(x, y)
            > submodules_out = module_with_submodules(x, y)
            > self.assertEqual(orig_out, submodules_out)
            True
    """

    def construct_graph(
        node: Node,
        base_mod_env: Dict[str, Node],
        base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule],
    ):
        # Copy a placeholder or get_attr node of `m` into the base graph (and,
        # for get_attr, record the resolved attribute). Other ops are ignored.
        if node.op == "placeholder":
            default_value = (
                node.args[0] if len(node.args) > 0 else inspect.Signature.empty
            )
            if keep_original_node_name:
                args = (
                    () if default_value is inspect.Signature.empty else (default_value,)
                )
                base_mod_env[node.name] = base_mod_graph.create_node(
                    "placeholder", node.name, args=args, type_expr=node.type
                )
            else:
                base_mod_env[node.name] = base_mod_graph.placeholder(
                    node.target, type_expr=node.type, default_value=default_value
                )
            base_mod_env[node.name].meta = node.meta.copy()
        elif node.op == "get_attr":
            base_mod_env[node.name] = base_mod_graph.get_attr(node.target)
            base_mod_env[node.name].meta = node.meta.copy()
            attr_val = m
            for atom in node.target.split("."):  # type: ignore[union-attr]
                if not hasattr(attr_val, atom):
                    raise AttributeError(f"Node target {node.target} not found!")
                attr_val = getattr(attr_val, atom)
            base_mod_attrs[node.target] = attr_val  # type: ignore[index]
        return base_mod_env, base_mod_attrs

    partitions: Dict[str, Partition] = {}
    orig_nodes: Dict[str, Node] = {}
    symbol_to_node: Dict["sympy.Symbol", Node] = {}

    def record_cross_partition_use(
        def_node: Node, use_node: Optional[Node]
    ):  # noqa: B950
        # Record that `use_node` (None stands for the graph output) consumes the
        # value of `def_node`; when the two live in different partitions, update
        # partition inputs/outputs and the partition dependency graph.
        from torch.fx.experimental.symbolic_shapes import free_symbols

        defined = getattr(def_node, "_fx_partition", None)
        used = getattr(use_node, "_fx_partition", None)
        if defined != used:
            if defined is not None:
                def_partition = partitions[defined]
                def_partition.outputs.setdefault(def_node.name)
                if used is not None:
                    def_partition.dependents.setdefault(used)

            if used is not None:
                use_partition = partitions[used]
                use_partition.inputs.setdefault(def_node.name)
                # Also pass in the placeholders that bind the free symbols of
                # the value's example_value.
                if (def_val := def_node.meta.get("example_value")) is not None:
                    for s in sorted(free_symbols(def_val), key=str):
                        use_partition.inputs.setdefault(symbol_to_node[s].name)
                if defined is not None:
                    use_partition.dependencies.setdefault(defined)

    def instantiate_node_partition_mapping(node):
        # Assign `node` to the partition chosen by split_callback, creating the
        # Partition the first time its name is seen.
        partition_name = str(split_callback(node))

        # add node to partitions
        partition = partitions.get(partition_name)
        if partition is None:
            partitions[partition_name] = partition = Partition(partition_name)

        partition.node_names.append(node.name)
        node._fx_partition = partition_name

    # Global State Nodes are nodes which by their global state effects,
    # "taint" all downstream nodes while they are active.
    GLOBAL_STATE_NODES = [
        torch.amp._enter_autocast,
        torch.amp._exit_autocast,
        torch._C._set_grad_enabled,
    ]

    # For grad regions:
    # ------------------------
    # 1. first region: we do nothing
    # 2. subsequent regions: we insert the set_grad at the beginning
    grad_regions: OrderedDict[Node, Set[int]] = OrderedDict()

    # For autocast regions:
    # ------------------------
    # 1. first region: we will only insert the _exit at the end
    # 2. intermediate regions: we will insert both the
    #    _enter at the beginning and _exit at the end
    # 3. last region: we will only insert _enter at the beginning
    # We will do so in the order in which the autocasts were instantiated.
    autocast_regions: OrderedDict[Node, Set[int]] = OrderedDict()
    autocast_exits: Dict[Node, Optional[Node]] = {}

    active_grad = None
    active_autocasts = set()

    import sympy  # noqa: F811

    for node in m.graph.nodes:
        if node.op in ["placeholder", "get_attr", "output"]:
            if (
                node.op == "placeholder"
                and (val := node.meta.get("example_value")) is not None
                and isinstance(val, torch.SymInt)
                and isinstance(val.node.expr, sympy.Symbol)
            ):
                symbol_to_node[val.node.expr] = node
            continue

        instantiate_node_partition_mapping(node)

        if node.op == "call_function" and node.target in GLOBAL_STATE_NODES:
            if node.target == torch._C._set_grad_enabled:
                assert len(node.args) == 1
                assert isinstance(node.args[0], bool)
                active_grad = node
                grad_regions[active_grad] = {split_callback(node)}
            elif node.target == torch.amp._enter_autocast:
                # Should all be python constants
                assert all(not isinstance(arg, Node) for arg in node.args)
                active_autocasts.add(node)
                autocast_regions[node] = {split_callback(node)}
                autocast_exits[node] = None
            elif node.target == torch.amp._exit_autocast:
                assert len(node.args) == 1
                autocast_regions[node.args[0]].add(split_callback(node))
                active_autocasts.remove(node.args[0])
                autocast_exits[node.args[0]] = node

        # Every node seen while a global-state region is active extends it.
        if active_grad is not None:
            grad_regions[active_grad].add(split_callback(node))

        for a in active_autocasts:
            autocast_regions[a].add(split_callback(node))

    assert all(v is not None for v in autocast_exits.values()), "autocast must exit"

    autocast_regions = {k: sorted(v) for k, v in autocast_regions.items()}
    grad_regions = {k: sorted(v) for k, v in grad_regions.items()}

    if _LOGGER.isEnabledFor(logging.DEBUG):
        _LOGGER.debug("autocast_regions: %s", autocast_regions)
        _LOGGER.debug("grad_regions: %s", grad_regions)

    assert_monotonically_increasing = bool(autocast_regions) or bool(grad_regions)

    # split nodes into partitions
    highest_partition = -1
    for node in m.graph.nodes:
        orig_nodes[node.name] = node

        # TODO currently placeholders/parameters aren't put into random partitions,
        # rather they're added to the graphs where they are used down below
        if node.op in ["placeholder", "get_attr"]:
            continue
        if node.op == "output":
            torch.fx.graph.map_arg(
                node.args[0], lambda n: record_cross_partition_use(n, None)
            )
            continue

        if assert_monotonically_increasing:
            pid = split_callback(node)
            assert highest_partition <= pid, (
                "autocast or set_grad_enabled require monotonically increasing partitions:"
                f"highest: {highest_partition}, this node's: {pid}"
            )
            highest_partition = pid

        # do not capture cross-partition dependencies for global state nodes as they will be
        # self-contained - their setup and unwind will be isolated to each partition submodule.
        if node.target not in GLOBAL_STATE_NODES:
            torch.fx.graph.map_arg(
                node.args, lambda def_node: record_cross_partition_use(def_node, node)
            )
            torch.fx.graph.map_arg(
                node.kwargs, lambda def_node: record_cross_partition_use(def_node, node)
            )  # noqa: B950

    original_partition_order = list(partitions.keys())
    # find partitions with no dependencies
    root_partitions: List[str] = []
    for partition_name, partition in partitions.items():
        if not len(partition.dependencies):
            root_partitions.append(partition_name)

    # check partitions for circular dependencies and create topological partition ordering
    sorted_partitions: List[str] = []
    while root_partitions:
        root_partition = root_partitions.pop()
        sorted_partitions.append(root_partition)
        for dependent in partitions[root_partition].dependents:
            partitions[dependent].dependencies.pop(root_partition)
            if not partitions[dependent].dependencies:
                root_partitions.append(dependent)
    if len(sorted_partitions) != len(partitions):
        raise RuntimeError("cycle exists between partitions!")

    # Enter prelude: the first partition of a region reuses the original node;
    # every later partition gets its own copy of the enter/set_grad call.
    for regions_mapping in [autocast_regions, grad_regions]:
        for node, regions in regions_mapping.items():
            assert len(regions) > 0
            partitions[str(regions[0])].environment[node] = node
            for r in regions[1:]:
                partition = partitions[str(r)]
                new_node = partition.graph.create_node(
                    op=node.op,
                    target=node.target,
                    args=tuple(node.args),
                    kwargs={},
                    type_expr=node.type,
                )
                new_node.meta = node.meta.copy()  # is it really a good idea to copy this?
                partition.environment[node] = new_node

    # add placeholders to partition inputs
    for partition_name in sorted_partitions:
        partition = partitions[partition_name]
        for inp in partition.inputs:
            placeholder = partition.graph.placeholder(
                inp,
                type_expr=orig_nodes[inp].type,
            )
            placeholder.meta = orig_nodes[inp].meta.copy()
            partition.environment[orig_nodes[inp]] = placeholder

    # Transform nodes and collect targets for partition's submodule
    for node in m.graph.nodes:
        if hasattr(node, "_fx_partition"):
            partition = partitions[node._fx_partition]

            # swap out old graph nodes in kw/args with references to new nodes in this submodule
            environment = partition.environment
            gathered_args = torch.fx.graph.map_arg(node.args, lambda n: environment[n])
            gathered_kwargs = torch.fx.graph.map_arg(
                node.kwargs, lambda n: environment[n]
            )

            if node.op not in ["call_module", "get_attr"]:
                target = node.target
            else:
                target_atoms = node.target.split(".")
                target_attr = m
                for atom in target_atoms:
                    if not hasattr(target_attr, atom):
                        raise AttributeError(f"Operator target {node.target} not found!")
                    target_attr = getattr(target_attr, atom)
                # Flatten the dotted path into a single attribute name of the submodule.
                target = "_".join(target_atoms)
                partition.targets[target] = target_attr
                # Fill in the passed-in mapping from new qualname to old qualname
                if qualname_map is not None:
                    # When creating the split module later, the submodules will have
                    # path prefix matching the corresponding partition's submod_name
                    qualname = f"{partition.submod_name}.{target}"
                    qualname_map[qualname] = node.target

            assert isinstance(gathered_args, tuple)
            assert isinstance(gathered_kwargs, dict)
            name = node.name if keep_original_node_name else None
            new_node = partition.graph.create_node(
                op=node.op,
                target=target,
                args=gathered_args,
                kwargs=gathered_kwargs,
                type_expr=node.type,
                name=name,
            )
            new_node.meta = node.meta.copy()
            partition.environment[node] = new_node

    # Exit epilogue: every partition of an autocast region except the last one
    # gets its own _exit_autocast call, innermost regions first.
    for regions_mapping in [autocast_regions]:
        for node in reversed(regions_mapping):
            regions = regions_mapping[node]
            assert len(regions) > 0
            for r in regions[:-1]:
                partition = partitions[str(r)]
                exit_node = autocast_exits[node]
                assert exit_node is not None, "Missing exit node"
                new_node = partition.graph.create_node(
                    op=exit_node.op,
                    target=exit_node.target,
                    args=(partition.environment[node],),
                    kwargs={},
                    type_expr=exit_node.type,
                )
                new_node.meta = exit_node.meta.copy()  # is it really a good idea to copy this?

    # original module environment dict mapping node names to nodes
    orig_mod_env: Dict[str, Node] = {}
    # Set up values to construct base module
    base_mod_env: Dict[str, Node] = {}
    base_mod_graph: torch.fx.graph.Graph = torch.fx.graph.Graph()
    base_mod_attrs: Dict[str, torch.fx.graph_module.GraphModule] = {}
    # Original nodes already copied into base_mod_graph (keep_original_order only).
    already_constructed_attr_nodes: Set[Node] = set()
    if not keep_original_order:
        for node in m.graph.nodes:
            base_mod_env, base_mod_attrs = construct_graph(
                node, base_mod_env, base_mod_attrs
            )
    else:
        # Go through the graph to construct the mapping dict
        for node in m.graph.nodes:
            orig_mod_env[node.name] = node
            # Placeholders define the signature of the returned module, so emit
            # all of them now and in their original order. Emitting them lazily
            # (when a partition first consumes them) would reorder the inputs and
            # drop placeholders that no partition uses.
            if node.op == "placeholder":
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

    # Do some things iterating over the partitions in topological order again:
    # 1) Finish off submodule Graphs by setting corresponding outputs
    # 2) Construct GraphModules for each submodule
    # 3) Construct the base graph by emitting calls to those submodules in
    #    topological order or original order specified by keep_original_order
    construct_order_partitions = (
        sorted_partitions if not keep_original_order else original_partition_order
    )

    for partition_name in construct_order_partitions:
        partition = partitions[partition_name]

        # Set correct output values
        output_vals = tuple(
            partition.environment[orig_nodes[name]] for name in partition.outputs
        )

        # skip output node generation if there are no output values
        num_output_vals = len(output_vals)
        if num_output_vals == 1:
            partition.graph.output(output_vals[0])
        elif num_output_vals > 1:
            partition.graph.output(output_vals)

        if keep_original_order:
            # Copy the get_attr nodes this partition consumes into the base graph
            # right before its call (placeholders were all emitted up front).
            orig_mod_attr_nodes: List[Node] = [
                orig_mod_env[key] for key in partition.inputs
            ]
            for node in orig_mod_attr_nodes:  # type: ignore[attr-defined]
                if node in already_constructed_attr_nodes:
                    continue
                base_mod_env, base_mod_attrs = construct_graph(
                    node, base_mod_env, base_mod_attrs
                )
                already_constructed_attr_nodes.add(node)

        # Construct GraphModule for this partition
        base_mod_attrs[partition.submod_name] = torch.fx.graph_module.GraphModule(
            partition.targets, partition.graph
        )  # noqa: B950

        # Emit call in base graph to this submodule
        output_val = base_mod_graph.call_module(
            partition.submod_name,
            tuple(base_mod_env[name] for name in partition.inputs),
        )

        num_outputs = len(partition.outputs)
        if num_outputs > 1:
            # Unpack multiple return values from submodule
            output_val_proxy = torch.fx.proxy.Proxy(output_val)
            for i, output_name in enumerate(partition.outputs):
                base_mod_env[output_name] = output_val_proxy[i].node  # type: ignore[index]
        elif num_outputs == 1:
            base_mod_env[next(iter(partition.outputs))] = output_val

    for node in m.graph.nodes:
        if node.op == "output":
            if keep_original_order:
                # get_attr values returned directly by the graph (and consumed by
                # no partition) have not been copied into the base graph yet.
                for inp in node.all_input_nodes:
                    if inp.op == "get_attr" and inp not in already_constructed_attr_nodes:
                        base_mod_env, base_mod_attrs = construct_graph(
                            inp, base_mod_env, base_mod_attrs
                        )
                        already_constructed_attr_nodes.add(inp)
            base_mod_graph.output(
                torch.fx.graph.map_arg(node.args[0], lambda n: base_mod_env[n.name])
            )  # noqa: B950

    return torch.fx.graph_module.GraphModule(base_mod_attrs, base_mod_graph)