import torch
import torch.nn as nn
import torch.utils.checkpoint
import einops
from einops import rearrange, repeat
from inspect import isfunction

from .rotary import RotaryEmbedding

# Use the fused scaled-dot-product attention kernel when the running PyTorch
# build provides it (>= 2.0); otherwise fall back to an explicit implementation.
if hasattr(nn.functional, 'scaled_dot_product_attention'):
    ATTENTION_MODE = 'flash'
else:
    ATTENTION_MODE = 'math'
print(f'attention mode is {ATTENTION_MODE}')


def add_mask(sim, mask):
    """Apply a boolean attention mask to a (b, h, n, m) similarity tensor.

    `mask` may be (b, n, m) or (n, m); it is broadcast over heads (and batch).
    """
    b, ndim = sim.shape[0], mask.ndim
    if ndim == 3:
        mask = rearrange(mask, "b n m -> b 1 n m")
    if ndim == 2:
        mask = repeat(mask, "n m -> b 1 n m", b=b)
    max_neg_value = -torch.finfo(sim.dtype).max
    sim = sim.masked_fill(~mask, max_neg_value)
    return sim


def create_mask(q, k, q_mask=None, k_mask=None):
    """Build a broadcastable boolean attention mask of shape (b, 1, i, j)
    from optional per-token masks q_mask (b, i) and k_mask (b, j)."""
    def default(val, d):
        return val if val is not None else (d() if isfunction(d) else d)

    b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device
    q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool))
    k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool))
    attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j')
    return attn_mask


class Attention(nn.Module):
    def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None,
                 attn_drop=0., proj_drop=0., use_rope=False):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5
        context_dim = dim if context_dim is None else context_dim

        self.to_q = nn.Linear(dim, dim, bias=qkv_bias)
        self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias)
        self.attn_drop_p = attn_drop
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.use_rope = use_rope
        if self.use_rope:
            self.rotary = RotaryEmbedding(dim=head_dim)

    def forward(self, x, context=None, context_mask=None):
        B, L, C = x.shape
        q = self.to_q(x)
        if context is None:
            # Self-attention: keys and values come from the input itself.
            context = x
        else:
            # Rotary embeddings are only applied in the self-attention case.
            assert self.use_rope is False
        k = self.to_k(context)
        v = self.to_v(context)

        if context_mask is not None:
            mask_binary = create_mask(x, context, None, context_mask)
        else:
            mask_binary = None

        # Split heads: (B, L, H*D) -> (B, H, L, D), computed in float32.
        q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float()
        k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float()
        v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float()

        if self.use_rope:
            q, k = self.rotary(q=q, k=k)

        if ATTENTION_MODE == 'flash':
            x = torch.nn.functional.scaled_dot_product_attention(
                q, k, v,
                # Only apply attention dropout while training.
                dropout_p=self.attn_drop_p if self.training else 0.,
                attn_mask=mask_binary)
            x = einops.rearrange(x, 'B H L D -> B L (H D)')
        elif ATTENTION_MODE == 'math':
            attn = (q @ k.transpose(-2, -1)) * self.scale
            attn = add_mask(attn, mask_binary) if mask_binary is not None else attn
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = (attn @ v).transpose(1, 2).reshape(B, L, C)
        else:
            raise NotImplementedError

        x = self.proj(x)
        x = self.proj_drop(x)
        return x
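

# ---------------------------------------------------------------------------
# Minimal usage sketch: exercises the layer on random tensors for plain
# self-attention and for masked cross-attention. Because of the relative
# `.rotary` import, this file has to be executed as part of its package
# (e.g. `python -m <package>.attention`); the package name is an assumption.
if __name__ == "__main__":
    torch.manual_seed(0)
    attn = Attention(dim=64, num_heads=8)

    x = torch.randn(2, 16, 64)                       # (batch, seq_len, dim)
    print(attn(x).shape)                             # self-attention -> torch.Size([2, 16, 64])

    ctx = torch.randn(2, 10, 64)                     # cross-attention context
    ctx_mask = torch.ones(2, 10, dtype=torch.bool)
    ctx_mask[:, 5:] = False                          # ignore the last 5 context tokens
    print(attn(x, context=ctx, context_mask=ctx_mask).shape)  # torch.Size([2, 16, 64])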