File size: 771 Bytes
4562a06
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import os
import socket
from contextlib import closing

import torch.distributed as dist


def get_open_port():
    """Return a TCP port number that the OS reports as currently free.

    Binds a throwaway IPv4 stream socket to port 0 so the OS assigns an
    unused port, reads the assigned number back, then closes the socket.

    Note: the port is released when the socket is closed on leaving the
    ``with`` block, so another process could in principle claim it before
    the caller binds to it. Callers should tolerate that (unlikely) race.

    Returns:
        int: the port number assigned by the OS.
    """
    with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
        # Options that affect address binding must be set *before* bind();
        # setting SO_REUSEADDR after bind() is too late to influence it.
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        s.bind(("", 0))
        return s.getsockname()[1]


# Distributed process group
def ddp_setup(rank, world_size, port="12345", *, backend="nccl", master_addr="localhost"):
    """Initialize the default ``torch.distributed`` process group.

    Writes ``MASTER_ADDR`` and ``MASTER_PORT`` into this process's
    environment, then calls ``init_process_group`` without an explicit
    ``init_method``. PyTorch then rendezvouses via its default ``env://``
    method, which reads those two variables.

    Args:
        rank: Unique identifier of this process within the group.
        world_size: Total number of processes in the group.
        port: Rendezvous TCP port; any value accepted by ``str()`` works,
            e.g. the int returned by ``get_open_port``.
        backend: Communication backend forwarded to ``init_process_group``.
            Defaults to ``"nccl"`` (the previous hard-coded value); pass
            ``"gloo"`` for CPU-only runs.
        master_addr: Address of the rank-0 process. Defaults to
            ``"localhost"``, i.e. all processes on a single machine.
    """
    os.environ["MASTER_ADDR"] = master_addr
    # f-string formatting already calls str(); output is unchanged.
    print(f"MasterPort: {port}")
    os.environ["MASTER_PORT"] = str(port)

    # initialize the process group
    dist.init_process_group(backend=backend, rank=rank, world_size=world_size)


def cleanup():
    """Tear down the default process group (the counterpart of ``ddp_setup``).

    Delegates to ``torch.distributed.destroy_process_group`` with no
    arguments; any exception it raises propagates unchanged.
    """
    teardown = dist.destroy_process_group
    teardown()