Spaces:
Running
on
A10G
Running
on
A10G
# Copyright (c) Meta Platforms, Inc. and affiliates. | |
# All rights reserved. | |
# | |
# This source code is licensed under the license found in the | |
# LICENSE file in the root directory of this source tree. | |
""" | |
Transformer model, with streaming support, xformers attention support
and easy causal attention with a potentially finite receptive field.
See `StreamingTransformer` for more information. | |
Unlike regular PyTorch Transformer, we make the hard choice that batches are first. | |
""" | |
import typing as tp | |
from einops import rearrange | |
import torch | |
import torch.nn as nn | |
from torch.nn import functional as F | |
from torch.utils.checkpoint import checkpoint as torch_checkpoint | |
from xformers import ops | |
from .rope import RotaryEmbedding | |
from .streaming import StreamingModule | |
def _is_profiled() -> bool:
    """Tell whether an xformers profiler is currently active.

    Returns:
        bool: False when the xformers profiler module cannot be imported,
        otherwise whether `_Profiler._CURRENT_PROFILER` is set.
    """
    try:
        from xformers.profiler import profiler
    except ImportError:
        # Without the xformers profiler module nothing can be profiling us.
        return False
    active = profiler._Profiler._CURRENT_PROFILER
    return active is not None
def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Build the normalization module used by the transformer layers.

    Args:
        norm_type (str): Normalization method; only 'layer_norm' is supported.
        dim (int): Size of the normalized (last) dimension.
        **kwargs (dict): Extra arguments forwarded to the normalization layer
            (e.g. `device`, `dtype`).
    Returns:
        nn.Module: The normalization module.
    Raises:
        ValueError: If `norm_type` is not supported.
    """
    if norm_type != 'layer_norm':
        raise ValueError(f"Unknown norm type: {norm_type}")
    # eps matches the value used in the PyTorch reference transformer implementation.
    return nn.LayerNorm(dim, eps=1e-5, **kwargs)
def create_sin_embedding(positions: torch.Tensor, dim: int, max_period: float = 10000,
                         dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    The first `dim // 2` output channels hold cosines and the last `dim // 2`
    hold sines. Positions are divided by factors spaced geometrically from 1
    to `max_period` before taking the cosine/sine.

    Args:
        positions (torch.Tensor): LongTensor of positions, typically of shape
            `[B, T, 1]` so that it broadcasts against the frequency axis.
        dim (int): Dimension of the embedding, must be even.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full([], max_period, device=positions.device, dtype=dtype)  # avoid sync point
    # With a single frequency (dim == 2), `half_dim - 1` is 0 and the exponent would be
    # 0 / 0 = NaN; clamp the denominator so that lone frequency uses a divisor of 1.
    denom = max(half_dim - 1, 1)
    phase = positions / (max_period_tensor ** (adim / denom))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
def expand_repeated_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat every key/value head `n_rep` times along dim 2.

    Equivalent to `torch.repeat_interleave(x, dim=2, repeats=n_rep)` (from xlformers).

    Args:
        x (torch.Tensor): Tensor of shape `[B, T, H_kv, D]`.
        n_rep (int): Number of consecutive copies of each head.
    Returns:
        torch.Tensor: Tensor of shape `[B, T, H_kv * n_rep, D]`, or `x` itself
        when `n_rep == 1`.
    """
    batch, steps, kv_heads, head_dim = x.shape  # unpacking also requires a 4-D input
    if n_rep != 1:
        widened = x.unsqueeze(3).expand(batch, steps, kv_heads, n_rep, head_dim)
        return widened.reshape(batch, steps, kv_heads * n_rep, head_dim)
    return x
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).

    Multiplies the residual outputs channel-wise (a diagonal rescaling) by a
    learnt scale whose entries all start at `init` (1e-4 by default).

    Args:
        channels (int): Number of channels.
        init (float): Initial value of every entry of the scale.
        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise `[*, C, T]`.
        device (torch.device or None): Device on which to initialize the module.
        dtype (torch.dtype or None): dtype to use to initialize the module.
    """
    def __init__(self, channels: int, init: float = 1e-4, channel_last: bool = True,
                 device=None, dtype=None):
        super().__init__()
        self.channel_last = channel_last
        initial = torch.full((channels,), init,
                             requires_grad=True, device=device, dtype=dtype)
        self.scale = nn.Parameter(initial)

    def forward(self, x: torch.Tensor):
        # For channel-first inputs, add a trailing axis so the scale broadcasts over time.
        scale = self.scale if self.channel_last else self.scale[:, None]
        return scale * x
class StreamingMultiheadAttention(StreamingModule):
    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.

    Inputs are batch first. When streaming, the streaming state holds the cached
    `past_keys` (and `past_values` when keys and values differ) plus an `offset`
    counting the time steps already dropped from that cache, so rope positions
    remain absolute.

    Args:
        embed_dim (int): Dimension to project to.
        num_heads (int): Number of heads.
        dropout (float): Dropout level.
        bias (bool): Use bias in projections.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        rope (`RotaryEmbedding` or None): Rope embedding to use.
        cross_attention: Should be true when used as a cross attention.
            All keys and values must be available at once, streaming is only for the queries.
            Cannot be used with `causal` or `rope` (as it wouldn't make sense to
            interpret the time steps in the keys relative to those in the queries).
        safe_streaming (bool): Bug fix, will go away with xformers update. When True,
            streaming on CUDA computes the query/key product under float32 autocast
            in the custom, non memory efficient path.
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product.
        kv_repeat (int): If > 1, will repeat keys and values multiple times (need to divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
    """
    def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0, bias: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 rope: tp.Optional[RotaryEmbedding] = None, cross_attention: bool = False,
                 safe_streaming: bool = True, qk_layer_norm: bool = False, kv_repeat: int = 1,
                 device=None, dtype=None):
        super().__init__()
        factory_kwargs = {'device': device, 'dtype': dtype}
        if past_context is not None:
            assert causal

        self.embed_dim = embed_dim
        self.causal = causal
        self.past_context = past_context
        self.memory_efficient = memory_efficient
        self.attention_as_float32 = attention_as_float32
        self.rope = rope
        self.cross_attention = cross_attention
        self.safe_streaming = safe_streaming
        self.num_heads = num_heads
        self.dropout = dropout
        self.kv_repeat = kv_repeat
        if cross_attention:
            assert not causal, "Causal cannot work with cross attention."
            assert rope is None, "Rope cannot work with cross attention."

        if memory_efficient:
            _verify_xformers_memory_efficient_compat()

        self.custom = _is_custom(custom, memory_efficient)
        if self.custom:
            out_dim = embed_dim
            assert num_heads % kv_repeat == 0
            assert not cross_attention or kv_repeat == 1
            num_kv = num_heads // kv_repeat
            kv_dim = (embed_dim // num_heads) * num_kv
            # Packed projection: `embed_dim` outputs for the queries, then `kv_dim`
            # each for keys and values (fewer than `embed_dim` when kv_repeat > 1).
            out_dim += 2 * kv_dim
            in_proj = nn.Linear(embed_dim, out_dim, bias=bias, **factory_kwargs)
            # We try to follow the default PyTorch MHA convention, to easily compare results.
            self.in_proj_weight = in_proj.weight
            self.in_proj_bias = in_proj.bias
            if bias:
                self.in_proj_bias.data.zero_()  # Following Pytorch convention
            self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
            if bias:
                self.out_proj.bias.data.zero_()
        else:
            assert not qk_layer_norm
            assert kv_repeat == 1
            self.mha = nn.MultiheadAttention(
                embed_dim, num_heads, dropout=dropout, bias=bias, batch_first=True,
                **factory_kwargs)

        self.qk_layer_norm = qk_layer_norm
        if qk_layer_norm:
            assert self.custom
            assert kv_repeat == 1
            ln_dim = embed_dim
            # Same device / dtype as the projections above.
            self.q_layer_norm = nn.LayerNorm(ln_dim, **factory_kwargs)
            self.k_layer_norm = nn.LayerNorm(ln_dim, **factory_kwargs)

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs):
        """Load weights, remapping plain MHA parameter names under the `mha.` prefix."""
        if not self.custom:
            # Support compat with regular MHA
            keys = [n for n, _ in self.mha.named_parameters()]
            for key in keys:
                if prefix + key in state_dict:
                    state_dict[prefix + "mha." + key] = state_dict.pop(prefix + key)
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    def _get_mask(self, current_steps: int, device: torch.device, dtype: torch.dtype):
        """Return the causal attention bias for `current_steps` new queries.

        We actually return a bias for the attention score, as this has the same
        convention both in the builtin MHA in Pytorch, and Xformers functions.
        With `memory_efficient`, this is None for a single step, an xformers
        `LowerTriangularMask` otherwise, and cached past keys are not supported.
        Otherwise it is a `[current_steps, past_steps + current_steps]` tensor with 0
        where attention is allowed and -inf elsewhere, accounting for stored past
        keys and the optional `past_context` window.
        """
        if self.memory_efficient:
            from xformers.ops import LowerTriangularMask
            if current_steps == 1:
                # If we only have one step, then we do not need a mask.
                return None
            elif 'past_keys' in self._streaming_state:
                raise RuntimeError('Not supported at the moment')
            else:
                # Then we can safely use a lower triangular mask
                return LowerTriangularMask()
        if self._streaming_state:
            past_keys = self._streaming_state['past_keys']
            past_steps = past_keys.shape[1]
        else:
            past_steps = 0

        queries_pos = torch.arange(
            past_steps, current_steps + past_steps, device=device).view(-1, 1)
        keys_pos = torch.arange(past_steps + current_steps, device=device).view(1, -1)
        delta = queries_pos - keys_pos
        # A query may attend to itself and to earlier keys...
        valid = delta >= 0
        if self.past_context is not None:
            # ...but no further back than `past_context` steps.
            valid &= (delta <= self.past_context)
        return torch.where(
            valid,
            torch.zeros([], device=device, dtype=dtype),
            torch.full([], float('-inf'), device=device, dtype=dtype))

    def _complete_kv(self, k, v):
        """Prepend cached keys/values and refresh the streaming cache.

        For cross attention `k` and `v` are returned untouched. Otherwise the cached
        `past_keys` / `past_values` are concatenated in front of them along dim 1 and,
        when streaming, the cache keeps only the last `past_context` steps. The
        `offset` streaming entry accumulates how many steps were dropped so far.
        """
        if self.cross_attention:
            # With cross attention we assume all keys and values
            # are already available, and streaming is with respect
            # to the queries only.
            return k, v
        # Complete the key/value pair using the streaming state.
        if self._streaming_state:
            pk = self._streaming_state['past_keys']
            nk = torch.cat([pk, k], dim=1)
            if v is k:
                nv = nk
            else:
                pv = self._streaming_state['past_values']
                nv = torch.cat([pv, v], dim=1)
        else:
            nk = k
            nv = v

        assert nk.shape[1] == nv.shape[1]
        offset = 0
        if self.past_context is not None:
            offset = max(0, nk.shape[1] - self.past_context)
        if self._is_streaming:
            self._streaming_state['past_keys'] = nk[:, offset:]
            if v is not k:
                self._streaming_state['past_values'] = nv[:, offset:]
            if 'offset' in self._streaming_state:
                self._streaming_state['offset'] += offset
            else:
                # Record the steps dropped on the first call too: starting from 0 would
                # shift the rope positions of every later step when the first chunk is
                # longer than `past_context`.
                self._streaming_state['offset'] = torch.tensor(offset)
        return nk, nv

    def _apply_rope(self, query: torch.Tensor, key: torch.Tensor):
        """Apply rope embeddings to query and key tensors.

        Rotation starts at the absolute index of the first new step: the number
        of cached past keys plus the number of steps already dropped from the
        cache (the `offset` streaming entry).
        """
        assert self.rope is not None
        if 'past_keys' in self._streaming_state:
            past_keys_offset = self._streaming_state['past_keys'].shape[1]
        else:
            past_keys_offset = 0
        if 'offset' in self._streaming_state:
            past_context_offset = int(self._streaming_state['offset'].item())
        else:
            past_context_offset = 0
        streaming_offset = past_context_offset + past_keys_offset
        return self.rope.rotate_qk(query, key, start=streaming_offset)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor,
                key_padding_mask=None, need_weights=False, attn_mask=None,
                average_attn_weights=True, is_causal=False):
        """Run multi-head attention on batch-first inputs.

        Args:
            query, key, value (torch.Tensor): Inputs shaped `[B, T, C]`. In the custom
                self-attention path they must be the very same tensor.
            key_padding_mask: Forwarded to `nn.MultiheadAttention`; must be None in the custom path.
            need_weights (bool): Forwarded to `nn.MultiheadAttention`; must be False in the custom path.
            attn_mask: Must be None; the mask is derived from `causal` / `past_context`.
            average_attn_weights (bool): Forwarded to `nn.MultiheadAttention`.
            is_causal (bool): Not supported, use `causal` in the constructor instead.
        Returns:
            tuple: `(output, None)`, with the output cast back to the dtype of `query`.
        """
        assert attn_mask is None
        assert not is_causal, ("new param added in torch 2.0.1 not supported, "
                               "use the causal args in the constructor.")

        dtype = query.dtype
        if self._is_streaming:
            assert self.causal or self.cross_attention, \
                "Streaming only available for causal or cross attention"

        if self.causal:
            # At the moment we specialize only for the self-attention case.
            assert query.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            assert value.shape[1] == key.shape[1], "Causal only for same length query / key / value"
            attn_mask = self._get_mask(query.shape[1], query.device, query.dtype)

        if self.custom:
            # custom implementation
            assert need_weights is False
            assert key_padding_mask is None
            if self.cross_attention:
                # Different queries, keys, values, we have to split manually the weights
                # before applying the linear.
                dim = self.in_proj_weight.shape[0] // 3
                if self.in_proj_bias is None:
                    bias_q, bias_k, bias_v = None, None, None
                else:
                    bias_q = self.in_proj_bias[:dim]
                    bias_k = self.in_proj_bias[dim: 2 * dim]
                    bias_v = self.in_proj_bias[2 * dim:]
                q = nn.functional.linear(query, self.in_proj_weight[:dim], bias_q)
                # todo: when streaming, we could actually save k, v and check the shape actually match.
                k = nn.functional.linear(key, self.in_proj_weight[dim: 2 * dim], bias_k)
                v = nn.functional.linear(value, self.in_proj_weight[2 * dim:], bias_v)
                if self.qk_layer_norm is True:
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                q, k, v = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k, v]]
            else:
                if not _is_profiled():
                    # profiling breaks that property somehow.
                    assert query is key, "specialized implementation"
                    assert value is key, "specialized implementation"
                projected = nn.functional.linear(query, self.in_proj_weight, self.in_proj_bias)
                if self.kv_repeat == 1:
                    packed = rearrange(projected, "b t (p h d) -> b t p h d", p=3, h=self.num_heads)
                    q, k, v = ops.unbind(packed, dim=2)
                else:
                    # Queries use all heads, keys/values only `num_heads // kv_repeat` heads.
                    embed_dim = self.embed_dim
                    per_head_dim = (embed_dim // self.num_heads)
                    kv_heads = self.num_heads // self.kv_repeat
                    q = projected[:, :, :embed_dim]
                    start = embed_dim
                    end = start + per_head_dim * kv_heads
                    k = projected[:, :, start: end]
                    v = projected[:, :, end:]
                    q = rearrange(q, "b t (h d) -> b t h d", h=self.num_heads)
                    k = rearrange(k, "b t (h d) -> b t h d", h=kv_heads)
                    v = rearrange(v, "b t (h d) -> b t h d", h=kv_heads)
                if self.qk_layer_norm is True:
                    assert self.kv_repeat == 1
                    q, k = [rearrange(x, "b t h d -> b t (h d)") for x in [q, k]]
                    q = self.q_layer_norm(q)
                    k = self.k_layer_norm(k)
                    q, k = [rearrange(x, "b t (h d) -> b t h d", h=self.num_heads) for x in [q, k]]
                if self.rope:
                    # Rope must be applied before caching so cached keys are already rotated.
                    q, k = self._apply_rope(q, k)
                k, v = self._complete_kv(k, v)
                if self.kv_repeat > 1:
                    k = expand_repeated_kv(k, self.kv_repeat)
                    v = expand_repeated_kv(v, self.kv_repeat)
            if self.attention_as_float32:
                q, k, v = [x.float() for x in [q, k, v]]
            if self.memory_efficient:
                p = self.dropout if self.training else 0
                x = ops.memory_efficient_attention(q, k, v, attn_mask, p=p)
            else:
                # We include the dot product as float32, for consistency
                # with the other implementations that include that step
                # as part of the attention. Note that when using `autocast`,
                # the einsums would be done as bfloat16, but the softmax
                # would be done as float32, so `attention_as_float32` will
                # extend a bit the range of operations done in float32,
                # although this should make no difference.
                q = q / q.shape[-1] ** 0.5
                if self._is_streaming and self.safe_streaming and q.device.type == 'cuda':
                    with torch.autocast(device_type=q.device.type, dtype=torch.float32):
                        pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
                else:
                    pre_w = torch.einsum("bqhc,bkhc->bhqk", q, k)
                if attn_mask is not None:
                    pre_w = pre_w + attn_mask
                w = torch.softmax(pre_w, dim=-1)
                w = F.dropout(w, self.dropout, training=self.training).to(v)
                x = torch.einsum("bhqk,bkhc->bqhc", w, v)
            x = x.to(dtype)
            x = rearrange(x, "b t h d -> b t (h d)", h=self.num_heads)
            x = self.out_proj(x)
        else:
            key, value = self._complete_kv(key, value)
            if self.attention_as_float32:
                query, key, value = [x.float() for x in [query, key, value]]
            x, _ = self.mha(
                query, key, value, key_padding_mask,
                need_weights, attn_mask, average_attn_weights)
            x = x.to(dtype)

        return x, None
class StreamingTransformerLayer(nn.TransformerEncoderLayer):
    """TransformerLayer with Streaming / Causal support.

    This also integrates cross_attention, when passing `cross_attention=True`,
    rather than having two separate classes like in PyTorch.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        qk_layer_norm (bool): Layer normalization applied to queries and keys before dot product in attention.
        qk_layer_norm_cross (bool): Same for the cross attention.
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
            Cross attention uses a `StreamingMultiheadAttention` with the same heads, dropout,
            bias, custom, memory_efficient and attention_as_float32 settings as the
            self-attention, but without causal masking, rope or kv_repeat.
        layer_scale (float or None): If not None, LayerScale will be used with
            the given value as initial scale.
        rope (`RotaryEmbedding` or None): Rope embedding to use.
        attention_dropout (float or None): If not None, dropout used inside the attention
            modules instead of `dropout`; `dropout` is then only used by the remaining
            dropout layers (those of `nn.TransformerEncoderLayer` and `dropout_cross`).
        kv_repeat (int): If > 1, will repeat keys and values multiple times (need to divide num_heads).
            This will lead to faster decoding time on A100 or other GPUs with tensorcore.
        norm (str): Normalization method, see `create_norm_fn`.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
        **kwargs: See `nn.TransformerEncoderLayer`.
    """
    def __init__(self, d_model: int, num_heads: int, dim_feedforward: int = 2048, dropout: float = 0.1,
                 bias_ff: bool = True, bias_attn: bool = True, causal: bool = False,
                 past_context: tp.Optional[int] = None, custom: bool = False,
                 memory_efficient: bool = False, attention_as_float32: bool = False,
                 qk_layer_norm: bool = False, qk_layer_norm_cross: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 rope: tp.Optional[RotaryEmbedding] = None, attention_dropout: tp.Optional[float] = None,
                 kv_repeat: int = 1, norm: str = 'layer_norm', device=None, dtype=None, **kwargs):
        super().__init__(d_model, num_heads, dim_feedforward, dropout,
                         device=device, dtype=dtype, batch_first=True, **kwargs)
        factory_kwargs = {'device': device, 'dtype': dtype}
        # Redefine self_attn to our streaming multi-head attention
        attn_kwargs: tp.Dict[str, tp.Any] = {
            'embed_dim': d_model,
            'num_heads': num_heads,
            'dropout': dropout if attention_dropout is None else attention_dropout,
            'bias': bias_attn,
            'custom': custom,
            'memory_efficient': memory_efficient,
            'attention_as_float32': attention_as_float32,
        }
        self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
            causal=causal, past_context=past_context, rope=rope, qk_layer_norm=qk_layer_norm,
            kv_repeat=kv_repeat, **attn_kwargs, **factory_kwargs)  # type: ignore
        # Redefine feedforward layers to expose bias parameter
        self.linear1 = nn.Linear(d_model, dim_feedforward, bias=bias_ff, **factory_kwargs)
        self.linear2 = nn.Linear(dim_feedforward, d_model, bias=bias_ff, **factory_kwargs)

        self.layer_scale_1: nn.Module
        self.layer_scale_2: nn.Module
        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)
            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)

        self.cross_attention: tp.Optional[nn.Module] = None
        if cross_attention:
            self.cross_attention = StreamingMultiheadAttention(
                cross_attention=True, qk_layer_norm=qk_layer_norm_cross,
                **attn_kwargs, **factory_kwargs)
            # Norm and dropout
            self.dropout_cross = nn.Dropout(dropout)
            # Same normalization factory as norm1 / norm2 (for 'layer_norm': eps=1e-5,
            # matching the PyTorch reference implementation).
            self.norm_cross = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
            self.layer_scale_cross: nn.Module
            if layer_scale is None:
                self.layer_scale_cross = nn.Identity()
            else:
                self.layer_scale_cross = LayerScale(d_model, layer_scale, **factory_kwargs)
        self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)  # type: ignore

    def _cross_attention_block(self, src: torch.Tensor,
                               cross_attention_src: torch.Tensor) -> torch.Tensor:
        """Cross attend from `src` (queries) to `cross_attention_src` (keys and values)."""
        assert self.cross_attention is not None
        # queries are from src, keys and values from cross_attention_src.
        x = self.cross_attention(
            src, cross_attention_src, cross_attention_src, need_weights=False)[0]
        return self.dropout_cross(x)  # type: ignore

    def forward(self, src: torch.Tensor, src_mask: tp.Optional[torch.Tensor] = None,  # type: ignore
                src_key_padding_mask: tp.Optional[torch.Tensor] = None,
                cross_attention_src: tp.Optional[torch.Tensor] = None):
        """Apply self attention, optional cross attention and the feed-forward block.

        Args:
            src (torch.Tensor): Layer input.
            src_mask: Forwarded to the self attention block.
            src_key_padding_mask: Forwarded to the self attention block.
            cross_attention_src (torch.Tensor or None): Keys/values for cross attention;
                required if and only if the layer was built with `cross_attention=True`.
        Returns:
            torch.Tensor: Layer output.
        """
        if self.cross_attention is None:
            assert cross_attention_src is None
        else:
            assert cross_attention_src is not None
        x = src
        if self.norm_first:
            x = x + self.layer_scale_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            if cross_attention_src is not None:
                x = x + self.layer_scale_cross(
                    self._cross_attention_block(
                        self.norm_cross(x), cross_attention_src))
            x = x + self.layer_scale_2(self._ff_block(self.norm2(x)))
        else:
            x = self.norm1(x + self.layer_scale_1(
                self._sa_block(x, src_mask, src_key_padding_mask)))
            if cross_attention_src is not None:
                # Queries come from the self-attention output `x`, not the raw layer
                # input `src`, so both sub-blocks are chained as in the pre-norm branch.
                x = self.norm_cross(
                    x + self.layer_scale_cross(
                        self._cross_attention_block(x, cross_attention_src)))
            x = self.norm2(x + self.layer_scale_2(self._ff_block(x)))
        return x
class StreamingTransformer(StreamingModule):
    """Transformer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        num_layers (int): Number of layers.
        dim_feedforward (int): Intermediate dimension of FF module.
        dropout (float): Dropout both for MHA and FF.
        bias_ff (bool): Use bias for FF.
        bias_attn (bool): Use bias for MHA.
        causal (bool): Causal mask applied automatically.
        past_context (int or None): Receptive field for the causal mask, infinite if None.
        custom (bool): Use custom MHA implementation, for testing / benchmarking.
        memory_efficient (bool): Use xformers based memory efficient attention.
        attention_as_float32 (bool): Perform the attention as float32
            (especially important with memory_efficient as autocast won't do this automatically).
        cross_attention (bool): If True, expect to get secondary input for cross-attention.
        layer_scale (float or None): If not None, LayerScale will be used
            with the given value as initial scale.
        positional_embedding (str): Positional embedding strategy (sin, rope, or sin_rope).
        max_period (float): Maximum period of the time embedding.
        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
        xpos (bool): Apply xpos exponential decay to positional embedding (rope only).
        lr (float or None): learning rate override through the `make_optim_group` API.
        weight_decay (float or None): Weight_decay override through the `make_optim_group` API.
        layer_class (subclass of `StreamingTransformerLayer`): class to use
            to initialize the layers, allowing further customization outside of Audiocraft.
        checkpointing (str): Checkpointing strategy to reduce memory usage.
            No checkpointing if set to 'none'. Per layer checkpointing using PyTorch
            if set to 'torch' (entire layer checkpointed, i.e. linears are evaluated twice,
            minimal memory usage, but maximal runtime). `xformers_default` provides
            a policy for opting-out some operations of the checkpointing like
            linear layers and attention, providing a middle ground between speed and memory.
            `xformers_mm` only opts out the matrix multiplications.
        device (torch.device or None): Device on which to initialize.
        dtype (torch.dtype or None): dtype to use.
        **kwargs: Forwarded to `layer_class` (see `StreamingTransformerLayer`
            and `nn.TransformerEncoderLayer`).
    """
    def __init__(self, d_model: int, num_heads: int, num_layers: int, dim_feedforward: int = 2048,
                 dropout: float = 0.1, bias_ff: bool = True, bias_attn: bool = True,
                 causal: bool = False, past_context: tp.Optional[int] = None,
                 custom: bool = False, memory_efficient: bool = False, attention_as_float32: bool = False,
                 cross_attention: bool = False, layer_scale: tp.Optional[float] = None,
                 positional_embedding: str = 'sin', max_period: float = 10_000, positional_scale: float = 1.,
                 xpos: bool = False, lr: tp.Optional[float] = None, weight_decay: tp.Optional[float] = None,
                 layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
                 checkpointing: str = 'none', device=None, dtype=None, **kwargs):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale
        self.weight_decay = weight_decay
        self.lr = lr

        assert positional_embedding in ['sin', 'rope', 'sin_rope']
        self.rope: tp.Optional[RotaryEmbedding] = None
        if self.positional_embedding in ['rope', 'sin_rope']:
            # Rope is only applied by the custom attention implementation.
            assert _is_custom(custom, memory_efficient)
            self.rope = RotaryEmbedding(d_model // num_heads, max_period=max_period,
                                        xpos=xpos, scale=positional_scale, device=device)

        self.checkpointing = checkpointing

        assert checkpointing in ['none', 'torch', 'xformers_default', 'xformers_mm']
        if self.checkpointing.startswith('xformers'):
            _verify_xformers_internal_compat()

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model, num_heads=num_heads, dim_feedforward=dim_feedforward,
                    dropout=dropout, bias_ff=bias_ff, bias_attn=bias_attn,
                    causal=causal, past_context=past_context, custom=custom,
                    memory_efficient=memory_efficient, attention_as_float32=attention_as_float32,
                    cross_attention=cross_attention, layer_scale=layer_scale, rope=self.rope,
                    device=device, dtype=dtype, **kwargs))

        if self.checkpointing != 'none':
            for layer in self.layers:
                # see audiocraft/optim/fsdp.py, magic signal to indicate this requires fixing the
                # backward hook inside of FSDP...
                layer._magma_checkpointed = True  # type: ignore
                # Layers without a `layer_drop` attribute (e.g. StreamingTransformerLayer)
                # do no layer dropping; reading it directly raised AttributeError.
                assert getattr(layer, 'layer_drop', 0.) == 0., "Need further checking"  # type: ignore

    def _apply_layer(self, layer, *args, **kwargs):
        """Call `layer`, wrapped in the configured activation checkpointing strategy."""
        method = self.checkpointing
        if method == 'none':
            return layer(*args, **kwargs)
        elif method == 'torch':
            return torch_checkpoint(layer, *args, use_reentrant=False, **kwargs)
        elif method.startswith('xformers'):
            from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy
            if method == 'xformers_default':
                # those operations will be saved, and not recomputed.
                # According to Francisco we can get smarter policies but this is a good start.
                allow_list = [
                    "xformers.efficient_attention_forward_cutlass.default",
                    "xformers_flash.flash_fwd.default",
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            elif method == 'xformers_mm':
                # those operations will be saved, and not recomputed.
                # According to Francisco we can get smarter policies but this is a good start.
                allow_list = [
                    "aten.addmm.default",
                    "aten.mm.default",
                ]
            else:
                raise ValueError(f"xformers checkpointing xformers policy {method} is not known.")
            policy_fn = _get_default_policy(allow_list)
            return checkpoint(layer, *args, policy_fn=policy_fn, **kwargs)
        else:
            raise ValueError(f"Checkpointing method {method} is unknown.")

    def forward(self, x: torch.Tensor, *args, **kwargs):
        """Add the sinusoidal positional embedding (if enabled) and run all layers.

        Args:
            x (torch.Tensor): Input of shape `[B, T, C]`.
            *args, **kwargs: Forwarded to every layer.
        Returns:
            torch.Tensor: Output of the last layer.

        When streaming, per-batch-item time offsets are kept in the `offsets`
        streaming entry so positions continue across successive calls.
        """
        B, T, C = x.shape

        if 'offsets' in self._streaming_state:
            offsets = self._streaming_state['offsets']
        else:
            offsets = torch.zeros(B, dtype=torch.long, device=x.device)

        if self.positional_embedding in ['sin', 'sin_rope']:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offsets.view(-1, 1, 1)
            pos_emb = create_sin_embedding(positions, C, max_period=self.max_period, dtype=x.dtype)
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = self._apply_layer(layer, x, *args, **kwargs)

        if self._is_streaming:
            self._streaming_state['offsets'] = offsets + T

        return x

    def make_optim_group(self):
        """Return an optimizer param group with the optional `lr` / `weight_decay` overrides."""
        group = {"params": list(self.parameters())}
        if self.lr is not None:
            group["lr"] = self.lr
        if self.weight_decay is not None:
            group["weight_decay"] = self.weight_decay
        return group
# Special attention related functions
def _verify_xformers_memory_efficient_compat():
    """Check that xformers memory efficient attention is importable.

    Raises:
        ImportError: With installation instructions, chained to the original
            import failure, if `memory_efficient_attention` or `LowerTriangularMask`
            cannot be imported from `xformers.ops`.
    """
    try:
        from xformers.ops import memory_efficient_attention, LowerTriangularMask  # noqa
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "xformers is not installed. Please install it and try again.\n"
            "To install on AWS and Azure, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
            "To install on FAIR Cluster, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") from err
def _verify_xformers_internal_compat():
    """Check that the internal xformers checkpointing helpers are importable.

    Raises:
        ImportError: With installation instructions, chained to the original
            import failure, if `checkpoint` or `_get_default_policy` cannot be
            imported from `xformers.checkpoint_fairinternal`.
    """
    try:
        from xformers.checkpoint_fairinternal import checkpoint, _get_default_policy  # noqa
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError(
            "Francisco's fairinternal xformers is not installed. Please install it and try again.\n"
            "To install on AWS and Azure, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='8.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n"
            "To install on FAIR Cluster, run \n"
            "FORCE_CUDA=1 TORCH_CUDA_ARCH_LIST='6.0;7.0'\\\n"
            "pip install -U git+https://git@github.com/fairinternal/xformers.git#egg=xformers\n") from err
def _is_custom(custom: bool, memory_efficient: bool):
    """Tell whether the custom attention implementation must be used.

    The xformers memory efficient path only exists in the custom code, so
    requesting it implies the custom implementation.
    """
    if custom:
        return custom
    return memory_efficient