|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Start multiple processes locally for DDP (one worker per GPU).
|
""" |
|
|
|
import logging
import subprocess as sp
import sys

from hydra import utils
|
|
|
# Module-level logger used by ChildrenManager and start_ddp_workers below.
logger = logging.getLogger(__name__)
|
|
|
|
|
class ChildrenManager:
    """Context manager supervising a group of worker subprocesses.

    Children are registered with `add` inside the ``with`` body. When the
    body exits, the manager polls the children until they have all exited
    or one of them fails. On any failure (non-zero exit code, exception
    raised in the ``with`` body, or keyboard interrupt while waiting), the
    children still running are terminated, then killed if they do not exit
    within the grace period.

    Children are expected to expose the `subprocess.Popen` methods
    ``wait(timeout)``, ``terminate()`` and ``kill()``.
    """

    def __init__(self, grace_period=5.0):
        """
        Args:
            grace_period: seconds to wait for a terminated child to exit
                before killing it. `None` waits indefinitely.
        """
        # Children not yet seen to have exited, in registration order.
        self.children = []
        # Becomes True as soon as anything goes wrong.
        self.failed = False
        self.grace_period = grace_period

    def add(self, child):
        """Register `child` and store its registration index as `child.rank`."""
        child.rank = len(self.children)
        self.children.append(child)

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Wait for the children, stopping all of them on the first failure.

        Returns None, so an exception raised in the ``with`` body is
        propagated to the caller after the children have been stopped.
        """
        if exc_value is not None:
            logger.error("An exception happened while starting workers %r", exc_value)
            self.failed = True
        try:
            while self.children and not self.failed:
                # Iterate over a copy: finished children are removed in the loop.
                for child in list(self.children):
                    try:
                        # Short timeout so that every child is polled in turn.
                        exitcode = child.wait(0.1)
                    except sp.TimeoutExpired:
                        continue
                    self.children.remove(child)
                    if exitcode:
                        logger.error("Worker %s died, killing all workers", child.rank)
                        self.failed = True
        except KeyboardInterrupt:
            logger.error("Received keyboard interrupt, trying to kill all workers.")
            self.failed = True
        self._stop_remaining()
        if not self.failed:
            logger.info("All workers completed successfully")

    def _stop_remaining(self):
        """Terminate the children still running, then reap (or kill) them."""
        # Signal everyone first so that the children shut down concurrently.
        for child in self.children:
            child.terminate()
        for child in self.children:
            try:
                child.wait(self.grace_period)
            except sp.TimeoutExpired:
                logger.error(
                    "Worker %s did not exit after being terminated, killing it", child.rank)
                child.kill()
                # Reap the killed process so that it does not linger as a zombie.
                child.wait()
|
|
|
|
|
def start_ddp_workers():
    """Spawn one worker process per CUDA device and wait for all of them.

    Each worker re-runs the current command line (`sys.argv`) with
    ``world_size=<n>`` and ``rank=<i>`` appended, from the original working
    directory reported by Hydra. Workers with rank > 0 have stdin, stdout
    and stderr redirected to DEVNULL and log to their own file: the Hydra
    job log filename with ``.<rank>`` appended.

    This function does not return: it exits the interpreter with status 1
    when no CUDA device is found or when the workers failed, and with
    status 0 otherwise.
    """
    # Imported here rather than at module level, as in the original code.
    import torch as th

    world_size = th.cuda.device_count()
    if not world_size:
        logger.error(
            "DDP is only available on GPU. Make sure GPUs are properly configured with cuda.")
        sys.exit(1)
    logger.info("Starting %s worker processes for DDP.", world_size)
    with ChildrenManager() as manager:
        # Identical for every worker, so computed once. Kept inside the
        # `with` block so that a failure here is still handled by the manager.
        cwd = utils.get_original_cwd()
        base_cmd = [sys.executable, *sys.argv, f"world_size={world_size}"]
        for rank in range(world_size):
            cmd = base_cmd + [f"rank={rank}"]
            kwargs = {}
            if rank > 0:
                # Only rank 0 keeps the terminal; other ranks are silenced
                # and write to a rank-specific log file instead.
                kwargs['stdin'] = sp.DEVNULL
                kwargs['stdout'] = sp.DEVNULL
                kwargs['stderr'] = sp.DEVNULL
                # Looked up only for rank > 0, so a single-GPU run never
                # touches the Hydra file handler configuration.
                log = utils.HydraConfig().hydra.job_logging.handlers.file.filename
                log += f".{rank}"
                cmd.append("hydra.job_logging.handlers.file.filename=" + log)
            manager.add(sp.Popen(cmd, cwd=cwd, **kwargs))
    sys.exit(int(manager.failed))
|
|