# File: OFA-OCR/fairseq/tests/distributed/test_distributed_timeout_wrapper.py
# Hosting-page metadata: JustinLin610 · "first commit" · ee21b96 · 1.35 kB · no virus
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
import signal
import time
import unittest
import torch
from torch import nn
from fairseq.distributed import DistributedTimeoutWrapper
class ModuleWithDelay(nn.Module):
    """Identity module whose forward pass blocks for a fixed amount of time.

    The tests in this file use it to simulate a forward pass of a known
    duration.

    Args:
        delay: number of seconds handed to ``time.sleep`` on every forward
            call.
    """

    def __init__(self, delay: float):
        super().__init__()
        self.delay = delay

    def forward(self, x):
        """Sleep for ``self.delay`` seconds, then return ``x`` unchanged."""
        time.sleep(self.delay)
        return x
class TestDistributedTimeoutWrapper(unittest.TestCase):
    """Run DistributedTimeoutWrapper around forward passes of known duration.

    Each test wraps a ``ModuleWithDelay`` with a given timeout value and
    ``signal.SIGINT`` as the signal, then performs a single forward pass.
    """

    def setUp(self):
        # Suppress log records of severity CRITICAL and below while a test runs.
        logging.disable(logging.CRITICAL)

    def tearDown(self):
        # Undo the suppression installed in setUp.
        logging.disable(logging.NOTSET)

    def _run_forward(self, delay, timeout):
        """Wrap a ``ModuleWithDelay(delay)`` and run one forward pass.

        ``stop_timeout()`` is called in a ``finally`` clause so that it runs
        even when the forward pass raises (as it is expected to in
        ``test_timeout_killed``). Without this, a ``stop_timeout()`` placed
        after the raising call is never reached.

        Args:
            delay: seconds the wrapped module sleeps in ``forward``.
            timeout: timeout value passed to ``DistributedTimeoutWrapper``.
        """
        module = DistributedTimeoutWrapper(
            ModuleWithDelay(delay), timeout, signal.SIGINT
        )
        try:
            module(torch.rand(5))
        finally:
            # NOTE(review): presumably shuts down the wrapper's timeout
            # monitoring; confirm against DistributedTimeoutWrapper.
            module.stop_timeout()

    def test_no_timeout(self):
        """A timeout of 0 with a 1 s forward pass must complete normally."""
        self._run_forward(delay=1, timeout=0)

    def test_timeout_safe(self):
        """A 1 s forward pass under a timeout of 10 must complete normally."""
        self._run_forward(delay=1, timeout=10)

    def test_timeout_killed(self):
        """A 5 s forward pass under a timeout of 1 must raise KeyboardInterrupt."""
        with self.assertRaises(KeyboardInterrupt):
            self._run_forward(delay=5, timeout=1)
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()