from typing import Callable, Dict, List, Set

import torch
import torch.fx as fx
import torch.utils._pytree as pytree
from torch import Tensor
from torch.distributed._tensor import DeviceMesh, Replicate, Shard
from torch.distributed._tensor.ops.view_ops import (
    DimSpec,
    InputDim,
    ops as view_op_rules,
)
from torch.distributed._tensor.placement_types import _Partial, DTensorSpec

aten = torch.ops.aten


class BatchDimAnalyzer:
    """This class is used to analyze the batch dimension of each tensor/node in the graph.

    We need to know the batch dimension of each tensor/node so that we know
    exactly the sharding layout of intermediate tensors.

    We possibly should evaluate using symbolic shapes to track the batch
    dimension. We can experiment with it later with dynamo integration (as
    dynamo has a mark_dynamic API that allows marking only the batch dimension)
    or try to use FakeTensorMode to mark the batch dimension. For now, we just
    use the batch dimension of the first input tensor as the hint to track the
    batch dimension of all tensors/nodes in the graph.
    """

    def __init__(self, batch_dim: int = 0) -> None:
        self.batch_dim = batch_dim

        self.batch_dim_map: Dict[fx.Node, int] = {}
        # batch dim size is used to track the batch dim size of the input tensor
        self.batch_dim_size = -1

        self.dim_rule_map: Dict[torch._ops.OpOverload, Callable[..., torch.Tensor]] = {
            aten.squeeze.default: torch.squeeze,
            aten.squeeze.dim: torch.squeeze,
            aten.view.default: Tensor.view,
            aten.reshape.default: torch.reshape,
            aten._unsafe_view.default: Tensor.view,
            aten.unsqueeze.default: torch.unsqueeze,
            aten.expand.default: Tensor.expand,
            aten.permute.default: torch.permute,
            aten.repeat.default: Tensor.repeat,
            aten.transpose.int: torch.transpose,
        }

    def init_batch_dim_size(self, batch_dim_size: int) -> None:
        """Initialize the batch dim size based on the first input batch size."""
        if self.batch_dim_size != -1 and self.batch_dim_size != batch_dim_size:
            raise RuntimeError(
                f"batch dim size is already initialized! "
                f"Found new batch size: {batch_dim_size} not "
                f"matching existing batch dim size: {self.batch_dim_size}!"
            )
        self.batch_dim_size = batch_dim_size

    def set_batch_dim(self, node: fx.Node, batch_dim: int) -> None:
        self.batch_dim_map[node] = batch_dim

    def get_batch_dim(self, node: fx.Node) -> int:
        if node not in self.batch_dim_map:
            raise RuntimeError(f"batch dim analysis failed on node: {node}!")
        return self.batch_dim_map[node]

    def compute_batch_dim(self, node: fx.Node, full_reduction=False) -> int:
        """Compute the batch dimension for the `node`."""
        assert self.batch_dim_size != -1, "batch dim size is not initialized!"
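        # Inference strategy (summary of the steps below): (1) return the cached
        # batch dim if this node was already analyzed; (2) for known view-like ops,
        # map the operand's batch dim through the view op's dim rules; (3) otherwise
        # match the output shape against the tracked batch dim size; (4) otherwise
        # fall back to the operands' batch dims to decide between replicate (-1),
        # shard, or partial (-2).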
        if node in self.batch_dim_map:
            # if the batch dim was already computed, simply return it
            return self.batch_dim_map[node]

        if node.target in self.dim_rule_map:
            view_op_rule = view_op_rules[self.dim_rule_map[node.target]]  # type: ignore[index]
            args_val = pytree.tree_map_only(fx.Node, lambda n: n.meta["val"], node.args)
            kwargs_val = pytree.tree_map_only(
                fx.Node, lambda n: n.meta["val"], node.kwargs
            )
            output_dim_rules = view_op_rule.dim_map(*args_val, **kwargs_val)

            def collect_input_dim(cmd: DimSpec, input_dims: Set[int]):
                if isinstance(cmd, InputDim):
                    input_dims.add(cmd.input_dim)
                for inp in cmd.inputs():
                    collect_input_dim(inp, input_dims)

            output_dim_to_input_dims: List[Set[int]] = []
            for inp in output_dim_rules:
                input_dims: Set[int] = set()
                collect_input_dim(inp, input_dims=input_dims)
                output_dim_to_input_dims.append(input_dims)

            operand = node.all_input_nodes[0]
            operand_batch_dim = self.get_batch_dim(operand)
            for output_dim, input_dims in enumerate(output_dim_to_input_dims):
                if operand_batch_dim in input_dims:
                    self.set_batch_dim(node, output_dim)
                    # update the batch dim size before returning, because the
                    # batch dim size might change in the middle of the graph
                    self.batch_dim_size = node.meta["val"].shape[output_dim]
                    return output_dim

        # if there are no hints from the output_dim_rules, we infer the batch dim
        # from the output shape and shard correspondingly
        node_val = node.meta["val"]
        if isinstance(node_val, (list, tuple)):
            shapes = [val.shape for val in node_val]
        else:
            shapes = [node_val.shape]

        # for a reduction op that reduces over the sharded batch dim we don't
        # generate partial, but rather, we generate shard. This is because the
        # intention of data parallel is to never do a full reduction across the
        # batch dimension; it would still keep the reduction activation sharded.
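        # For example (illustrative): summing a batch-sharded [batch, feature]
        # activation down to a 0-dim scalar keeps the scalar as Shard(batch),
        # i.e. each rank holds its local partial result, instead of marking
        # the activation as _Partial.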
        full_reduction = False
        # loop through the dim sizes to find the output batch dim
        for shape in shapes:
            if len(shape) == 0:
                full_reduction = True

            for i, dim_size in enumerate(shape):
                if dim_size == self.batch_dim_size:
                    self.set_batch_dim(node, i)
                    return i

        operands = node.all_input_nodes
        if not operands:
            # if there are no operands, it must be a factory op; the tensor it
            # generates is only used for computation and should be marked as replicated
            self.set_batch_dim(node, -1)  # -1 means replicated
            return -1
        else:
            # if there are operands, check whether any operand has a batch dim. If an
            # operand has a batch dim but the output does not, it's either a full
            # reduction, where we should stay sharded, or a reduction over the batch
            # dim only, where we should produce partial
            operand_batch_dim = -1
            for operand in operands:
                if operand in self.batch_dim_map:
                    operand_batch_dim = self.get_batch_dim(operand)

            if operand_batch_dim < 0:
                # if the operand does not have a batch dim, we also don't have one
                self.set_batch_dim(node, operand_batch_dim)
                return operand_batch_dim
            elif full_reduction:
                self.set_batch_dim(node, operand_batch_dim)
                return operand_batch_dim
            else:
                # if the operand has a batch dim but the output does not, it should
                # produce partial; we use -2 to indicate partial
                self.set_batch_dim(node, -2)
                return -2

    def compute_act_spec(self, node: fx.Node, mesh: DeviceMesh) -> DTensorSpec:
        """Compute the batch dimension for the current node, then generate the
        sharding spec that shards on the batch dimension.
        """
        node_batch_dim = self.compute_batch_dim(node)
        if node_batch_dim == -1:
            # indicate this activation is replicated
            act_spec = DTensorSpec(mesh=mesh, placements=(Replicate(),))
        elif node_batch_dim == -2:
            # indicate this activation is partial
            act_spec = DTensorSpec(mesh=mesh, placements=(_Partial(),))
        else:
            # indicate this activation is sharded on the batch dim
            act_spec = DTensorSpec(mesh=mesh, placements=(Shard(node_batch_dim),))

        return act_spec
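

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API). The graph
# module, mesh construction, world size, and batch size below are hypothetical
# placeholders: the caller seeds the analyzer with the batch size of the first
# input, marks the input placeholders, and then queries a DTensorSpec for each
# activation node.
#
#     analyzer = BatchDimAnalyzer(batch_dim=0)
#     analyzer.init_batch_dim_size(32)  # hypothetical input batch size
#     mesh = DeviceMesh("cuda", list(range(world_size)))  # hypothetical mesh
#     for node in traced_gm.graph.nodes:  # hypothetical traced fx.GraphModule
#         if node.op == "placeholder":
#             analyzer.set_batch_dim(node, analyzer.batch_dim)
#         elif node.op == "call_function":
#             act_spec = analyzer.compute_act_spec(node, mesh)
# ---------------------------------------------------------------------------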