Spaces:
Runtime error
Runtime error
""" | |
Based on https://github.com/buoyancy99/diffusion-forcing/blob/main/algorithms/diffusion_forcing/models/attention.py | |
""" | |
from typing import Optional | |
from collections import namedtuple | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
from einops import rearrange | |
from rotary_embedding_torch import RotaryEmbedding, apply_rotary_emb | |
from embeddings import TimestepEmbedding, Timesteps, Positions2d | |
class TemporalAxialAttention(nn.Module):
    """Multi-head self-attention along axis 1 (T) of a rank-5 ``(B, T, H, W, D)`` tensor.

    Every ``(b, h, w)`` location is folded into the batch and treated as an
    independent sequence of length ``T``, so attention only mixes information
    across axis 1. Positional information comes from ``rotary_emb`` when it is
    given; otherwise an additive embedding of the indices ``0..T-1`` is built
    from ``Timesteps`` followed by ``TimestepEmbedding``.

    Args:
        dim: Size of the last (feature) axis of the input and of the output.
        heads: Number of attention heads.
        dim_head: Per-head feature size; q/k/v each have ``heads * dim_head``
            features.
        is_causal: Forwarded as ``is_causal`` to
            ``F.scaled_dot_product_attention``. When True, position ``t``
            attends only to positions ``<= t``.
        rotary_emb: Optional rotary embedding applied to q and k via
            ``rotate_queries_or_keys``. When None, the additive index
            embedding is created instead.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 4,
        dim_head: int = 32,
        is_causal: bool = True,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        # Width of the concatenated per-head features.
        self.inner_dim = dim_head * heads

        # One bias-free projection produces q, k and v side by side; they are
        # split with chunk(3) in forward().
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        # Additive positional embedding, only needed when no rotary embedding
        # supplies position information.
        self.time_pos_embedding = (
            nn.Sequential(
                Timesteps(dim),
                TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
            )
            if rotary_emb is None
            else None
        )

        self.is_causal = is_causal

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention along axis 1 of ``x``.

        Args:
            x: Rank-5 tensor unpacked as ``(B, T, H, W, D)``. ``D`` must equal
                the ``dim`` given at construction, because ``to_qkv`` is
                ``nn.Linear(dim, ...)``.

        Returns:
            Tensor of shape ``(B, T, H, W, dim)``.
        """
        B, T, H, W, _ = x.shape

        if self.time_pos_embedding is not None:
            # One embedding vector per index 0..T-1, broadcast over B, H and W.
            time_emb = self.time_pos_embedding(torch.arange(T, device=x.device))
            x = x + rearrange(time_emb, "t d -> 1 t 1 1 d")

        # Fold the non-attended axes (B, H, W) into the batch so that each
        # location becomes an independent length-T sequence per head.
        q, k, v = (
            rearrange(t, "B T H W (h d) -> (B H W) h T d", h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        if self.rotary_emb is not None:
            q = self.rotary_emb.rotate_queries_or_keys(q, self.rotary_emb.freqs)
            k = self.rotary_emb.rotate_queries_or_keys(k, self.rotary_emb.freqs)

        q, k, v = (t.contiguous() for t in (q, k, v))

        x = F.scaled_dot_product_attention(
            query=q, key=k, value=v, is_causal=self.is_causal
        )
        x = rearrange(x, "(B H W) h T d -> B T H W (h d)", B=B, H=H, W=W)
        # NOTE(review): SDPA normally returns q's dtype already. The cast is
        # kept in case attention ran at a different precision (e.g. autocast).
        x = x.to(q.dtype)

        # Project the concatenated heads back to `dim` features.
        return self.to_out(x)
class SpatialAxialAttention(nn.Module):
    """Multi-head self-attention over the ``H x W`` grid of a rank-5 ``(B, T, H, W, D)`` tensor.

    Each ``(b, t)`` slice is folded into the batch and treated as an
    independent sequence of ``H * W`` tokens. Attention is bidirectional: it
    is called with ``is_causal=False``. Positional information comes from
    ``rotary_emb`` (axial frequencies for the current ``H x W`` grid) when it
    is given. Otherwise an additive embedding of the ``(row, col)`` index grid
    is built from ``Positions2d`` followed by ``TimestepEmbedding``.

    Args:
        dim: Size of the last (feature) axis of the input and of the output.
        heads: Number of attention heads.
        dim_head: Per-head feature size; q/k/v each have ``heads * dim_head``
            features.
        rotary_emb: Optional rotary embedding. Its ``get_axial_freqs(H, W)``
            result is applied to q and k with ``apply_rotary_emb``. When None,
            the additive grid embedding is created instead.
    """

    def __init__(
        self,
        dim: int,
        heads: int = 4,
        dim_head: int = 32,
        rotary_emb: Optional[RotaryEmbedding] = None,
    ):
        super().__init__()
        self.heads = heads
        self.head_dim = dim_head
        # Width of the concatenated per-head features.
        self.inner_dim = dim_head * heads

        # One bias-free projection produces q, k and v side by side; they are
        # split with chunk(3) in forward().
        self.to_qkv = nn.Linear(dim, self.inner_dim * 3, bias=False)
        self.to_out = nn.Linear(self.inner_dim, dim)

        self.rotary_emb = rotary_emb
        # Additive positional embedding, only needed when no rotary embedding
        # supplies position information.
        self.space_pos_embedding = (
            nn.Sequential(
                Positions2d(dim),
                TimestepEmbedding(in_channels=dim, time_embed_dim=dim * 4, out_dim=dim),
            )
            if rotary_emb is None
            else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention across the H and W axes of ``x``.

        Args:
            x: Rank-5 tensor unpacked as ``(B, T, H, W, D)``. ``D`` must equal
                the ``dim`` given at construction, because ``to_qkv`` is
                ``nn.Linear(dim, ...)``.

        Returns:
            Tensor of shape ``(B, T, H, W, dim)``.
        """
        B, T, H, W, _ = x.shape

        if self.space_pos_embedding is not None:
            # meshgrid with indexing="ij" yields the row-index and
            # column-index grids, each of shape (H, W).
            grid = torch.meshgrid(
                torch.arange(H, device=x.device),
                torch.arange(W, device=x.device),
                indexing="ij",
            )
            space_emb = self.space_pos_embedding(grid)
            # The pattern below requires the embedding to come back as
            # (H, W, dim); it is broadcast over B and T.
            x = x + rearrange(space_emb, "h w d -> 1 1 h w d")

        # Fold (B, T) into the batch and keep H and W as separate axes, so
        # that axial rotary frequencies can be applied per grid position.
        q, k, v = (
            rearrange(t, "B T H W (h d) -> (B T) h H W d", h=self.heads)
            for t in self.to_qkv(x).chunk(3, dim=-1)
        )

        if self.rotary_emb is not None:
            freqs = self.rotary_emb.get_axial_freqs(H, W)
            q = apply_rotary_emb(freqs, q)
            k = apply_rotary_emb(freqs, k)

        # Flatten the grid into a single token axis of length H * W.
        q, k, v = (
            rearrange(
                t, "(B T) h H W d -> (B T) h (H W) d", B=B, T=T, h=self.heads
            ).contiguous()
            for t in (q, k, v)
        )

        x = F.scaled_dot_product_attention(
            query=q, key=k, value=v, is_causal=False
        )
        x = rearrange(x, "(B T) h (H W) d -> B T H W (h d)", B=B, H=H, W=W)
        # NOTE(review): SDPA normally returns q's dtype already. The cast is
        # kept in case attention ran at a different precision (e.g. autocast).
        x = x.to(q.dtype)

        # Project the concatenated heads back to `dim` features.
        return self.to_out(x)