Robert001's picture
first commit
b334e29
raw
history blame
No virus
4.08 kB
import asyncio
import contextlib
import logging
import os
import time
from typing import List
import torch
logger = logging.getLogger(__name__)
DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
@contextlib.asynccontextmanager
async def completed(trace_name='',
name='',
sleep_interval=0.05,
streams: List[torch.cuda.Stream] = None):
"""Async context manager that waits for work to complete on given CUDA
streams."""
if not torch.cuda.is_available():
yield
return
stream_before_context_switch = torch.cuda.current_stream()
if not streams:
streams = [stream_before_context_switch]
else:
streams = [s if s else stream_before_context_switch for s in streams]
end_events = [
torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
]
if DEBUG_COMPLETED_TIME:
start = torch.cuda.Event(enable_timing=True)
stream_before_context_switch.record_event(start)
cpu_start = time.monotonic()
logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
grad_enabled_before = torch.is_grad_enabled()
try:
yield
finally:
current_stream = torch.cuda.current_stream()
assert current_stream == stream_before_context_switch
if DEBUG_COMPLETED_TIME:
cpu_end = time.monotonic()
for i, stream in enumerate(streams):
event = end_events[i]
stream.record_event(event)
grad_enabled_after = torch.is_grad_enabled()
# observed change of torch.is_grad_enabled() during concurrent run of
# async_test_bboxes code
assert (grad_enabled_before == grad_enabled_after
), 'Unexpected is_grad_enabled() value change'
are_done = [e.query() for e in end_events]
logger.debug('%s %s completed: %s streams: %s', trace_name, name,
are_done, streams)
with torch.cuda.stream(stream_before_context_switch):
while not all(are_done):
await asyncio.sleep(sleep_interval)
are_done = [e.query() for e in end_events]
logger.debug(
'%s %s completed: %s streams: %s',
trace_name,
name,
are_done,
streams,
)
current_stream = torch.cuda.current_stream()
assert current_stream == stream_before_context_switch
if DEBUG_COMPLETED_TIME:
cpu_time = (cpu_end - cpu_start) * 1000
stream_times_ms = ''
for i, stream in enumerate(streams):
elapsed_time = start.elapsed_time(end_events[i])
stream_times_ms += f' {stream} {elapsed_time:.2f} ms'
logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
stream_times_ms)
@contextlib.asynccontextmanager
async def concurrent(streamqueue: asyncio.Queue,
trace_name='concurrent',
name='stream'):
"""Run code concurrently in different streams.
:param streamqueue: asyncio.Queue instance.
Queue tasks define the pool of streams used for concurrent execution.
"""
if not torch.cuda.is_available():
yield
return
initial_stream = torch.cuda.current_stream()
with torch.cuda.stream(initial_stream):
stream = await streamqueue.get()
assert isinstance(stream, torch.cuda.Stream)
try:
with torch.cuda.stream(stream):
logger.debug('%s %s is starting, stream: %s', trace_name, name,
stream)
yield
current = torch.cuda.current_stream()
assert current == stream
logger.debug('%s %s has finished, stream: %s', trace_name,
name, stream)
finally:
streamqueue.task_done()
streamqueue.put_nowait(stream)