Spaces:
Runtime error
Runtime error
File size: 1,719 Bytes
733aa30 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import functools
import tempfile
import torch
def spawn_and_init(fn, world_size, args=None):
    """Spawn ``world_size`` processes and run ``fn`` in each one.

    Each spawned process is started through ``init_and_run``, which
    initializes the default process group using a shared temporary file as
    the rendezvous point and then calls ``fn(rank, group, *args)``.

    Args:
        fn: Callable invoked in every spawned process as
            ``fn(rank, group, *args)``.
        world_size: Number of processes to spawn.
        args: Optional extra positional arguments forwarded to ``fn``.
            Defaults to an empty tuple.
    """
    import os  # local import: only needed for rendezvous-file cleanup

    if args is None:
        args = ()

    # Reserve a unique path for the file:// rendezvous. The handle itself is
    # not needed, so close it immediately and keep only the name.
    with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
        rendezvous_path = tmp_file.name

    try:
        torch.multiprocessing.spawn(
            fn=functools.partial(init_and_run, fn, args),
            args=(world_size, rendezvous_path),
            nprocs=world_size,
            join=True,
        )
    finally:
        # delete=False means nothing removes the file for us; clean it up
        # once all workers have joined (or spawning failed). The file may
        # already have been removed, which is fine.
        try:
            os.unlink(rendezvous_path)
        except FileNotFoundError:
            pass
def distributed_init(rank, world_size, tmp_file):
    """Initialize the default process group and select this rank's GPU.

    Args:
        rank: Rank of the calling process; also used as the CUDA device
            index passed to ``torch.cuda.set_device``.
        world_size: Total number of processes in the group.
        tmp_file: Path of the file used for ``file://`` rendezvous.
    """
    rendezvous_url = "file://{}".format(tmp_file)
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=rendezvous_url,
        world_size=world_size,
        rank=rank,
    )
    torch.cuda.set_device(rank)
def init_and_run(fn, args, rank, world_size, tmp_file):
    """Set up distributed state for this process, then invoke ``fn``.

    Calls ``distributed_init`` with the given rank, world size and
    rendezvous file, creates a new process group, and runs
    ``fn(rank, group, *args)``.
    """
    distributed_init(rank, world_size, tmp_file)
    fn(rank, torch.distributed.new_group(), *args)
def objects_are_equal(a, b) -> bool:
    """Recursively compare two (possibly nested) objects for equality.

    Objects of different types are never equal. Dicts are compared by key
    set and then value-by-value; lists and tuples are compared element-wise
    in order; sets are compared without regard to iteration order; tensors
    must match in size, dtype, device and every element. Anything else
    falls back to ``==``.

    Args:
        a: First object.
        b: Second object.

    Returns:
        ``True`` if the two objects are structurally equal, else ``False``.
    """
    if type(a) is not type(b):
        return False

    if isinstance(a, dict):
        if set(a.keys()) != set(b.keys()):
            return False
        return all(objects_are_equal(value, b[key]) for key, value in a.items())

    if isinstance(a, (list, tuple)):
        if len(a) != len(b):
            return False
        return all(objects_are_equal(x, y) for x, y in zip(a, b))

    if isinstance(a, set):
        if len(a) != len(b):
            return False
        # Set iteration order depends on insertion history, so pairing
        # elements with zip() can report equal sets as different. Match
        # each element of ``a`` against a not-yet-used element of ``b``.
        unmatched = list(b)
        for x in a:
            for index, y in enumerate(unmatched):
                if objects_are_equal(x, y):
                    del unmatched[index]
                    break
            else:
                return False
        return True

    if torch.is_tensor(a):
        if a.size() != b.size() or a.dtype != b.dtype or a.device != b.device:
            return False
        # torch.all returns a 0-dim tensor; convert so callers get a bool.
        return bool(torch.all(a == b))

    return a == b
|