Spaces:
Runtime error
Runtime error
# Copyright (c) Meta Platforms, Inc. and affiliates | |
# Sharding propagation rules for assorted aten ops on distributed tensors
# (slice_backward, bernoulli, nll_loss forward/backward).
from typing import List | |
import numpy as np | |
import torch | |
from torch.distributed._tensor.op_schema import OpSchema, OutputSharding | |
from torch.distributed._tensor.ops.utils import register_prop_rule | |
from torch.distributed._tensor.placement_types import DTensorSpec, TensorMeta | |
aten = torch.ops.aten | |
def slice_backward_rules(op_schema: OpSchema) -> OutputSharding:
    """Propagate sharding for ``slice_backward``.

    The gradient w.r.t. the input has the full ``input_sizes`` shape and
    reuses the dim map and pending sums of the incoming ``grad_output`` spec.

    Args:
        op_schema: Schema whose ``args_schema`` unpacks to exactly
            ``(grad_output_spec, input_sizes, dim, start, end, step)``.

    Returns:
        An ``OutputSharding`` wrapping the spec of the input gradient.

    Raises:
        ValueError: If ``args_schema`` does not hold exactly six entries.
        AssertionError: If ``grad_output_spec`` is not a ``DTensorSpec``,
            ``input_sizes`` is not a list, or the spec carries no tensor meta.
    """
    # dim/start/end/step are not needed to derive the output spec; they are
    # still unpacked so a schema with the wrong arity fails loudly here.
    grad_output_spec, input_sizes, _dim, _start, _end, _step = op_schema.args_schema
    assert isinstance(grad_output_spec, DTensorSpec)
    assert isinstance(input_sizes, list)
    assert grad_output_spec.tensor_meta is not None

    # Contiguous (row-major) strides as plain Python ints. Walking the sizes
    # from the innermost dim outwards keeps a 0-dim shape at an empty stride
    # tuple, and clamping each size to >= 1 matches the strides PyTorch
    # assigns to contiguous tensors that contain zero-size dims.
    grad_input_stride: List[int] = []
    running = 1
    for size in reversed(input_sizes):
        grad_input_stride.append(running)
        running *= max(size, 1)
    grad_input_stride.reverse()

    grad_input_tensor_meta = TensorMeta(
        torch.Size(input_sizes),
        tuple(grad_input_stride),
        grad_output_spec.tensor_meta.dtype,
    )
    grad_input_spec = DTensorSpec.from_dim_map(
        grad_output_spec.mesh,
        grad_output_spec.dim_map,
        grad_output_spec.sums,
        tensor_meta=grad_input_tensor_meta,
    )
    return OutputSharding(grad_input_spec)
def bernoulli_rules(op_schema: OpSchema) -> OutputSharding:
    """Propagate sharding for ``bernoulli``.

    The output reuses the spec of the first argument unchanged.

    Args:
        op_schema: Schema whose first ``args_schema`` entry is the
            ``DTensorSpec`` of the tensor being sampled.

    Returns:
        An ``OutputSharding`` wrapping that same spec.
    """
    self_spec = op_schema.args_schema[0]
    assert isinstance(self_spec, DTensorSpec)
    output_sharding = OutputSharding(self_spec)
    return output_sharding
def nll_loss_forward_rules(op_schema: OpSchema) -> OutputSharding:
    """Propagate sharding for ``nll_loss_forward``.

    Emits two output specs, ``[result, total_weight]``. Both are described as
    0-dim tensors with the input's dtype, built from an empty dim map and no
    pending sums on the input's mesh.

    Args:
        op_schema: Schema whose first ``args_schema`` entry is the
            ``DTensorSpec`` of the loss input.

    Returns:
        An ``OutputSharding`` holding the two output specs in order.

    Raises:
        AssertionError: If the first argument is not a ``DTensorSpec`` or it
            carries no tensor meta.
    """
    input_spec = op_schema.args_schema[0]
    assert isinstance(input_spec, DTensorSpec)
    assert input_spec.tensor_meta is not None
    out_dtype = input_spec.tensor_meta.dtype

    # NOTE(review): both outputs are hard-coded as scalars; presumably this
    # rule is only hit with a reducing mode (mean/sum) -- confirm that
    # reduction="none" is never routed here.
    result_shape: List[int] = []
    result_stride: List[int] = []
    result_dim = 0

    total_weight_shape: List[int] = []
    total_weight_stride: List[int] = []
    total_weight_dim = 0

    result_tensor_meta = TensorMeta(
        torch.Size(result_shape),
        tuple(result_stride),
        out_dtype,
    )
    # Each output's meta is built from its own shape/stride pair so the two
    # stay independent if either is ever changed.
    total_weight_tensor_meta = TensorMeta(
        torch.Size(total_weight_shape),
        tuple(total_weight_stride),
        out_dtype,
    )

    result_spec = DTensorSpec.from_dim_map(
        input_spec.mesh,
        [-1] * result_dim,
        [],
        tensor_meta=result_tensor_meta,
    )
    total_weight_spec = DTensorSpec.from_dim_map(
        input_spec.mesh,
        [-1] * total_weight_dim,
        [],
        tensor_meta=total_weight_tensor_meta,
    )
    return OutputSharding([result_spec, total_weight_spec])
def nll_loss_backward_rules(op_schema: OpSchema) -> OutputSharding:
    """Propagate sharding for ``nll_loss_backward``.

    The output reuses, unchanged, the spec found at argument index 1
    (presumably the forward input ``self`` in the aten signature -- confirm).

    Args:
        op_schema: Schema whose second ``args_schema`` entry is a
            ``DTensorSpec``.

    Returns:
        An ``OutputSharding`` wrapping that same spec.
    """
    forward_input_spec = op_schema.args_schema[1]
    assert isinstance(forward_input_spec, DTensorSpec)
    output_sharding = OutputSharding(forward_input_spec)
    return output_sharding