PommesPeter committed on
Commit fbed413
1 Parent(s): a935b35

Update models/model.py

Files changed (1)
  1. models/model.py +90 -141
models/model.py CHANGED
@@ -1,9 +1,17 @@
1
- import functools
2
- import logging
3
  import math
4
- from typing import Optional, Tuple, List
5
 
6
- from .components import RMSNorm
7
  from flash_attn import flash_attn_varlen_func
8
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
9
  import torch
@@ -11,7 +19,7 @@ import torch.distributed as dist
11
  import torch.nn as nn
12
  import torch.nn.functional as F
13
 
14
- logger = logging.getLogger(__name__)
15
 
16
 
17
  def modulate(x, scale):
@@ -57,17 +65,13 @@ class ParallelTimestepEmbedder(nn.Module):
57
  """
58
  # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
59
  half = dim // 2
60
- freqs = torch.exp(
61
- -math.log(max_period)
62
- * torch.arange(start=0, end=half, dtype=torch.float32)
63
- / half
64
- ).to(device=t.device)
65
  args = t[:, None].float() * freqs[None]
66
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
67
  if dim % 2:
68
- embedding = torch.cat(
69
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
70
- )
71
  return embedding
72
 
73
  def forward(self, t):
@@ -85,8 +89,7 @@ class ParallelLabelEmbedder(nn.Module):
85
  super().__init__()
86
  use_cfg_embedding = int(dropout_prob > 0)
87
  self.embedding_table = nn.Embedding(
88
- num_classes + use_cfg_embedding,
89
- hidden_size,
90
  )
91
  self.num_classes = num_classes
92
  self.dropout_prob = dropout_prob
@@ -96,9 +99,7 @@ class ParallelLabelEmbedder(nn.Module):
96
  Drops labels to enable classifier-free guidance.
97
  """
98
  if force_drop_ids is None:
99
- drop_ids = (
100
- torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
101
- )
102
  drop_ids = drop_ids.cuda()
103
  drop_ids = drop_ids.to(labels.device)
104
  else:
@@ -141,10 +142,9 @@ class Attention(nn.Module):
141
  """
142
  super().__init__()
143
  self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
144
- model_parallel_size = 1
145
- self.n_local_heads = n_heads // model_parallel_size
146
- self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
147
- self.n_rep = self.n_local_heads // self.n_local_kv_heads
148
  self.head_dim = dim // n_heads
149
 
150
  self.wq = nn.Linear(
@@ -173,7 +173,7 @@ class Attention(nn.Module):
173
  self.n_kv_heads * self.head_dim,
174
  bias=False,
175
  )
176
- self.gate = nn.Parameter(torch.zeros([self.n_local_heads]))
177
 
178
  self.wo = nn.Linear(
179
  n_heads * self.head_dim,
@@ -182,10 +182,10 @@ class Attention(nn.Module):
182
  )
183
 
184
  if qk_norm:
185
- self.q_norm = nn.LayerNorm(self.n_local_heads * self.head_dim)
186
- self.k_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
187
  if y_dim > 0:
188
- self.ky_norm = nn.LayerNorm(self.n_local_kv_heads * self.head_dim)
189
  else:
190
  self.ky_norm = nn.Identity()
191
  else:
@@ -255,17 +255,12 @@ class Attention(nn.Module):
255
  return x_out.type_as(x_in)
256
 
257
  # copied from huggingface modeling_llama.py
258
- def _upad_input(
259
- self, query_layer, key_layer, value_layer, attention_mask, query_length
260
- ):
261
-
262
  def _get_unpad_data(attention_mask):
263
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
264
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
265
  max_seqlen_in_batch = seqlens_in_batch.max().item()
266
- cu_seqlens = F.pad(
267
- torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
268
- )
269
  return (
270
  indices,
271
  cu_seqlens,
@@ -285,9 +280,7 @@ class Attention(nn.Module):
285
  )
286
  if query_length == kv_seq_len:
287
  query_layer = index_first_axis(
288
- query_layer.reshape(
289
- batch_size * kv_seq_len, self.n_local_heads, head_dim
290
- ),
291
  indices_k,
292
  )
293
  cu_seqlens_q = cu_seqlens_k
@@ -303,9 +296,7 @@ class Attention(nn.Module):
303
  else:
304
  # The -q_len: slice assumes left padding.
305
  attention_mask = attention_mask[:, -query_length:]
306
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
307
- query_layer, attention_mask
308
- )
309
 
310
  return (
311
  query_layer,
@@ -343,15 +334,20 @@ class Attention(nn.Module):
343
  xq = self.q_norm(xq)
344
  xk = self.k_norm(xk)
345
 
346
- xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
347
- xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
348
- xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
349
 
350
  xq = Attention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
351
  xk = Attention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
352
 
353
  xq, xk = xq.to(dtype), xk.to(dtype)
354
 
 
355
  if dtype in [torch.float16, torch.bfloat16]:
356
  # begin var_len flash attn
357
  (
@@ -366,13 +362,6 @@ class Attention(nn.Module):
366
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
367
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
368
 
369
- if self.proportional_attn:
370
- softmax_scale = math.sqrt(
371
- math.log(seqlen, self.base_seqlen) / self.head_dim
372
- )
373
- else:
374
- softmax_scale = math.sqrt(1 / self.head_dim)
375
-
376
  attn_output_unpad = flash_attn_varlen_func(
377
  query_states,
378
  key_states,
@@ -394,21 +383,17 @@ class Attention(nn.Module):
394
  xq.permute(0, 2, 1, 3),
395
  xk.permute(0, 2, 1, 3),
396
  xv.permute(0, 2, 1, 3),
397
- attn_mask=x_mask.bool()
398
- .view(bsz, 1, 1, seqlen)
399
- .expand(-1, self.n_local_heads, seqlen, -1),
400
  )
401
  .permute(0, 2, 1, 3)
402
  .to(dtype)
403
  )
404
 
405
  if hasattr(self, "wk_y"):
406
- # todo better flash_attn support
407
- yk = self.ky_norm(self.wk_y(y)).view(
408
- bsz, -1, self.n_local_kv_heads, self.head_dim
409
- )
410
- yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
411
- n_rep = self.n_local_heads // self.n_local_kv_heads
412
  if n_rep >= 1:
413
  yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
414
  yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
@@ -416,7 +401,7 @@ class Attention(nn.Module):
416
  xq.permute(0, 2, 1, 3),
417
  yk.permute(0, 2, 1, 3),
418
  yv.permute(0, 2, 1, 3),
419
- y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1),
420
  ).permute(0, 2, 1, 3)
421
  output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
422
  output = output + output_y
@@ -534,9 +519,9 @@ class TransformerBlock(nn.Module):
534
  )
535
  self.layer_id = layer_id
536
  self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
537
- self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
538
-
539
  self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
 
 
540
  self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
541
 
542
  self.adaLN_modulation = nn.Sequential(
@@ -583,33 +568,28 @@ class TransformerBlock(nn.Module):
583
  y_mask,
584
  )
585
  )
586
- d = x.shape[-1]
587
  x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
588
  self.feed_forward(
589
- modulate(self.ffn_norm1(x), scale_mlp).view(-1, d),
590
- ).view(*x.shape)
591
  )
592
 
593
  else:
594
- x = x + self.attention_norm1(
595
  self.attention(
596
- self.attention_norm(x),
597
  x_mask,
598
  freqs_cis,
599
  self.attention_y_norm(y),
600
  y_mask,
601
  )
602
  )
603
- # for compatibility with torch.compile because the sequence length changes
604
- B, L, D = x.shape
605
- x = x.view(B * L, D)
606
- x = x + self.ffn_norm1(self.feed_forward(self.ffn_norm(x)))
607
- x = x.view(B, L, D)
608
 
609
  return x
610
 
611
 
612
- class ParallelFinalLayer(nn.Module):
613
  """
614
  The final layer of NextDiT.
615
  """
@@ -624,19 +604,18 @@ class ParallelFinalLayer(nn.Module):
624
  self.linear = nn.Linear(
625
  hidden_size,
626
  patch_size * patch_size * out_channels,
627
- bias=True,
628
  )
629
  self.adaLN_modulation = nn.Sequential(
630
  nn.SiLU(),
631
  nn.Linear(
632
  min(hidden_size, 1024),
633
  hidden_size,
634
- bias=True,
635
  ),
636
  )
637
 
638
  def forward(self, x, c):
639
  scale = self.adaLN_modulation(c)
 
640
  x = modulate(self.norm_final(x), scale)
641
  x = self.linear(x)
642
  return x
@@ -661,7 +640,6 @@ class NextDiT(nn.Module):
661
  learn_sigma: bool = True,
662
  qk_norm: bool = False,
663
  cap_feat_dim: int = 5120,
664
- rope_scaling_factor: float = 1.0,
665
  scale_factor: float = 1.0,
666
  ) -> None:
667
  super().__init__()
@@ -703,27 +681,21 @@ class NextDiT(nn.Module):
703
  for layer_id in range(n_layers)
704
  ]
705
  )
706
- self.final_layer = ParallelFinalLayer(dim, patch_size, self.out_channels)
707
 
708
  assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
709
- self.dim = dim
710
- self.n_heads = n_heads
711
  self.freqs_cis = NextDiT.precompute_freqs_cis(
712
  dim // n_heads,
713
  384,
714
- rope_scaling_factor=rope_scaling_factor,
715
  scale_factor=scale_factor,
716
  )
717
- self.rope_scaling_factor = rope_scaling_factor
 
718
  self.scale_factor = scale_factor
719
- # self.eol_token = nn.Parameter(torch.empty(dim))
720
  self.pad_token = nn.Parameter(torch.empty(dim))
721
- # nn.init.normal_(self.eol_token, std=0.02)
722
  nn.init.normal_(self.pad_token, std=0.02)
723
 
724
- def unpatchify(
725
- self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False
726
- ) -> List[torch.Tensor]:
727
  """
728
  x: (N, T, patch_size**2 * C)
729
  imgs: (N, H, W, C)
@@ -757,18 +729,12 @@ class NextDiT(nn.Module):
757
  if isinstance(x, torch.Tensor):
758
  pH = pW = self.patch_size
759
  B, C, H, W = x.size()
760
- x = (
761
- x.view(B, C, H // pH, pH, W // pW, pW)
762
- .permute(0, 2, 4, 1, 3, 5)
763
- .flatten(3)
764
- )
765
  x = self.x_embedder(x)
766
  x = x.flatten(1, 2)
767
 
768
- mask = torch.ones(
769
- x.shape[0], x.shape[1], dtype=torch.int32, device=x.device
770
- )
771
- # leave the first line for text
772
  return (
773
  x,
774
  mask,
@@ -787,20 +753,14 @@ class NextDiT(nn.Module):
787
  item_freqs_cis = self.freqs_cis[: H // pH, : W // pW]
788
  freqs_cis.append(item_freqs_cis.flatten(0, 1))
789
  img_size.append((H, W))
790
- img = (
791
- img.view(C, H // pH, pH, W // pW, pW)
792
- .permute(1, 3, 0, 2, 4)
793
- .flatten(2)
794
- )
795
  img = self.x_embedder(img)
796
  img = img.flatten(0, 1)
797
  l_effective_seq_len.append(len(img))
798
  x_embed.append(img)
799
 
800
  max_seq_len = max(l_effective_seq_len)
801
- mask = torch.zeros(
802
- len(x), max_seq_len, dtype=torch.int32, device=x[0].device
803
- )
804
  padded_x_embed = []
805
  padded_freqs_cis = []
806
  for i, (item_embed, item_freqs_cis, item_seq_len) in enumerate(
@@ -809,9 +769,7 @@ class NextDiT(nn.Module):
809
  item_embed = torch.cat(
810
  [
811
  item_embed,
812
- self.pad_token.view(1, -1).expand(
813
- max_seq_len - item_seq_len, -1
814
- ),
815
  ],
816
  dim=0,
817
  )
@@ -840,13 +798,9 @@ class NextDiT(nn.Module):
840
  x, mask, img_size, freqs_cis = self.patchify_and_embed(x)
841
  freqs_cis = freqs_cis.to(x.device)
842
 
843
- # cap_freqs_cis = self.freqs_cis[:1, :cap_feats.shape[1]].to(x.device)
844
-
845
  t = self.t_embedder(t) # (N, D)
846
  cap_mask_float = cap_mask.float().unsqueeze(-1)
847
- cap_feats_pool = (cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(
848
- dim=1
849
- )
850
  cap_feats_pool = cap_feats_pool.to(cap_feats)
851
  cap_emb = self.cap_embedder(cap_feats_pool)
852
  adaln_input = t + cap_emb
@@ -871,25 +825,23 @@ class NextDiT(nn.Module):
871
  cap_feats,
872
  cap_mask,
873
  cfg_scale,
874
- rope_scaling_factor=None,
875
- scale_factor=None,
876
  base_seqlen: Optional[int] = None,
877
  proportional_attn: bool = False,
878
  ):
879
- # """
880
- # Forward pass of NextDiT, but also batches the unconNextditional forward pass
881
- # for classifier-free guidance.
882
- # """
883
  # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
884
- # print(ntk_factor, rope_scaling_factor, self.ntk_factor, self.rope_scaling_factor)
885
- if scale_factor is not None:
886
- assert scale_factor is not None
887
- self.freqs_cis = NextDiT.precompute_freqs_cis(
888
- self.dim // self.n_heads,
889
- 384,
890
- scale_factor=scale_factor,
891
- timestep=t[0],
892
- )
893
 
894
  if proportional_attn:
895
  assert base_seqlen is not None
@@ -903,7 +855,7 @@ class NextDiT(nn.Module):
903
 
904
  half = x[: len(x) // 2]
905
  combined = torch.cat([half, half], dim=0)
906
- model_out = self.forward(combined, t, cap_feats, cap_mask)
907
  # For exact reproducibility reasons, we apply classifier-free guidance on only
908
  # three channels by default. The standard approach to cfg applies it to all channels.
909
  # This can be done by uncommenting the following line and commenting-out the line following that.
@@ -912,6 +864,7 @@ class NextDiT(nn.Module):
912
  cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
913
  half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
914
  eps = torch.cat([half_eps, half_eps], dim=0)
 
915
  return torch.cat([eps, rest], dim=1)
916
 
917
  @staticmethod
@@ -919,8 +872,8 @@ class NextDiT(nn.Module):
919
  dim: int,
920
  end: int,
921
  theta: float = 10000.0,
922
- rope_scaling_factor: float = 1.0,
923
  scale_factor: float = 1.0,
 
924
  timestep: float = 1.0,
925
  ):
926
  """
@@ -942,15 +895,16 @@ class NextDiT(nn.Module):
942
  torch.Tensor: Precomputed frequency tensor with complex
943
  exponentials.
944
  """
945
- freqs_inter = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)) / scale_factor
946
-
947
- target_dim = timestep * dim + 1
948
- scale_factor = scale_factor ** (dim / target_dim)
949
- theta = theta * scale_factor
950
 
951
- freqs_time_scaled = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim))
952
 
953
- freqs = torch.max(freqs_inter, freqs_time_scaled)
 
954
 
955
  timestep = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
956
 
@@ -960,20 +914,14 @@ class NextDiT(nn.Module):
960
  freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
961
  freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
962
  freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
963
-
964
  return freqs_cis
965
 
966
  def parameter_count(self) -> int:
967
- tensor_parallel_module_list = (
968
- nn.Linear,
969
- nn.Linear,
970
- nn.Embedding,
971
- )
972
  total_params = 0
973
 
974
  def _recursive_count_params(module):
975
  nonlocal total_params
976
- is_tp_module = isinstance(module, tensor_parallel_module_list)
977
  for param in module.parameters(recurse=False):
978
  total_params += param.numel()
979
  for submodule in module.children():
@@ -992,5 +940,6 @@ class NextDiT(nn.Module):
992
  def NextDiT_2B_patch2(**kwargs):
993
  return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs)
994
 
 
995
  def NextDiT_2B_GQA_patch2(**kwargs):
996
  return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, n_kv_heads=8, **kwargs)
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ # --------------------------------------------------------
7
+ # References:
8
+ # GLIDE: https://github.com/openai/glide-text2im
9
+ # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
10
+ # --------------------------------------------------------
11
+
12
  import math
13
+ from typing import List, Optional, Tuple
14
 
 
15
  from flash_attn import flash_attn_varlen_func
16
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
17
  import torch
 
19
  import torch.nn as nn
20
  import torch.nn.functional as F
21
 
22
+ from .components import RMSNorm
23
 
24
 
25
  def modulate(x, scale):
 
65
  """
66
  # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
67
  half = dim // 2
68
+ freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
69
+ device=t.device
70
+ )
 
 
71
  args = t[:, None].float() * freqs[None]
72
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
73
  if dim % 2:
74
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
 
 
75
  return embedding
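The reflowed block above is the standard GLIDE-style sinusoidal timestep embedding. A minimal standalone sketch of the same computation, assuming only torch (the trailing zero column handles an odd dim):

import math
import torch

def sinusoidal_timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 10000) -> torch.Tensor:
    # t: (N,) timesteps -> (N, dim) embedding, cosine terms first as in the hunk above.
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(half, dtype=torch.float32) / half).to(t.device)
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

# e.g. sinusoidal_timestep_embedding(torch.tensor([0, 250, 999]), 256) has shape (3, 256)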
76
 
77
  def forward(self, t):
 
89
  super().__init__()
90
  use_cfg_embedding = int(dropout_prob > 0)
91
  self.embedding_table = nn.Embedding(
92
+ num_classes + use_cfg_embedding
 
93
  )
94
  self.num_classes = num_classes
95
  self.dropout_prob = dropout_prob
 
99
  Drops labels to enable classifier-free guidance.
100
  """
101
  if force_drop_ids is None:
102
+ drop_ids = torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
 
 
103
  drop_ids = drop_ids.cuda()
104
  drop_ids = drop_ids.to(labels.device)
105
  else:
 
142
  """
143
  super().__init__()
144
  self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
145
+ self.n_heads = n_heads
146
+ self.n_kv_heads = self.n_kv_heads
147
+ self.n_rep = self.n_heads // self.n_kv_heads
 
148
  self.head_dim = dim // n_heads
149
 
150
  self.wq = nn.Linear(
 
173
  self.n_kv_heads * self.head_dim,
174
  bias=False,
175
  )
176
+ self.gate = nn.Parameter(torch.zeros([self.n_heads]))
177
 
178
  self.wo = nn.Linear(
179
  n_heads * self.head_dim,
 
182
  )
183
 
184
  if qk_norm:
185
+ self.q_norm = nn.LayerNorm(self.n_heads * self.head_dim)
186
+ self.k_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim)
187
  if y_dim > 0:
188
+ self.ky_norm = nn.LayerNorm(self.n_kv_heads * self.head_dim)
189
  else:
190
  self.ky_norm = nn.Identity()
191
  else:
 
255
  return x_out.type_as(x_in)
256
 
257
  # copied from huggingface modeling_llama.py
258
+ def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
259
  def _get_unpad_data(attention_mask):
260
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
261
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
262
  max_seqlen_in_batch = seqlens_in_batch.max().item()
263
+ cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
 
 
264
  return (
265
  indices,
266
  cu_seqlens,
 
280
  )
281
  if query_length == kv_seq_len:
282
  query_layer = index_first_axis(
283
+ query_layer.reshape(batch_size * kv_seq_len, self.n_heads, head_dim),
 
 
284
  indices_k,
285
  )
286
  cu_seqlens_q = cu_seqlens_k
 
296
  else:
297
  # The -q_len: slice assumes left padding.
298
  attention_mask = attention_mask[:, -query_length:]
299
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
 
 
300
 
301
  return (
302
  query_layer,
 
334
  xq = self.q_norm(xq)
335
  xk = self.k_norm(xk)
336
 
337
+ xq = xq.view(bsz, seqlen, self.n_heads, self.head_dim)
338
+ xk = xk.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
339
+ xv = xv.view(bsz, seqlen, self.n_kv_heads, self.head_dim)
340
 
341
  xq = Attention.apply_rotary_emb(xq, freqs_cis=freqs_cis)
342
  xk = Attention.apply_rotary_emb(xk, freqs_cis=freqs_cis)
343
 
344
  xq, xk = xq.to(dtype), xk.to(dtype)
345
 
346
+ if self.proportional_attn:
347
+ softmax_scale = math.sqrt(math.log(seqlen, self.base_seqlen) / self.head_dim)
348
+ else:
349
+ softmax_scale = math.sqrt(1 / self.head_dim)
350
+
351
  if dtype in [torch.float16, torch.bfloat16]:
352
  # begin var_len flash attn
353
  (
 
362
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
363
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
364
 
365
  attn_output_unpad = flash_attn_varlen_func(
366
  query_states,
367
  key_states,
 
383
  xq.permute(0, 2, 1, 3),
384
  xk.permute(0, 2, 1, 3),
385
  xv.permute(0, 2, 1, 3),
386
+ attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_heads, seqlen, -1),
387
+ scale=softmax_scale,
 
388
  )
389
  .permute(0, 2, 1, 3)
390
  .to(dtype)
391
  )
392
 
393
  if hasattr(self, "wk_y"):
394
+ yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_kv_heads, self.head_dim)
395
+ yv = self.wv_y(y).view(bsz, -1, self.n_kv_heads, self.head_dim)
396
+ n_rep = self.n_heads // self.n_kv_heads
397
  if n_rep >= 1:
398
  yk = yk.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
399
  yv = yv.unsqueeze(3).repeat(1, 1, 1, n_rep, 1).flatten(2, 3)
 
401
  xq.permute(0, 2, 1, 3),
402
  yk.permute(0, 2, 1, 3),
403
  yv.permute(0, 2, 1, 3),
404
+ y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_heads, seqlen, -1),
405
  ).permute(0, 2, 1, 3)
406
  output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
407
  output = output + output_y
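For reference, the softmax_scale branch added earlier in this hunk keeps attention entropy roughly stable when sampling beyond the training sequence length: with proportional_attn enabled, the usual 1/sqrt(d) scale is replaced by sqrt(log_base_seqlen(seqlen) / d). A small sketch of just that selection; names mirror the diff and nothing else is assumed:

import math
from typing import Optional

def attention_softmax_scale(head_dim: int, seqlen: int, base_seqlen: Optional[int] = None,
                            proportional_attn: bool = False) -> float:
    # Default scaled-dot-product scale is 1/sqrt(head_dim); the proportional
    # variant grows with log base `base_seqlen` of the current `seqlen`.
    if proportional_attn and base_seqlen is not None:
        return math.sqrt(math.log(seqlen, base_seqlen) / head_dim)
    return math.sqrt(1 / head_dim)

# e.g. with the 2B config below (dim=2304, n_heads=32, so head_dim=72):
# attention_softmax_scale(72, seqlen=4096, base_seqlen=1024, proportional_attn=True)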
 
519
  )
520
  self.layer_id = layer_id
521
  self.attention_norm1 = RMSNorm(dim, eps=norm_eps)
 
 
522
  self.ffn_norm1 = RMSNorm(dim, eps=norm_eps)
523
+
524
+ self.attention_norm2 = RMSNorm(dim, eps=norm_eps)
525
  self.ffn_norm2 = RMSNorm(dim, eps=norm_eps)
526
 
527
  self.adaLN_modulation = nn.Sequential(
 
568
  y_mask,
569
  )
570
  )
 
571
  x = x + gate_mlp.unsqueeze(1).tanh() * self.ffn_norm2(
572
  self.feed_forward(
573
+ modulate(self.ffn_norm1(x), scale_mlp),
574
+ )
575
  )
576
 
577
  else:
578
+ x = x + self.attention_norm2(
579
  self.attention(
580
+ self.attention_norm1(x),
581
  x_mask,
582
  freqs_cis,
583
  self.attention_y_norm(y),
584
  y_mask,
585
  )
586
  )
587
+ x = x + self.ffn_norm2(self.feed_forward(self.ffn_norm1(x)))
588
 
589
  return x
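The non-adaLN path above now uses a sandwich arrangement: normalize the input, apply the sub-layer, normalize its output, then add the residual (the adaLN path additionally modulates and tanh-gates it). A minimal sketch of the ungated pattern, with nn.LayerNorm standing in for the RMSNorm used in this file:

import torch
import torch.nn as nn

class SandwichResidual(nn.Module):
    # x -> x + norm_out(sublayer(norm_in(x))), mirroring
    # x = x + self.attention_norm2(self.attention(self.attention_norm1(x), ...)) above.
    def __init__(self, dim: int, sublayer: nn.Module):
        super().__init__()
        self.norm_in = nn.LayerNorm(dim)
        self.norm_out = nn.LayerNorm(dim)
        self.sublayer = sublayer

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.norm_out(self.sublayer(self.norm_in(x)))

# e.g. SandwichResidual(64, nn.Sequential(nn.Linear(64, 256), nn.SiLU(), nn.Linear(256, 64)))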
590
 
591
 
592
+ class FinalLayer(nn.Module):
593
  """
594
  The final layer of NextDiT.
595
  """
 
604
  self.linear = nn.Linear(
605
  hidden_size,
606
  patch_size * patch_size * out_channels,
 
607
  )
608
  self.adaLN_modulation = nn.Sequential(
609
  nn.SiLU(),
610
  nn.Linear(
611
  min(hidden_size, 1024),
612
  hidden_size,
 
613
  ),
614
  )
615
 
616
  def forward(self, x, c):
617
  scale = self.adaLN_modulation(c)
618
+
619
  x = modulate(self.norm_final(x), scale)
620
  x = self.linear(x)
621
  return x
 
640
  learn_sigma: bool = True,
641
  qk_norm: bool = False,
642
  cap_feat_dim: int = 5120,
 
643
  scale_factor: float = 1.0,
644
  ) -> None:
645
  super().__init__()
 
681
  for layer_id in range(n_layers)
682
  ]
683
  )
684
+ self.final_layer = FinalLayer(dim, patch_size, self.out_channels)
685
 
686
  assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
 
 
687
  self.freqs_cis = NextDiT.precompute_freqs_cis(
688
  dim // n_heads,
689
  384,
 
690
  scale_factor=scale_factor,
691
  )
692
+ self.dim = dim
693
+ self.n_heads = n_heads
694
  self.scale_factor = scale_factor
 
695
  self.pad_token = nn.Parameter(torch.empty(dim))
 
696
  nn.init.normal_(self.pad_token, std=0.02)
697
 
698
+ def unpatchify(self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False) -> List[torch.Tensor]:
 
 
699
  """
700
  x: (N, T, patch_size**2 * C)
701
  imgs: (N, H, W, C)
 
729
  if isinstance(x, torch.Tensor):
730
  pH = pW = self.patch_size
731
  B, C, H, W = x.size()
732
+ x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3)
733
  x = self.x_embedder(x)
734
  x = x.flatten(1, 2)
735
 
736
+ mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
737
+
 
 
738
  return (
739
  x,
740
  mask,
 
753
  item_freqs_cis = self.freqs_cis[: H // pH, : W // pW]
754
  freqs_cis.append(item_freqs_cis.flatten(0, 1))
755
  img_size.append((H, W))
756
+ img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 0, 2, 4).flatten(2)
757
  img = self.x_embedder(img)
758
  img = img.flatten(0, 1)
759
  l_effective_seq_len.append(len(img))
760
  x_embed.append(img)
761
 
762
  max_seq_len = max(l_effective_seq_len)
763
+ mask = torch.zeros(len(x), max_seq_len, dtype=torch.int32, device=x[0].device)
 
 
764
  padded_x_embed = []
765
  padded_freqs_cis = []
766
  for i, (item_embed, item_freqs_cis, item_seq_len) in enumerate(
 
769
  item_embed = torch.cat(
770
  [
771
  item_embed,
772
+ self.pad_token.view(1, -1).expand(max_seq_len - item_seq_len, -1),
 
 
773
  ],
774
  dim=0,
775
  )
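The loop above right-pads each image's patch sequence with the learned pad_token and records the real positions in an int32 mask. A self-contained sketch of that padding step (list of (L_i, D) tensors in, (B, L_max, D) batch plus mask out; the function name is illustrative):

import torch

def pad_with_token(seqs, pad_token):
    # seqs: list of (L_i, D) tensors; pad_token: (D,) learned embedding.
    max_len = max(s.shape[0] for s in seqs)
    mask = torch.zeros(len(seqs), max_len, dtype=torch.int32, device=seqs[0].device)
    padded = []
    for i, s in enumerate(seqs):
        mask[i, : s.shape[0]] = 1
        pad = pad_token.view(1, -1).expand(max_len - s.shape[0], -1)
        padded.append(torch.cat([s, pad], dim=0))
    return torch.stack(padded), mask

# e.g. pad_with_token([torch.randn(16, 8), torch.randn(9, 8)], torch.zeros(8))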
 
798
  x, mask, img_size, freqs_cis = self.patchify_and_embed(x)
799
  freqs_cis = freqs_cis.to(x.device)
800
 
 
 
801
  t = self.t_embedder(t) # (N, D)
802
  cap_mask_float = cap_mask.float().unsqueeze(-1)
803
+ cap_feats_pool = (cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(dim=1)
 
 
804
  cap_feats_pool = cap_feats_pool.to(cap_feats)
805
  cap_emb = self.cap_embedder(cap_feats_pool)
806
  adaln_input = t + cap_emb
 
825
  cap_feats,
826
  cap_mask,
827
  cfg_scale,
828
+ scale_factor=1.0,
829
+ scale_watershed=1.0,
830
  base_seqlen: Optional[int] = None,
831
  proportional_attn: bool = False,
832
  ):
833
+ """
834
+ Forward pass of NextDiT, but also batches the unconditional forward pass
835
+ for classifier-free guidance.
836
+ """
837
  # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
838
+ self.freqs_cis = NextDiT.precompute_freqs_cis(
839
+ self.dim // self.n_heads,
840
+ 384,
841
+ scale_factor=scale_factor,
842
+ scale_watershed=scale_watershed,
843
+ timestep=t[0].item(),
844
+ )
 
 
845
 
846
  if proportional_attn:
847
  assert base_seqlen is not None
 
855
 
856
  half = x[: len(x) // 2]
857
  combined = torch.cat([half, half], dim=0)
858
+ model_out = self(combined, t, cap_feats, cap_mask)
859
  # For exact reproducibility reasons, we apply classifier-free guidance on only
860
  # three channels by default. The standard approach to cfg applies it to all channels.
861
  # This can be done by uncommenting the following line and commenting-out the line following that.
 
864
  cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
865
  half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
866
  eps = torch.cat([half_eps, half_eps], dim=0)
867
+
868
  return torch.cat([eps, rest], dim=1)
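The tail of forward_with_cfg above recombines the doubled batch: conditional outputs sit in the first half, unconditional in the second, and, per the comment, guidance is applied to only the first three channels by default. A hedged sketch of that recombination, assuming the usual eps/rest channel split the comment describes:

import torch

def apply_cfg(model_out: torch.Tensor, cfg_scale: float, guided_channels: int = 3) -> torch.Tensor:
    # model_out: (2B, C, ...) with conditional results first, unconditional second.
    eps, rest = model_out[:, :guided_channels], model_out[:, guided_channels:]
    cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
    half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
    eps = torch.cat([half_eps, half_eps], dim=0)
    return torch.cat([eps, rest], dim=1)

# e.g. apply_cfg(torch.randn(4, 8, 32, 32), cfg_scale=4.0) keeps the (4, 8, 32, 32) shape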
869
 
870
  @staticmethod
 
872
  dim: int,
873
  end: int,
874
  theta: float = 10000.0,
 
875
  scale_factor: float = 1.0,
876
+ scale_watershed: float = 1.0,
877
  timestep: float = 1.0,
878
  ):
879
  """
 
895
  torch.Tensor: Precomputed frequency tensor with complex
896
  exponentials.
897
  """
 
 
 
 
 
898
 
899
+ if timestep < scale_watershed:
900
+ linear_factor = scale_factor
901
+ ntk_factor = 1.0
902
+ else:
903
+ linear_factor = 1.0
904
+ ntk_factor = scale_factor
905
 
906
+ theta = theta * ntk_factor
907
+ freqs = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)) / linear_factor
908
 
909
  timestep = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
910
 
 
914
  freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
915
  freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
916
  freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
917
+
918
  return freqs_cis
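The added branch above is the time-aware RoPE scaling: for timesteps below scale_watershed the positions are scaled linearly (position-interpolation style), and above it the rotary base theta is scaled instead (NTK-aware style). A 1-D sketch of just the frequency selection, identical to the added lines except that .cuda() is dropped so it runs anywhere:

import torch

def rope_frequencies(dim: int, scale_factor: float = 1.0, scale_watershed: float = 1.0,
                     timestep: float = 1.0, theta: float = 10000.0) -> torch.Tensor:
    if timestep < scale_watershed:
        linear_factor, ntk_factor = scale_factor, 1.0
    else:
        linear_factor, ntk_factor = 1.0, scale_factor
    theta = theta * ntk_factor
    return 1.0 / (theta ** (torch.arange(0, dim, 4)[: dim // 4].float() / dim)) / linear_factor

# e.g. rope_frequencies(72, scale_factor=2.0, scale_watershed=0.3, timestep=0.1)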
919
 
920
  def parameter_count(self) -> int:
921
  total_params = 0
922
 
923
  def _recursive_count_params(module):
924
  nonlocal total_params
 
925
  for param in module.parameters(recurse=False):
926
  total_params += param.numel()
927
  for submodule in module.children():
 
940
  def NextDiT_2B_patch2(**kwargs):
941
  return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs)
942
 
943
+
944
  def NextDiT_2B_GQA_patch2(**kwargs):
945
  return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, n_kv_heads=8, **kwargs)
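The two factories differ only in n_kv_heads: the GQA variant shares 8 key/value heads across the 32 query heads. A hedged construction sketch; it assumes the surrounding package (flash_attn, the local .components module) is importable, and the keyword values shown are illustrative rather than confirmed defaults, apart from cap_feat_dim=5120 which matches the signature above:

# Hypothetical usage of the factory functions defined above.
model = NextDiT_2B_GQA_patch2(qk_norm=True, cap_feat_dim=5120)
print(f"parameters: {model.parameter_count():,}")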