| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
| | from typing import Dict
|
| |
|
| | from verl.single_controller.base import ResourcePool, WorkerGroup
|
| |
|
| | from .worker import DistGlobalInfo, DistRankInfo
|
| |
|
| |
|
| | class MegatronWorkerGroup(WorkerGroup):
|
| | def __init__(self, resource_pool: ResourcePool, **kwargs):
|
| | super().__init__(resource_pool=resource_pool, **kwargs)
|
| | self._megatron_rank_info = None
|
| | self._megatron_global_info: DistGlobalInfo = None
|
| |
|
| | def init_megatron(self, default_megatron_kwargs: Dict = None):
|
| | raise NotImplementedError("MegatronWorkerGroup.init_megatron should be overwritten")
|
| |
|
| | def get_megatron_rank_info(self, rank: int) -> DistRankInfo:
|
| | assert 0 <= rank < self.world_size, f"rank must be from [0, world_size), Got {rank}"
|
| | return self._megatron_rank_info[rank]
|
| |
|
| | @property
|
| | def tp_size(self):
|
| | assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
|
| | return self._megatron_global_info.tp_size
|
| |
|
| | @property
|
| | def dp_size(self):
|
| | assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
|
| | return self._megatron_global_info.dp_size
|
| |
|
| | @property
|
| | def pp_size(self):
|
| | assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
|
| | return self._megatron_global_info.pp_size
|
| |
|
| | @property
|
| | def cp_size(self):
|
| | assert self._megatron_global_info is not None, "MegatronWorkerGroup._megatron_global_info must be initialized"
|
| | return self._megatron_global_info.cp_size
|
| |
|
| | def get_megatron_global_info(self):
|
| | return self._megatron_global_info
|
| |
|