File size: 2,616 Bytes
b8fae22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
"""Minimal DDP helpers driven entirely by torchrun environment variables.

Launch with:  torchrun --nproc_per_node=<N> framework/train.py ...
Single-process (no torchrun) also works: world_size falls back to 1.
"""
from __future__ import annotations

import os
import random
from typing import List, Any

import numpy as np
import torch
import torch.distributed as dist


def is_dist() -> bool:
    """Report whether a torch.distributed default process group is live.

    False when torch was built without distributed support or when no
    process group has been initialised yet.
    """
    if not dist.is_available():
        return False
    return dist.is_initialized()


def get_rank() -> int:
    """Global rank of this process; 0 when not running under a process group."""
    if not is_dist():
        return 0
    return dist.get_rank()


def get_world_size() -> int:
    """Number of participating processes; 1 outside a process group."""
    if not is_dist():
        return 1
    return dist.get_world_size()


def is_main() -> bool:
    """True on the rank-0 process (and therefore always in single-process runs)."""
    rank = get_rank()
    return rank == 0


def setup_distributed() -> int:
    """Init the process group if launched under torchrun. Returns local_rank.

    When both ``RANK`` and ``WORLD_SIZE`` are present in the environment, the
    default process group is created via ``init_method="env://"``:

    * with CUDA available: NCCL, with this rank's GPU selected and bound to
      the process group;
    * without CUDA: Gloo, so CPU-only torchrun jobs work too.

    Without those variables no process group is created (single-process
    mode) and GPU 0 is selected if CUDA is available.

    Returns:
        ``int(os.environ["LOCAL_RANK"])`` (default 0) under torchrun, else 0.
    """
    if "RANK" not in os.environ or "WORLD_SIZE" not in os.environ:
        # single GPU / CPU fallback
        if torch.cuda.is_available():
            torch.cuda.set_device(0)
        return 0

    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    if not torch.cuda.is_available():
        # No GPU: torch.cuda.set_device and the NCCL backend would fail here,
        # so use the CPU-capable Gloo backend. barrier's device_ids argument
        # is NCCL-specific, hence the plain barrier.
        dist.init_process_group(backend="gloo", init_method="env://")
        dist.barrier()
        return local_rank

    torch.cuda.set_device(local_rank)
    # bind this rank's device to the PG so collectives/barrier don't guess
    # GPU 0 (avoids the NCCL "devices unknown" warning + potential hang)
    try:
        dist.init_process_group(backend="nccl", init_method="env://",
                                device_id=torch.device("cuda", local_rank))
    except TypeError:  # older torch without device_id kwarg
        dist.init_process_group(backend="nccl", init_method="env://")
    dist.barrier(device_ids=[local_rank])
    return local_rank


def cleanup_distributed() -> None:
    """Synchronise all ranks, then destroy the default process group.

    Does nothing when no process group is active.
    """
    if not is_dist():
        return
    # wait for every rank before tearing down the communicator
    dist.barrier()
    dist.destroy_process_group()


def all_gather_object(obj: Any) -> List[Any]:
    """Collect one picklable object from every rank into a list.

    The result has one entry per rank, in rank order, as filled in by
    ``dist.all_gather_object``. Outside a process group the result is
    simply ``[obj]``.
    """
    if is_dist():
        gathered: List[Any] = [None] * get_world_size()
        dist.all_gather_object(gathered, obj)
        return gathered
    return [obj]


def set_seed(seed: int, rank: int = 0, deterministic: bool = False) -> None:
    """Seed Python, NumPy and torch (CPU + all CUDA devices) RNGs.

    The effective seed is ``seed + rank``, so each DDP rank draws a distinct
    but reproducible random stream (e.g. for augmentation noise).

    Args:
        seed: base seed shared by all ranks.
        rank: offset added to ``seed`` for this process.
        deterministic: if True, force deterministic cuDNN kernels and disable
            autotuning; otherwise enable cuDNN benchmark autotuning (the
            ``deterministic`` flag is left as-is in that case).
    """
    stream_seed = seed + rank
    for seeder in (random.seed, np.random.seed,
                   torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(stream_seed)

    cudnn = torch.backends.cudnn
    if deterministic:
        cudnn.deterministic = True
    cudnn.benchmark = not deterministic


def print_main(*args, **kwargs) -> None:
    """``print`` only on the main (rank-0) process.

    Arguments are forwarded to the built-in ``print``. Output is flushed by
    default so rank-0 log lines appear promptly even when stdout is
    buffered; an explicit ``flush=...`` from the caller takes precedence.
    """
    if is_main():
        # setdefault instead of a literal flush=True: passing both would
        # raise "got multiple values for keyword argument 'flush'".
        kwargs.setdefault("flush", True)
        print(*args, **kwargs)