| | |
| | |
| | |
| | |
| |
|
| | import functools |
| | import tempfile |
| |
|
| | import torch |
| |
|
| |
|
def spawn_and_init(fn, world_size, args=None):
    """Spawn ``world_size`` processes that each join a process group and run ``fn``.

    Every child runs ``init_and_run`` (defined in this module) with its rank,
    which initializes ``torch.distributed`` through a shared ``file://``
    rendezvous file and then calls ``fn(rank, group, *args)``. Blocks until
    all children have exited (``join=True``).

    Args:
        fn: callable invoked in every child as ``fn(rank, group, *args)``.
        world_size: number of processes to spawn; also the process-group size.
        args: optional tuple of extra positional arguments forwarded to ``fn``.
    """
    import os  # local import: only needed to clean up the rendezvous file

    if args is None:
        args = ()
    # Create a fresh, empty file for the file:// rendezvous. ``delete=False``
    # keeps it on disk after our handle is closed so the children can open it
    # by name; it is removed below once every child has finished.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        store_path = tmp_file.name
    try:
        torch.multiprocessing.spawn(
            fn=functools.partial(init_and_run, fn, args),
            args=(world_size, store_path),
            nprocs=world_size,
            join=True,
        )
    finally:
        # Always remove the rendezvous file, even if a child failed. It may
        # already be gone (e.g. cleaned up by the store), which is fine.
        try:
            os.remove(store_path)
        except FileNotFoundError:
            pass
| |
|
| |
|
def distributed_init(rank, world_size, tmp_file, *, backend="nccl"):
    """Join the default ``torch.distributed`` process group via a file rendezvous.

    Args:
        rank: this process's rank within the group.
        world_size: total number of processes in the group.
        tmp_file: path of the rendezvous file; every rank must pass the same path.
        backend: ``torch.distributed`` backend name. Defaults to ``"nccl"``.
            For ``"nccl"`` this process is pinned to the CUDA device whose
            index equals ``rank`` (one GPU per rank). Other backends leave the
            current CUDA device untouched, so e.g. ``"gloo"`` works on CPU-only
            machines.
    """
    if backend == "nccl":
        # Select this rank's GPU before the process group exists so any
        # CUDA/NCCL state created for this rank is placed on its own device.
        torch.cuda.set_device(rank)
    torch.distributed.init_process_group(
        backend=backend,
        init_method="file://{}".format(tmp_file),
        world_size=world_size,
        rank=rank,
    )
| |
|
| |
|
def init_and_run(fn, args, rank, world_size, tmp_file):
    """Per-child entry point used by ``spawn_and_init``.

    Joins the default process group through ``distributed_init``, then creates
    a fresh group with ``torch.distributed.new_group()`` (no ``ranks``
    argument, i.e. spanning all ranks). It hands that group to the user
    callable as ``fn(rank, group, *args)``.

    ``fn`` and ``args`` come first so they can be bound with
    ``functools.partial``. ``torch.multiprocessing.spawn`` then supplies
    ``rank`` followed by ``(world_size, tmp_file)``.
    """
    distributed_init(rank, world_size, tmp_file)
    fn(rank, torch.distributed.new_group(), *args)
| |
|
| |
|
def objects_are_equal(a, b) -> bool:
    """Recursively compare two (possibly nested) objects for exact equality.

    Rules:
      * The types must match exactly (``type(a) is type(b)``), so a list never
        equals a tuple and an ``int`` never equals a ``float``.
      * Dicts must have the same key set and recursively equal values.
      * Lists and tuples must have the same length and recursively equal
        elements, position by position.
      * Tensors must agree in size, dtype, device, and every element
        (compared with ``==``, so NaN entries are unequal).
      * Anything else, including sets, is compared with ``a == b``.

    Returns:
        Whether the objects are equal under the rules above. The dict, list,
        tuple and tensor branches always return a Python ``bool``. The
        fallback returns whatever ``a == b`` returns.
    """
    if type(a) is not type(b):
        return False
    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            return False
        return all(objects_are_equal(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)):
        # Ordered containers only. Sets are deliberately not handled here:
        # two equal sets can iterate in different orders (e.g. {1, 9} vs
        # {9, 1}), so zipping them would pair unrelated elements. They fall
        # through to ``a == b`` below.
        return len(a) == len(b) and all(
            objects_are_equal(x, y) for x, y in zip(a, b)
        )
    if torch.is_tensor(a):
        # The size check also prevents ``a == b`` from broadcasting. ``bool``
        # turns the 0-dim result of ``torch.all`` into a plain Python bool.
        return (
            a.size() == b.size()
            and a.dtype == b.dtype
            and a.device == b.device
            and bool(torch.all(a == b))
        )
    return a == b
| |
|