|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import random |
|
import typing as tp |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import numpy as np |
|
import math |
|
from einops import rearrange |
|
|
|
|
|
def create_sin_embedding(
    length: int, dim: int, shift: int = 0, device="cpu", max_period=10000
):
    """Build a 1D sinusoidal positional embedding.

    Args:
        length: number of positions to encode.
        dim: size of the embedding; must be even.
        shift: constant offset added to every position index.
        device: device on which the tensors are created.
        max_period: base of the geometric progression of divisors; channel
            ``j`` divides the position by ``max_period ** (j / (dim // 2 - 1))``.

    Returns:
        Tensor of shape ``(length, 1, dim)``. The first ``dim // 2`` channels
        hold the cosines and the last ``dim // 2`` channels the sines.
    """
    assert dim % 2 == 0
    half_dim = dim // 2
    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    adim = torch.arange(half_dim, device=device).view(1, 1, -1)
    # With dim == 2 there is a single channel pair and ``half_dim - 1`` is 0;
    # dividing by it would evaluate 0 / 0 and fill the embedding with NaN.
    denom = max(half_dim - 1, 1)
    phase = pos / (max_period ** (adim / denom))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
|
|
|
|
|
def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    """Build a 2D sinusoidal positional embedding.

    The first half of the channels encodes the position along the width axis
    and the second half the position along the height axis; within each half,
    even channels hold sines and odd channels hold cosines.

    :param d_model: dimension of the model; must be a multiple of 4
    :param height: height of the positions
    :param width: width of the positions
    :param device: device the result is moved to (it is built on the CPU first)
    :param max_period: base used for the geometric progression of frequencies
    :return: position tensor of shape (1, d_model, height, width)
    :raises ValueError: if ``d_model`` is not a multiple of 4
    """
    if d_model % 4 != 0:
        # The check is on multiples of 4 (two axes, each needing sin/cos pairs),
        # so the message says so instead of the misleading "odd dimension".
        raise ValueError(
            "Cannot use 2D sin/cos positional encoding with a dimension "
            "that is not a multiple of 4 (got dim={:d})".format(d_model)
        )
    pe = torch.zeros(d_model, height, width)

    # Each spatial axis gets half of the channels.
    half = d_model // 2
    div_term = torch.exp(
        torch.arange(0.0, half, 2) * -(math.log(max_period) / half)
    )
    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)

    # (half / 2, 1, width): broadcast over the height axis on assignment.
    w_phase = (pos_w * div_term).transpose(0, 1).unsqueeze(1)
    # (half / 2, height, 1): broadcast over the width axis on assignment.
    h_phase = (pos_h * div_term).transpose(0, 1).unsqueeze(2)

    pe[0:half:2, :, :] = torch.sin(w_phase)
    pe[1:half:2, :, :] = torch.cos(w_phase)
    pe[half::2, :, :] = torch.sin(h_phase)
    pe[half + 1::2, :, :] = torch.cos(h_phase)

    return pe[None, :].to(device)
|
|
|
|
|
def create_sin_embedding_cape(
    length: int,
    dim: int,
    batch_size: int,
    mean_normalize: bool,
    augment: bool,
    max_global_shift: float = 0.0,
    max_local_shift: float = 0.0,
    max_scale: float = 1.0,
    device: str = "cpu",
    max_period: float = 10000.0,
):
    """Build a per-batch-item sinusoidal embedding with optional position augmentation.

    Args:
        length: number of positions.
        dim: embedding size; must be even.
        batch_size: number of batch items; each gets its own position track.
        mean_normalize: if True, positions are centred by subtracting their
            mean over the length axis.
        augment: if True, positions are randomly perturbed (NumPy global RNG):
            one global shift per batch item, one local shift per position, and
            one multiplicative scale per batch item drawn log-uniformly.
        max_global_shift: bound of the uniform global shift.
        max_local_shift: bound of the uniform local shift.
        max_scale: the log-scale is drawn in ``[-log(max_scale), log(max_scale)]``.
        device: device the positions are moved to before computing phases.
        max_period: base of the geometric progression of divisors.

    Returns:
        float32 tensor of shape ``(length, batch_size, dim)``; cosines in the
        first ``dim // 2`` channels, sines in the last ``dim // 2``.
    """
    assert dim % 2 == 0
    pos = 1.0 * torch.arange(length).view(-1, 1, 1)
    pos = pos.repeat(1, batch_size, 1)
    if mean_normalize:
        pos -= torch.nanmean(pos, dim=0, keepdim=True)

    if augment:
        # The order of the three draws is kept as-is so that seeded runs
        # consume the NumPy RNG stream exactly as before.
        delta = np.random.uniform(
            -max_global_shift, +max_global_shift, size=[1, batch_size, 1]
        )
        delta_local = np.random.uniform(
            -max_local_shift, +max_local_shift, size=[length, batch_size, 1]
        )
        log_lambdas = np.random.uniform(
            -np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1]
        )
        pos = (pos + delta + delta_local) * np.exp(log_lambdas)

    pos = pos.to(device)

    half_dim = dim // 2
    adim = torch.arange(half_dim, device=device).view(1, 1, -1)
    # With dim == 2, ``half_dim - 1`` is 0 and the exponent would be 0 / 0,
    # turning the whole embedding into NaN; fall back to a divisor of 1.
    denom = max(half_dim - 1, 1)
    phase = pos / (max_period ** (adim / denom))
    # The trailing cast guarantees a float32 result whatever dtype the
    # augmentation arithmetic produced.
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
|
|
|
|
|
def get_causal_mask(length):
    """Return a ``(length, length)`` boolean matrix, True strictly above the diagonal.

    Entry ``[i, j]`` is True exactly when ``j > i``.
    """
    idx = torch.arange(length)
    col_index = idx.unsqueeze(0)  # (1, length): j varies along the last axis
    row_index = idx.unsqueeze(1)  # (length, 1): i varies along the first axis
    return col_index > row_index
|
|
|
|
|
def get_elementary_mask(
    T1,
    T2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """Build one dense boolean attention mask of shape (T2, T1).

    T1 is the length of the decoder input and T2 the length of its output.

    Supported ``mask_type`` values:
      * "global": the first ``global_window`` columns are True, as are the
        first ``int(global_window * T2 / T1)`` rows.
      * "diag": a band of half-width ``sparse_attn_window`` around the
        (rescaled) diagonal.
      * "jmask": columns at triangular-number offsets (0, 1, 3, 6, ...) on
        both sides of the (rescaled) diagonal.
      * "random": entries of a seeded uniform draw that exceed ``sparsity``.

    The returned mask lives on ``device``.
    """
    assert mask_type in ["diag", "jmask", "random", "global"]

    if mask_type == "global":
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        mask[:, :global_window] = True
        n_global_rows = int(global_window * T2 / T1)
        mask[:n_global_rows, :] = True

    elif mask_type == "diag":
        mask = torch.zeros(T2, T1, dtype=torch.bool)
        row_idx = torch.arange(T2)[:, None]
        band = torch.arange(-sparse_attn_window, sparse_attn_window + 1)
        # T1 / T2 rescales the row index onto the column axis for
        # non-square masks; out-of-range columns are clamped to the border.
        col_idx = (T1 / T2 * row_idx + band).long().clamp(0, T1 - 1)
        mask.scatter_(1, col_idx, torch.ones_like(col_idx, dtype=torch.bool))

    elif mask_type == "jmask":
        # Work on a mask padded by one cell on every side so that clamped
        # (out-of-range) offsets land in the padding, which is cropped below.
        padded = torch.zeros(T2 + 2, T1 + 2, dtype=torch.bool)
        row_idx = torch.arange(T2 + 2)[:, None]
        steps = torch.arange(0, int((2 * T1) ** 0.5 + 1))
        triangular = (steps * (steps + 1) / 2).int()
        offsets = torch.cat([-triangular.flip(0)[:-1], triangular])
        col_idx = (T1 / T2 * row_idx + offsets).long().clamp(0, T1 + 1)
        padded.scatter_(1, col_idx, torch.ones_like(col_idx, dtype=torch.bool))
        mask = padded[1:-1, 1:-1]

    elif mask_type == "random":
        gene = torch.Generator(device=device)
        gene.manual_seed(mask_random_seed)
        draws = torch.rand(T1 * T2, generator=gene, device=device)
        mask = draws.reshape(T2, T1) > sparsity

    return mask.to(device)
|
|
|
|
|
def get_mask(
    T1,
    T2,
    mask_type,
    sparse_attn_window,
    global_window,
    mask_random_seed,
    sparsity,
    device,
):
    """Combine elementary masks into a single ``SparseCSRTensor``.

    ``mask_type`` is an underscore-separated list of elementary mask names
    (for instance "diag_jmask_random"). Each one is built with
    ``get_elementary_mask`` using the same remaining arguments, and the final
    mask is True wherever at least one elementary mask is True. A leading
    axis of size 1 is added before the conversion to sparse CSR.
    """
    from xformers.sparse import SparseCSRTensor

    elementary_masks = []
    for elementary_type in mask_type.split("_"):
        elementary_masks.append(
            get_elementary_mask(
                T1,
                T2,
                elementary_type,
                sparse_attn_window,
                global_window,
                mask_random_seed,
                sparsity,
                device,
            )
        )

    # Logical OR across the stacked masks.
    combined = torch.stack(elementary_masks).any(dim=0)
    return SparseCSRTensor.from_dense(combined[None])
|
|
|
|
|
class ScaledEmbedding(nn.Module):
    """Embedding table whose stored weights are shrunk and re-amplified on read.

    At construction the freshly initialised ``nn.Embedding`` weights are
    multiplied by ``scale / boost``; both ``weight`` and ``forward`` multiply
    by ``boost`` again, so the values seen by callers start at ``scale`` times
    the default initialisation.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        scale: float = 1.0,
        boost: float = 3.0,
    ):
        super().__init__()
        self.embedding = nn.Embedding(num_embeddings, embedding_dim)
        shrink = scale / boost
        with torch.no_grad():
            self.embedding.weight.mul_(shrink)
        self.boost = boost

    @property
    def weight(self):
        """Effective (boosted) embedding matrix."""
        stored = self.embedding.weight
        return stored * self.boost

    def forward(self, x):
        """Look up the indices ``x`` and return the boosted embeddings."""
        looked_up = self.embedding(x)
        return looked_up * self.boost
|
|
|
|
|
class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).

    Multiplies its input by a learnable per-channel gain. The gain starts at
    ``init`` (close to 0 in typical use) so the scaled residual branch is
    nearly switched off at first, and is then learnt.
    """

    def __init__(self, channels: int, init: float = 0, channel_last=False):
        """
        channel_last = False corresponds to (B, C, T) tensors
        channel_last = True corresponds to (T, B, C) tensors
        """
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        with torch.no_grad():
            self.scale.fill_(init)

    def forward(self, x):
        """Scale ``x`` along its channel axis."""
        # Channel-last inputs broadcast against the 1D gain directly; otherwise
        # a trailing axis is added so the gain lines up with the C axis.
        gain = self.scale if self.channel_last else self.scale[:, None]
        return gain * x
|
|
|
|
|
class MyGroupNorm(nn.GroupNorm):
    """``nn.GroupNorm`` for channel-last inputs.

    The constructor is inherited unchanged from ``nn.GroupNorm``.
    """

    def forward(self, x):
        """
        x: (B, T, C)
        if num_groups=1: Normalisation on all T and C together for each B
        """
        # nn.GroupNorm expects channels on axis 1: swap in, normalise, swap back.
        channels_first = x.transpose(1, 2)
        normalised = super().forward(channels_first)
        return normalised.transpose(1, 2)
|
|
|
|
|
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    """Self-attention encoder layer built on ``nn.TransformerEncoderLayer``.

    On top of the parent layer it can optionally use group normalisation in
    place of the parent's ``norm1`` / ``norm2``, add a final normalisation in
    pre-norm mode, apply ``LayerScale`` to both residual branches, and replace
    the attention module by the file-local sparse ``MultiheadAttention``.
    """

    def __init__(
        self,
        d_model,
        nhead,
        dim_feedforward=2048,
        dropout=0.1,
        activation=F.relu,
        group_norm=0,
        norm_first=False,
        norm_out=False,
        layer_norm_eps=1e-5,
        layer_scale=False,
        init_values=1e-4,
        device=None,
        dtype=None,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        auto_sparsity=False,
        sparsity=0.95,
        batch_first=False,
    ):
        """
        Args:
            d_model, nhead, dim_feedforward, dropout, activation,
            layer_norm_eps, batch_first, norm_first, device, dtype:
                forwarded to ``nn.TransformerEncoderLayer``.
            group_norm: if truthy, number of groups of the ``MyGroupNorm``
                layers that replace ``norm1`` and ``norm2``.
            norm_out: if truthy and ``norm_first`` is set, number of groups of
                a final ``MyGroupNorm`` applied to the layer output.
            layer_scale: if True, both residual branches go through a
                ``LayerScale`` initialised at ``init_values``.
            sparse: if True, self-attention uses the file-local
                ``MultiheadAttention``.
            auto_sparsity: with ``sparse``, pass ``sparsity`` to the attention
                module instead of using an explicit mask.
            mask_type, sparse_attn_window, global_window, mask_random_seed,
            sparsity: arguments given to ``get_mask`` when ``sparse`` is set
                and ``auto_sparsity`` is not.
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation=activation,
            layer_norm_eps=layer_norm_eps,
            batch_first=batch_first,
            norm_first=norm_first,
            device=device,
            dtype=dtype,
        )
        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # Logical ``and`` on purpose: ``norm_out`` doubles as a group count, and
        # the former bitwise ``&`` evaluated to 0 for even counts (True & 2 == 0),
        # silently dropping the output normalisation.
        if self.norm_first and norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        if sparse:
            self.self_attn = MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0,
            )
            # Placeholder whose shape never matches a real sequence, so the
            # first forward pass builds the actual mask.
            self.src_mask = torch.zeros(1, 1)
            self.mask_random_seed = mask_random_seed

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        """Run the layer on ``src``.

        ``src`` is (T, B, C) when ``batch_first`` is False and (B, T, C) when
        it is True. With explicit sparse masks (``sparse`` and not
        ``auto_sparsity``) ``src_mask`` must be None: the mask is built by
        ``get_mask`` for the current sequence length and cached on the module
        until the length changes.
        """
        x = src
        if self.sparse and not self.auto_sparsity:
            assert src_mask is None
            # The mask must be sized by the sequence axis, which is axis 1 for
            # batch-first inputs (previously axis 0 was always used).
            seq_len = x.shape[1] if self.self_attn.batch_first else x.shape[0]
            src_mask = self.src_mask
            if src_mask.shape[-1] != seq_len:
                src_mask = get_mask(
                    seq_len,
                    seq_len,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    x.device,
                )
                self.src_mask = src_mask

        if self.norm_first:
            x = x + self.gamma_1(
                self._sa_block(self.norm1(x), src_mask, src_key_padding_mask)
            )
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out is not None:
                x = self.norm_out(x)
        else:
            x = self.norm1(
                x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask))
            )
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x
|
|
|
|
|
class CrossTransformerEncoderLayer(nn.Module):
    """Cross-attention encoder layer: queries attend to a second sequence.

    The layer is a cross-attention block followed by a feed-forward block,
    each on a residual branch, with optional group normalisation, output
    normalisation (pre-norm mode only), ``LayerScale`` and sparse attention.
    """

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation=F.relu,
        layer_norm_eps: float = 1e-5,
        layer_scale: bool = False,
        init_values: float = 1e-4,
        norm_first: bool = False,
        group_norm: bool = False,
        norm_out: bool = False,
        sparse=False,
        mask_type="diag",
        mask_random_seed=42,
        sparse_attn_window=500,
        global_window=50,
        sparsity=0.95,
        auto_sparsity=None,
        device=None,
        dtype=None,
        batch_first=False,
    ):
        """
        Args:
            d_model: channel dimension of the inputs.
            nhead: number of attention heads.
            dim_feedforward: hidden size of the feed-forward block.
            dropout: dropout probability (attention, feed-forward, residuals).
            activation: callable, or one of the strings "relu" / "gelu".
            layer_norm_eps: epsilon of the normalisation layers.
            layer_scale: if True, both residual branches go through a
                ``LayerScale`` initialised at ``init_values``.
            norm_first: pre-norm if True, post-norm otherwise.
            group_norm: if truthy, number of groups of the ``MyGroupNorm``
                layers used instead of ``nn.LayerNorm``.
            norm_out: if truthy and ``norm_first`` is set, number of groups of
                a final ``MyGroupNorm`` applied to the output.
            sparse: if True, cross-attention uses the file-local
                ``MultiheadAttention``.
            auto_sparsity: with ``sparse``, pass ``sparsity`` to the attention
                module instead of using an explicit mask.
            mask_type, sparse_attn_window, global_window, mask_random_seed,
            sparsity: arguments given to ``get_mask`` when ``sparse`` is set
                and ``auto_sparsity`` is not.
            device, dtype: forwarded to the linear and normalisation layers.
            batch_first: forwarded to the attention module.

        Raises:
            RuntimeError: if ``activation`` is a string other than
                "relu" or "gelu".
        """
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.sparse = sparse
        self.auto_sparsity = auto_sparsity
        if sparse:
            if not auto_sparsity:
                self.mask_type = mask_type
                self.sparse_attn_window = sparse_attn_window
                self.global_window = global_window
            self.sparsity = sparsity

        self.cross_attn: nn.Module
        self.cross_attn = nn.MultiheadAttention(
            d_model, nhead, dropout=dropout, batch_first=batch_first)

        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first
        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module
        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        # Logical ``and`` on purpose: ``norm_out`` doubles as a group count, and
        # the former bitwise ``&`` evaluated to 0 for even counts (True & 2 == 0),
        # silently dropping the output normalisation.
        if self.norm_first and norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )
        self.gamma_2 = (
            LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        )

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if isinstance(activation, str):
            self.activation = self._get_activation_fn(activation)
        else:
            self.activation = activation

        if sparse:
            self.cross_attn = MultiheadAttention(
                d_model, nhead, dropout=dropout, batch_first=batch_first,
                auto_sparsity=sparsity if auto_sparsity else 0)
            if not auto_sparsity:
                # Placeholder whose shape never matches real sequences, so the
                # first forward pass builds the actual mask.
                self.mask = torch.zeros(1, 1)
            self.mask_random_seed = mask_random_seed

    def forward(self, q, k, mask=None):
        """
        Args:
            q: tensor of shape (T, B, C), or (B, T, C) when batch_first is True
            k: tensor of shape (S, B, C), or (B, S, C) when batch_first is True
            mask: tensor of shape (T, S); must be None when explicit sparse
                masks are used (``sparse`` and not ``auto_sparsity``), in
                which case the mask is built by ``get_mask`` and cached until
                the sequence lengths change.
        """
        if self.sparse and not self.auto_sparsity:
            assert mask is None
            # The mask must be sized by the sequence axes, which are axis 1 for
            # batch-first inputs (previously axis 0 was always used).
            seq_axis = 1 if self.cross_attn.batch_first else 0
            T = q.shape[seq_axis]
            S = k.shape[seq_axis]
            mask = self.mask
            if mask.shape[-1] != S or mask.shape[-2] != T:
                mask = get_mask(
                    S,
                    T,
                    self.mask_type,
                    self.sparse_attn_window,
                    self.global_window,
                    self.mask_random_seed,
                    self.sparsity,
                    q.device,
                )
                self.mask = mask

        if self.norm_first:
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))
            if self.norm_out is not None:
                x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    def _ca_block(self, q, k, attn_mask=None):
        """Cross-attention block: ``q`` attends to ``k`` (keys and values)."""
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    def _ff_block(self, x):
        """Feed-forward block: linear, activation, dropout, linear, dropout."""
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        """Map the strings "relu" / "gelu" to the matching functional."""
        if activation == "relu":
            return F.relu
        elif activation == "gelu":
            return F.gelu

        raise RuntimeError("activation should be relu/gelu, not {}".format(activation))
|
|
|
|
|
|
|
|
|
|
|
class CrossTransformerEncoder(nn.Module):
    """Two-branch transformer encoder alternating self- and cross-attention.

    Two stacks of layers are kept in parallel: ``layers`` processes a 4D input
    ``x`` of shape (B, C, Fr, T1), flattened to a sequence of length
    ``T1 * Fr``, and ``layers_t`` processes a 3D input ``xt`` of shape
    (B, C, T2). Layers whose index parity equals ``classic_parity`` are
    self-attention layers applied to each branch separately; the other layers
    are cross-attention layers in which each branch attends to the other one.
    All layers are created with ``batch_first=True``.
    """

    def __init__(
        self,
        dim: int,
        emb: str = "sin",
        hidden_scale: float = 4.0,
        num_heads: int = 8,
        num_layers: int = 6,
        cross_first: bool = False,
        dropout: float = 0.0,
        max_positions: int = 1000,
        norm_in: bool = True,
        norm_in_group: bool = False,
        group_norm: int = False,
        norm_first: bool = False,
        norm_out: bool = False,
        max_period: float = 10000.0,
        weight_decay: float = 0.0,
        lr: tp.Optional[float] = None,
        layer_scale: bool = False,
        gelu: bool = True,
        sin_random_shift: int = 0,
        weight_pos_embed: float = 1.0,
        cape_mean_normalize: bool = True,
        cape_augment: bool = True,
        cape_glob_loc_scale: list = [5000.0, 1.0, 1.4],
        sparse_self_attn: bool = False,
        sparse_cross_attn: bool = False,
        mask_type: str = "diag",
        mask_random_seed: int = 42,
        sparse_attn_window: int = 500,
        global_window: int = 50,
        auto_sparsity: bool = False,
        sparsity: float = 0.95,
    ):
        """
        Args:
            dim: channel dimension; must be divisible by ``num_heads``.
            emb: positional embedding of the ``xt`` branch: "sin", "cape" or
                "scaled".
            hidden_scale: feed-forward size is ``int(dim * hidden_scale)``.
            num_heads: number of attention heads.
            num_layers: number of layers in each branch.
            cross_first: if True the first layer is a cross-attention layer,
                otherwise a self-attention layer.
            dropout: dropout probability passed to every layer.
            max_positions: table size of the "scaled" embedding.
            norm_in: apply ``nn.LayerNorm`` to both inputs.
            norm_in_group: when ``norm_in`` is False and this is truthy, number
                of groups of a ``MyGroupNorm`` applied to both inputs.
            group_norm, norm_first, norm_out, layer_scale: passed to every layer.
            max_period: base period of the sinusoidal embeddings.
            weight_decay, lr: only stored, and reported by ``make_optim_group``.
            gelu: use GELU (True) or ReLU (False) in the layers.
            sin_random_shift: upper bound (inclusive) of the random position
                shift drawn for the "sin" embedding.
            weight_pos_embed: multiplier of the positional embeddings.
            cape_mean_normalize, cape_augment, cape_glob_loc_scale: settings of
                the "cape" embedding; the list holds the maximum global shift,
                maximum local shift and maximum scale.
            sparse_self_attn, sparse_cross_attn: enable sparse attention in the
                self- / cross-attention layers.
            mask_type, mask_random_seed, sparse_attn_window, global_window,
            auto_sparsity, sparsity: sparse attention settings passed to every
                layer.
        """
        super().__init__()
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        # Layers whose index has this parity are the self-attention layers.
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift
        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            # Copy so the instance never aliases the shared default list.
            self.cape_glob_loc_scale = list(cape_glob_loc_scale)
        if emb == "scaled":
            self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module
        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        self.layers = nn.ModuleList()
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({
            "sparse": sparse_self_attn,
        })
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({
            "sparse": sparse_cross_attn,
        })

        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:
                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(
                    MyTransformerEncoderLayer(**kwargs_classic_encoder)
                )
            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
                self.layers_t.append(
                    CrossTransformerEncoderLayer(**kwargs_cross_encoder)
                )

    def forward(self, x, xt):
        """Encode both branches.

        Args:
            x: tensor of shape (B, C, Fr, T1).
            xt: tensor of shape (B, C, T2).

        Returns:
            Tuple ``(x, xt)`` with the same shapes as the inputs.
        """
        B, C, Fr, T1 = x.shape
        pos_emb_2d = create_2d_sin_embedding(
            C, Fr, T1, x.device, self.max_period
        )
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")
        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                # Both cross layers must see the other branch as it was
                # before this layer, hence the saved copy of ``x``.
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")
        return x, xt

    def _get_pos_embedding(self, T, B, C, device):
        """Return the positional embedding of the ``xt`` branch.

        The result has ``T`` positions on its first axis and ``C`` channels on
        its last axis; the middle (batch) axis is ``B`` for "cape" and 1
        otherwise.

        Raises:
            ValueError: if ``self.emb`` is not "sin", "cape" or "scaled".
        """
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(
                T, C, shift=shift, device=device, max_period=self.max_period
            )
        elif self.emb == "cape":
            if self.training:
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=self.cape_augment,
                    max_global_shift=self.cape_glob_loc_scale[0],
                    max_local_shift=self.cape_glob_loc_scale[1],
                    max_scale=self.cape_glob_loc_scale[2],
                )
            else:
                # No position augmentation outside of training.
                pos_emb = create_sin_embedding_cape(
                    T,
                    C,
                    B,
                    device=device,
                    max_period=self.max_period,
                    mean_normalize=self.cape_mean_normalize,
                    augment=False,
                )
        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]
        else:
            # Previously this fell through to an UnboundLocalError.
            raise ValueError(
                "Unknown positional embedding {!r}; expected 'sin', 'cape' "
                "or 'scaled'".format(self.emb)
            )

        return pos_emb

    def make_optim_group(self):
        """Return an optimizer parameter group for this module.

        The group holds all parameters and ``weight_decay``; ``lr`` is added
        only when it was given at construction.
        """
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None:
            group["lr"] = self.lr
        return group
|
|
|
|
|
|
|
|
|
|
|
class MultiheadAttention(nn.Module):
    """Multi-head attention backed by the sparse attention helpers of this file.

    The signatures mirror ``nn.MultiheadAttention`` so the module can be
    swapped in for it, but several arguments are accepted only for
    compatibility and are not used by this implementation: ``add_bias_kv``,
    ``add_zero_attn``, ``kdim`` and ``vdim`` in the constructor, and
    ``key_padding_mask``, ``need_weights``, ``average_attn_weights`` and
    ``is_causal`` in ``forward``.
    """

    def __init__(
        self,
        embed_dim,
        num_heads,
        dropout=0.0,
        bias=True,
        add_bias_kv=False,
        add_zero_attn=False,
        kdim=None,
        vdim=None,
        batch_first=False,
        auto_sparsity=None,
    ):
        """
        Args:
            embed_dim: channel dimension; must be divisible by ``num_heads``.
            num_heads: number of attention heads.
            dropout: dropout probability on attention weights and on the output.
            bias: whether the four linear projections have a bias.
            batch_first: if True, inputs and outputs are (B, N, C); otherwise
                (N, B, C).
            auto_sparsity: must not be None. If truthy it is the sparsity given
                to ``dynamic_sparse_attention``; if falsy (e.g. 0) masked
                attention with an explicit ``attn_mask`` is used.

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        assert auto_sparsity is not None, "sanity check"
        if embed_dim % num_heads != 0:
            # Fail here rather than with an opaque reshape error in forward().
            raise ValueError(
                "embed_dim ({}) must be divisible by num_heads ({})".format(
                    embed_dim, num_heads
                )
            )
        self.num_heads = num_heads
        self.q = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.k = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.v = torch.nn.Linear(embed_dim, embed_dim, bias=bias)
        self.attn_drop = torch.nn.Dropout(dropout)
        self.proj = torch.nn.Linear(embed_dim, embed_dim, bias)
        self.proj_drop = torch.nn.Dropout(dropout)
        self.batch_first = batch_first
        self.auto_sparsity = auto_sparsity

    def _split_heads(self, projected, batch, length, channels):
        """Reshape (B, N, C) into (B * num_heads, N, C // num_heads)."""
        head_dim = channels // self.num_heads
        heads = projected.reshape(batch, length, self.num_heads, head_dim)
        return heads.permute(0, 2, 1, 3).flatten(0, 1)

    def forward(
        self,
        query,
        key,
        value,
        key_padding_mask=None,
        need_weights=True,
        attn_mask=None,
        average_attn_weights=True,
        is_causal=False,
    ):
        """Compute attention of ``query`` over ``key`` / ``value``.

        ``is_causal`` is accepted (and ignored) because recent versions of
        ``nn.TransformerEncoderLayer`` pass it to their attention module.
        ``attn_mask`` must be None when ``auto_sparsity`` is enabled.

        Returns:
            Tuple ``(output, None)``; attention weights are never returned.
        """
        if not self.batch_first:
            query = query.permute(1, 0, 2)
            key = key.permute(1, 0, 2)
            value = value.permute(1, 0, 2)
        B, N_q, C = query.shape
        B, N_k, C = key.shape

        q = self._split_heads(self.q(query), B, N_q, C)
        k = self._split_heads(self.k(key), B, N_k, C)
        v = self._split_heads(self.v(value), B, N_k, C)

        if self.auto_sparsity:
            assert attn_mask is None
            x = dynamic_sparse_attention(q, k, v, sparsity=self.auto_sparsity)
        else:
            x = scaled_dot_product_attention(q, k, v, attn_mask, dropout=self.attn_drop)
        # Undo the head flattening: (B * H, N_q, d) -> (B, H, N_q, d) -> (B, N_q, C).
        x = x.reshape(B, self.num_heads, N_q, C // self.num_heads)

        x = x.transpose(1, 2).reshape(B, N_q, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        if not self.batch_first:
            x = x.permute(1, 0, 2)
        return x, None
|
|
|
|
|
def scaled_query_key_softmax(q, k, att_mask):
    """Return softmax-normalised attention weights of ``q`` against ``k``.

    Queries are divided by the square root of the last dimension of ``k``,
    the query/key product is computed by xformers' ``masked_matmul`` with
    ``att_mask``, and a softmax is taken over the last axis.
    """
    from xformers.ops import masked_matmul

    scale = (k.size(-1)) ** 0.5
    scaled_q = q / scale
    logits = masked_matmul(scaled_q, k.transpose(-2, -1), att_mask)
    return torch.nn.functional.softmax(logits, -1)
|
|
|
|
|
def scaled_dot_product_attention(q, k, v, att_mask, dropout):
    """Masked attention: dropout-regularised weights applied to the values.

    ``dropout`` is a callable (typically an ``nn.Dropout`` module) applied to
    the attention weights before they multiply ``v``.
    """
    weights = dropout(scaled_query_key_softmax(q, k, att_mask=att_mask))
    return weights @ v
|
|
|
|
|
def _compute_buckets(x, R): |
|
qq = torch.einsum('btf,bfhi->bhti', x, R) |
|
qq = torch.cat([qq, -qq], dim=-1) |
|
buckets = qq.argmax(dim=-1) |
|
|
|
return buckets.permute(0, 2, 1).byte().contiguous() |
|
|
|
|
|
def dynamic_sparse_attention(query, key, value, sparsity, infer_sparsity=True, attn_bias=None):
    """Sparse attention whose pattern is derived from random-projection hashes.

    Queries and keys are hashed into buckets with a shared random projection
    (32 hashes, drawn afresh on every call, no gradient). The bucket ids are
    given to xformers' ``find_locations`` together with ``sparsity`` and
    ``infer_sparsity``, and the row offsets / column indices it returns are
    passed to ``sparse_memory_efficient_attention`` with the original
    (contiguous) query, key and value and ``attn_bias``.
    """
    from xformers.ops import find_locations, sparse_memory_efficient_attention

    n_hashes = 32
    proj_size = 4
    query, key, value = (tensor.contiguous() for tensor in (query, key, value))

    # Hashing only selects which locations are attended to; it is kept out of
    # the autograd graph.
    with torch.no_grad():
        feature_dim = query.shape[-1]
        R = torch.randn(1, feature_dim, n_hashes, proj_size // 2, device=query.device)
        bucket_query = _compute_buckets(query, R)
        bucket_key = _compute_buckets(key, R)
        locations = find_locations(bucket_query, bucket_key, sparsity, infer_sparsity)
        row_offsets, column_indices = locations

    return sparse_memory_efficient_attention(
        query, key, value, row_offsets, column_indices, attn_bias)
|
|