| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import torch |
| import torch.distributed as dist |
| import torch.multiprocessing as mp |
|
|
| from verl import DataProto |
| from verl.utils.device import get_device_name, get_nccl_backend, get_torch_device |
| from verl.utils.model import create_random_mask |
| from verl.utils.seqlen_balancing import ( |
| ceildiv, |
| get_reverse_idx, |
| prepare_dynamic_batch, |
| rearrange_micro_batches, |
| restore_dynamic_batch, |
| ) |
|
|
|
|
def test_seqlen_balancing():
    """Round-trip check: micro-batches produced by ``rearrange_micro_batches``
    can be concatenated and permuted back into the original batch order."""
    input_ids = torch.randint(low=0, high=10, size=(20, 100))
    attention_mask = create_random_mask(
        input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5
    )
    dataproto = DataProto.from_single_dict({"input_ids": input_ids, "attention_mask": attention_mask})

    micro_batches, idx_lists = rearrange_micro_batches(dataproto.batch, max_token_len=300)
    merged = torch.cat(micro_batches)

    # Flatten the per-micro-batch index lists, then invert the permutation.
    flat_idx = [i for sub in idx_lists for i in sub]
    inverse = torch.tensor(get_reverse_idx(flat_idx))
    torch.testing.assert_close(merged[inverse], dataproto.batch)
|
|
|
|
def test_dynamic_batch():
    """``prepare_dynamic_batch`` followed by ``restore_dynamic_batch`` must be the identity."""
    input_ids = torch.randint(low=0, high=10, size=(20, 100))
    attention_mask = create_random_mask(
        input_ids=input_ids, max_ratio_of_left_padding=0.1, max_ratio_of_valid_token=0.9, min_ratio_of_valid_token=0.5
    )
    dataproto = DataProto.from_single_dict({"input_ids": input_ids, "attention_mask": attention_mask})

    micro_batches, idx_lists = prepare_dynamic_batch(dataproto, max_token_len=300)

    # Re-assemble in micro-batch order, then undo the permutation.
    gathered = torch.cat([mb.batch["input_ids"] for mb in micro_batches], dim=0)
    restored = restore_dynamic_batch(gathered, idx_lists)
    torch.testing.assert_close(restored, dataproto.batch["input_ids"])
|
|
|
|
def _worker(rank, world_size, init_method, max_token_len, use_same_dp, min_mb):
    """Distributed worker validating ``rearrange_micro_batches`` options.

    Args:
        rank: this process's rank in the spawned group.
        world_size: total number of spawned processes.
        init_method: process-group rendezvous URL (e.g. ``file://...``).
        max_token_len: token budget per micro-batch.
        use_same_dp: if True, every dp rank must end up with the same
            number of micro-batches (the max across the group).
        min_mb: optional lower bound on the number of micro-batches, or None.
    """
    device = f"{get_device_name()}:{rank}"
    get_torch_device().set_device(rank)
    dist.init_process_group(
        backend=get_nccl_backend(),
        init_method=init_method,
        world_size=world_size,
        rank=rank,
    )

    # Different seed AND different batch size per rank, so dp ranks
    # legitimately disagree on their local micro-batch count.
    torch.manual_seed(42 + rank)
    input_ids = torch.randint(0, 10, (20 + rank * 5, 100), device=device)
    attention_mask = create_random_mask(
        input_ids=input_ids,
        max_ratio_of_left_padding=0.1,
        max_ratio_of_valid_token=0.9,
        min_ratio_of_valid_token=0.5,
    )
    proto = DataProto.from_single_dict({"input_ids": input_ids, "attention_mask": attention_mask})
    batch = proto.batch

    micros, idx_lst = rearrange_micro_batches(
        batch,
        max_token_len=max_token_len,
        dp_group=dist.group.WORLD,
        same_micro_num_in_dp=use_same_dp,
        min_num_micro_batch=min_mb,
    )

    # Expected local micro-batch count from effective sequence lengths alone.
    seq_len_effective: torch.Tensor = batch["attention_mask"].sum(dim=1)
    total_seqlen = seq_len_effective.sum().item()
    local = min(len(seq_len_effective), ceildiv(total_seqlen, max_token_len))

    # BUG FIX: the three expectations are mutually exclusive, but the original
    # used `if` + independent `if/else`. With min_mb set and use_same_dp False,
    # the `else` branch also ran and asserted `len(micros) == local`, which
    # contradicts `max(local, min_mb)` whenever local < min_mb.
    if min_mb is not None:
        assert len(micros) == max(local, min_mb)
    elif use_same_dp:
        # All ranks must agree on the max micro-batch count across the dp group.
        # Use a dedicated input tensor rather than aliasing an entry of the
        # output list passed to all_gather.
        gathered = [torch.zeros(1, device=device) for _ in range(world_size)]
        local_count = torch.full((1,), float(local), device=device)
        dist.all_gather(gathered, local_count)
        expected = max(int(c.item()) for c in gathered)
        assert len(micros) == expected
    else:
        assert len(micros) == local

    # Restore original ordering and verify it matches the input batch.
    flat = torch.cat(micros, dim=0)
    idx = []
    for sub in idx_lst:
        idx.extend(sub)
    inv = torch.tensor(get_reverse_idx(idx), device=flat.device)
    torch.testing.assert_close(flat[inv], batch)

    dist.destroy_process_group()
|
|
|
|
def test_dataproto_split_uneven():
    """Test DataProto.split with uneven splits"""
    import numpy as np

    input_ids = torch.randint(low=0, high=10, size=(10, 5))
    attention_mask = torch.ones(10, 5)
    dataproto = DataProto.from_single_dict({"input_ids": input_ids, "attention_mask": attention_mask})

    # 10 rows split by 3 -> chunks of sizes 3, 3, 3, 1.
    splits = dataproto.split(3)
    assert len(splits) == 4
    assert [len(s) for s in splits] == [3, 3, 3, 1]

    # Concatenating the chunks must reproduce the original tensors.
    reconstructed = DataProto.concat(splits)
    for key in ("input_ids", "attention_mask"):
        torch.testing.assert_close(reconstructed.batch[key], dataproto.batch[key])

    # Split size equal to the batch size -> a single full chunk.
    splits = dataproto.split(10)
    assert len(splits) == 1
    assert len(splits[0]) == 10

    # Split size larger than the batch size -> still a single full chunk.
    splits = dataproto.split(15)
    assert len(splits) == 1
    assert len(splits[0]) == 10

    # Non-tensor payloads must split and re-concat consistently too.
    labels = np.array([f"label_{i}" for i in range(10)], dtype=object)
    proto_with_labels = DataProto.from_single_dict(
        {"input_ids": input_ids, "attention_mask": attention_mask, "labels": labels}
    )

    splits = proto_with_labels.split(3)
    assert len(splits) == 4
    assert [len(s) for s in splits] == [3, 3, 3, 1]

    reconstructed = DataProto.concat(splits)
    np.testing.assert_array_equal(
        reconstructed.non_tensor_batch["labels"], proto_with_labels.non_tensor_batch["labels"]
    )
|
|
|
|
def test_seqlen_balancing_distributed_params(tmp_path):
    """Spawn 2-process groups to exercise min_num_micro_batch and same_micro_num_in_dp.

    BUG FIX: each ``mp.spawn`` gets its OWN rendezvous file. The original
    reused one ``file://`` init file for both process groups; a FileStore
    writes rendezvous data into the file, and torch.distributed requires the
    file to be empty or non-existent at init time, so the second group could
    fail or hang on the stale contents.
    """
    world_size = 2

    def _fresh_init_method(tag):
        # Create a new, empty rendezvous file for this process group.
        init_file = tmp_path / f"dist_init_{tag}"
        init_file.write_text("")
        return f"file://{init_file}"

    # Case 1: min_num_micro_batch forces at least 4 micro-batches per rank.
    mp.spawn(
        _worker,
        args=(world_size, _fresh_init_method("min_mb"), 300, False, 4),
        nprocs=world_size,
        join=True,
    )

    # Case 2: same_micro_num_in_dp aligns the micro-batch count across dp ranks.
    mp.spawn(
        _worker,
        args=(world_size, _fresh_init_method("same_dp"), 300, True, None),
        nprocs=world_size,
        join=True,
    )
|
|
|
|
def test_group_balanced_partitions():
    """Test group-level balancing keeps same-uid samples together."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Four groups of four samples each, with distinct per-group lengths.
    seqlen_list = [100] * 4 + [200] * 4 + [150] * 4 + [50] * 4
    uid_list = [0] * 4 + [1] * 4 + [2] * 4 + [3] * 4

    partitions = get_group_balanced_partitions(seqlen_list, uid_list, k_partitions=2)
    assert len(partitions) == 2

    # Every sample index is covered exactly once across the partitions.
    covered = {i for part in partitions for i in part}
    assert covered == set(range(16))

    # No uid's samples may straddle two partitions.
    for part in partitions:
        members = set(part)
        for uid in {uid_list[i] for i in part}:
            owners = [i for i, u in enumerate(uid_list) if u == uid]
            assert members.issuperset(owners), f"uid {uid} samples split across partitions"
|
|
|
|
def test_group_balanced_partitions_single_sample_groups():
    """Test group balancing with single-sample groups (n=1)."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Each sample is its own group, so grouping adds no constraint.
    seqlen_list = [100, 200, 150, 50, 300, 250]
    uid_list = list(range(6))

    partitions = get_group_balanced_partitions(seqlen_list, uid_list, k_partitions=2)
    assert len(partitions) == 2

    # The two partitions together cover every index exactly once.
    covered = set()
    for part in partitions:
        covered |= set(part)
    assert covered == set(range(6))
|
|
|
|
def test_group_balanced_partitions_equal_size():
    """Test group balancing with equal_size constraint simulation."""
    from verl.utils.seqlen_balancing import get_group_balanced_partitions

    # Eight uids with two equal-length samples each, spread over 4 partitions.
    lengths = [100, 200, 150, 50, 300, 250, 180, 120]
    seqlen_list = [v for v in lengths for _ in range(2)]
    uid_list = [u for u in range(8) for _ in range(2)]

    partitions = get_group_balanced_partitions(seqlen_list, uid_list, k_partitions=4)
    assert len(partitions) == 4

    # Full, disjoint coverage of all 16 sample indices.
    covered = {i for part in partitions for i in part}
    assert covered == set(range(16))

    # Each uid's pair of samples must land in a single partition.
    for part in partitions:
        members = set(part)
        for uid in {uid_list[i] for i in part}:
            owners = [i for i, u in enumerate(uid_list) if u == uid]
            assert members.issuperset(owners)
|
|