| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """ |
| In this test, we instantiate a data parallel worker with 8 GPUs |
| """ |
|
|
| import ray |
| import tensordict |
| import torch |
| from codetiming import Timer |
| from torch import distributed as dist |
|
|
| from verl import DataProto |
| from verl.single_controller.base import Worker |
| from verl.single_controller.base.decorator import Dispatch, register |
| from verl.single_controller.ray import RayClassWithInitArgs, RayResourcePool, RayWorkerGroup |
| from verl.utils.ray_utils import parallel_put |
|
|
|
|
| @ray.remote |
| class DummyWorker(Worker): |
| def __init__(self): |
| super().__init__() |
| dist.init_process_group() |
|
|
| @register(dispatch_mode=Dispatch.DP_COMPUTE, blocking=False) |
| def do_nothing(self, data): |
| for key in data.batch.keys(): |
| data.batch[key] += 1 |
| if tensordict.__version__ >= "0.5.0": |
| data.batch = data.batch.consolidate() |
| return data |
|
|
|
|
| def test_data_transfer(): |
| ray.init() |
| |
| resource_pool = RayResourcePool([8]) |
| cls_with_init = RayClassWithInitArgs(cls=DummyWorker) |
| |
| wg = RayWorkerGroup(resource_pool, cls_with_init) |
|
|
| |
| batch_size = 4096 |
| seqlen = 32768 |
|
|
| data_dict = {} |
|
|
| for i in range(2): |
| data_dict[str(i)] = torch.randint(0, 10000, (batch_size, seqlen)) |
|
|
| data = DataProto.from_dict(tensors=data_dict) |
|
|
| print(data) |
|
|
| |
| data_list = data.chunk(wg.world_size) |
|
|
| for i in range(wg.world_size): |
| |
| if tensordict.__version__ >= "0.5.0": |
| data_list[i].batch = data_list[i].batch.consolidate() |
|
|
| with Timer(name="ray.pickle", initial_text=True): |
| for i in range(wg.world_size): |
| ray.cloudpickle.pickle.dumps(data_list[i]) |
|
|
| with Timer(name="raw.pickle", initial_text=True): |
| import pickle |
|
|
| for i in range(wg.world_size): |
| pickle.dumps(data_list[i]) |
|
|
| |
| with Timer(name="put", initial_text=True): |
| |
| data_list_ref = parallel_put(data_list) |
| |
| |
|
|
| with Timer(name="launch", initial_text=True): |
| output_ref = wg.do_nothing(data_list_ref) |
|
|
| with Timer(name="get", initial_text=True): |
| |
| output_lst = ray.get(output_ref) |
|
|
| for input_data, output_data in zip(data_list, output_lst, strict=True): |
| for key in input_data.batch.keys(): |
| assert torch.all(torch.eq(input_data.batch[key] + 1, output_data.batch[key])), ( |
| input_data.batch[key], |
| output_data.batch[key], |
| key, |
| ) |
|
|
| ray.shutdown() |
|
|