| |
| |
| |
| |
|
|
| import functools |
| import sys |
| import unittest |
|
|
| import torch |
|
|
| from fairseq.distributed import utils as dist_utils |
|
|
| from .utils import objects_are_equal, spawn_and_init |
|
|
|
|
class DistributedTest(unittest.TestCase):
    """Base class for multi-GPU distributed tests.

    ``setUp`` skips the test unless the environment can actually run
    NCCL-based distributed code: CUDA present, not Windows, and at
    least two visible GPUs.
    """

    def setUp(self):
        # Each guard raises unittest.SkipTest with the matching reason.
        if not torch.cuda.is_available():
            self.skipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            self.skipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            self.skipTest("distributed tests require 2+ GPUs, skipping")
|
|
|
|
class TestBroadcastObject(DistributedTest):
    """Check that ``dist_utils.broadcast_object`` round-trips values
    from rank 0 to all other ranks across a 2-process group."""

    def test_str(self):
        target = functools.partial(
            TestBroadcastObject._test_broadcast_object, "hello world"
        )
        spawn_and_init(target, world_size=2)

    def test_tensor(self):
        target = functools.partial(
            TestBroadcastObject._test_broadcast_object, torch.rand(5)
        )
        spawn_and_init(target, world_size=2)

    def test_complex(self):
        # Nested container mixing strings, ints, CPU tensors, a set and
        # a CUDA tensor — exercises the generic (pickled) broadcast path.
        payload = {
            "a": "1",
            "b": [2, torch.rand(2, 3), 3],
            "c": (torch.rand(2, 3), 4),
            "d": {5, torch.rand(5)},
            "e": torch.rand(5),
            "f": torch.rand(5).int().cuda(),
        }
        target = functools.partial(
            TestBroadcastObject._test_broadcast_object, payload
        )
        spawn_and_init(target, world_size=2)

    @staticmethod
    def _test_broadcast_object(ref_obj, rank, group):
        # Only rank 0 supplies the payload; every other rank passes None
        # and must receive an equal copy from the broadcast.
        received = dist_utils.broadcast_object(
            ref_obj if rank == 0 else None, src_rank=0, group=group
        )
        assert objects_are_equal(ref_obj, received)
|
|
|
|
class TestAllGatherList(DistributedTest):
    """Check that ``dist_utils.all_gather_list`` gathers an equal copy
    of each rank's object on every rank of a 2-process group."""

    def test_str_equality(self):
        target = functools.partial(
            TestAllGatherList._test_all_gather_list_equality, "hello world"
        )
        spawn_and_init(target, world_size=2)

    def test_tensor_equality(self):
        target = functools.partial(
            TestAllGatherList._test_all_gather_list_equality, torch.rand(5)
        )
        spawn_and_init(target, world_size=2)

    def test_complex_equality(self):
        # Nested container mixing strings, ints, tensors and a set —
        # exercises the generic (pickled) gather path on CPU tensors.
        payload = {
            "a": "1",
            "b": [2, torch.rand(2, 3), 3],
            "c": (torch.rand(2, 3), 4),
            "d": {5, torch.rand(5)},
            "e": torch.rand(5),
            "f": torch.rand(5).int(),
        }
        target = functools.partial(
            TestAllGatherList._test_all_gather_list_equality, payload
        )
        spawn_and_init(target, world_size=2)

    @staticmethod
    def _test_all_gather_list_equality(ref_obj, rank, group):
        # All ranks pass the same object, so every gathered entry must
        # compare equal to the local input.
        gathered = dist_utils.all_gather_list(ref_obj, group)
        for item in gathered:
            assert objects_are_equal(ref_obj, item)

    def test_rank_tensor(self):
        spawn_and_init(
            TestAllGatherList._test_all_gather_list_rank_tensor, world_size=2
        )

    @staticmethod
    def _test_all_gather_list_rank_tensor(rank, group):
        # Each rank contributes its own rank id, so the gathered list
        # must read [0, 1, ...] in rank order.
        gathered = dist_utils.all_gather_list(torch.tensor([rank]), group)
        for expected, item in enumerate(gathered):
            assert item.item() == expected
|
|
|
|
| if __name__ == "__main__": |
| unittest.main() |
|
|