Spaces:
Running
on
T4
Running
on
T4
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Streaming module API that should be implemented by all Streaming components.
""" | |
from contextlib import contextmanager | |
import typing as tp | |
from torch import nn | |
import torch | |
# Streaming state of a single module: key name -> tensor. By convention the
# first dim of each tensor is the batch size, and key names must not contain
# dots (see `StreamingModule`).
State = tp.Dict[str, torch.Tensor]


class StreamingModule(nn.Module):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming children modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` after.
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State = {}
        self._is_streaming = False

    def _apply_named_streaming(self, fn: tp.Callable[[str, "StreamingModule"], tp.Any]) -> None:
        """Call `fn(name, module)` on every `StreamingModule` in this module tree.

        `named_modules()` reports `self` under the empty name `""` and each
        descendant under its dotted path; non-streaming modules are skipped.
        """
        for name, module in self.named_modules():
            if isinstance(module, StreamingModule):
                fn(name, module)

    def _set_streaming(self, streaming: bool) -> None:
        """Set `_is_streaming` on this module and every streaming descendant."""
        def _set_flag(name: str, module: StreamingModule) -> None:
            module._is_streaming = streaming
        self._apply_named_streaming(_set_flag)

    @contextmanager
    def streaming(self) -> tp.Iterator[None]:
        """Context manager to enter streaming mode. Reset streaming state on exit.

        On entry, every streaming module in the tree is switched to streaming
        mode. On exit, it is switched back and all streaming states are cleared.
        """
        self._set_streaming(True)
        try:
            yield
        finally:
            # Runs even if the `with` body raises, so no module is left in
            # streaming mode holding stale state.
            self._set_streaming(False)
            self.reset_streaming()

    def reset_streaming(self) -> None:
        """Reset the streaming state of this module and all streaming sub-modules."""
        def _reset(name: str, module: StreamingModule) -> None:
            module._streaming_state.clear()
        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> State:
        """Return the streaming state, including that of sub-modules.

        Returns:
            A new flat dict. Keys are `"<submodule path>.<key>"` for
            sub-modules and plain `"<key>"` for `self`. Tensors are the
            same objects held by the modules (not copied).
        """
        state: State = {}

        def _add(name: str, module: StreamingModule) -> None:
            prefix = name + "." if name else ""
            for key, value in module._streaming_state.items():
                state[prefix + key] = value

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: State) -> None:
        """Set the streaming state, including that of sub-modules.

        Every streaming module's previous state is discarded first.

        Args:
            state: flat mapping in the format produced by `get_streaming_state`.
                The mapping itself is not modified; tensors are stored by reference.

        Raises:
            AssertionError: if some keys do not belong to any streaming module.
                Raised explicitly so the check is not stripped under `python -O`.
        """
        remaining = dict(state)

        def _set(name: str, module: StreamingModule) -> None:
            prefix = name + "." if name else ""
            module._streaming_state.clear()
            for key, value in list(remaining.items()):
                # complexity is not ideal here, but probably fine.
                if key.startswith(prefix):
                    local_key = key[len(prefix):]
                    # A remaining dot means the key belongs to a deeper sub-module.
                    if '.' not in local_key:
                        module._streaming_state[local_key] = value
                        del remaining[key]

        self._apply_named_streaming(_set)
        if remaining:
            raise AssertionError(list(remaining.keys()))

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush any remaining outputs that were waiting for completion.

        Typically, for convolutions, this will add the final padding
        and process the last buffer.

        This should take an optional argument `x`, which will be provided
        if a module before this one in the streaming pipeline has already
        emitted a flushed-out buffer.

        The default implementation holds no pending output: it returns None
        when `x` is None, and otherwise returns `self(x)`.
        """
        if x is None:
            return None
        return self(x)
class StreamingSequential(StreamingModule, nn.Sequential):
    """A streaming compatible alternative of `nn.Sequential`.

    Takes the same constructor arguments as `nn.Sequential`, and propagates
    `flush` through its children in order.
    """

    def __init__(self, *args: tp.Any) -> None:
        # `StreamingModule.__init__` comes first in the MRO but accepts no
        # arguments, so the default chain would reject the child modules with a
        # TypeError. Build the container through `nn.Sequential` (its own
        # `super().__init__()` reaches `nn.Module`), then set the same streaming
        # attributes `StreamingModule.__init__` sets.
        nn.Sequential.__init__(self, *args)
        self._streaming_state: State = {}
        self._is_streaming = False

    def flush(self, x: tp.Optional[torch.Tensor] = None):
        """Flush every child in order, feeding each one the previous output.

        Streaming children are always flushed (they may receive None).
        Non-streaming children are only applied when there is an actual tensor
        to process.
        """
        for module in self:
            if isinstance(module, StreamingModule):
                x = module.flush(x)
            elif x is not None:
                x = module(x)
        return x