arithmetic-grpo/tests/single_controller/test_split_resource_pool.py
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
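
"""Tests for verl.single_controller.ray.base.split_resource_pool.

Each test splits a global RayResourcePool into smaller pools (by a uniform
split size, by a list of sizes, and by splitting a sub-pool a second time),
builds a RayWorkerGroup on every sub-pool, and checks the resulting world
sizes and per-rank outputs.
"""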
import os

import ray
import torch

from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.ray.base import (
    RayClassWithInitArgs,
    RayResourcePool,
    RayWorkerGroup,
    split_resource_pool,
)
from verl.utils.device import get_device_name, get_nccl_backend
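

# Minimal test worker: each rank adds (rank + worker_id) to every element of the
# DataProto shard it receives, so the collected output reveals which ranks of which
# worker group processed which chunk. The throwaway tensor is presumably allocated
# just so every worker touches its assigned device.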
@ray.remote
class Actor(Worker):
    def __init__(self, worker_id) -> None:
        super().__init__()
        self.worker_id = worker_id
        self.temp_tensor = torch.rand(4096, 4096).to(get_device_name())
        if not torch.distributed.is_initialized():
            rank = int(os.environ.get("RANK", 0))
            world_size = int(os.environ.get("WORLD_SIZE", 1))
            torch.distributed.init_process_group(backend=get_nccl_backend(), world_size=world_size, rank=rank)

    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
    def add(self, data: DataProto):
        data.batch["a"] += self.rank + self.worker_id
        return data
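

# Dispatch.DP_COMPUTE_PROTO splits the input DataProto evenly across the ranks of a
# worker group and concatenates the per-rank outputs, so the expected values in the
# assertions below encode (rank + worker_id) for each chunk.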
def test_split_resource_pool_with_split_size():
    ray.init()

    # assume we have 2 nodes, with 4 GPUs each
    global_resource_pool = RayResourcePool(process_on_nodes=[4, 4])
    global_resource_pool.get_placement_groups(device_name=get_device_name())

    # first 4 gpus for actor_1, last 4 gpus for actor_2
    actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(resource_pool=global_resource_pool, split_size=4)

    actor_cls_1 = RayClassWithInitArgs(cls=Actor, worker_id=0)
    actor_cls_2 = RayClassWithInitArgs(cls=Actor, worker_id=100)
    actor_worker_1 = RayWorkerGroup(
        resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1, device_name=get_device_name()
    )
    actor_worker_2 = RayWorkerGroup(
        resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2, device_name=get_device_name()
    )
    assert actor_worker_1.world_size == 4
    assert actor_worker_2.world_size == 4

    data = DataProto.from_dict({"a": torch.zeros(8)})
    actor_output_1 = actor_worker_1.add(data)
    actor_output_2 = actor_worker_2.add(data)
    assert actor_output_1.batch["a"].tolist() == [0, 0, 1, 1, 2, 2, 3, 3]
    assert actor_output_2.batch["a"].tolist() == [100, 100, 101, 101, 102, 102, 103, 103]

    ray.shutdown()
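

# Split by a list of sizes: 2 GPUs and 6 GPUs out of four 2-GPU nodes.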
def test_split_resource_pool_with_split_size_list():
    ray.init()

    # assume we have 4 nodes, with 2 GPUs each
    global_resource_pool = RayResourcePool(process_on_nodes=[2, 2, 2, 2])
    global_resource_pool.get_placement_groups(device_name=get_device_name())

    # first 2 gpus for actor_1, last 6 gpus for actor_2
    actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(
        resource_pool=global_resource_pool,
        split_size=[2, 6],
    )

    actor_cls_1 = RayClassWithInitArgs(cls=Actor, worker_id=0)
    actor_cls_2 = RayClassWithInitArgs(cls=Actor, worker_id=100)
    actor_worker_1 = RayWorkerGroup(
        resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1, device_name=get_device_name()
    )
    actor_worker_2 = RayWorkerGroup(
        resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2, device_name=get_device_name()
    )
    assert actor_worker_1.world_size == 2
    assert actor_worker_2.world_size == 6

    data_1 = DataProto.from_dict({"a": torch.zeros(4)})
    data_2 = DataProto.from_dict({"a": torch.zeros(6)})
    actor_output_1 = actor_worker_1.add(data_1)
    actor_output_2 = actor_worker_2.add(data_2)
    print(actor_output_1.batch["a"].tolist())
    print(actor_output_2.batch["a"].tolist())
    assert actor_output_1.batch["a"].tolist() == [0, 0, 1, 1]
    assert actor_output_2.batch["a"].tolist() == [100, 101, 102, 103, 104, 105]

    ray.shutdown()
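

# Same [2, 6] split, but over two 4-GPU nodes, so the 6-GPU pool spans nodes.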
def test_split_resource_pool_with_split_size_list_cross_nodes():
    ray.init()

    # assume we have 2 nodes, with 4 GPUs each
    global_resource_pool = RayResourcePool(process_on_nodes=[4, 4])
    global_resource_pool.get_placement_groups(device_name=get_device_name())

    # first 2 gpus for actor_1, last 6 gpus for actor_2
    actor_1_resource_pool, actor_2_resource_pool = split_resource_pool(
        resource_pool=global_resource_pool,
        split_size=[2, 6],
    )

    actor_cls_1 = RayClassWithInitArgs(cls=Actor, worker_id=0)
    actor_cls_2 = RayClassWithInitArgs(cls=Actor, worker_id=100)
    actor_worker_1 = RayWorkerGroup(
        resource_pool=actor_1_resource_pool, ray_cls_with_init=actor_cls_1, device_name=get_device_name()
    )
    actor_worker_2 = RayWorkerGroup(
        resource_pool=actor_2_resource_pool, ray_cls_with_init=actor_cls_2, device_name=get_device_name()
    )
    assert actor_worker_1.world_size == 2
    assert actor_worker_2.world_size == 6

    data_1 = DataProto.from_dict({"a": torch.zeros(4)})
    data_2 = DataProto.from_dict({"a": torch.zeros(6)})
    actor_output_1 = actor_worker_1.add(data_1)
    actor_output_2 = actor_worker_2.add(data_2)
    print(actor_output_1.batch["a"].tolist())
    print(actor_output_2.batch["a"].tolist())
    assert actor_output_1.batch["a"].tolist() == [0, 0, 1, 1]
    assert actor_output_2.batch["a"].tolist() == [100, 101, 102, 103, 104, 105]

    ray.shutdown()
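

# Split the global pool into [2, 4, 2], then split the middle 4-GPU pool again
# into four single-GPU pools, and run a worker group on each of the six pools.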
def test_split_resource_pool_with_split_twice():
    ray.init()

    # assume we have 4 nodes, with 2 GPUs each
    global_resource_pool = RayResourcePool(process_on_nodes=[2, 2, 2, 2])
    global_resource_pool.get_placement_groups(device_name=get_device_name())

    # actors with [2, 1, 1, 1, 1, 2] GPUs each (split twice)
    rp_1, rp_2, rp_3 = split_resource_pool(
        resource_pool=global_resource_pool,
        split_size=[2, 4, 2],
    )
    rp_2_1, rp_2_2, rp_2_3, rp_2_4 = split_resource_pool(
        resource_pool=rp_2,
        split_size=1,
    )
    rp_list = [rp_1, rp_2_1, rp_2_2, rp_2_3, rp_2_4, rp_3]
    correct_world_size = [2, 1, 1, 1, 1, 2]
    correct_output = [
        [0.0, 0.0, 1.0, 1.0],  # 2 workers
        [100.0, 100.0, 100.0, 100.0],  # 1 worker
        [200.0, 200.0, 200.0, 200.0],  # 1 worker
        [300.0, 300.0, 300.0, 300.0],  # 1 worker
        [400.0, 400.0, 400.0, 400.0],  # 1 worker
        [500.0, 500.0, 501.0, 501.0],  # 2 workers
    ]

    for idx, rp in enumerate(rp_list):
        actor_cls = RayClassWithInitArgs(cls=Actor, worker_id=idx * 100)
        actor_worker = RayWorkerGroup(resource_pool=rp, ray_cls_with_init=actor_cls, device_name=get_device_name())
        data = DataProto.from_dict({"a": torch.zeros(4)})
        actor_output = actor_worker.add(data)
        assert actor_worker.world_size == correct_world_size[idx]
        assert actor_output.batch["a"].tolist() == correct_output[idx]

    ray.shutdown()