# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import signal
import threading

from torch import nn


logger = logging.getLogger(__name__)


class DistributedTimeoutWrapper(nn.Module):
    """
    A wrapper that kills the process if no progress is made within a given
    *timeout*. The timer is reset every time :func:`forward` is called.

    Usage::

        module = DistributedTimeoutWrapper(module, timeout=30)
        x = module(input)
        time.sleep(20)  # safe
        x = module(input)
        time.sleep(45)  # job will be killed before this returns

    Args:
        module (nn.Module): module to wrap
        timeout (int): number of seconds before killing the process
            (set to a value <= 0 to disable the timeout)
        signal (Optional): signal to send once timeout is triggered
    """

    def __init__(self, module: nn.Module, timeout: int, signal=signal.SIGINT):
        super().__init__()
        self.module = module
        self.timeout = timeout
        self.signal = signal

        # Initialise the stop flag unconditionally and *before* the watchdog
        # thread starts, so the thread can never observe it as missing.
        self._terminated = False

        if timeout > 0:
            self._heartbeat = threading.Event()
            self._heartbeat_thread = threading.Thread(
                target=self._check_heartbeat,
                args=(os.getpid(),),
                daemon=True,
            )
            self._heartbeat_thread.start()
        else:
            self._heartbeat = None
            self._heartbeat_thread = None

    def __del__(self):
        self.stop_timeout()

    def __getattr__(self, name):
        """Forward missing attributes to wrapped module."""
        try:
            return super().__getattr__(name)  # defer to nn.Module's logic
        except AttributeError:
            # If the wrapped module itself is not registered yet (partially
            # constructed instance), looking up ``self.module`` below would
            # re-enter this method forever; report the attribute as missing.
            if name == "module":
                raise
            return getattr(self.module, name)

    def stop_timeout(self):
        """Stop the watchdog thread, if any, and wait for it to exit.

        Safe to call more than once, on an instance whose timeout is
        disabled, and on a partially constructed instance.
        """
        # Read from ``__dict__`` directly so that a missing attribute is not
        # forwarded to the wrapped module by ``__getattr__``.
        thread = self.__dict__.get("_heartbeat_thread")
        if thread is None:
            return

        # Order matters: raise the flag first, then wake the thread. The
        # watchdog re-checks the flag after every ``clear()``, so the wake-up
        # cannot be lost regardless of interleaving.
        self._terminated = True
        self._heartbeat.set()

        # ``join`` raises RuntimeError for a thread that never started or
        # when called from the thread itself; skip it in both cases.
        if thread.is_alive() and thread is not threading.current_thread():
            thread.join()

    def state_dict(self, *args, **kwargs):
        """Return the wrapped module's state dict (no wrapper prefix)."""
        return self.module.state_dict(*args, **kwargs)

    def load_state_dict(self, *args, **kwargs):
        """Load a state dict directly into the wrapped module."""
        return self.module.load_state_dict(*args, **kwargs)

    def forward(self, *args, **kwargs):
        """Signal a heartbeat, then run the wrapped module."""
        if self._heartbeat is not None:
            self._heartbeat.set()
        return self.module(*args, **kwargs)

    def _check_heartbeat(self, parent_pid):
        """Watchdog loop run in a daemon thread.

        Sends ``self.signal`` to *parent_pid* if no heartbeat arrives within
        ``self.timeout`` seconds; exits quietly once ``stop_timeout`` is
        called.
        """
        # Wait for the first forward pass (or for stop_timeout to wake us).
        self._heartbeat.wait()
        while True:
            self._heartbeat.clear()
            # stop_timeout sets the flag before setting the event, so if our
            # clear() just swallowed its wake-up the flag is already visible.
            if self._terminated:
                break
            success = self._heartbeat.wait(timeout=self.timeout)
            if self._terminated:
                break
            elif not success:
                logger.error((
                    "Killing job for not making progress in {} seconds. "
                    "Set --heartbeat-timeout=-1 to disable this timeout."
                ).format(int(self.timeout)))
                os.kill(parent_pid, self.signal)
                return