from typing import Any

import stk
import stk.backend.triton_kernels
import stk.ops
import torch
from packaging import version

from megablocks import grouped_gemm_util as gg
from megablocks.layers import common, gelu, mpu
from megablocks.layers.activation_fn import act_fn
from megablocks.layers.arguments import DEFAULT_ACTIVATION_FN, Arguments, InitFn


class ScaleGradient(torch.autograd.Function):
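    """Identity in the forward pass; scales the incoming gradient by a constant factor in the backward pass."""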

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx: Any, x: torch.Tensor, scale: float):
        ctx.scale = scale
        return x

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx: Any, grad: torch.Tensor):
        return grad * ctx.scale, None


scale_gradient = ScaleGradient.apply


def resolve_dtensor(weight: torch.Tensor):
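    """Return weight.to_local() if weight is a DTensor (torch >= 2.0); otherwise return weight unchanged."""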
    if version.parse(torch.__version__) >= version.parse('2.0.0'):
        from torch.distributed._tensor import DTensor
        if isinstance(weight, DTensor):
            return weight.to_local()
    return weight


def create_moe_expert_weights(
    args: Arguments,
    num_experts: int,
    ffn_hidden_size: int,
    hidden_size: int,
    init_method: InitFn,
):
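    # Create and initialize the full (num_experts, ffn_hidden_size, hidden_size)
    # weight tensor so that initialization is identical regardless of how the
    # experts are later sharded across ranks.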
    master_weights = torch.empty(
        num_experts,
        ffn_hidden_size,
        hidden_size,
        device=args.device,
        dtype=common.dtype(args),
    )
    init_method(master_weights)

    if not args.moe_expert_model_parallelism:
        return master_weights
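
    # Determine how the experts and the ffn dimension are sharded across ranks.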
    expert_sharding_degree = mpu.expert_sharding_degree(args)
    hidden_sharding_degree = mpu.hidden_sharding_degree(args)
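
    # Select the slice of experts owned by this rank.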
    rank = mpu.get_expert_parallel_rank(args)
    expert_rank = rank % expert_sharding_degree
    num_experts_per_rank = num_experts // expert_sharding_degree
    start_expert = expert_rank * num_experts_per_rank
    end_expert = (expert_rank + 1) * num_experts_per_rank
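
    # Select the rows of the ffn dimension owned by this rank.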
    row_rank = rank // expert_sharding_degree
    num_rows_per_rank = ffn_hidden_size // hidden_sharding_degree
    start_row = row_rank * num_rows_per_rank
    end_row = (row_rank + 1) * num_rows_per_rank

    with torch.no_grad():
        weights = master_weights[start_expert:end_expert, start_row:end_row]
    return weights


class MLP(torch.nn.Module):

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        experts_per_rank = mpu.experts_per_rank(args)

        self.w1 = torch.nn.Parameter(
            torch.empty(
                experts_per_rank,
                args.hidden_size,
                mpu.features_per_rank(args),
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                experts_per_rank,
                mpu.features_per_rank(args),
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        mpu.set_expert_model_parallel_attributes(
            self.w1,
            args.moe_expert_model_parallelism,
        )
        mpu.set_expert_model_parallel_attributes(
            self.w2,
            args.moe_expert_model_parallelism,
        )
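
        # Initialize both weights from the full master weights so the result
        # does not depend on the expert-parallel layout, then keep only this
        # rank's shard. w1 is stored transposed for the batched matmul below.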
        with torch.no_grad():
            w1 = create_moe_expert_weights(
                args,
                args.moe_num_experts,
                args.ffn_hidden_size,
                args.hidden_size,
                args.init_method,
            )
            self.w1.copy_(w1.transpose(1, 2).contiguous())
            self.w2.copy_(
                create_moe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.output_layer_init_method,
                ),
            )

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
            self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args)

    def scale_grad(self, w):
        if self.gradient_scale is None:
            return w
        return scale_gradient(w, self.gradient_scale)

    def forward(self, x):
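        # x is expected as a batch of per-expert token blocks, shaped
        # [experts_per_rank, tokens, hidden_size], to match the batched
        # matmuls below.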
        w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
        w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
        x = torch.bmm(x, w1)
        x = self.args.activation_fn(x)
        return torch.bmm(x, w2)


def create_dmoe_expert_weights(
    args: Arguments,
    num_experts: int,
    rows: int,
    columns: int,
    init_method: InitFn,
):
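    # Initialize as for the dense MoE, then flatten the expert dimension into
    # the rows so the weights form a single stacked 2D matrix.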
    weights = create_moe_expert_weights(
        args,
        num_experts,
        rows,
        columns,
        init_method,
    )
    return weights.view([-1, columns])


class MemoryOptimizedMLP(torch.autograd.Function):
    """Sparse MLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, topo, activation_fn):
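        # If autocast was active, cast the inputs to the autocast dtype that
        # custom_fwd recorded on ctx.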
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)

        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")

        topo_tensors = (
            topo.row_indices,
            topo.column_indices,
            topo.offsets,
            topo.column_indices_t,
            topo.offsets_t,
            topo.block_offsets_t,
        )
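
        # Forward: block-sparse first GEMM (sdd), activation applied on the
        # sparse blocks, then the second GEMM (dsd) back to a dense output.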
        sdd_out = stk.ops.sdd(x, w1.t(), topo)
        activation_fn_out = act_fn(sdd_out, activation_fn)
        dsd_out = stk.ops.dsd(activation_fn_out, w2)
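
        # Save what the backward pass needs. The activation output is
        # recomputed in backward from sdd_out to avoid keeping another
        # intermediate alive.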
        ctx.shape = topo.shape
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.data.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, *topo_tensors, x, sdd_out.data)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx, ddsd_out):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        saved_tensors = ctx.saved_tensors
        w1, w2 = saved_tensors[:2]
        topo_tensors = saved_tensors[2:8]
        x = saved_tensors[8]
        sdd_out_data = saved_tensors[9]
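
        # Rebuild the sparse activation input and recompute the activation
        # output (and its grad function) instead of having saved it.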
        activation_fn = ctx.activation_fn
        sdd_out = stk.Matrix(ctx.shape, sdd_out_data, *topo_tensors)
        activation_fn_out, activation_grad_fn = act_fn(
            sdd_out,
            activation_fn,
            return_grad_fn=True,
        )

        dw2 = stk.ops.dsd(activation_fn_out.t(), ddsd_out)
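
        # Compute the gradient w.r.t. the activation output, writing the
        # result into the activation_fn_out buffer to reuse its memory.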
        dactivation_fn_out = activation_fn_out
        stk.backend.triton_kernels.sdd(
            ddsd_out,
            w2.t(),
            dactivation_fn_out.shape,
            dactivation_fn_out.data,
            dactivation_fn_out.offsets,
            dactivation_fn_out.row_indices,
            dactivation_fn_out.column_indices,
        )

        if activation_fn is DEFAULT_ACTIVATION_FN:
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out.data)
            dsdd_out = stk.Matrix(ctx.shape, sdd_out.data.grad, *topo_tensors)

        dw1 = stk.ops.dsd(dsdd_out.t(), x)
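
        # Compute dx, writing the result into the incoming gradient buffer
        # (ddsd_out) to reuse its memory.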
        stk.backend.triton_kernels.dsd(
            dsdd_out.shape,
            dsdd_out.data,
            dsdd_out.offsets,
            dsdd_out.row_indices,
            dsdd_out.column_indices,
            dsdd_out.offsets_t,
            dsdd_out.column_indices_t,
            dsdd_out.block_offsets_t,
            False,
            w1,
            ddsd_out,
        )
        dx = ddsd_out
        return dx, dw1, dw2, None, None


memory_optimized_mlp = MemoryOptimizedMLP.apply


class SparseMLP(torch.nn.Module):
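    """dMoE expert MLP with the expert weights stacked into 2D matrices for block-sparse compute."""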

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        self._num_rows_per_rank = mpu.experts_per_rank(args) * mpu.features_per_rank(args)

        self.w1 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
        self.w2 = torch.nn.Parameter(
            torch.empty(
                self._num_rows_per_rank,
                args.hidden_size,
                device=args.device,
                dtype=common.dtype(args),
            ),
        )
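
        # Initialize from the full master weights and keep this rank's shard,
        # mirroring the dense MLP initialization above.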
        with torch.no_grad():
            self.w1.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.init_method,
                ),
            )
            self.w2.copy_(
                create_dmoe_expert_weights(
                    args,
                    args.moe_num_experts,
                    args.ffn_hidden_size,
                    args.hidden_size,
                    args.output_layer_init_method,
                ),
            )

        self._should_set_parallelism_attribute = args.moe_expert_model_parallelism
        mpu.set_expert_model_parallel_attributes(
            self.w1,
            self._should_set_parallelism_attribute,
        )
        mpu.set_expert_model_parallel_attributes(
            self.w2,
            self._should_set_parallelism_attribute,
        )

        self.gradient_scale = None
        if self.args.moe_expert_model_parallelism:
            self.gradient_scale = 1 / mpu.get_expert_parallel_world_size(self.args)

    def scale_grad(self, w):
        if self.gradient_scale is None:
            return w
        return scale_gradient(w, self.gradient_scale)

    def forward(self, x, topo):
        w1, w2 = self.scale_grad(self.w1), self.scale_grad(self.w2)
        w1, w2 = resolve_dtensor(w1), resolve_dtensor(w2)
        if self.args.memory_optimized_mlp:
            return memory_optimized_mlp(
                x,
                w1,
                w2,
                topo,
                self.args.activation_fn,
            )

        x = stk.ops.sdd(x, w1.t(), topo)
        activation_fn_out = act_fn(x, self.args.activation_fn)
        return stk.ops.dsd(activation_fn_out, w2)


class MemoryOptimizedGroupedMLP(torch.autograd.Function):
    """GroupedMLP with manually scheduled memory reuse."""

    @staticmethod
    @torch.amp.autocast_mode.custom_fwd(device_type='cuda')
    def forward(ctx, x, w1, w2, batch_sizes, activation_fn):
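        # If autocast was active, cast the inputs to the autocast dtype that
        # custom_fwd recorded on ctx.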
        if ctx._fwd_used_autocast:
            x = x.to(ctx._dtype)
            w1 = w1.to(ctx._dtype)
            w2 = w2.to(ctx._dtype)

        if (not x.is_contiguous() or not w1.is_contiguous() or not w2.is_contiguous()):
            raise ValueError("Expected contiguous 'x', 'w1' and 'w2'.")
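
        # Forward: grouped GEMM for the first layer, the activation, then the
        # grouped GEMM for the second layer.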
        assert gg.backend is not None
        sdd_out = gg.backend.gmm(x, w1, batch_sizes, trans_b=True)
        activation_fn_out = activation_fn(sdd_out)
        dsd_out = gg.backend.gmm(activation_fn_out, w2, batch_sizes)
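
        # Save what the backward pass needs. The activation output is
        # recomputed in backward from sdd_out to avoid keeping another
        # intermediate alive.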
        ctx.x_shape = x.shape
        ctx.sdd_out_shape = sdd_out.shape
        ctx.dtype = x.dtype
        ctx.activation_fn = activation_fn
        ctx.save_for_backward(w1, w2, batch_sizes, x, sdd_out)
        return dsd_out

    @staticmethod
    @torch.amp.autocast_mode.custom_bwd(device_type='cuda')
    def backward(ctx: Any, ddsd_out: torch.Tensor):
        if (not ctx.needs_input_grad[0] or not ctx.needs_input_grad[1] or not ctx.needs_input_grad[2]):
            raise ValueError('Expected all MLP inputs to need grad.')

        saved_tensors = ctx.saved_tensors
        w1, w2 = saved_tensors[:2]
        batch_sizes = saved_tensors[2]
        x = saved_tensors[3]
        sdd_out = saved_tensors[4]
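
        # Recompute the activation output with grad enabled so its backward
        # function can be called for non-default activations.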
        activation_fn = ctx.activation_fn
        with torch.set_grad_enabled(True):
            sdd_out.requires_grad = True
            activation_fn_out = activation_fn(sdd_out)
            activation_grad_fn = activation_fn_out.backward

        assert gg.backend is not None
        dw2 = gg.backend.gmm(
            activation_fn_out,
            ddsd_out,
            batch_sizes,
            trans_a=True,
        )
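
        # Compute the gradient w.r.t. the activation output, writing the
        # result into the activation_fn_out buffer to reuse its memory.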
        dactivation_fn_out = activation_fn_out
        gg.backend.gmm(
            ddsd_out,
            w2,
            batch_sizes,
            trans_b=True,
            c=dactivation_fn_out,
        )

        if activation_fn is DEFAULT_ACTIVATION_FN:
            dsdd_out = gelu.gelu_backward_(dactivation_fn_out, sdd_out)
        else:
            assert activation_grad_fn is not None
            activation_grad_fn(dactivation_fn_out)
            dsdd_out = sdd_out.grad

        dw1 = gg.backend.gmm(dsdd_out, x, batch_sizes, trans_a=True)
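
        # Compute dx, writing the result into the incoming gradient buffer
        # (ddsd_out) to reuse its memory.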
        gg.backend.gmm(dsdd_out, w1, batch_sizes, c=ddsd_out)
        dx = ddsd_out
        return dx, dw1, dw2, None, None


memory_optimized_grouped_mlp = MemoryOptimizedGroupedMLP.apply


class GroupedMLP(SparseMLP):

    def forward(self, x, tokens_per_expert):
        batch_sizes = tokens_per_expert.cpu().to(torch.long)
        w1, w2 = (self.scale_grad(self.w1), self.scale_grad(self.w2))
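
        # Reshape the flat dMoE weights into per-expert 3D tensors for the
        # grouped GEMMs.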
        ne = mpu.experts_per_rank(self.args)
        w1 = resolve_dtensor(w1).view(ne, -1, self.args.hidden_size)
        w2 = resolve_dtensor(w2).view(ne, -1, self.args.hidden_size)

        if self.args.memory_optimized_mlp:
            return memory_optimized_grouped_mlp(
                x,
                w1,
                w2,
                batch_sizes,
                self.args.activation_fn,
            )

        assert gg.ops is not None
        x = gg.ops.gmm(x, w1, batch_sizes, trans_b=True)
        x = self.args.activation_fn(x)
        return gg.ops.gmm(x, w2, batch_sizes)


class SharedMLP(torch.nn.Module):
    """MLP for the shared expert.

    Note: this is a copy-paste-modify of the LLM-Foundry MPTMLP class.
    """

    def __init__(self, args: Arguments):
        super().__init__()
        self.args = args
        self.fc_kwargs: dict[str, Any] = {
            'bias': args.bias,
            'device': args.device,
        }
        self.fc_kwargs.update(args.fc_kwargs)

        self.up_proj = args.fc_cls(
            args.hidden_size,
            args.shared_expert_hidden_size,
            **self.fc_kwargs,
        )
        self.act = args.activation_fn
        self.down_proj = args.fc_cls(
            args.shared_expert_hidden_size,
            args.hidden_size,
            **self.fc_kwargs,
        )
        self.down_proj._is_residual = True

    def add_experts_sharedexpert(
        self,
        shared_expert_out: torch.Tensor,
        expert_out: torch.Tensor,
    ) -> torch.Tensor:
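        # Optionally treat the shared expert as one additional expert among
        # (moe_top_k + 1) and rescale both terms accordingly; otherwise just
        # sum the two outputs.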
        if self.args.shared_expert_weighted_sum:
            t_experts = self.args.moe_top_k + 1
            sh_mlp_out = shared_expert_out / t_experts
            return sh_mlp_out.add(
                expert_out,
                alpha=(self.args.moe_top_k / t_experts),
            )

        return shared_expert_out + expert_out

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(self.act(self.up_proj(x)))