# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import functools
import sys
import unittest

import torch

from fairseq.distributed import utils as dist_utils

from .utils import objects_are_equal, spawn_and_init


class DistributedTest(unittest.TestCase):
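    """Base class: skip unless CUDA, a non-Windows platform, and 2+ GPUs are available."""
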
    def setUp(self):
        if not torch.cuda.is_available():
            raise unittest.SkipTest("CUDA not available, skipping test")
        if sys.platform == "win32":
            raise unittest.SkipTest("NCCL doesn't support Windows, skipping test")
        if torch.cuda.device_count() < 2:
            raise unittest.SkipTest("distributed tests require 2+ GPUs, skipping")


class TestBroadcastObject(DistributedTest):
    def test_str(self):
        spawn_and_init(
            functools.partial(
                TestBroadcastObject._test_broadcast_object, "hello world"
            ),
            world_size=2,
        )

    def test_tensor(self):
        spawn_and_init(
            functools.partial(
                TestBroadcastObject._test_broadcast_object,
                torch.rand(5),
            ),
            world_size=2,
        )

    def test_complex(self):
        spawn_and_init(
            functools.partial(
                TestBroadcastObject._test_broadcast_object,
                {
                    "a": "1",
                    "b": [2, torch.rand(2, 3), 3],
                    "c": (torch.rand(2, 3), 4),
                    "d": {5, torch.rand(5)},
                    "e": torch.rand(5),
                    "f": torch.rand(5).int().cuda(),
                },
            ),
            world_size=2,
        )

    @staticmethod
    def _test_broadcast_object(ref_obj, rank, group):
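        # Rank 0 broadcasts ref_obj; every rank should receive an identical copy.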
        obj = dist_utils.broadcast_object(
            ref_obj if rank == 0 else None, src_rank=0, group=group
        )
        assert objects_are_equal(ref_obj, obj)


class TestAllGatherList(DistributedTest):
    def test_str_equality(self):
        spawn_and_init(
            functools.partial(
                TestAllGatherList._test_all_gather_list_equality,
                "hello world",
            ),
            world_size=2,
        )

    def test_tensor_equality(self):
        spawn_and_init(
            functools.partial(
                TestAllGatherList._test_all_gather_list_equality,
                torch.rand(5),
            ),
            world_size=2,
        )

    def test_complex_equality(self):
        spawn_and_init(
            functools.partial(
                TestAllGatherList._test_all_gather_list_equality,
                {
                    "a": "1",
                    "b": [2, torch.rand(2, 3), 3],
                    "c": (torch.rand(2, 3), 4),
                    "d": {5, torch.rand(5)},
                    "e": torch.rand(5),
                    "f": torch.rand(5).int(),
                },
            ),
            world_size=2,
        )

    @staticmethod
    def _test_all_gather_list_equality(ref_obj, rank, group):
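        # all_gather_list returns one entry per rank; since every rank passed the
        # same ref_obj, every gathered copy must compare equal to it.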
        objs = dist_utils.all_gather_list(ref_obj, group)
        for obj in objs:
            assert objects_are_equal(ref_obj, obj)

    def test_rank_tensor(self):
        spawn_and_init(
            TestAllGatherList._test_all_gather_list_rank_tensor, world_size=2
        )

    @staticmethod
    def _test_all_gather_list_rank_tensor(rank, group):
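        # Each rank contributes a tensor holding its own rank, so the gathered
        # list must come back ordered by rank: [tensor([0]), tensor([1]), ...].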
        obj = torch.tensor([rank])
        objs = dist_utils.all_gather_list(obj, group)
        for i, obj in enumerate(objs):
            assert obj.item() == i


if __name__ == "__main__":
    unittest.main()