|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Streaming module API that should be implemented by all Streaming components, |
|
""" |
|
|
|
from contextlib import contextmanager |
|
import typing as tp |
|
from torch import nn |
|
import torch |
|
|
|
|
|
# Streaming state of a component: a flat mapping from key name to tensor.
State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a
    dict[str, Tensor]. By convention, the first dim of each tensor must be
    the batch size. Don't use dots in the key names, as this would clash
    with submodules (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the
    context manager. This also automatically propagates to all streaming
    children modules.

    Some modules might also implement the `StreamingModule.flush` method,
    although this one is trickier, as all parent modules must be
    StreamingModule and implement it as well for it to work properly.
    See `StreamingSequential` after.
    """

    def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
        """Initialize an empty streaming state, with streaming disabled.

        Args:
            *args: Forwarded untouched to the next `__init__` in the MRO.
            **kwargs: Forwarded untouched to the next `__init__` in the MRO.
        """
        # Forwarding the arguments keeps this class usable as a mixin placed
        # in front of a container whose constructor takes arguments (for
        # instance `nn.Sequential`); a bare `super().__init__()` would make
        # such a subclass impossible to build with children.
        super().__init__(*args, **kwargs)
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(
        self, fn: tp.Callable[[str, "StreamingModule"], tp.Any]
    ):
        """Call `fn(name, module)` on every streaming module in the tree.

        Iterates over `self.named_modules()`, which includes `self` under the
        empty name, and skips modules that are not `StreamingModule`.
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool):
        """Set the `_is_streaming` flag on this module and streaming children."""
        def _set_flag(name: str, module: StreamingModule):
            module._is_streaming = streaming

        self._apply_named_streaming(_set_flag)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._set_streaming(True)
        try:
            yield
        finally:
            # Runs even if the body raised, so no stale state is left behind.
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Reset the streaming state."""
        def _reset(name: str, module: StreamingModule):
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        Keys of sub-modules are prefixed with the dotted module name, e.g.
        the key `buf` of the child named `0` is returned as `0.buf`.
        """
        state: State = {}

        def _add(name: str, module: StreamingModule):
            if name:
                name += "."
            for key, value in module._streaming_state.items():
                state[name + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State):
        """Set the streaming state, including that of sub-modules.

        The previous state of every streaming module in the tree is dropped,
        then each entry of `state` is assigned to the module whose dotted name
        is the part of the key before its last dot (keys without a dot belong
        to `self`). The `state` argument itself is not modified.

        Raises:
            AssertionError: if some keys match no streaming module; the
                exception carries the list of those keys.
        """
        pending = dict(state)

        # Index the keys by module prefix once, instead of rescanning every
        # key for every module. The prefix keeps its trailing dot so that it
        # compares equal to `name + "."` below; keys without any dot get the
        # empty prefix, which is the one of the root module.
        keys_by_prefix: tp.Dict[str, tp.List[str]] = {}
        for key in pending:
            prefix = key[:key.rfind(".") + 1]
            keys_by_prefix.setdefault(prefix, []).append(key)

        def _set(name: str, module: StreamingModule):
            prefix = name + "." if name else ""
            module._streaming_state.clear()
            for key in keys_by_prefix.pop(prefix, ()):
                module._streaming_state[key[len(prefix):]] = pending.pop(key)

        self._apply_named_streaming(_set)
        if pending:
            # Raised explicitly rather than through `assert`, so that unknown
            # keys are still reported when Python runs with `-O`.
            raise AssertionError(list(pending))

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.

        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        emitted a flushed buffer.

        Returns:
            None when `x` is None, otherwise the result of `self(x)`.
        """
        if x is None:
            return None
        else:
            return self(x)
|
|
|
|
|
class StreamingSequential(StreamingModule, nn.Sequential):
    """Streaming-aware counterpart of `nn.Sequential`.

    Adds a `flush` implementation that walks the contained modules in order.
    """

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush the contained modules one after the other.

        Streaming modules always get their `flush` called with whatever has
        been produced so far (possibly None). Other modules are only applied
        once there is an actual buffer to process.

        Returns:
            The output after the last module, or None if nothing was flushed.
        """
        pending = x
        for stage in self:
            if isinstance(stage, StreamingModule):
                pending = stage.flush(pending)
                continue
            if pending is None:
                # Nothing has been flushed yet: a regular module has no input.
                continue
            pending = stage(pending)
        return pending
|
|