| | |
| | |
| |
|
| | import torch.distributed as dist |
| |
|
| | from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode |
| | from internlm.core.context import global_context as gpc |
| |
|
| |
|
def is_model_parallel_parameter(p):
    """Return True-ish iff *p* is marked as split by tensor parallelism.

    Parameters carry the ``IS_TENSOR_PARALLEL`` attribute when they are
    partitioned across the tensor-parallel group; absent attribute means
    the parameter is replicated.
    """
    # getattr with a False default collapses the hasattr/getattr pair:
    # missing attribute -> False, present attribute -> its value.
    return getattr(p, IS_TENSOR_PARALLEL, False)
| |
|
| |
|
def sync_model_param(model, parallel_mode):
    r"""Make sure data parameters are consistent during Data Parallel Mode.

    Broadcasts every model parameter from the first rank of the given
    parallel group to all other ranks in that group, so every rank starts
    from identical weights. No-op when the mode is uninitialized or the
    group has a single member.

    Args:
        model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
        parallel_mode (:class:`internlm.core.context.ParallelMode`): Parallel mode to be checked.
    """
    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
        # Hoist loop-invariant group/rank lookups out of the per-parameter
        # loop instead of querying gpc once per parameter.
        ranks = gpc.get_ranks_in_group(parallel_mode)
        group = gpc.get_group(parallel_mode)
        for param in model.parameters():
            dist.broadcast(param, src=ranks[0], group=group)
| |
|
| |
|
def sync_model_param_within_tp(model):
    r"""This function is changed from colossalai, which is ``sync_model_param``.

    We modified this function to make sure it only syncs parameters within
    tensor parallelism that are *not* split by tensor parallelism.
    It keeps parameters that are replicated across the tensor-parallel
    group (e.g. RMSNorm or LayerNorm weights) identical on every rank of
    that group by broadcasting them from the group's first rank.

    Args:
        model (:class:`torch.nn.Module`): A pyTorch model on whose parameters you check the consistency.
    """
    parallel_mode = ParallelMode.TENSOR
    if gpc.is_initialized(parallel_mode) and gpc.get_world_size(parallel_mode) > 1:
        # Hoist loop-invariant group/rank lookups out of the per-parameter
        # loop instead of querying gpc once per parameter.
        ranks = gpc.get_ranks_in_group(parallel_mode)
        group = gpc.get_group(parallel_mode)
        for param in model.parameters():
            # Tensor-parallel-split parameters differ by design; only the
            # replicated ones must be kept in sync.
            if not is_model_parallel_parameter(param):
                dist.broadcast(param, src=ranks[0], group=group)
| |
|
| |
|
def is_no_pp_or_last_stage():
    """Return True when pipeline parallelism is absent or this rank is its last stage.

    Useful as a guard for work that should run exactly once per pipeline
    (e.g. loss computation on the final stage).
    """
    if gpc.is_initialized(ParallelMode.PIPELINE):
        return gpc.is_last_rank(ParallelMode.PIPELINE)
    return True
| |
|
| |
|
def get_parallel_log_file_name():
    """Build a per-rank log file name encoding the data/tensor/pipeline local ranks.

    The designated logging rank gets a ``main_`` prefix so its file is easy
    to spot among the per-rank logs.
    """
    prefix = "main_" if gpc.is_rank_for_log() else ""
    dp = gpc.get_local_rank(ParallelMode.DATA)
    tp = gpc.get_local_rank(ParallelMode.TENSOR)
    pp = gpc.get_local_rank(ParallelMode.PIPELINE)
    return f"{prefix}dp={dp}_tp={tp}_pp={pp}"
| |
|