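# Scraped page header (file-view metadata), kept as a comment so the module stays importable:
# zjowowen's picture
# init space
# 3dfe8fb
# raw
# history blame contribute delete
# 5.66 kB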
import os
import re
from time import sleep
import numpy as np
from typing import Any, Dict, List, Optional
class SlurmParser():
    # One host expression of a Slurm hostlist: plain characters and whole ``[...]`` groups, so that
    # commas inside brackets (``node[1,3]``) do not split the expression.
    _HOSTLIST_EXPRESSION = re.compile(r"(?:[^,\[\]]|\[[^\]]*\])+")

    def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
        """
        Overview:
            Read the global Slurm job properties from the environment and prepare per-task defaults.
            Should only set global cluster properties.
        Arguments:
            - platform_spec (:obj:`Optional[Dict]`): Optional spec whose ``tasks`` list (indexed by procid) \
                may pre-set ``ports``, ``address``, ``node_ids`` and ``attach_to`` per task. It is not mutated.
            - kwargs: Extra options. ``ports`` (default 15151), ``parallel_workers`` (default 1) and \
                ``topology`` (default ``"alone"``) are read here; all kwargs are merged into ``parse()``'s result.
        """
        self.kwargs = kwargs
        self.ntasks = int(os.environ["SLURM_NTASKS"])
        self.platform_spec = platform_spec
        self.tasks = {}  # procid -> resolved task dict (cache)
        self.ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"])
        # Must come after ntasks_per_node: the node list is repeated once per task on each node.
        self.nodelist = self._parse_node_list()
        self.ports = int(kwargs.get("ports") or 15151)
        # Cast like ``ports``: a string value (e.g. from a CLI) would otherwise break the port arithmetic.
        self.parallel_workers = int(kwargs.get("parallel_workers") or 1)
        self.topology = kwargs.get("topology") or "alone"

    def parse(self) -> dict:
        """
        Overview:
            Resolve the task of the current process (``SLURM_PROCID``) and merge it over the kwargs.
        Returns:
            - task (:obj:`dict`): ``kwargs`` updated with the resolved task fields.
        """
        procid = int(os.environ["SLURM_PROCID"])
        task = self._get_task(procid)
        # Validation: the computed address must be the node this process is actually running on.
        assert task["address"] == os.environ["SLURMD_NODENAME"], \
            "Task address {} does not match SLURMD_NODENAME {}".format(task["address"], os.environ["SLURMD_NODENAME"])
        return {**self.kwargs, **task}

    def _get_task(self, procid: int) -> Dict[str, Any]:
        """
        Overview:
            Build (and cache) the task dict of ``procid``: ports, address, node_ids, attach_to, topology \
            and parallel_workers. Values given in ``platform_spec`` take precedence over computed defaults.
        Arguments:
            - procid (:obj:`int`): The Slurm process id of the task.
        Returns:
            - task (:obj:`Dict[str, Any]`): The resolved task.
        """
        if procid in self.tasks:
            return self.tasks[procid]
        # Copy so that filling in defaults does not mutate the caller's platform_spec.
        task = dict(self.platform_spec["tasks"][procid]) if self.platform_spec else {}
        if "ports" not in task:
            task["ports"] = self._get_ports(procid)
        if "address" not in task:
            task["address"] = self._get_address(procid)
        if "node_ids" not in task:
            task["node_ids"] = self._get_node_id(procid)
        # Cache before resolving attach_to: resolving ``$node.N`` scans tasks by node_ids (possibly this
        # one), which would otherwise recurse forever when N belongs to this or a later task.
        self.tasks[procid] = task
        task["attach_to"] = self._get_attach_to(procid, task.get("attach_to"))
        task["topology"] = self.topology
        task["parallel_workers"] = self.parallel_workers
        return task

    def _parse_node_list(self) -> List[str]:
        """
        Overview:
            Expand ``SLURM_NODELIST`` into one hostname per task: every node is repeated \
            ``ntasks_per_node`` times (when > 1), keeping the tasks of one node adjacent.
            Supports comma-separated expressions, bracket ranges/lists, zero padding and several \
            bracket groups, e.g. ``node[01-02],gpu5`` -> ``node01, node02, gpu5``.
        Returns:
            - nodelist (:obj:`List[str]`): Hostname of each task, indexed by procid.
        """
        nodelist = []
        for expression in self._HOSTLIST_EXPRESSION.findall(os.environ["SLURM_NODELIST"]):
            nodelist.extend(self._expand_host_expression(expression))
        if self.ntasks_per_node > 1:
            nodelist = [node for node in nodelist for _ in range(self.ntasks_per_node)]
        return nodelist

    @classmethod
    def _expand_host_expression(cls, expression: str) -> List[str]:
        """
        Overview:
            Expand one host expression, e.g. ``node[01-03,7]`` -> ``node01, node02, node03, node7``.
        """
        match = re.search(r"\[([^\]]*)\]", expression)
        if match is None:
            return [expression]
        head, rest = expression[:match.start()], expression[match.end():]
        values = []
        for item in match.group(1).split(","):
            if "-" in item:
                start, stop = item.split("-", 1)
                # Keep zero padding: Slurm writes node01..node10 as node[01-10].
                values.extend(str(number).zfill(len(start)) for number in range(int(start), int(stop) + 1))
            else:
                values.append(item)
        # Any later bracket group (e.g. rack[1-2]-node[1-2]) is expanded recursively.
        tails = cls._expand_host_expression(rest)
        return [head + value + tail for value in values for tail in tails]

    def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
        """
        Overview:
            Resolve the comma-separated ``attach_to`` links of ``procid``. An explicit value has its \
            ``$node.N`` parts resolved. Otherwise task 0 attaches to nothing; other tasks follow the \
            topology: ``mesh`` attaches to every earlier task, ``star`` to task 0, anything else to nothing.
        Returns:
            - attach_to (:obj:`str`): Comma-joined links, e.g. ``tcp://SH-0:50000,tcp://SH-0:50001``.
        """
        if attach_to:
            links = [self._get_attach_to_part(part) for part in attach_to.split(",")]
        elif procid == 0:
            links = []
        elif self.topology == "mesh":
            links = [link for i in range(procid) for link in self._get_attach_to_from_task(self._get_task(i))]
        elif self.topology == "star":
            links = self._get_attach_to_from_task(self._get_task(0))
        else:
            links = []
        return ",".join(links)

    def _get_attach_to_part(self, attach_part: str) -> str:
        """
        Overview:
            Parse each part of attach_to.
        Arguments:
            - attach_part (:obj:`str`): The attach_to field with specific pattern, e.g. $node.0
        Returns:
            - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000 (other parts are returned as-is).
        """
        attach_part = attach_part.strip()  # tolerate "a, b" style lists
        if not attach_part.startswith("$node."):
            return attach_part
        attach_node_id = int(attach_part[len("$node."):])
        attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id))
        return self._get_tcp_link(attach_task["address"], attach_task["ports"])

    def _get_attach_to_from_task(self, task: dict) -> List[str]:
        """
        Overview:
            Get attach_to list from task. Note that parallel_workers affects the connected processes: \
            one link per worker, on consecutive ports starting at the task's ``ports``.
        Arguments:
            - task (:obj:`dict`): The task object.
        Returns:
            - attach_to (:obj:`List[str]`): The real addresses, e.g. [tcp://SH-0:50000, tcp://SH-0:50001]
        """
        base_port = int(task.get("ports"))
        address = task.get("address")
        return [self._get_tcp_link(address, base_port + i) for i in range(self.parallel_workers)]

    def _get_procid_from_nodeid(self, nodeid: int) -> int:
        """
        Overview:
            Find the procid whose task has ``node_ids == nodeid``; raises ``Exception`` if none does.
        """
        for procid in range(self.ntasks):
            if self._get_task(procid)["node_ids"] == nodeid:
                return procid
        raise Exception("Can not find procid from nodeid: {}".format(nodeid))

    def _get_ports(self, procid) -> int:
        """Default base port: tasks on the same node are spaced ``parallel_workers`` ports apart."""
        return self.ports + (procid % self.ntasks_per_node) * self.parallel_workers

    def _get_address(self, procid: int) -> str:
        """Default address: the hostname assigned to ``procid`` in the expanded node list."""
        return self.nodelist[procid]

    def _get_node_id(self, procid: int) -> int:
        """Default first node id of ``procid``: each task reserves ``parallel_workers`` ids."""
        return procid * self.parallel_workers

    def _get_tcp_link(self, address: str, port: int) -> str:
        """Format a ``tcp://address:port`` link."""
        return "tcp://{}:{}".format(address, port)
def slurm_parser(platform_spec: Optional[Dict] = None, **kwargs) -> dict:
    """
    Overview:
        Build a ``SlurmParser`` from the current Slurm environment and return the parsed config \
        of the current process.
    Arguments:
        - platform_spec (:obj:`Optional[Dict]`): Optional spec with a per-procid ``tasks`` list.
        - kwargs: Options forwarded to ``SlurmParser`` (e.g. ``ports``, ``parallel_workers``, ``topology``).
    Returns:
        - task (:obj:`dict`): The result of ``SlurmParser.parse()``.
    """
    return SlurmParser(platform_spec, **kwargs).parse()