# Copyright (c) Meta Platforms, Inc. and affiliates
# Implement sharding propagation rules for experimental ops on distributed tensor
from typing import List

import numpy as np

import torch
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding
from torch.distributed._tensor.ops.utils import register_prop_rule
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta

aten = torch.ops.aten
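
# Orientation note for the rules below (my reading of DTensorSpec.from_dim_map,
# not new API): dim_map[i] == m means tensor dim i is sharded over mesh dim m,
# dim_map[i] == -1 means tensor dim i is replicated, and `sums` lists the mesh
# dims carrying a pending (partial) reduction.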


@register_prop_rule(aten.slice_backward.default)
def slice_backward_rules(op_schema: OpSchema) -> OutputSharding:
    grad_output_spec, input_sizes, dim, start, end, step = op_schema.args_schema
    assert isinstance(grad_output_spec, DTensorSpec)
    assert isinstance(input_sizes, List)
    assert grad_output_spec.tensor_meta is not None
    # grad_input is a zero-filled tensor of shape `input_sizes` with
    # grad_output scattered into the sliced region, so it is contiguous;
    # build its strides as the reversed cumulative product of the sizes.
    grad_input_stride = list(np.cumprod(input_sizes[::-1])[:-1][::-1])
    grad_input_stride.append(1)
    dim_map = grad_output_spec.dim_map
    sums = grad_output_spec.sums
    grad_input_tensor_meta = TensorMeta(
        torch.Size(input_sizes),
        tuple(grad_input_stride),
        grad_output_spec.tensor_meta.dtype,
    )
    # grad_input keeps grad_output's sharding (dim_map) and pending
    # reductions (sums); only the shape/stride metadata changes.
    grad_input_spec = DTensorSpec.from_dim_map(
        grad_output_spec.mesh,
        dim_map,
        sums,
        tensor_meta=grad_input_tensor_meta,
    )
    return OutputSharding(grad_input_spec)
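
# A minimal, self-contained sketch of the stride computation above, assuming
# only numpy: a contiguous tensor of shape [2, 3, 4] should get strides
# (12, 4, 1).
#
#   sizes = [2, 3, 4]
#   stride = list(np.cumprod(sizes[::-1])[:-1][::-1]) + [1]
#   assert tuple(stride) == (12, 4, 1)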


@register_prop_rule(aten.bernoulli.default)
@register_prop_rule(aten.bernoulli_.float)
def bernoulli_rules(op_schema: OpSchema) -> OutputSharding:
    # bernoulli is elementwise, so the output simply inherits the
    # input's sharding.
    input_spec = op_schema.args_schema[0]
    assert isinstance(input_spec, DTensorSpec)
    return OutputSharding(input_spec)


@register_prop_rule(aten.nll_loss_forward.default)
def nll_loss_forward_rules(op_schema: OpSchema) -> OutputSharding:
    input_spec = op_schema.args_schema[0]
    assert isinstance(input_spec, DTensorSpec)
    assert input_spec.tensor_meta is not None
    # With mean/sum reduction, nll_loss_forward returns two 0-dim tensors
    # (the loss and total_weight), so both outputs have empty shape, empty
    # stride, and no sharded dims.
    result_shape: List[int] = []
    result_stride: List[int] = []
    result_dim = 0
    total_weight_shape: List[int] = []
    total_weight_stride: List[int] = []
    total_weight_dim = 0
    result_tensor_meta = TensorMeta(
        torch.Size(result_shape),
        tuple(result_stride),
        input_spec.tensor_meta.dtype,
    )
    total_weight_tensor_meta = TensorMeta(
        torch.Size(total_weight_shape),
        tuple(total_weight_stride),
        input_spec.tensor_meta.dtype,
    )
    result_spec = DTensorSpec.from_dim_map(
        input_spec.mesh,
        [-1 for _ in range(result_dim)],
        [],
        tensor_meta=result_tensor_meta,
    )
    total_weight_spec = DTensorSpec.from_dim_map(
        input_spec.mesh,
        [-1 for _ in range(total_weight_dim)],
        [],
        tensor_meta=total_weight_tensor_meta,
    )
    return OutputSharding([result_spec, total_weight_spec])
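
# A quick sanity sketch (mean reduction) of why both outputs are 0-dim,
# assuming the default overload signature
# (self, target, weight, reduction, ignore_index):
#
#   logp = torch.log_softmax(torch.randn(4, 5), dim=1)
#   target = torch.randint(0, 5, (4,))
#   out, total_weight = aten.nll_loss_forward(logp, target, None, 1, -100)
#   assert out.shape == () and total_weight.shape == ()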


@register_prop_rule(aten.nll_loss_backward.default)
def nll_loss_backward_rules(op_schema: OpSchema) -> OutputSharding:
    # args_schema is (grad_output, self, target, weight, reduction,
    # ignore_index, total_weight); grad_input has the same shape as `self`
    # (args_schema[1]), so reusing its spec preserves shape and sharding.
    input_spec = op_schema.args_schema[1]
    assert isinstance(input_spec, DTensorSpec)
    return OutputSharding(input_spec)