Spaces:
Sleeping
Sleeping
# Adapted from https://github.com/lucidrains/naturalspeech2-pytorch/blob/659bec7f7543e7747e809e950cc2f84242fbeec7/naturalspeech2_pytorch/naturalspeech2_pytorch.py#L532
from collections import namedtuple | |
from functools import wraps | |
import torch | |
import torch.nn.functional as F | |
from einops import rearrange, repeat | |
from einops.layers.torch import Rearrange | |
from packaging import version | |
from torch import einsum, nn | |
def exists(val):
    """Return ``True`` unless ``val`` is ``None``."""
    return not (val is None)
def once(fn):
    """Wrap ``fn`` so that only the first call is executed.

    The first invocation forwards all positional and keyword arguments to
    ``fn`` and returns its result. Every later invocation is a no-op that
    returns ``None``. The "already called" flag belongs to the returned
    wrapper, so it is shared by all calls regardless of their arguments.

    Args:
        fn: The callable to guard.

    Returns:
        A wrapper carrying ``fn``'s metadata (name, docstring, ...).
    """
    called = False

    @wraps(fn)
    def inner(*args, **kwargs):
        nonlocal called
        if called:
            return None
        called = True
        return fn(*args, **kwargs)

    return inner


# ``print`` variant that emits only the very first message passed to it.
print_once = once(print)
# main class
class Attend(nn.Module):
    """Scaled dot-product attention with an optional PyTorch 2.0 fused backend.

    Args:
        dropout: Dropout probability applied to the attention weights.
        causal: When true, a query position may not attend to key positions
            that come after it.
        use_flash: Compute attention with ``F.scaled_dot_product_attention``
            instead of the explicit einsum/softmax path. Requires
            PyTorch >= 2.0.
    """

    def __init__(self, dropout=0.0, causal=False, use_flash=False):
        super().__init__()
        self.dropout = dropout
        self.attn_dropout = nn.Dropout(dropout)

        self.causal = causal
        # Cache slot for the causal mask; filled lazily by ``get_mask`` and
        # kept out of the state dict (non-persistent).
        self.register_buffer("mask", None, persistent=False)

        self.use_flash = use_flash
        assert not (
            use_flash and version.parse(torch.__version__) < version.parse("2.0.0")
        ), "in order to use flash attention, you must be using pytorch 2.0 or above"

        # determine efficient attention configs for cuda and cpu
        self.config = namedtuple("EfficientAttentionConfig", ["enable_flash", "enable_math", "enable_mem_efficient"])
        self.cpu_config = self.config(True, True, True)
        self.cuda_config = None

        if not torch.cuda.is_available() or not use_flash:
            return

        device_properties = torch.cuda.get_device_properties(torch.device("cuda"))

        # Compute capability 8.0 is treated as A100 here.
        if device_properties.major == 8 and device_properties.minor == 0:
            print_once("A100 GPU detected, using flash attention if input tensor is on cuda")
            self.cuda_config = self.config(True, False, False)
        else:
            print_once("Non-A100 GPU detected, using math or mem efficient attention if input tensor is on cuda")
            self.cuda_config = self.config(False, True, True)

    def get_mask(self, n, device):
        """Return an ``(n, n)`` boolean mask that is ``True`` strictly above the diagonal.

        ``True`` entries are the positions to be masked out. The mask is
        cached on the module and sliced when a cached one is large enough.
        """
        if exists(self.mask) and self.mask.shape[-1] >= n:
            return self.mask[:n, :n]

        mask = torch.ones((n, n), device=device, dtype=torch.bool).triu(1)
        self.register_buffer("mask", mask, persistent=False)
        return mask

    def flash_attn(self, q, k, v, mask=None):
        """Attention through ``F.scaled_dot_product_attention``.

        Args:
            q: Queries of shape ``(b, h, i, d)``.
            k: Keys, either ``(b, h, j, d)`` or head-less ``(b, j, d)``.
            v: Values, laid out like ``k``.
            mask: Optional ``(b, j)`` key padding mask; ``True`` marks keys
                that may be attended to.

        Returns:
            The attention output produced by the fused kernel.
        """
        batch, heads, q_len, _ = q.shape
        is_cuda = q.is_cuda

        # Recommended for multi-query single-key-value attention by Tri Dao
        # kv shape torch.Size([1, 512, 64]) -> torch.Size([1, 8, 512, 64])
        # Only the batch and head axes are broadcast, so the key/value length
        # is allowed to differ from the query length.
        if k.ndim == 3:
            k = rearrange(k, "b ... -> b 1 ...").expand(batch, heads, -1, -1)

        if v.ndim == 3:
            v = rearrange(v, "b ... -> b 1 ...").expand(batch, heads, -1, -1)

        causal = self.causal

        # Check if mask exists and expand to compatible shape
        # The mask is B L, so it would have to be expanded to B H N L
        if exists(mask):
            mask = rearrange(mask, "b j -> b 1 1 j")
            mask = mask.expand(-1, heads, q_len, -1)

            if causal:
                # scaled_dot_product_attention rejects attn_mask together with
                # is_causal=True, so the causal constraint is merged into the
                # boolean mask and the flag is dropped.
                causal_mask = self.get_mask(q_len, q.device)
                mask = mask & ~causal_mask
                causal = False

        # Check if there is a compatible device for flash attention
        config = self.cuda_config if is_cuda else self.cpu_config

        # pytorch 2.0 flash attn: q, k, v, mask, dropout, causal, softmax_scale
        with torch.backends.cuda.sdp_kernel(**config._asdict()):
            out = F.scaled_dot_product_attention(
                q, k, v, attn_mask=mask, dropout_p=self.dropout if self.training else 0.0, is_causal=causal
            )

        return out

    def forward(self, q, k, v, mask=None):
        """
        einstein notation
        b - batch
        h - heads
        n, i, j - sequence length (base sequence length, source, target)
        d - feature dimension

        ``mask`` is an optional ``(b, j)`` boolean key padding mask in which
        ``True`` marks keys that may be attended to. ``k`` and ``v`` may be
        given without a head axis (``b j d``), in which case they are shared
        by all heads.
        """
        if self.use_flash:
            return self.flash_attn(q, k, v, mask=mask)

        n, device = q.shape[-2], q.device
        scale = q.shape[-1] ** -0.5

        kv_einsum_eq = "b j d" if k.ndim == 3 else "b h j d"

        # similarity
        sim = einsum(f"b h i d, {kv_einsum_eq} -> b h i j", q, k) * scale

        # key padding mask
        if exists(mask):
            mask = rearrange(mask, "b j -> b 1 1 j")
            sim = sim.masked_fill(~mask, -torch.finfo(sim.dtype).max)

        # causal mask
        if self.causal:
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

        # attention
        attn = sim.softmax(dim=-1)
        attn = self.attn_dropout(attn)

        # aggregate values
        out = einsum(f"b h i j, {kv_einsum_eq} -> b h i d", attn, v)

        return out
def Sequential(*mods):
    """Build an ``nn.Sequential`` from ``mods``, dropping entries that are ``None``."""
    present = [mod for mod in mods if mod is not None]
    return nn.Sequential(*present)
def exists(x):
    """Tell whether ``x`` holds a value, i.e. is anything other than ``None``."""
    if x is None:
        return False
    return True
def default(val, d):
    """Return ``val`` unless it is ``None``, otherwise fall back to ``d``.

    A callable ``d`` is treated as a factory and invoked (with no arguments)
    to produce the fallback; any other ``d`` is returned as-is.
    """
    if val is None:
        return d() if callable(d) else d
    return val
class RMSNorm(nn.Module):
    """L2-normalise the last axis and rescale by ``sqrt(dim)``.

    Args:
        dim: Size of the normalised (last) axis.
        scale: When true, a learned per-feature gain initialised to ones is
            multiplied in.
        dim_cond: When given, a linear layer maps a conditioning vector of
            this size to a per-feature multiplier and offset that are applied
            to the normalised output.
    """

    def __init__(self, dim, scale=True, dim_cond=None):
        super().__init__()
        self.cond = dim_cond is not None
        if self.cond:
            self.to_gamma_beta = nn.Linear(dim_cond, dim * 2)
        else:
            self.to_gamma_beta = None

        self.scale = dim**0.5
        if scale:
            self.gamma = nn.Parameter(torch.ones(dim))
        else:
            self.gamma = None

    def forward(self, x, cond=None):
        """Normalise ``x``; ``cond`` is required when the module was built with ``dim_cond``."""
        gain = 1 if self.gamma is None else self.gamma
        normed = F.normalize(x, dim=-1) * self.scale * gain

        if self.cond:
            assert cond is not None
            film = self.to_gamma_beta(cond)
            # Insert a length-1 axis so the (b, d) terms broadcast over the
            # middle axis of the normalised tensor.
            cond_gamma, cond_beta = (rearrange(t, "b d -> b 1 d") for t in film.chunk(2, dim=-1))
            return normed * cond_gamma + cond_beta

        return normed
class CausalConv1d(nn.Conv1d):
    """``nn.Conv1d`` that zero-pads on the left only.

    The pad width is ``dilation * (kernel_size - 1)`` and the stride must be 1.
    Constructor arguments are those of ``nn.Conv1d``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Conv1d stores these hyper-parameters as 1-tuples.
        assert self.stride[0] == 1
        self.causal_padding = self.dilation[0] * (self.kernel_size[0] - 1)

    def forward(self, x):
        """Left-pad the last axis of ``x`` with zeros, then run the convolution."""
        left_pad = (self.causal_padding, 0)
        return super().forward(F.pad(x, left_pad, value=0.0))
class GEGLU(nn.Module):
    """GELU-gated linear unit: halves the last axis and gates one half with the other."""

    def forward(self, x):
        """Split ``x`` in two along the last axis and return ``first * gelu(second)``."""
        value, gate = torch.chunk(x, 2, dim=-1)
        return F.gelu(gate) * value
def FeedForward(dim, mult=4, causal_conv=False):
    """Build a GEGLU feed-forward stack.

    Layout: ``Linear(dim -> 2 * inner)``, ``GEGLU``, an optional causal
    convolution over the sequence axis, then ``Linear(inner -> dim)``, where
    ``inner = int(dim * mult * 2 / 3)``.

    Args:
        dim: Input and output feature size.
        mult: Expansion factor used to derive the inner width.
        causal_conv: Insert a kernel-size-3 ``CausalConv1d`` after the gate.
    """
    hidden = int(dim * mult * 2 / 3)

    # Built before the linear layers, matching the original construction order
    # so seeded parameter initialisation is unchanged.
    temporal_mix = (
        nn.Sequential(
            Rearrange("b n d -> b d n"),
            CausalConv1d(hidden, hidden, 3),
            Rearrange("b d n -> b n d"),
        )
        if causal_conv
        else None
    )

    proj_in = nn.Linear(dim, hidden * 2)
    gate = GEGLU()
    proj_out = nn.Linear(hidden, dim)

    return Sequential(proj_in, gate, temporal_mix, proj_out)
class PerceiverResampler(nn.Module):
    """Cross-attend a fixed set of learned latents to an input sequence.

    Args:
        dim: Feature size of the latents and of the output.
        depth: Number of (attention, feed-forward) layer pairs.
        dim_context: Feature size of the input; defaults to ``dim``. A linear
            projection to ``dim`` is added only when the two differ.
        num_latents: Number of learned latent vectors.
        dim_head: Per-head width of each attention layer.
        heads: Number of attention heads.
        ff_mult: Expansion factor of each feed-forward block.
        use_flash_attn: Forwarded to each attention layer as ``use_flash``.
    """

    def __init__(
        self,
        *,
        dim,
        depth=2,
        dim_context=None,
        num_latents=32,
        dim_head=64,
        heads=8,
        ff_mult=4,
        use_flash_attn=False,
    ):
        super().__init__()
        dim_context = default(dim_context, dim)

        if dim_context != dim:
            self.proj_context = nn.Linear(dim_context, dim)
        else:
            self.proj_context = nn.Identity()

        self.latents = nn.Parameter(torch.randn(num_latents, dim))
        nn.init.normal_(self.latents, std=0.02)

        # Each entry pairs an attention layer with its feed-forward block.
        self.layers = nn.ModuleList(
            [
                nn.ModuleList(
                    [
                        Attention(
                            dim=dim,
                            dim_head=dim_head,
                            heads=heads,
                            use_flash=use_flash_attn,
                            cross_attn_include_queries=True,
                        ),
                        FeedForward(dim=dim, mult=ff_mult),
                    ]
                )
                for _ in range(depth)
            ]
        )

        self.norm = RMSNorm(dim)

    def forward(self, x, mask=None):
        """Resample ``x`` into the latents and return them normalised.

        ``x`` is projected to ``dim``, the latents are repeated once per batch
        element, and each layer applies residual attention over ``x`` followed
        by a residual feed-forward. ``mask`` is passed to every attention layer.
        """
        context = self.proj_context(x)
        latents = repeat(self.latents, "n d -> b n d", b=x.shape[0])

        for cross_attn, feed_forward in self.layers:
            latents = cross_attn(latents, context, mask=mask) + latents
            latents = feed_forward(latents) + latents

        return self.norm(latents)
class Attention(nn.Module):
    """Multi-head attention usable as self- or cross-attention.

    Args:
        dim: Feature size of the query input and of the output.
        dim_context: Feature size of the key/value input; defaults to ``dim``.
        causal: Forwarded to ``Attend``.
        dim_head: Width of each head.
        heads: Number of heads.
        dropout: Attention dropout probability, forwarded to ``Attend``.
        use_flash: Forwarded to ``Attend``.
        cross_attn_include_queries: When a context is supplied, prepend the
            query input to it so the keys/values also cover the queries.
    """

    def __init__(
        self,
        dim,
        *,
        dim_context=None,
        causal=False,
        dim_head=64,
        heads=8,
        dropout=0.0,
        use_flash=False,
        cross_attn_include_queries=False,
    ):
        super().__init__()
        # Kept as an attribute; this class does not read it itself.
        self.scale = dim_head**-0.5
        self.heads = heads
        self.cross_attn_include_queries = cross_attn_include_queries

        dim_inner = dim_head * heads
        dim_context = default(dim_context, dim)

        self.attend = Attend(causal=causal, dropout=dropout, use_flash=use_flash)
        self.to_q = nn.Linear(dim, dim_inner, bias=False)
        self.to_kv = nn.Linear(dim_context, dim_inner * 2, bias=False)
        self.to_out = nn.Linear(dim_inner, dim, bias=False)

    def forward(self, x, context=None, mask=None):
        """Attend from ``x`` to ``context`` (or to ``x`` itself when no context is given).

        Args:
            x: Query input, ``(b, n, dim)``.
            context: Optional key/value input, ``(b, m, dim_context)``.
            mask: Optional ``(b, j)`` boolean key padding mask. When
                ``cross_attn_include_queries`` is on, a mask that covers only
                ``context`` is accepted and extended automatically.

        Returns:
            Tensor of shape ``(b, n, dim)``.
        """
        h, has_context = self.heads, exists(context)

        context = default(context, x)

        if has_context and self.cross_attn_include_queries:
            # The keys become [x ; context]. A mask sized for ``context`` alone
            # would no longer line up, so it is left-padded with True (always
            # attendable) for the prepended query positions. A mask that is
            # already sized for the concatenated sequence passes through as-is.
            if exists(mask) and mask.shape[-1] == context.shape[-2]:
                mask = F.pad(mask, (x.shape[-2], 0), value=True)
            context = torch.cat((x, context), dim=-2)

        q, k, v = (self.to_q(x), *self.to_kv(context).chunk(2, dim=-1))
        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))

        out = self.attend(q, k, v, mask=mask)

        out = rearrange(out, "b h n d -> b n (h d)")
        return self.to_out(out)