# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Streaming module API that should be implemented by all streaming components.
"""

from contextlib import contextmanager
import typing as tp

from torch import nn
import torch


State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming child modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` below.
    """

    def __init__(self, *args: tp.Any, **kwargs: tp.Any) -> None:
        # Forward constructor arguments along the MRO. With no arguments this is
        # exactly the old `super().__init__()` call. With a cooperative base such
        # as `nn.Sequential` (see `StreamingSequential`), the child modules now
        # reach that base instead of raising `TypeError` here.
        super().__init__(*args, **kwargs)
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(
            self, fn: tp.Callable[[str, "StreamingModule"], None]) -> None:
        """Call `fn(name, module)` on every `StreamingModule` yielded by
        `self.named_modules()`, including `self` (whose name is the empty string).
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool) -> None:
        """Set the `_is_streaming` flag on this module and all streaming submodules."""
        def _set_flag(name: str, module: "StreamingModule") -> None:
            module._is_streaming = streaming

        self._apply_named_streaming(_set_flag)

    @contextmanager
    def streaming(self) -> tp.Iterator[None]:
        """Context manager to enter streaming mode. Reset streaming state on exit.
        """
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self) -> None:
        """Reset the streaming state of this module and all streaming submodules.
        """
        def _reset(name: str, module: "StreamingModule") -> None:
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        Keys of a submodule's state are prefixed with that submodule's dotted
        name followed by a dot (as in `state_dict`). Keys of this module's own
        state are left unprefixed.
        """
        state: State = {}

        def _add(name: str, module: "StreamingModule") -> None:
            prefix = name + "." if name else ""
            for key, value in module._streaming_state.items():
                state[prefix + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State) -> None:
        """Set the streaming state, including that of sub-modules.

        `state` is expected to use the key layout produced by
        `get_streaming_state`. Each streaming module's current state is cleared
        before the matching entries are assigned to it.

        Raises:
            AssertionError: if some keys of `state` do not belong to any
                streaming module. This is raised explicitly (not with `assert`)
                so the check is not stripped under `python -O`.
        """
        remaining = dict(state)  # copy: entries are consumed as they are matched

        def _set(name: str, module: "StreamingModule") -> None:
            prefix = name + "." if name else ""
            module._streaming_state.clear()
            for key, value in list(remaining.items()):
                # complexity is not ideal here, but probably fine.
                if key.startswith(prefix):
                    local_key = key[len(prefix):]
                    # A remaining dot means the key belongs to a deeper submodule.
                    if '.' not in local_key:
                        module._streaming_state[local_key] = value
                        del remaining[key]

        self._apply_named_streaming(_set)
        if remaining:
            raise AssertionError(list(remaining.keys()))

    def flush(self, x: tp.Optional[torch.Tensor] = None) -> tp.Optional[torch.Tensor]:
        """Flush any remaining outputs that were waiting for completion.

        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spat out a flushed buffer.

        This default implementation returns None when `x` is None, and
        otherwise just runs `x` through the module.
        """
        if x is None:
            return None
        return self(x)


class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`.

    Accepts the same constructor arguments as `nn.Sequential` (child modules,
    or a single `OrderedDict`). `StreamingModule.__init__` forwards them.
    """

    def flush(self, x: tp.Optional[torch.Tensor] = None) -> tp.Optional[torch.Tensor]:
        """Flush every child in order, feeding each child's result into the next.

        Streaming children always get `flush(x)`. Other children are called
        only when there is something (`x is not None`) to process.
        """
        for module in self:
            if isinstance(module, StreamingModule):
                x = module.flush(x)
            elif x is not None:
                x = module(x)
        return x