| | from RoPE import apply_angles_2d, generate_angles_2d
|
| | import torch
|
| | import torch.nn as nn
|
| | import torch.nn.functional as F
|
| | from einops import rearrange
|
| |
|
| |
|
class Attention(nn.Module):
    """Multi-head self-attention over a 2D token grid with rotary (RoPE) position encoding.

    Args:
        H, W: grid height and width; the sequence length is N = H * W.
        emb_dim: embedding dimension; must be divisible by n_heads.
        n_heads: number of attention heads.

    Input/output shape: [B, H*W, emb_dim].
    """

    def __init__(self, H, W, emb_dim, n_heads=8):
        super().__init__()
        self.H = H
        self.W = W
        self.n_heads = n_heads
        head_dim = emb_dim // n_heads
        # Fused q/k/v projection in a single matmul.
        self.qkv = nn.Linear(emb_dim, 3 * emb_dim, bias=False)
        self.apply_angles_2d = apply_angles_2d
        self.proj = nn.Linear(emb_dim, emb_dim)
        # Precomputed per-(row, col) rotation angles. Registered as a non-persistent
        # buffer so it follows .to(device)/.cuda() but is neither trained nor saved.
        self.register_buffer("freq", generate_angles_2d(H, W, head_dim), persistent=False)

    def forward(self, x):
        B, N, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        # Split heads and restore the 2D grid so the rotary angles can be
        # applied per (row, col) position.
        q = rearrange(q, "B (H W) (h D) -> B h H W D", H=self.H, W=self.W, h=self.n_heads)
        k = rearrange(k, "B (H W) (h D) -> B h H W D", H=self.H, W=self.W, h=self.n_heads)

        # RoPE rotates queries and keys ONLY. The previous version also rotated
        # the values, which corrupts the attention output — v must stay unrotated.
        q = self.apply_angles_2d(q, self.freq)
        k = self.apply_angles_2d(k, self.freq)

        # Flatten the grid back to a sequence for attention.
        q = rearrange(q, "B h H W D -> B h (H W) D")
        k = rearrange(k, "B h H W D -> B h (H W) D")
        # v never needs the grid layout, so split heads directly.
        v = rearrange(v, "B N (h D) -> B h N D", h=self.n_heads)

        # Scaling (1/sqrt(head_dim)) and softmax are handled internally by SDPA.
        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "B h N D -> B N (h D)")
        return self.proj(x)
|
| |
|
class ViTBlock(nn.Module):
    """Pre-norm Transformer encoder block: x + Attn(LN(x)), then x + MLP(LN(x)).

    Args:
        H, W: token-grid dimensions (sequence length is H * W).
        emb_dim: embedding dimension.
        n_heads: attention heads, forwarded to Attention.
        dropout: dropout probability inside the MLP.

    Raises:
        ValueError: if the input is not of shape [B, H*W, emb_dim].
    """

    def __init__(self, H, W, emb_dim, n_heads=8, dropout=0.1):
        # super().__init__() must run before any attribute assignment on an
        # nn.Module (the original assigned H/W/emb_dim first).
        super().__init__()
        self.H, self.W, self.emb_dim = H, W, emb_dim
        self.attn = nn.Sequential(
            nn.LayerNorm(emb_dim),
            Attention(H, W, emb_dim, n_heads=n_heads),
        )
        # Standard ViT MLP with 4x hidden expansion.
        self.MLP = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, emb_dim * 4, bias=True),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(emb_dim * 4, emb_dim, bias=True),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        # Explicit raises instead of asserts: asserts are stripped under `python -O`,
        # silently disabling this shape validation.
        if x.ndim != 3:
            raise ValueError(
                f"Expected shape [B, N, D], but got shape {x.shape}. "
                "You probably passed [B, H, W, D] instead."
            )
        expected = torch.Size([x.shape[0], self.H * self.W, self.emb_dim])
        if x.shape != expected:
            raise ValueError(f"Expected shape [B, N, D] -> {expected}, got {x.shape}")
        # Residual connections around each sub-block (pre-norm formulation).
        x = x + self.attn(x)
        x = x + self.MLP(x)
        return x
|
| |
|
| |
|
if __name__ == "__main__":
    # Smoke test: a 64x64 grid of 384-dim tokens must come back with the same
    # shape. Guarded so importing this module doesn't build a model and print.
    print(ViTBlock(64, 64, 384)(torch.randn(1, 64**2, 384)).shape)