"""Unit tests for ``iterators.AsyncTimeoutIterator``.

Each test drives one of the async source generators below through an
``AsyncTimeoutIterator`` and checks the sequence of values (and timeout
sentinels) it produces. Timeout values in the tests are chosen relative to
the sleeps in ``iter_with_sleep`` (0.6 s before ``2``, 0.4 s before ``3``).
"""
import asyncio
import unittest

from iterators import AsyncTimeoutIterator


async def iter_simple():
    """Yield 1 and 2 immediately, then finish."""
    yield 1
    yield 2


async def iter_with_sleep():
    """Yield 1, 2, 3 with a 0.6 s pause before 2 and a 0.4 s pause before 3."""
    yield 1
    await asyncio.sleep(0.6)
    yield 2
    await asyncio.sleep(0.4)
    yield 3


async def iter_with_exception():
    """Yield 1 and 2, then raise ``Exception`` instead of finishing normally."""
    yield 1
    yield 2
    raise Exception


class TestTimeoutIterator(unittest.TestCase):
    """Behavioural tests for ``AsyncTimeoutIterator``.

    Every test defines a local coroutine and runs it with ``asyncio.run`` so
    that each test gets its own fresh event loop, which is closed afterwards.
    """

    def test_normal_iteration(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_simple())
            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), 2)
            # Once exhausted, every further call keeps raising.
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())

    def test_normal_iteration_for_loop(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_simple())
            results = []
            async for value in it:
                results.append(value)
            self.assertEqual(results, [1, 2])

        asyncio.run(scenario())

    def test_timeout_block(self):
        async def scenario():
            # No timeout argument: each call waits for the next value, so no
            # sentinel shows up despite the 0.6 s / 0.4 s gaps.
            it = AsyncTimeoutIterator(iter_with_sleep())
            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), 2)
            self.assertEqual(await it.__anext__(), 3)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())

    def test_timeout_block_for_loop(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_with_sleep())
            results = []
            async for value in it:
                results.append(value)
            self.assertEqual(results, [1, 2, 3])

        asyncio.run(scenario())

    def test_fixed_timeout(self):
        async def scenario():
            # 0.5 s timeout: the 0.6 s gap before 2 times out once (sentinel),
            # the 0.4 s gap before 3 does not.
            it = AsyncTimeoutIterator(iter_with_sleep(), timeout=0.5)
            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), it.get_sentinel())
            self.assertEqual(await it.__anext__(), 2)
            self.assertEqual(await it.__anext__(), 3)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())

    def test_fixed_timeout_for_loop(self):
        # Previously also named ``test_fixed_timeout``, which silently
        # replaced the step-by-step test above in the class namespace.
        async def scenario():
            it = AsyncTimeoutIterator(iter_with_sleep(), timeout=0.5)
            results = []
            async for value in it:
                results.append(value)
            self.assertEqual(results, [1, it.get_sentinel(), 2, 3])

        asyncio.run(scenario())

    def test_timeout_update(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_with_sleep(), timeout=0.5)
            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), it.get_sentinel())
            # Shrinking the timeout below the 0.4 s gap before 3 makes that
            # gap time out as well.
            it.set_timeout(0.3)
            self.assertEqual(await it.__anext__(), 2)
            self.assertEqual(await it.__anext__(), it.get_sentinel())
            self.assertEqual(await it.__anext__(), 3)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())

    def test_custom_sentinel(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_with_sleep(), timeout=0.5, sentinel="END")
            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), "END")
            self.assertEqual(await it.__anext__(), 2)
            self.assertEqual(await it.__anext__(), 3)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())

    def test_feature_timeout_reset(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_with_sleep(), timeout=0.5, reset_on_next=True)
            self.assertEqual(await it.__anext__(), 1)
            # timeout gets reset after first iteration
            self.assertEqual(await it.__anext__(), 2)
            self.assertEqual(await it.__anext__(), 3)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())

    def test_function_set_reset_on_next(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_with_sleep(), timeout=0.35, reset_on_next=False)
            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), it.get_sentinel())
            it.set_reset_on_next(True)
            self.assertEqual(await it.__anext__(), 2)
            self.assertEqual(await it.__anext__(), 3)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())

    def test_iterator_raises_exception(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_with_exception(), timeout=0.5, sentinel="END")
            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), 2)
            with self.assertRaises(Exception) as caught:
                await it.__anext__()
            # StopAsyncIteration is itself an Exception subclass; make sure the
            # source's error was surfaced rather than iteration simply ending.
            self.assertNotIsInstance(caught.exception, StopAsyncIteration)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())

    def test_interrupt_thread(self):
        async def scenario():
            it = AsyncTimeoutIterator(iter_with_sleep(), timeout=0.5, sentinel="END")
            self.assertEqual(await it.__anext__(), 1)
            self.assertEqual(await it.__anext__(), it.get_sentinel())
            # After interrupt(), 2 is still delivered, then iteration stops
            # without producing 3.
            it.interrupt()
            self.assertEqual(await it.__anext__(), 2)
            with self.assertRaises(StopAsyncIteration):
                await it.__anext__()

        asyncio.run(scenario())


if __name__ == "__main__":
    unittest.main()