# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""A streamable transformer."""

import typing as tp

import torch
import torch.nn as nn
import torch.nn.functional as F


def create_sin_embedding(positions: torch.Tensor, dim: int,
                         max_period: float = 10000) -> torch.Tensor:
    """Create a sinusoidal time embedding for the given positions.

    Args:
        positions (torch.Tensor): positions to embed. They are broadcast
            against a tensor of shape (1, 1, dim // 2), so a (B, T, 1)
            tensor yields a (B, T, dim) embedding (BTC format).
        dim (int): target embedding dimension, must be even.
        max_period (float): maximum period of the cosines/sines.

    Returns:
        torch.Tensor: cosine components followed by sine components,
            concatenated along the last dimension.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    adim = torch.arange(half_dim, device=positions.device).view(1, 1, -1)
    # With dim == 2 there is a single frequency and `half_dim - 1` is 0, which
    # would give 0 / 0 = NaN in the exponent; clamp so that frequency is 1.
    denom = max(half_dim - 1, 1)
    phase = positions / (max_period ** (adim / denom))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


class StreamingTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """`nn.TransformerEncoderLayer` whose self-attention also sees cached frames.

    Queries are the current frames only; keys/values are the cached past
    frames followed by the current frames, under a causal mask that reaches
    at most `past_context` steps back (unbounded when `past_context` is None).
    """

    def forward(self, x: torch.Tensor, x_past: torch.Tensor,
                past_context: tp.Optional[int]):  # type: ignore
        """Process one chunk of frames.

        Args:
            x (torch.Tensor): current frames; time is read from dim 1
                (the encoder builds this layer with `batch_first=True`).
            x_past (torch.Tensor): cached self-attention inputs from previous
                chunks, concatenated before `x` along dim 1.
            past_context (int or None): how many steps back a query may
                attend; None means no limit.

        Returns:
            tuple: the layer output, and the self-attention input for the
                current frames (post-`norm1` when `norm_first`), which the
                caller caches as future `x_past`.
        """
        if self.norm_first:
            sa_input = self.norm1(x)
            x = x + self._sa_block(sa_input, x_past, past_context)
            x = x + self._ff_block(self.norm2(x))
        else:
            sa_input = x
            x = self.norm1(x + self._sa_block(sa_input, x_past, past_context))
            x = self.norm2(x + self._ff_block(x))
        return x, sa_input

    # self-attention block
    def _sa_block(self, x: torch.Tensor, x_past: torch.Tensor,
                  past_context: tp.Optional[int]):  # type: ignore
        """Causal self-attention of `x` over `[x_past, x]`, then dropout."""
        _, T, _ = x.shape
        _, H, _ = x_past.shape
        queries = x
        keys = torch.cat([x_past, x], dim=1)
        values = keys

        # Current frames sit at positions H..H+T-1 in the key sequence.
        queries_pos = torch.arange(H, T + H, device=x.device).view(-1, 1)
        keys_pos = torch.arange(T + H, device=x.device).view(1, -1)
        delta = queries_pos - keys_pos
        # Causal: never attend to future frames.
        valid_access = delta >= 0
        if past_context is not None:
            valid_access = valid_access & (delta <= past_context)
        # The diagonal (delta == 0) is always valid, so no query row is fully
        # masked. A True entry in a boolean attn_mask means "not allowed".
        x = self.self_attn(
            queries, keys, values, attn_mask=~valid_access,
            need_weights=False)[0]
        return self.dropout1(x)


class StreamingTransformerEncoder(nn.Module):
    """TransformerEncoder with streaming support.

    Args:
        dim (int): dimension of the data.
        hidden_scale (float): intermediate dimension of FF module is this
            times the dimension.
        num_heads (int): number of heads.
        num_layers (int): number of layers.
        max_period (float): maximum period of cosines in the positional
            embedding.
        past_context (int or None): receptive field for the causal mask,
            infinite if None.
        gelu (bool): if true uses GeLUs, otherwise use ReLUs.
        norm_in (bool): normalize the input.
        dropout (float): dropout probability.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """

    def __init__(self, dim, hidden_scale: float = 4., num_heads: int = 8,
                 num_layers: int = 5, max_period: float = 10000,
                 past_context: tp.Optional[int] = 1000, gelu: bool = True,
                 norm_in: bool = True, dropout: float = 0., **kwargs):
        super().__init__()
        assert dim % num_heads == 0
        hidden_dim = int(dim * hidden_scale)

        self.max_period = max_period
        self.past_context = past_context
        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
        else:
            self.norm_in = nn.Identity()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                StreamingTransformerEncoderLayer(
                    dim, num_heads, hidden_dim,
                    activation=activation, batch_first=True,
                    dropout=dropout, **kwargs))

    def _trim_state(self, state: torch.Tensor) -> torch.Tensor:
        """Keep only the cached frames still reachable by the attention mask."""
        if self.past_context is None:
            return state
        # Explicit start index: `state[:, -0:]` would keep every frame when
        # past_context == 0 instead of none.
        start = max(state.shape[1] - self.past_context, 0)
        return state[:, start:]

    def forward(self, x: torch.Tensor,
                states: tp.Optional[tp.List[torch.Tensor]] = None,
                offset: tp.Union[int, torch.Tensor] = 0):
        """Encode one chunk of a stream.

        Args:
            x (torch.Tensor): input chunk of shape (B, T, C).
            states (list of torch.Tensor or None): per-layer caches returned
                by a previous call; None starts a new stream with one zero
                frame of history per layer.
            offset (int or torch.Tensor): absolute time index of the first
                frame of `x`, used for the positional embedding.

        Returns:
            tuple: (output of shape (B, T, C), new per-layer states,
                offset + T) — pass the last two back in for the next chunk.
        """
        _, T, C = x.shape
        if states is None:
            states = [torch.zeros_like(x[:, :1]) for _ in self.layers]

        positions = torch.arange(T, device=x.device).view(1, -1, 1) + offset
        pos_emb = create_sin_embedding(positions, C, max_period=self.max_period)

        new_state: tp.List[torch.Tensor] = []
        x = self.norm_in(x)
        # Match the input dtype so half/bfloat16 inputs are not promoted.
        x = x + pos_emb.to(x.dtype)

        for layer_state, layer in zip(states, self.layers):
            x, new_layer_state = layer(x, layer_state, self.past_context)
            new_layer_state = torch.cat([layer_state, new_layer_state], dim=1)
            new_state.append(self._trim_state(new_layer_state))

        return x, new_state, offset + T