import torch
import math
import torch.nn.functional as F
from torch import nn, einsum
from inspect import isfunction

# helpers

def exists(val):
    return val is not None

def uniq(arr):
    # order-preserving de-duplication
    return {el: True for el in arr}.keys()

def default(val, d):
    # return val if it is set, otherwise the default d (called if d is a function)
    if exists(val):
        return val
    return d() if isfunction(d) else d

def max_neg_value(t):
    # most negative finite value representable in t's dtype (useful for masking)
    return -torch.finfo(t.dtype).max

def init_(tensor):
    # in-place uniform init scaled by 1/sqrt(dim)
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor

# feedforward

class GEGLU(nn.Module):
    # GELU-gated linear unit: project to 2x width, then gate one half with GELU(other half)
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)

class FeedForward(nn.Module):
    # position-wise MLP: (GEGLU | Linear+GELU) -> Dropout -> Linear, inner width = dim * mult
    def __init__(self, dim, dim_out=None, mult=4, glu=True, dropout=0.):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = nn.Sequential(
            nn.Linear(dim, inner_dim),
            nn.GELU()
        ) if not glu else GEGLU(dim, inner_dim)

        self.net = nn.Sequential(
            project_in,
            nn.Dropout(dropout),
            nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)

class SelfAttention(nn.Module):
    # multi-head self-attention; inner width = heads * dim_head, projected back to query_dim
    def __init__(self, query_dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, query_dim), nn.Dropout(dropout))

    def forward(self, x):
        q = self.to_q(x)  # B*N*(H*C)
        k = self.to_k(x)  # B*N*(H*C)
        v = self.to_v(x)  # B*N*(H*C)

        B, N, HC = q.shape
        H = self.heads
        C = HC // H

        # split heads: B*N*(H*C) -> (B*H)*N*C
        q = q.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)
        k = k.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)
        v = v.view(B, N, H, C).permute(0, 2, 1, 3).reshape(B * H, N, C)

        # scaled dot-product attention per head
        sim = torch.einsum('b i c, b j c -> b i j', q, k) * self.scale  # (B*H)*N*N
        attn = sim.softmax(dim=-1)                                      # (B*H)*N*N
        out = torch.einsum('b i j, b j c -> b i c', attn, v)            # (B*H)*N*C

        # merge heads back: (B*H)*N*C -> B*N*(H*C)
        out = out.view(B, H, N, C).permute(0, 2, 1, 3).reshape(B, N, H * C)
        return self.to_out(out)

class Resampler(nn.Module):
    # pre-norm transformer block: LayerNorm -> self-attention and LayerNorm -> feed-forward,
    # each wrapped in a residual connection; input and output shapes are identical
    def __init__(self, query_dim=1024, n_heads=8, d_head=64):
        super().__init__()
        self.attn = SelfAttention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
        self.ff = FeedForward(query_dim, glu=True)
        self.norm1 = nn.LayerNorm(query_dim)
        self.norm2 = nn.LayerNorm(query_dim)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x
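
# Minimal usage sketch (not part of the original file; names below are illustrative).
# It shows the expected tensor layout: the Resampler takes tokens of shape
# B*N*query_dim and returns the same shape, so blocks can be stacked freely.
if __name__ == "__main__":
    torch.manual_seed(0)
    B, N, D = 2, 16, 1024  # batch, tokens, query_dim (D must equal query_dim)
    block = Resampler(query_dim=D, n_heads=8, d_head=64)
    tokens = torch.randn(B, N, D)
    out = block(tokens)
    assert out.shape == (B, N, D)  # both residual sub-blocks preserve the shape
    print(out.shape)  # torch.Size([2, 16, 1024])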