import torch
import torch.nn as nn
import numpy as np
from math import sqrt

from utils.masking import TriangularCausalMask

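# Note on the imported mask: TriangularCausalMask (from utils.masking) is
# assumed here to expose a boolean `.mask` attribute of shape [B, 1, L, L]
# that is True strictly above the diagonal, i.e. at the future positions a
# causal query must not attend to. That shape broadcasts over the head
# dimension of the [B, H, L, S] score tensor used below.
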
class FullAttention(nn.Module):
    """Canonical scaled dot-product attention with optional causal masking."""

    def __init__(
        self, mask_flag=True, scale=None, attention_dropout=0.1, output_attention=False
    ):
        super(FullAttention, self).__init__()
        self.scale = scale
        self.mask_flag = mask_flag
        self.output_attention = output_attention
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        # queries: [B, L, H, E], keys: [B, S, H, E], values: [B, S, H, D]
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        scale = self.scale or 1.0 / sqrt(E)

        # Raw attention scores over query/key positions: [B, H, L, S]
        scores = torch.einsum("blhe,bshe->bhls", queries, keys)

        if self.mask_flag:
            if attn_mask is None:
                attn_mask = TriangularCausalMask(B, L, device=queries.device)
            # Block attention to future positions by driving their scores to -inf
            scores.masked_fill_(attn_mask.mask, -np.inf)

        A = self.dropout(torch.softmax(scale * scores, dim=-1))
        # Weighted sum of values per head: [B, L, H, D]
        V = torch.einsum("bhls,bshd->blhd", A, values)

        if self.output_attention:
            return V.contiguous(), A
        else:
            return V.contiguous(), None
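
# Example (sketch, with made-up sizes): FullAttention can be called directly
# on per-head tensors, following the einsum contracts above — queries are
# [B, L, H, E], keys [B, S, H, E], values [B, S, H, D], output [B, L, H, D].
#
#     attn = FullAttention(mask_flag=True, attention_dropout=0.1)
#     q = torch.randn(2, 96, 8, 64)            # [B, L, H, E]
#     k = torch.randn(2, 96, 8, 64)            # [B, S, H, E]
#     v = torch.randn(2, 96, 8, 64)            # [B, S, H, D]
#     out, _ = attn(q, k, v, attn_mask=None)   # out: [2, 96, 8, 64]
#
# With mask_flag=True and attn_mask=None, a TriangularCausalMask is built
# internally, so this call is causal self-attention.
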
class AttentionLayer(nn.Module):
    """Multi-head wrapper: projects inputs to per-head Q/K/V, runs the inner
    attention, and projects the concatenated heads back to d_model."""

    def __init__(self, attention, d_model, n_heads, d_keys=None, d_values=None):
        super(AttentionLayer, self).__init__()

        # Default to splitting d_model evenly across heads
        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_attention = attention
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask, tau=None, delta=None):
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Split the model dimension into H heads: [B, L or S, H, d_keys/d_values]
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_attention(
            queries, keys, values, attn_mask, tau=tau, delta=delta
        )
        # Merge heads back: [B, L, H, d_values] -> [B, L, H * d_values]
        out = out.view(B, L, -1)

        return self.out_projection(out), attn
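

if __name__ == "__main__":
    # Smoke test (sketch, with arbitrary sizes): wire FullAttention into
    # AttentionLayer as cross-attention and check the output shapes. With
    # mask_flag=False no causal mask is applied, so L and S may differ.
    B, L, S, d_model, n_heads = 2, 96, 48, 512, 8
    layer = AttentionLayer(
        FullAttention(mask_flag=False, output_attention=True), d_model, n_heads
    )
    q = torch.randn(B, L, d_model)
    kv = torch.randn(B, S, d_model)
    out, attn = layer(q, kv, kv, attn_mask=None)
    print(out.shape)   # torch.Size([2, 96, 512])  -> [B, L, d_model]
    print(attn.shape)  # torch.Size([2, 8, 96, 48]) -> [B, H, L, S]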