import math

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
from diffusers.models.modeling_utils import ModelMixin
from diffusers.utils.accelerate_utils import apply_forward_hook
from einops import rearrange
from peft import get_peft_model_state_dict, set_peft_model_state_dict
from torch import nn


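# Sinusoidal timestep embedding: maps a batch of scalar timesteps to
# `dim`-dimensional vectors of cosine/sine features with geometrically
# spaced frequencies (standard DDPM/transformer-style embedding).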
def timestep_embedding(t, dim, max_period=10000):
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
        device=t.device
    )
    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    return embedding


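# RMSNorm: normalizes by the root-mean-square of the features (no mean
# subtraction); a per-channel scale is learned only when `trainable=True`.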
class RMSNorm(nn.Module):
    def __init__(self, dim, eps=1e-6, trainable=False):
        super().__init__()
        self.eps = eps
        if trainable:
            self.weight = nn.Parameter(torch.ones(dim))
        else:
            self.weight = None

    def forward(self, x):
        x_dtype = x.dtype
        x = x.float()
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        if self.weight is not None:
            return (x * norm * self.weight).to(dtype=x_dtype)
        else:
            return (x * norm).to(dtype=x_dtype)


class QKNorm(nn.Module):
    """Normalizes the query and the key independently, as Flux proposes."""

    def __init__(self, dim, trainable=False):
        super().__init__()
        self.query_norm = RMSNorm(dim, trainable=trainable)
        self.key_norm = RMSNorm(dim, trainable=trainable)

    def forward(self, q, k):
        q = self.query_norm(q)
        k = self.key_norm(k)
        return q, k


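# Multi-head attention used both for self-attention (fused QKV projection) and
# cross-attention (Q from `x`, K/V from `context`). Supports QK-RMSNorm, an
# optional residual connection on the value stream (`residual_v`), an optional
# length-dependent softmax-temperature rescaling, and rotary position embeddings.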
class Attention(nn.Module):
    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        is_self_attn=True,
        cross_attn_input_size=None,
        residual_v=False,
        dynamic_softmax_temperature=False,
    ):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim**-0.5
        self.is_self_attn = is_self_attn
        self.residual_v = residual_v
        self.dynamic_softmax_temperature = dynamic_softmax_temperature

        if is_self_attn:
            self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        else:
            self.q = nn.Linear(dim, dim, bias=qkv_bias)
            self.context_kv = nn.Linear(cross_attn_input_size, dim * 2, bias=qkv_bias)

        self.proj = nn.Linear(dim, dim, bias=False)

        if residual_v:
            self.lambda_param = nn.Parameter(torch.tensor(0.5).reshape(1))

        self.qk_norm = QKNorm(self.head_dim)

    def forward(self, x, context=None, v_0=None, rope=None):
        if self.is_self_attn:
            qkv = self.qkv(x)
            qkv = rearrange(qkv, "b l (k h d) -> k b h l d", k=3, h=self.num_heads)
            q, k, v = qkv.unbind(0)

            if self.residual_v and v_0 is not None:
                v = self.lambda_param * v + (1 - self.lambda_param) * v_0

            # rope may arrive as (None, None) when learned positional embeddings are used instead.
            if rope is not None and rope[0] is not None:
                q = apply_rotary_emb(q, rope[0], rope[1])
                k = apply_rotary_emb(k, rope[0], rope[1])

            token_length = q.shape[2]
            if self.dynamic_softmax_temperature:
                ratio = math.sqrt(math.log(token_length) / math.log(1040.0))
                k = k * ratio
            q, k = self.qk_norm(q, k)

        else:
            q = rearrange(self.q(x), "b l (h d) -> b h l d", h=self.num_heads)
            kv = rearrange(
                self.context_kv(context),
                "b l (k h d) -> k b h l d",
                k=2,
                h=self.num_heads,
            )
            k, v = kv.unbind(0)
            q, k = self.qk_norm(q, k)

        x = F.scaled_dot_product_attention(q, k, v)
        x = rearrange(x, "b h l d -> b l (h d)")
        x = self.proj(x)
        return x, (v if self.is_self_attn else None)


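# Transformer block with adaLN-Zero conditioning: RMSNorm + self-attention,
# optional cross-attention over `context`, and an MLP, each modulated by
# shift/scale/gate vectors predicted from the conditioning vector `c`.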
class DiTBlock(nn.Module):
    def __init__(
        self,
        hidden_size,
        cross_attn_input_size,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=True,
        residual_v=False,
        dynamic_softmax_temperature=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.norm1 = RMSNorm(hidden_size, trainable=qkv_bias)
        self.self_attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            is_self_attn=True,
            residual_v=residual_v,
            dynamic_softmax_temperature=dynamic_softmax_temperature,
        )

        if cross_attn_input_size is not None:
            self.norm2 = RMSNorm(hidden_size, trainable=qkv_bias)
            self.cross_attn = Attention(
                hidden_size,
                num_heads=num_heads,
                qkv_bias=qkv_bias,
                is_self_attn=False,
                cross_attn_input_size=cross_attn_input_size,
                dynamic_softmax_temperature=dynamic_softmax_temperature,
            )
        else:
            self.norm2 = None
            self.cross_attn = None

        self.norm3 = RMSNorm(hidden_size, trainable=qkv_bias)
        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(),
            nn.Linear(mlp_hidden, hidden_size),
        )

        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 9 * hidden_size, bias=True))

        self.adaLN_modulation[-1].weight.data.zero_()
        self.adaLN_modulation[-1].bias.data.zero_()

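    # The 9 modulation vectors are, in order: shift/scale/gate for self-attention,
    # for cross-attention, and for the MLP; each is broadcast over the token axis.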
    def forward(self, x, context, c, v_0=None, rope=None):
        (
            shift_sa,
            scale_sa,
            gate_sa,
            shift_ca,
            scale_ca,
            gate_ca,
            shift_mlp,
            scale_mlp,
            gate_mlp,
        ) = self.adaLN_modulation(c).chunk(9, dim=1)

        scale_sa = scale_sa[:, None, :]
        scale_ca = scale_ca[:, None, :]
        scale_mlp = scale_mlp[:, None, :]

        shift_sa = shift_sa[:, None, :]
        shift_ca = shift_ca[:, None, :]
        shift_mlp = shift_mlp[:, None, :]

        gate_sa = gate_sa[:, None, :]
        gate_ca = gate_ca[:, None, :]
        gate_mlp = gate_mlp[:, None, :]

        norm_x = self.norm1(x.clone())
        norm_x = norm_x * (1 + scale_sa) + shift_sa
        attn_out, v = self.self_attn(norm_x, v_0=v_0, rope=rope)
        x = x + attn_out * gate_sa

        if self.norm2 is not None:
            norm_x = self.norm2(x)
            norm_x = norm_x * (1 + scale_ca) + shift_ca
            x = x + self.cross_attn(norm_x, context)[0] * gate_ca

        norm_x = self.norm3(x)
        norm_x = norm_x * (1 + scale_mlp) + shift_mlp
        x = x + self.mlp(norm_x) * gate_mlp

        return x, v


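# Patchify via a strided convolution: (B, C, H, W) -> (B, H/p * W/p, embed_dim).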
class PatchEmbed(nn.Module):
    def __init__(self, patch_size=16, in_channels=3, embed_dim=768):
        super().__init__()
        self.patch_proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.patch_size = patch_size

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.patch_proj(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        return x


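# 2D axial rotary embedding: per-axis frequency tables are precomputed over an
# (h, w) grid and concatenated, then sliced to the actual latent grid at
# forward time. Register tokens are given an identity rotation (cos=1, sin=0).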
class TwoDimRotary(torch.nn.Module):
    def __init__(self, dim, base=10000, h=256, w=256):
        super().__init__()
        self.inv_freq = torch.FloatTensor([1.0 / (base ** (i / dim)) for i in range(0, dim, 2)])
        self.h = h
        self.w = w

        t_h = torch.arange(h, dtype=torch.float32)
        t_w = torch.arange(w, dtype=torch.float32)

        freqs_h = torch.outer(t_h, self.inv_freq).unsqueeze(1)
        freqs_w = torch.outer(t_w, self.inv_freq).unsqueeze(0)
        freqs_h = freqs_h.repeat(1, w, 1)
        freqs_w = freqs_w.repeat(h, 1, 1)
        freqs_hw = torch.cat([freqs_h, freqs_w], 2)

        self.register_buffer("freqs_hw_cos", freqs_hw.cos())
        self.register_buffer("freqs_hw_sin", freqs_hw.sin())

    def forward(self, x, height_width=None, extend_with_register_tokens=0):
        if height_width is not None:
            this_h, this_w = height_width
        else:
            this_hw = x.shape[1]
            this_h, this_w = int(this_hw**0.5), int(this_hw**0.5)

        cos = self.freqs_hw_cos[0:this_h, 0:this_w]
        sin = self.freqs_hw_sin[0:this_h, 0:this_w]

        cos = cos.clone().reshape(this_h * this_w, -1)
        sin = sin.clone().reshape(this_h * this_w, -1)

        if extend_with_register_tokens > 0:
            cos = torch.cat(
                [
                    torch.ones(extend_with_register_tokens, cos.shape[1]).to(cos.device),
                    cos,
                ],
                0,
            )
            sin = torch.cat(
                [
                    torch.zeros(extend_with_register_tokens, sin.shape[1]).to(sin.device),
                    sin,
                ],
                0,
            )

        return cos[None, None, :, :], sin[None, None, :, :]


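# Applies the rotary rotation to the last dimension of x (split into two halves),
# computing in float32 and casting back to the input dtype.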
def apply_rotary_emb(x, cos, sin):
    orig_dtype = x.dtype
    x = x.to(dtype=torch.float32)
    assert x.ndim == 4
    d = x.shape[3] // 2
    x1 = x[..., :d]
    x2 = x[..., d:]
    y1 = x1 * cos + x2 * sin
    y2 = x1 * (-sin) + x2 * cos
    return torch.cat([y1, y2], 3).to(dtype=orig_dtype)


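# Diffusion Transformer backbone: patch embedding, learned register tokens,
# timestep-conditioned DiT blocks with optional cross-attention, and a final
# modulated projection back to patches. Inherits the diffusers mixins so it can
# be configured, saved/loaded, and used with PEFT adapters.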
class DiT(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
    @register_to_config
    def __init__(
        self,
        in_channels=4,
        patch_size=2,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        cross_attn_input_size=128,
        residual_v=False,
        train_bias_and_rms=True,
        use_rope=True,
        gradient_checkpoint=False,
        dynamic_softmax_temperature=False,
        rope_base=10000,
    ):
        super().__init__()

        self.patch_embed = PatchEmbed(patch_size, in_channels, hidden_size)

        if use_rope:
            self.rope = TwoDimRotary(hidden_size // (2 * num_heads), base=rope_base, h=512, w=512)
        else:
            self.positional_embedding = nn.Parameter(torch.zeros(1, 2048, hidden_size))

        self.register_tokens = nn.Parameter(torch.randn(1, 16, hidden_size))

        self.time_embed = nn.Sequential(
            nn.Linear(hidden_size, 4 * hidden_size),
            nn.SiLU(),
            nn.Linear(4 * hidden_size, hidden_size),
        )

        self.blocks = nn.ModuleList(
            [
                DiTBlock(
                    hidden_size=hidden_size,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    cross_attn_input_size=cross_attn_input_size,
                    residual_v=residual_v,
                    qkv_bias=train_bias_and_rms,
                    dynamic_softmax_temperature=dynamic_softmax_temperature,
                )
                for _ in range(depth)
            ]
        )

        self.final_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

        self.final_norm = RMSNorm(hidden_size, trainable=train_bias_and_rms)
        self.final_proj = nn.Linear(hidden_size, patch_size * patch_size * in_channels)
        nn.init.zeros_(self.final_modulation[-1].weight)
        nn.init.zeros_(self.final_modulation[-1].bias)
        nn.init.zeros_(self.final_proj.weight)
        nn.init.zeros_(self.final_proj.bias)
        self.paramstatus = {}
        for n, p in self.named_parameters():
            self.paramstatus[n] = {
                "shape": p.shape,
                "requires_grad": p.requires_grad,
            }

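    # Convenience wrappers around peft's state-dict helpers for saving and
    # loading only the LoRA adapter weights.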
    def save_lora_weights(self, save_directory):
        """Save LoRA weights to a file."""
        lora_state_dict = get_peft_model_state_dict(self)
        torch.save(lora_state_dict, f"{save_directory}/lora_weights.pt")

    def load_lora_weights(self, load_directory):
        """Load LoRA weights from a file."""
        lora_state_dict = torch.load(f"{load_directory}/lora_weights.pt")
        set_peft_model_state_dict(self, lora_state_dict)

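    # Forward pass: patchify -> prepend 16 register tokens -> add rotary or learned
    # positional information -> run timestep-conditioned blocks (optionally
    # checkpointed) -> drop register tokens -> final modulated norm/projection ->
    # unpatchify back to (B, C, H, W).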
    @apply_forward_hook
    def forward(self, x, context, timesteps):
        b, c, h, w = x.shape
        x = self.patch_embed(x)

        x = torch.cat([self.register_tokens.repeat(b, 1, 1), x], 1)

        if self.config.use_rope:
            cos, sin = self.rope(
                x,
                extend_with_register_tokens=16,
                height_width=(h // self.config.patch_size, w // self.config.patch_size),
            )
        else:
            x = x + self.positional_embedding.repeat(b, 1, 1)[:, : x.shape[1], :]
            cos, sin = None, None

        t_emb = timestep_embedding(timesteps * 1000, self.config.hidden_size).to(x.device, dtype=x.dtype)
        t_emb = self.time_embed(t_emb)

        v_0 = None

        for _idx, block in enumerate(self.blocks):
            if self.config.gradient_checkpoint:
                x, v = torch.utils.checkpoint.checkpoint(
                    block,
                    x,
                    context,
                    t_emb,
                    v_0,
                    (cos, sin),
                    use_reentrant=True,
                )
            else:
                x, v = block(x, context, t_emb, v_0, (cos, sin))
            if v_0 is None:
                v_0 = v

        x = x[:, 16:, :]
        final_shift, final_scale = self.final_modulation(t_emb).chunk(2, dim=1)
        x = self.final_norm(x)
        x = x * (1 + final_scale[:, None, :]) + final_shift[:, None, :]
        x = self.final_proj(x)

        x = rearrange(
            x,
            "b (h w) (p1 p2 c) -> b c (h p1) (w p2)",
            h=h // self.config.patch_size,
            w=w // self.config.patch_size,
            p1=self.config.patch_size,
            p2=self.config.patch_size,
        )
        return x


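# Quick smoke test: build the model and run a single forward pass on random
# latents, a random context sequence, and a single timestep.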
if __name__ == "__main__":
    model = DiT(
        in_channels=4,
        patch_size=2,
        hidden_size=1152,
        depth=28,
        num_heads=16,
        mlp_ratio=4.0,
        cross_attn_input_size=128,
        residual_v=False,
        train_bias_and_rms=True,
        use_rope=True,
    ).cuda()
    print(
        model(
            torch.randn(1, 4, 64, 64).cuda(),
            torch.randn(1, 37, 128).cuda(),
            torch.tensor([1.0]).cuda(),
        )
    )