# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Streaming module API that should be implemented by all Streaming components.
"""

from contextlib import contextmanager
import typing as tp

import torch
from torch import nn

State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To put a streaming component in streaming mode, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming child modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` below.
    """
    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Any):
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool):
        def _set_streaming(name, module):
            module._is_streaming = streaming

        self._apply_named_streaming(_set_streaming)

    @contextmanager
    def streaming(self):
        """Context manager to enter streaming mode. Reset streaming state on exit."""
        self._set_streaming(True)
        try:
            yield
        finally:
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self):
        """Reset the streaming state."""
        def _reset(name: str, module: StreamingModule):
            module._streaming_state.clear()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules."""
        state: State = {}

        def _add(name: str, module: StreamingModule):
            if name:
                name += "."
            for key, value in module._streaming_state.items():
                state[name + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State):
        """Set the streaming state, including that of sub-modules."""
        state = dict(state)

        def _set(name: str, module: StreamingModule):
            if name:
                name += "."
            module._streaming_state.clear()
            for key, value in list(state.items()):
                # complexity is not ideal here, but probably fine.
                if key.startswith(name):
                    local_key = key[len(name):]
                    if '.' not in local_key:
                        module._streaming_state[local_key] = value
                        del state[key]

        self._apply_named_streaming(_set)
        assert len(state) == 0, list(state.keys())

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.
        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        spit out a flushed buffer.
        """
        if x is None:
            return None
        else:
            return self(x)
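

# --- Illustrative example (not part of the original module) ---
# A minimal sketch of how a concrete component can use the streaming API:
# in streaming mode it accumulates a running frame count in
# `self._streaming_state`, while offline it stays stateless. The name
# `StreamingOffset` and its behaviour are assumptions made purely for
# illustration; real components (e.g. streaming convolutions) would keep
# buffers of past samples instead.
class StreamingOffset(StreamingModule):
    """Toy module: adds the number of frames already seen to a [B, C, T] input."""
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The state tensor follows the convention above: first dim is the batch size.
        offset = self._streaming_state.get(
            'offset', torch.zeros(x.shape[0], dtype=x.dtype, device=x.device))
        out = x + offset.view(-1, 1, 1)
        if self._is_streaming:
            # Remember how many time steps were processed for the next call.
            self._streaming_state['offset'] = offset + x.shape[-1]
        return out
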
class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming-compatible alternative to `nn.Sequential`.
    """
    def flush(self, x: tp.Optional[torch.Tensor] = None):
        for module in self:
            if isinstance(module, StreamingModule):
                x = module.flush(x)
            elif x is not None:
                x = module(x)
        return x
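

# --- Usage sketch (illustrative, not part of the original module) ---
# Feeding a StreamingSequential chunk by chunk inside the `streaming()` context:
# streaming state persists across calls and is reset when the context exits.
# `StreamingOffset` is the toy module defined above. Since `StreamingModule.__init__`
# takes no arguments, sub-modules are registered with `add_module` here.
if __name__ == "__main__":
    pipeline = StreamingSequential()
    pipeline.add_module("offset", StreamingOffset())
    pipeline.add_module("act", nn.Identity())

    x = torch.randn(2, 4, 30)
    with pipeline.streaming():
        chunks = [pipeline(chunk) for chunk in x.split(10, dim=-1)]
        tail = pipeline.flush()  # None: nothing is buffered in this toy pipeline
        snapshot = pipeline.get_streaming_state()  # e.g. {'offset.offset': tensor([30., 30.])}
    out = torch.cat(chunks, dim=-1)
    assert out.shape == x.shape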