File size: 3,449 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import subprocess
from typing import Optional, Dict, Tuple

# Maps the first three dot-separated octets of a node ip (e.g. '10.5.36') to the ip of the
# manager node that get_manager_node_ip() returns for that node. Several prefixes can share
# one manager (e.g. '10.5.39' -> '10.5.38.31'). A prefix missing from this table makes
# get_manager_node_ip() raise KeyError, so new clusters must be added here.
MANAGER_NODE_TABLE = {
    '10.198.8': '10.198.8.31',
    '10.198.6': '10.198.6.31',
    '10.5.38': '10.5.38.31',
    '10.5.39': '10.5.38.31',
    '10.5.36': '10.5.36.31',
    '10.5.37': '10.5.36.31',
    '10.10.30': '10.10.30.91',
}


def get_ip() -> str:
    """
    Overview:
        Get the ip of the current node by decoding the slurm node name stored in the \
        ``SLURMD_NODENAME`` environment variable, e.g. ``'SH-IDC1-10-5-36-64'`` -> ``'10.5.36.64'``.
    Returns:
        - ip (:obj:`str`): The last four ``-``-separated fields of the node name joined with ``.``.
    Raises:
        - AssertionError: If ``SLURMD_NODENAME`` is unset or empty.
    """
    nodename = os.environ.get('SLURMD_NODENAME', '')
    if not nodename:
        # Raised explicitly instead of via ``assert`` so the check is not stripped under
        # ``python -O``; the AssertionError type is kept for existing callers.
        raise AssertionError('not found SLURMD_NODENAME env variable')
    # expecting nodename to be like: 'SH-IDC1-10-5-36-64'
    return '.'.join(nodename.split('-')[-4:])


def get_manager_node_ip(node_ip: Optional[str] = None) -> str:
    """
    Overview:
        Resolve the ip of the manager node for the slurm cluster that ``node_ip`` belongs to, \
        looked up in ``MANAGER_NODE_TABLE`` by the first three octets of the ip.
    Arguments:
        - node_ip (:obj:`Optional[str]`): The ip of the current node. When falsy, it is derived \
            from ``SLURMD_NODENAME`` via ``get_ip()``.
    Returns:
        - manager_ip (:obj:`str`): The manager node ip, or ``'127.0.0.1'`` when ``SLURM_JOB_ID`` \
            is not set (an error is logged in that case).
    Raises:
        - KeyError: If the ip prefix has no entry in ``MANAGER_NODE_TABLE``.
    """
    running_on_slurm = 'SLURM_JOB_ID' in os.environ
    if not running_on_slurm:
        # Imported only on this path, the sole place logging is used here.
        from ditk import logging
        logging.error(
            "We are not running on slurm!, 'auto' for manager_ip or "
            "coordinator_ip is only intended for running on multiple slurm clusters"
        )
        return '127.0.0.1'

    if not node_ip:
        node_ip = get_ip()
    octets = node_ip.split('.')
    cluster_prefix = '.'.join(octets[:3])

    if cluster_prefix not in MANAGER_NODE_TABLE:
        raise KeyError("Cluster not found, please add it to the MANAGER_NODE_TABLE in {}".format(__file__))
    return MANAGER_NODE_TABLE[cluster_prefix]


# get all info of cluster
def get_cls_info() -> Dict[str, list]:
    """
    Overview:
        Get the cluster info by parsing ``sinfo -Nh``. Only rows with exactly four whitespace-separated \
        fields (``NODELIST NODES PARTITION STATE``) are considered; anything else is skipped.
    Returns:
        - info (:obj:`Dict[str, list]`): Maps every partition name seen to the list of its nodes whose \
            state is exactly ``'idle'`` or ``'mix'``. A partition with no such node maps to an empty list. \
            The ``*`` suffix that sinfo appends to the default partition is stripped from the name.
    """
    ret_dict = {}
    info = subprocess.getoutput('sinfo -Nh').split('\n')
    for line in info:
        fields = line.split()
        if len(fields) != 4:
            continue
        node, _, partition, state = fields
        # sinfo marks the default partition with a trailing '*' (e.g. 'debug*'). That suffix is not
        # part of the partition name, and keeping it would produce keys unusable with ``srun -p``.
        partition = partition.rstrip('*')
        nodes = ret_dict.setdefault(partition, [])
        # Sanity check against a node being listed twice for the same partition (only catches
        # repeats among idle/mix nodes, since only those are stored).
        assert node not in nodes
        if state in ('idle', 'mix'):
            nodes.append(node)

    return ret_dict


def node_to_partition(target_node: str) -> str:
    """
    Overview:
        Get the partition of the target node by scanning ``sinfo -Nh`` output. Only rows with exactly \
        four whitespace-separated fields are considered.
    Arguments:
        - target_node (:obj:`str`): The target node name.
    Returns:
        - partition (:obj:`str`): The partition on the first row listing ``target_node``, with the \
            ``*`` suffix that sinfo uses to mark the default partition removed.
    Raises:
        - RuntimeError: If ``target_node`` does not appear in the ``sinfo`` output.
    """
    info = subprocess.getoutput('sinfo -Nh').split('\n')
    for line in info:
        fields = line.split()
        if len(fields) != 4:
            continue
        node, _, partition, _ = fields
        if node == target_node:
            # A trailing '*' marks the default partition and is not part of its name;
            # it would not be accepted as a partition name by e.g. ``srun -p``.
            return partition.rstrip('*')
    raise RuntimeError("not found target_node: {}".format(target_node))


def node_to_host(node: str) -> str:
    """
    Overview:
        Convert a slurm node name into the host ip encoded in it, \
        e.g. ``'SH-IDC1-10-5-36-64'`` -> ``'10.5.36.64'``.
    Arguments:
        - node (:obj:`str`): The node name.
    Returns:
        - host (:obj:`str`): The last (up to) four ``-``-separated fields of ``node`` joined with ``.``.
    """
    # Splitting from the right at most four times leaves the trailing fields intact.
    octets = node.rsplit('-', 4)[-4:]
    return '.'.join(octets)


def find_free_port_slurm(node: str) -> int:
    """
    Overview:
        Find a free port on ``node`` by running ``ding.utils.find_free_port`` there through ``srun``.
    Arguments:
        - node (:obj:`str`): The slurm node name, e.g. ``'SH-IDC1-10-5-36-64'``.
    Returns:
        - port (:obj:`int`): The port printed by the remote process.
    Raises:
        - RuntimeError: Propagated from ``node_to_partition`` if ``node`` is not listed by ``sinfo``.
        - ValueError: If no ``port<digits>`` marker can be found in the ``srun`` output.
    """
    # sinfo may report the default partition with a trailing '*', which is not part of its name.
    partition = node_to_partition(node).rstrip('*')
    comment = '--comment=spring-submit' if partition == 'spring_scheduler' else ''
    output = subprocess.getoutput(
        "srun -p {} -w {} {} python -c \"from ding.utils import find_free_port; print('port' + str(find_free_port(0)))\""  # noqa
        .format(partition, node, comment)
    )
    # ``getoutput`` merges stdout and stderr, so srun may print extra lines around the marker.
    # Scan from the end for the last line whose text after 'port' is a plain number.
    for line in reversed(output.splitlines()):
        _, marker, tail = line.rpartition('port')
        tail = tail.strip()
        if marker and tail.isdecimal():
            return int(tail)
    raise ValueError("could not find a free port in srun output: {!r}".format(output))