victan committed
Commit c24ca9f
Parent: 519ab1b

Upload seamless_communication/cli/m4t/finetune/dist_utils.py with huggingface_hub
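
For reference, uploads like this are typically performed with huggingface_hub's upload_file API. A minimal sketch, assuming a local checkout and a target repo id (both are assumptions, not recorded in this commit):

# sketch only: repo_id and the local path are assumptions, not part of this commit
from huggingface_hub import upload_file

upload_file(
    path_or_fileobj="seamless_communication/cli/m4t/finetune/dist_utils.py",  # local file to push
    path_in_repo="seamless_communication/cli/m4t/finetune/dist_utils.py",     # destination path in the repo
    repo_id="victan/your-repo-name",                                          # assumption: target repository id
    commit_message="Upload seamless_communication/cli/m4t/finetune/dist_utils.py with huggingface_hub",
)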

seamless_communication/cli/m4t/finetune/dist_utils.py ADDED
@@ -0,0 +1,76 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+
+ import logging
+ import os
+ from datetime import timedelta
+ from typing import List
+
+ import torch
+ import torch.distributed as dist
+ import torch.multiprocessing
+
+ logger = logging.getLogger(__name__)
+
+
+ def is_dist_initialized() -> bool:
+     if not dist.is_available():
+         return False
+     if not dist.is_initialized():
+         return False
+     return True
+
+
+ def get_rank() -> int:
+     if not is_dist_initialized():
+         return 0
+     return dist.get_rank()
+
+
+ def get_local_rank() -> int:
+     if not is_dist_initialized():
+         return 0
+     return int(os.environ["LOCAL_RANK"])
+
+
+ def get_world_size() -> int:
+     if not is_dist_initialized():
+         return 1
+     return dist.get_world_size()
+
+
+ def is_main_process() -> bool:
+     return get_rank() == 0
+
+
+ def init_distributed(loggers: List[logging.Logger]) -> None:
+     """Initializes the distributed backend"""
+     torch.multiprocessing.set_start_method("spawn")
+     if "RANK" not in os.environ:
+         logger.error(
+             "Cannot init distributed context, as environment variables are not set."
+         )
+         return
+     rank = int(os.environ["RANK"])
+     world_size = int(os.environ["WORLD_SIZE"])
+     local_rank = int(os.environ["LOCAL_RANK"])
+     logger.info(
+         f"Rank={rank} local rank={local_rank}, world_size={world_size}, is_master={rank == 0}"
+     )
+     dist.init_process_group(
+         backend="nccl",
+         init_method="env://",
+         world_size=world_size,
+         rank=rank,
+         timeout=timedelta(seconds=180),
+     )
+     logger.info(f"Setting cuda:{local_rank} as main device")
+     if not is_main_process():
+         for to_mute in loggers:
+             to_mute.setLevel(logging.ERROR)
+     torch.cuda.set_device(local_rank)
+     dist.barrier()
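
For context, a minimal sketch of how these helpers would typically be used from a fine-tuning entry point launched with torchrun (the script name, logger name, and training steps below are illustrative assumptions, not part of this commit):

# sketch: launched as e.g. `torchrun --nproc_per_node=8 finetune.py`
# torchrun sets RANK, LOCAL_RANK and WORLD_SIZE; init_distributed reads them,
# joins the NCCL process group and pins the local CUDA device.
import logging

from seamless_communication.cli.m4t.finetune import dist_utils

logger = logging.getLogger("finetune")


def main() -> None:
    dist_utils.init_distributed([logger])  # non-main ranks have their loggers muted to ERROR
    if dist_utils.is_main_process():
        logger.info("world size: %d", dist_utils.get_world_size())
    # ... build the model, wrap it in DistributedDataParallel, run the training loop ...


if __name__ == "__main__":
    main()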