import math
import re
from importlib.metadata import version
from typing import Optional

import packaging.version
import torch
from megatron.core import ModelParallelConfig, parallel_state
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
from megatron.core.tensor_parallel import ColumnParallelLinear, RowParallelLinear
from megatron.core.tensor_parallel.mappings import (
    gather_from_sequence_parallel_region,
    scatter_to_sequence_parallel_region,
)
from megatron.core.transformer.mlp import apply_swiglu_sharded_factory
from torch import nn

from nemo.collections.common.parts.adapter_modules import AdapterModuleUtil
from nemo.collections.common.parts.utils import activation_registry
from nemo.core.classes.mixins import adapter_mixin_strategies
from nemo.utils.import_utils import safe_import_from

TEColumnParallelLinear, HAVE_TE_COL_LINEAR = safe_import_from(
    "megatron.core.extensions.transformer_engine", "TEColumnParallelLinear"
)
TELayerNormColumnParallelLinear, HAVE_TE_LN_COL_LINEAR = safe_import_from(
    "megatron.core.extensions.transformer_engine",
    "TELayerNormColumnParallelLinear",
)
TEColumnParallelGroupedLinear, HAVE_TE_COL_GRP_LINEAR = safe_import_from(
    "megatron.core.extensions.transformer_engine", "TEColumnParallelGroupedLinear"
)
TERowParallelLinear, HAVE_TE_ROW_LINEAR = safe_import_from(
    "megatron.core.extensions.transformer_engine", "TERowParallelLinear"
)
TERowParallelGroupedLinear, HAVE_TE_ROW_GRP_LINEAR = safe_import_from(
    "megatron.core.extensions.transformer_engine", "TERowParallelGroupedLinear"
)
TELinear, HAVE_TE_LINEAR = safe_import_from("megatron.core.extensions.transformer_engine", "TELinear")
HAVE_TE = all(
    (
        HAVE_TE_COL_LINEAR,
        HAVE_TE_LN_COL_LINEAR,
        HAVE_TE_ROW_LINEAR,
        HAVE_TE_LINEAR,
        HAVE_TE_COL_GRP_LINEAR,
        HAVE_TE_ROW_GRP_LINEAR,
    )
)

MixedFusedLayerNorm, HAVE_APEX = safe_import_from("apex.normalization.fused_layer_norm", "MixedFusedLayerNorm")

TECL = (TEColumnParallelLinear, TELayerNormColumnParallelLinear, TEColumnParallelGroupedLinear)
TERL = (TERowParallelLinear, TERowParallelGroupedLinear)


def get_adapter_attributes_from_linear(m: nn.Module):
    """
    Return input_is_parallel, in_features, out_features, disable_sequence_parallel_comm and
    base_linear_is_parallel attributes based on the implementation of the base linear layer.
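
    Illustrative example (hypothetical shapes; requires an initialized Megatron Core model-parallel
    state): for a megatron-core ``ColumnParallelLinear(1024, 4096)`` base layer with sequence
    parallelism disabled, this returns ``(False, 1024, 4096, True, True)``.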
    """
    disable_sequence_parallel_comm = not m.config.sequence_parallel
    base_linear_is_parallel = True
    if HAVE_TE and any(isinstance(m, te_column_parallel) for te_column_parallel in TECL):
        input_is_parallel = False
        # TE column-parallel layers store the per-partition output size, so the full output
        # size is recovered by multiplying with the tensor-parallel world size.
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        in_features = m.in_features
        out_features = m.out_features * tp_size

        if isinstance(m, TELayerNormColumnParallelLinear):
            # LoRA is applied after the layernorm, so the layernorm output must be returned.
            m.return_layernorm_output = True
            # Perf optimization for LoRA + sequence parallelism: if the all-gather is not
            # overlapped with the GEMM, request the gathered layernorm output as well.
            if hasattr(m, "ub_overlap_ag"):
                ub_overlap_ag = m.ub_overlap_ag
            elif hasattr(m, "ub_overlap_ag_fprop"):
                ub_overlap_ag = m.ub_overlap_ag_fprop
            else:
                ub_overlap_ag = False
            if m.config.sequence_parallel and not ub_overlap_ag:
                m.return_layernorm_output_gathered = True
                te_version = packaging.version.Version(version("transformer-engine"))
                if te_version >= packaging.version.Version("1.5.0dev") and (
                    not getattr(m.config, "tp_comm_overlap", False)
                    or getattr(m.config, "tp_comm_overlap_disable_qkv", False)
                ):
                    # The layernorm output handed to the adapter is already gathered across the
                    # sequence-parallel region, so the adapter can skip its own gather.
                    disable_sequence_parallel_comm = True
    elif HAVE_TE and any(isinstance(m, te_row_parallel) for te_row_parallel in TERL):
        input_is_parallel = True
        tp_size = parallel_state.get_tensor_model_parallel_world_size()
        in_features = m.in_features * tp_size
        out_features = m.out_features
    elif HAVE_TE and isinstance(m, TELinear):
        input_is_parallel = False
        in_features = m.in_features
        out_features = m.out_features
        base_linear_is_parallel = False
    elif isinstance(m, ColumnParallelLinear):
        input_is_parallel = False
        in_features = m.input_size
        out_features = m.output_size
    elif isinstance(m, RowParallelLinear):
        input_is_parallel = True
        in_features = m.input_size
        out_features = m.output_size
    else:
        raise NotImplementedError(f"Layer type is unrecognized for LoRA: {type(m)}")

    return input_is_parallel, in_features, out_features, disable_sequence_parallel_comm, base_linear_is_parallel


def is_expert_linear(fqn):
    """
    Return whether the current base module is an expert linear module.
    See ParallelLinearAdapter.is_expert for usage details.
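
    Example (illustrative module names, following the doctest style used elsewhere in this file):
    >>> is_expert_linear("decoder.layers.0.mlp.experts.local_experts.3.linear_fc1")
    True
    >>> is_expert_linear("decoder.layers.0.self_attention.linear_qkv")
    False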
    """
    return re.match(r'.*mlp\..*experts.*\.linear_fc[1-2]$', fqn) is not None


def wildcard_match(pattern, key):
    """
    Return whether the pattern (target module to add LoRA) matches the key (model weight name).

    Example:
    --------
    >>> wildcard_match("*.layers.0.*.linear_qkv", "decoder.layers.0.self_attention.linear_qkv")
    True
    >>> wildcard_match("*.layers.0.*.linear_qkv", "decoder.layers.1.self_attention.linear_qkv")
    False
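
    A pattern may contain multiple wildcards (module names here are illustrative):
    >>> wildcard_match("*.layers.*.mlp.linear_fc1", "decoder.layers.3.mlp.linear_fc1")
    True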
    """
    if key is None:
        return None
    regex_pattern = re.compile("^" + pattern.replace("*", "(.*)") + "$")
    match = regex_pattern.match(key)
    return match is not None


def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""

    def init_(tensor):
        return nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def init_method_kaiming_uniform(val):
    """Init method based on a Kaiming-uniform distribution with negative slope ``val``."""

    def init_(tensor):
        return nn.init.kaiming_uniform_(tensor, a=val)

    return init_


def init_method_const(val):
    """Init method that fills the tensor with the constant ``val``."""

    def init_(tensor):
        return nn.init.constant_(tensor, val)

    return init_


def pad_seq_to_mult(x, mult):
    """
    Pad the sequence (first) dimension of ``x`` at the end so its length is a multiple of ``mult``.
    Return the (possibly padded) tensor and the number of padded rows.
    """
    if x.shape[0] % mult == 0:
        return x, 0
    pad_len = mult - (x.shape[0] % mult)
    with torch.no_grad():
        # Zero-pad along the sequence (first) dimension only.
        x = nn.functional.pad(x, (0, 0, 0, pad_len))
    return x, pad_len


def unpad_seq_to_mult(x, pad_len):
    """
    Remove ``pad_len`` trailing rows from the sequence (first) dimension of ``x``,
    reversing ``pad_seq_to_mult``.
    """
    if pad_len <= 0:
        return x
    with torch.no_grad():
        # Drop the rows appended by pad_seq_to_mult.
        return x[:-pad_len, :]


class _All2AllHp2Sp(torch.autograd.Function):
    """
    All-to-All from Hidden Parallel to Sequence Parallel.
    This is a temporary workaround and can be updated in the future.
    TODO: Move the functionality to MCore
    """

    @staticmethod
    def forward(ctx, input_):
        """Chunk ``input_`` along the sequence dimension, all-to-all within the TP group, and
        concatenate the received chunks along the hidden dimension (hidden-parallel to sequence-parallel)."""
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        group = parallel_state.get_tensor_model_parallel_group()
        send_list = list(input_.chunk(world_size, dim=0))
        send_list = [tensor.contiguous() for tensor in send_list]
        receive_list = [torch.empty_like(send_list[0]) for _ in range(world_size)]
        torch.distributed.all_to_all(receive_list, send_list, group=group)
        x = torch.cat(receive_list, dim=-1)

        return x

    @staticmethod
    def backward(ctx, grad_output):
        """Chunk the gradient along the hidden dimension, all-to-all within the TP group, and
        concatenate the received chunks along the sequence dimension (the inverse of forward)."""
        world_size = parallel_state.get_tensor_model_parallel_world_size()
        group = parallel_state.get_tensor_model_parallel_group()
        send_list = list(grad_output.chunk(world_size, dim=-1))
        send_list = [tensor.contiguous() for tensor in send_list]
        receive_list = [torch.empty_like(send_list[0]) for _ in range(world_size)]
        torch.distributed.all_to_all(receive_list, send_list, group=group)
        x = torch.cat(receive_list, dim=0)

        return x


def all2all_hp2sp(input_):
    """Redistribute ``input_`` from the hidden-parallel layout to the sequence-parallel layout via all-to-all."""
    return _All2AllHp2Sp.apply(input_)


class ParallelLinearAdapter(nn.Module, AdapterModuleUtil):
    """
    Low-rank adapter (as used for LoRA) built from two parallel linear layers: ``linear_in`` projects
    the input down to ``dim``, an activation is applied, and ``linear_out`` projects back up to
    ``out_features``. The adapter output is scaled by ``alpha / dim`` and is intended to be added to
    the output of the frozen base linear layer identified by ``base_linear_name``.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        dim: int,
        base_linear_name: str,
        activation: str = 'swish',
        column_init_method: str = 'xavier',
        row_init_method: str = 'zero',
        input_is_parallel: bool = False,
        dropout: float = 0.0,
        model_parallel_config: Optional[ModelParallelConfig] = None,
        alpha: float | None = None,
        dropout_position: str = 'post',
        a2a_experimental: bool = False,
        is_expert: bool = False,
        disable_sequence_parallel_comm: bool = True,
        dropout_recompute: bool = False,
        base_linear_is_parallel: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.base_linear_name = base_linear_name
        self.activation = activation_registry[activation]()
        self.dim = dim
        self.alpha = alpha if alpha is not None else self.dim
        self.input_is_parallel = input_is_parallel
        self.dropout_position = dropout_position
        self.use_a2a = a2a_experimental
        self.is_expert = is_expert

        # If no config is provided, fall back to a default ModelParallelConfig.
        # Sequence parallelism is disabled while building the adapter's own linear layers;
        # the original setting is restored at the end of __init__.
        if model_parallel_config is None:
            model_parallel_config = ModelParallelConfig()
        _sequence_parallel = model_parallel_config.sequence_parallel
        model_parallel_config.sequence_parallel = False
        self.config = model_parallel_config

        if input_is_parallel:
            self.linear_in = RowParallelLinear(
                in_features,
                dim,
                config=model_parallel_config,
                input_is_parallel=True,
                skip_bias_add=True,
                bias=False,
                init_method=self._get_init_fn(column_init_method),
            )
        else:
            self.linear_in = ColumnParallelLinear(
                in_features,
                dim,
                config=model_parallel_config,
                bias=False,
                gather_output=True,
                init_method=self._get_init_fn(column_init_method),
                disable_grad_reduce=_sequence_parallel,
            )

        # gather_output for linear_out mirrors the layout of the base layer's output:
        #   - row-parallel base layer (input_is_parallel=True): gather, unless the experimental
        #     all-to-all path redistributes the output in forward();
        #   - column-parallel base layer: keep the output split across TP ranks;
        #   - non-parallel base layer: always gather.
        lin_out_gather_output = True if input_is_parallel else False
        if self.use_a2a and input_is_parallel and _sequence_parallel:
            lin_out_gather_output = False

        if not base_linear_is_parallel:
            lin_out_gather_output = True

        self.linear_out = ColumnParallelLinear(
            dim,
            out_features,
            config=model_parallel_config,
            bias=False,
            gather_output=lin_out_gather_output,
            init_method=self._get_init_fn(row_init_method),
        )

        if dropout > 0.0:
            if dropout_recompute:
                import thunder

                self.dropout = thunder.jit(nn.Dropout(dropout))
            else:
                self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        # Cast adapter parameters to bf16/fp16 when the model config requests it.
        if model_parallel_config.bf16:
            self.bfloat16()
        elif model_parallel_config.fp16:
            self.half()

        # Setup adapter strategy
        self.setup_adapter_strategy(adapter_mixin_strategies.ReturnResultAdapterStrategy())

        # Restore the original sequence-parallel setting on the shared config object.
        model_parallel_config.sequence_parallel = _sequence_parallel
        self.disable_sequence_parallel_comm = disable_sequence_parallel_comm
        if not _sequence_parallel:
            self.disable_sequence_parallel_comm = True

        if not base_linear_is_parallel:
            self.disable_sequence_parallel_comm = True

    def _get_init_fn(self, init_method: str):
        """Map an init method name to the corresponding initialization function."""
        if init_method == 'xavier':
            init_fn = nn.init.xavier_normal_
        elif init_method == 'normal':
            init_fn = init_method_normal(0.2)
        elif init_method == 'kaiming':
            init_fn = init_method_kaiming_uniform(math.sqrt(5))
        elif init_method == "zero":
            init_fn = init_method_const(0.0)
        else:
            raise NotImplementedError("out_init_method should be zero, normal, kaiming or xavier")
        return init_fn

    def forward(self, x):
        """Compute the low-rank adapter update for ``x``, scaled by ``alpha / dim``."""
        if self.dropout is not None and self.dropout_position == 'pre':
            x = self.dropout(x)

        pad_len = 0
        if self.is_expert:
            x, pad_len = pad_seq_to_mult(x, self.config.tensor_model_parallel_size)

        if not self.disable_sequence_parallel_comm and not self.input_is_parallel and not self.is_expert:
            # The input to a column-parallel base layer is sequence-parallel, so gather the
            # sequence dimension before the adapter's linear layers. The gather op also handles
            # the corresponding scatter in the backward pass.
            x = gather_from_sequence_parallel_region(x)

        if self.config.cpu_offloading and self.config.cpu_offloading_activations:
            x.activation_offloading = True
        x, _ = self.linear_in(x)

        x = self.activation(x)

        if self.config.cpu_offloading and self.config.cpu_offloading_activations:
            x.activation_offloading = True
        x, _ = self.linear_out(x)

        if not self.disable_sequence_parallel_comm and self.input_is_parallel and not self.is_expert:
            # The output of a row-parallel base layer is sequence-parallel, so redistribute the
            # adapter output back to the sequence-parallel layout.
            if self.use_a2a:
                # Experimental path: all-to-all from hidden-parallel to sequence-parallel.
                x = all2all_hp2sp(x)
            else:
                x = scatter_to_sequence_parallel_region(x)

        if self.dropout is not None and self.dropout_position == 'post':
            x = self.dropout(x)

        x = x * (self.alpha / self.dim)

        if pad_len > 0:
            # Remove the rows added by pad_seq_to_mult for the expert (MoE) path.
            x = unpad_seq_to_mult(x, pad_len)

        return x

    def sharded_state_dict(
        self, prefix: str = '', sharded_offsets: tuple = (), metadata: Optional[dict] = None
    ) -> ShardedStateDict:
        """
        Sharded state dict for the LoRA adapter. Special treatment is given to the linear_fc1 adapter
        since TP is sharded separately for the two logical matrices (gate and up).
        """
        sharded_state_dict = {}
        linear_in_sd = self.linear_in.sharded_state_dict(f"{prefix}linear_in.", sharded_offsets, metadata)
        linear_out_sd = self.linear_out.sharded_state_dict(f"{prefix}linear_out.", sharded_offsets, metadata)

        if 'linear_fc1' in self.base_linear_name:
            for k, v in linear_out_sd.items():
                if k in (f'{prefix}linear_out.weight', f'{prefix}linear_out.bias'):
                    linear_out_sd[k] = apply_swiglu_sharded_factory(v, sharded_offsets)

        sharded_state_dict.update(linear_in_sd)
        sharded_state_dict.update(linear_out_sd)
        return sharded_state_dict