PommesPeter committed
Commit db5524c
1 Parent(s): be34a3d

Update models/model.py
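This commit removes the fairscale tensor-parallel layers (ColumnParallelLinear, RowParallelLinear, ParallelEmbedding) and the fs_init model-parallel bookkeeping from models/model.py, replacing them with plain torch.nn modules and fixing model_parallel_size to 1. A minimal sketch of the substitution pattern, for orientation only (the dimensions are the NextDiT_2B_patch2 values used in this file; the variable names are illustrative, not part of the repository):

import torch.nn as nn

dim, n_heads = 2304, 32            # NextDiT_2B_patch2 sizes used in this file
head_dim = dim // n_heads

# before: ColumnParallelLinear(dim, n_heads * head_dim, bias=False,
#                              gather_output=False, init_method=nn.init.xavier_uniform_)
wq = nn.Linear(dim, n_heads * head_dim, bias=False)

# before: RowParallelLinear(n_heads * head_dim, dim, bias=False,
#                           input_is_parallel=True, init_method=nn.init.xavier_uniform_)
wo = nn.Linear(n_heads * head_dim, dim, bias=False)

Note that the explicit init_method arguments disappear along with the parallel layers, so the affected projections fall back to PyTorch's default nn.Linear initialization (no replacement init is added in this diff).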

Files changed (1)
  1. models/model.py +318 -198
models/model.py CHANGED
@@ -1,25 +1,9 @@
- # Copyright (c) Meta Platforms, Inc. and affiliates.
- # All rights reserved.
-
- # This source code is licensed under the license found in the
- # LICENSE file in the root directory of this source tree.
- # --------------------------------------------------------
- # References:
- # GLIDE: https://github.com/openai/glide-text2im
- # MAE: https://github.com/facebookresearch/mae/blob/main/models_mae.py
- # --------------------------------------------------------
-
  import functools
  import logging
  import math
  from typing import Optional, Tuple, List

- # from apex.normalization import FusedRMSNorm as RMSNorm
  from .components import RMSNorm
- import fairscale.nn.model_parallel.initialize as fs_init
- from fairscale.nn.model_parallel.layers import (
-     ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
- )
  from flash_attn import flash_attn_varlen_func
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
  import torch
@@ -38,22 +22,25 @@ def modulate(x, shift, scale):
  # Embedding Layers for Timesteps and Class Labels #
  #############################################################################

  class ParallelTimestepEmbedder(nn.Module):
      """
      Embeds scalar timesteps into vector representations.
      """
      def __init__(self, hidden_size, frequency_embedding_size=256):
          super().__init__()
          self.mlp = nn.Sequential(
-             ColumnParallelLinear(
-                 frequency_embedding_size, hidden_size, bias=True,
-                 gather_output=False,
-                 init_method=functools.partial(nn.init.normal_, std=0.02),
              ),
              nn.SiLU(),
-             RowParallelLinear(
-                 hidden_size, hidden_size, bias=True, input_is_parallel=True,
-                 init_method=functools.partial(nn.init.normal_, std=0.02),
              ),
          )
          self.frequency_embedding_size = frequency_embedding_size
@@ -71,16 +58,16 @@ class ParallelTimestepEmbedder(nn.Module):
          # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
          half = dim // 2
          freqs = torch.exp(
-             -math.log(max_period) * torch.arange(
-                 start=0, end=half, dtype=torch.float32
-             ) / half
          ).to(device=t.device)
          args = t[:, None].float() * freqs[None]
          embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
          if dim % 2:
-             embedding = torch.cat([
-                 embedding, torch.zeros_like(embedding[:, :1])
-             ], dim=-1)
          return embedding

      def forward(self, t):
@@ -93,12 +80,13 @@ class ParallelLabelEmbedder(nn.Module):
      r"""Embeds class labels into vector representations. Also handles label
      dropout for classifier-free guidance.
      """
      def __init__(self, num_classes, hidden_size, dropout_prob):
          super().__init__()
          use_cfg_embedding = int(dropout_prob > 0)
-         self.embedding_table = ParallelEmbedding(
-             num_classes + use_cfg_embedding, hidden_size,
-             init_method=functools.partial(nn.init.normal_, std=0.02),
          )
          self.num_classes = num_classes
          self.dropout_prob = dropout_prob
@@ -108,15 +96,10 @@ class ParallelLabelEmbedder(nn.Module):
          Drops labels to enable classifier-free guidance.
          """
          if force_drop_ids is None:
-             drop_ids = torch.rand(
-                 labels.shape[0], device=labels.device
-             ) < self.dropout_prob
-             drop_ids = drop_ids.cuda()
-             dist.broadcast(
-                 drop_ids,
-                 fs_init.get_model_parallel_src_rank(),
-                 fs_init.get_model_parallel_group(),
              )
              drop_ids = drop_ids.to(labels.device)
          else:
              drop_ids = force_drop_ids == 1
@@ -132,13 +115,21 @@ class ParallelLabelEmbedder(nn.Module):


  #############################################################################
- # Core DiT Model #
  #############################################################################


  class Attention(nn.Module):
      """Multi-head attention module."""
-     def __init__(self, dim: int, n_heads: int, n_kv_heads: Optional[int], qk_norm: bool, y_dim: int):
          """
          Initialize the Attention module.

@@ -150,38 +141,44 @@ class Attention(nn.Module):
150
  """
151
  super().__init__()
152
  self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
153
- model_parallel_size = fs_init.get_model_parallel_world_size()
154
  self.n_local_heads = n_heads // model_parallel_size
155
  self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
156
  self.n_rep = self.n_local_heads // self.n_local_kv_heads
157
  self.head_dim = dim // n_heads
158
 
159
- self.wq = ColumnParallelLinear(
160
- dim, n_heads * self.head_dim, bias=False, gather_output=False,
161
- init_method=nn.init.xavier_uniform_,
 
162
  )
163
- self.wk = ColumnParallelLinear(
164
- dim, self.n_kv_heads * self.head_dim, bias=False,
165
- gather_output=False, init_method=nn.init.xavier_uniform_,
 
166
  )
167
- self.wv = ColumnParallelLinear(
168
- dim, self.n_kv_heads * self.head_dim, bias=False,
169
- gather_output=False, init_method=nn.init.xavier_uniform_,
 
170
  )
171
  if y_dim > 0:
172
- self.wk_y = ColumnParallelLinear(
173
- y_dim, self.n_kv_heads * self.head_dim, bias=False,
174
- gather_output=False, init_method=nn.init.xavier_uniform_,
 
175
  )
176
- self.wv_y = ColumnParallelLinear(
177
- y_dim, self.n_kv_heads * self.head_dim, bias=False,
178
- gather_output=False, init_method=nn.init.xavier_uniform_,
 
179
  )
180
  self.gate = nn.Parameter(torch.zeros([self.n_local_heads]))
181
 
182
- self.wo = RowParallelLinear(
183
- n_heads * self.head_dim, dim, bias=False,
184
- input_is_parallel=True, init_method=nn.init.xavier_uniform_,
 
185
  )
186
 
187
  if qk_norm:
@@ -194,7 +191,7 @@ class Attention(nn.Module):
194
  else:
195
  self.q_norm = self.k_norm = nn.Identity()
196
  self.ky_norm = nn.Identity()
197
-
198
  # for proportional attention computation
199
  self.base_seqlen = None
200
  self.proportional_attn = False
@@ -224,8 +221,7 @@ class Attention(nn.Module):
224
  ndim = x.ndim
225
  assert 0 <= 1 < ndim
226
  assert freqs_cis.shape == (x.shape[1], x.shape[-1])
227
- shape = [d if i == 1 or i == ndim - 1 else 1
228
- for i, d in enumerate(x.shape)]
229
  return freqs_cis.view(*shape)
230
 
231
  @staticmethod
@@ -259,13 +255,17 @@ class Attention(nn.Module):
259
  return x_out.type_as(x_in)
260
 
261
  # copied from huggingface modeling_llama.py
262
- def _upad_input(self, query_layer, key_layer, value_layer, attention_mask, query_length):
 
 
263
 
264
  def _get_unpad_data(attention_mask):
265
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
266
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
267
  max_seqlen_in_batch = seqlens_in_batch.max().item()
268
- cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
 
 
269
  return (
270
  indices,
271
  cu_seqlens,
@@ -276,14 +276,19 @@ class Attention(nn.Module):
276
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
277
 
278
  key_layer = index_first_axis(
279
- key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
280
  )
281
  value_layer = index_first_axis(
282
- value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim), indices_k
 
283
  )
284
  if query_length == kv_seq_len:
285
  query_layer = index_first_axis(
286
- query_layer.reshape(batch_size * kv_seq_len, self.n_local_heads, head_dim), indices_k
 
 
 
287
  )
288
  cu_seqlens_q = cu_seqlens_k
289
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
@@ -298,7 +303,9 @@ class Attention(nn.Module):
298
  else:
299
  # The -q_len: slice assumes left padding.
300
  attention_mask = attention_mask[:, -query_length:]
301
- query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(query_layer, attention_mask)
 
 
302
 
303
  return (
304
  query_layer,
@@ -347,15 +354,22 @@ class Attention(nn.Module):
347
 
348
  if dtype in [torch.float16, torch.bfloat16]:
349
  # begin var_len flash attn
350
- query_states, key_states, value_states, indices_q, cu_seq_lens, max_seq_lens = self._upad_input(
351
- xq, xk, xv, x_mask, seqlen
352
- )
 
 
 
 
 
353
 
354
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
355
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
356
 
357
  if self.proportional_attn:
358
- softmax_scale = math.sqrt(math.log(seqlen, self.base_seqlen) / self.head_dim)
 
 
359
  else:
360
  softmax_scale = math.sqrt(1 / self.head_dim)
361
 
@@ -367,24 +381,32 @@ class Attention(nn.Module):
367
  cu_seqlens_k=cu_seqlens_k,
368
  max_seqlen_q=max_seqlen_in_batch_q,
369
  max_seqlen_k=max_seqlen_in_batch_k,
370
- dropout_p=0.,
371
  causal=False,
372
- softmax_scale=softmax_scale
373
  )
374
  output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
375
  # end var_len_flash_attn
376
 
377
  else:
378
- output = F.scaled_dot_product_attention(
379
- xq.permute(0, 2, 1, 3),
380
- xk.permute(0, 2, 1, 3),
381
- xv.permute(0, 2, 1, 3),
382
- attn_mask=x_mask.bool().view(bsz, 1, 1, seqlen).expand(-1, self.n_local_heads, seqlen, -1),
383
- ).permute(0, 2, 1, 3).to(dtype)
 
 
 
 
 
 
384
 
385
  if hasattr(self, "wk_y"):
386
  # todo better flash_attn support
387
- yk = self.ky_norm(self.wk_y(y)).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
 
 
388
  yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
389
  n_rep = self.n_local_heads // self.n_local_kv_heads
390
  if n_rep >= 1:
@@ -394,7 +416,7 @@ class Attention(nn.Module):
394
  xq.permute(0, 2, 1, 3),
395
  yk.permute(0, 2, 1, 3),
396
  yv.permute(0, 2, 1, 3),
397
- y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1)
398
  ).permute(0, 2, 1, 3)
399
  output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
400
  output = output + output_y
@@ -424,10 +446,10 @@ class FeedForward(nn.Module):
424
  dimension. Defaults to None.
425
 
426
  Attributes:
427
- w1 (ColumnParallelLinear): Linear transformation for the first
428
  layer.
429
- w2 (RowParallelLinear): Linear transformation for the second layer.
430
- w3 (ColumnParallelLinear): Linear transformation for the third
431
  layer.
432
 
433
  """
@@ -436,21 +458,22 @@ class FeedForward(nn.Module):
436
  # custom dim factor multiplier
437
  if ffn_dim_multiplier is not None:
438
  hidden_dim = int(ffn_dim_multiplier * hidden_dim)
439
- hidden_dim = multiple_of * (
440
- (hidden_dim + multiple_of - 1) // multiple_of
441
- )
442
 
443
- self.w1 = ColumnParallelLinear(
444
- dim, hidden_dim, bias=False, gather_output=False,
445
- init_method=nn.init.xavier_uniform_,
 
446
  )
447
- self.w2 = RowParallelLinear(
448
- hidden_dim, dim, bias=False, input_is_parallel=True,
449
- init_method=nn.init.xavier_uniform_,
 
450
  )
451
- self.w3 = ColumnParallelLinear(
452
- dim, hidden_dim, bias=False, gather_output=False,
453
- init_method=nn.init.xavier_uniform_,
 
454
  )
455
 
456
  # @torch.compile
@@ -462,9 +485,18 @@ class FeedForward(nn.Module):
462
 
463
 
464
  class TransformerBlock(nn.Module):
465
- def __init__(self, layer_id: int, dim: int, n_heads: int, n_kv_heads: int,
466
- multiple_of: int, ffn_dim_multiplier: float, norm_eps: float,
467
- qk_norm: bool, y_dim: int) -> None:
 
 
 
 
 
 
 
 
 
468
  """
469
  Initialize a TransformerBlock.
470
 
@@ -495,7 +527,9 @@ class TransformerBlock(nn.Module):
495
  self.head_dim = dim // n_heads
496
  self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim)
497
  self.feed_forward = FeedForward(
498
- dim=dim, hidden_dim=4 * dim, multiple_of=multiple_of,
 
 
499
  ffn_dim_multiplier=ffn_dim_multiplier,
500
  )
501
  self.layer_id = layer_id
@@ -506,9 +540,10 @@ class TransformerBlock(nn.Module):
506
 
507
  self.adaLN_modulation = nn.Sequential(
508
  nn.SiLU(),
509
- ColumnParallelLinear(
510
- min(dim, 1024), 6 * dim, bias=True, gather_output=True,
511
- init_method=nn.init.zeros_,
 
512
  ),
513
  )
514
 
@@ -536,28 +571,41 @@ class TransformerBlock(nn.Module):
536
 
537
  """
538
  if adaln_input is not None:
539
- shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
540
  self.adaLN_modulation(adaln_input).chunk(6, dim=1)
 
541
 
542
- x = x + self.attention_norm1(gate_msa.unsqueeze(1) * self.attention(
543
- modulate(self.attention_norm(x), shift_msa, scale_msa),
544
- x_mask,
545
- freqs_cis,
546
- self.attention_y_norm(y),
547
- y_mask,
548
- ))
 
 
 
549
  d = x.shape[-1]
550
- x = x + self.ffn_norm1(gate_mlp.unsqueeze(1) * self.feed_forward(
551
- modulate(self.ffn_norm(x), shift_mlp, scale_mlp).view(-1, d),
552
- ).view(*x.shape))
 
 
 
553
 
554
  else:
555
- x = x + self.attention_norm1(self.attention(
556
- self.attention_norm(x), x_mask, freqs_cis, self.attention_y_norm(y), y_mask
557
- ))
 
 
 
 
 
 
558
  # for compatibility with torch.compile because the sequence length changes
559
  B, L, D = x.shape
560
- x = x.view(B*L, D)
561
  x = x + self.ffn_norm1(self.feed_forward(self.ffn_norm(x)))
562
  x = x.view(B, L, D)
563
 
@@ -566,22 +614,27 @@ class TransformerBlock(nn.Module):
566
 
567
  class ParallelFinalLayer(nn.Module):
568
  """
569
- The final layer of DiT.
570
  """
 
571
  def __init__(self, hidden_size, patch_size, out_channels):
572
  super().__init__()
573
  self.norm_final = nn.LayerNorm(
574
- hidden_size, elementwise_affine=False, eps=1e-6,
 
 
575
  )
576
- self.linear = ColumnParallelLinear(
577
- hidden_size, patch_size * patch_size * out_channels, bias=True,
578
- init_method=nn.init.zeros_, gather_output=True,
 
579
  )
580
  self.adaLN_modulation = nn.Sequential(
581
  nn.SiLU(),
582
- ColumnParallelLinear(
583
- min(hidden_size, 1024), 2 * hidden_size, bias=True,
584
- init_method=nn.init.zeros_, gather_output=True,
 
585
  ),
586
  )
587
 
@@ -596,6 +649,7 @@ class NextDiT(nn.Module):
596
  """
597
  Diffusion model with a Transformer backbone.
598
  """
 
599
  def __init__(
600
  self,
601
  patch_size: int = 2,
@@ -610,8 +664,8 @@ class NextDiT(nn.Module):
610
  learn_sigma: bool = True,
611
  qk_norm: bool = False,
612
  cap_feat_dim: int = 5120,
613
- rope_scaling_factor: float = 1.,
614
- ntk_factor: float=1.
615
  ) -> None:
616
  super().__init__()
617
  self.learn_sigma = learn_sigma
@@ -619,34 +673,49 @@ class NextDiT(nn.Module):
619
  self.out_channels = in_channels * 2 if learn_sigma else in_channels
620
  self.patch_size = patch_size
621
 
622
- self.x_embedder = ColumnParallelLinear(
623
  in_features=patch_size * patch_size * in_channels,
624
  out_features=dim,
625
  bias=True,
626
- gather_output=True,
627
- init_method=nn.init.xavier_uniform_,
628
  )
629
- nn.init.constant_(self.x_embedder.bias, 0.)
630
 
631
  self.t_embedder = ParallelTimestepEmbedder(min(dim, 1024))
632
  self.cap_embedder = nn.Sequential(
633
  nn.LayerNorm(cap_feat_dim),
634
- ColumnParallelLinear(cap_feat_dim, min(dim, 1024), bias=True, gather_output=True,
635
- init_method=nn.init.zeros_),
 
 
 
636
  )
637
 
638
- self.layers = nn.ModuleList([
639
- TransformerBlock(layer_id, dim, n_heads, n_kv_heads, multiple_of,
640
- ffn_dim_multiplier, norm_eps, qk_norm, cap_feat_dim)
641
- for layer_id in range(n_layers)
642
- ])
 
 
 
 
 
 
 
 
 
 
 
643
  self.final_layer = ParallelFinalLayer(dim, patch_size, self.out_channels)
644
 
645
  assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
646
  self.dim = dim
647
  self.n_heads = n_heads
648
  self.freqs_cis = NextDiT.precompute_freqs_cis(
649
- dim // n_heads, 384, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
 
 
 
650
  )
651
  self.rope_scaling_factor = rope_scaling_factor
652
  self.ntk_factor = ntk_factor
@@ -655,7 +724,9 @@ class NextDiT(nn.Module):
655
  # nn.init.normal_(self.eol_token, std=0.02)
656
  nn.init.normal_(self.pad_token, std=0.02)
657
 
658
- def unpatchify(self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False) -> List[torch.Tensor]:
 
 
659
  """
660
  x: (N, T, patch_size**2 * C)
661
  imgs: (N, H, W, C)
@@ -673,26 +744,40 @@ class NextDiT(nn.Module):
673
  for i in range(x.size(0)):
674
  H, W = img_size[i]
675
  L = (H // pH) * (W // pW)
676
- imgs.append(x[i][:L].view(
677
- H // pH, W // pW, pH, pW, self.out_channels
678
- ).permute(4, 0, 2, 1, 3).flatten(3, 4).flatten(1, 2))
 
 
 
 
679
  return imgs
680
 
681
  def patchify_and_embed(
682
- self,
683
- x: List[torch.Tensor] | torch.Tensor
684
  ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]:
685
  self.freqs_cis = self.freqs_cis.to(x[0].device)
686
  if isinstance(x, torch.Tensor):
687
  pH = pW = self.patch_size
688
  B, C, H, W = x.size()
689
- x = x.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 1, 3, 5).flatten(3)
 
 
 
 
690
  x = self.x_embedder(x)
691
  x = x.flatten(1, 2)
692
 
693
- mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
 
 
694
  # leave the first line for text
695
- return x, mask, [(H, W)] * B, self.freqs_cis[:H//pH, :W//pW].flatten(0,1).unsqueeze(0)
 
 
 
 
 
696
  else:
697
  pH = pW = self.patch_size
698
  x_embed = []
@@ -702,30 +787,44 @@ class NextDiT(nn.Module):
702
 
703
  for img in x:
704
  C, H, W = img.size()
705
- item_freqs_cis = self.freqs_cis[:H//pH, :W//pW]
706
- freqs_cis.append(item_freqs_cis.flatten(0,1))
707
  img_size.append((H, W))
708
- img = img.view(C, H // pH, pH, W // pW, pW).permute(1, 3, 0, 2, 4).flatten(2)
 
 
 
 
709
  img = self.x_embedder(img)
710
  img = img.flatten(0, 1)
711
  l_effective_seq_len.append(len(img))
712
  x_embed.append(img)
713
 
714
  max_seq_len = max(l_effective_seq_len)
715
- mask = torch.zeros(len(x), max_seq_len, dtype=torch.int32, device=x[0].device)
 
 
716
  padded_x_embed = []
717
  padded_freqs_cis = []
718
- for i, (item_embed, item_freqs_cis, item_seq_len) in enumerate(zip(
719
- x_embed, freqs_cis, l_effective_seq_len
720
- )):
721
- item_embed = torch.cat([
722
- item_embed,
723
- self.pad_token.view(1, -1).expand(max_seq_len - item_seq_len, -1),
724
- ], dim=0)
725
- item_freqs_cis = torch.cat([
726
- item_freqs_cis,
727
- item_freqs_cis[-1:].expand(max_seq_len - item_seq_len, -1)
728
- ], dim=0)
 
 
 
 
 
 
 
 
729
  padded_x_embed.append(item_embed)
730
  padded_freqs_cis.append(item_freqs_cis)
731
  mask[i][:item_seq_len] = 1
@@ -736,7 +835,7 @@ class NextDiT(nn.Module):
736
 
737
  def forward(self, x, t, cap_feats, cap_mask):
738
  """
739
- Forward pass of DiT.
740
  t: (N,) tensor of diffusion timesteps
741
  y: (N,) tensor of class labels
742
  """
@@ -746,19 +845,18 @@ class NextDiT(nn.Module):
746
 
747
  # cap_freqs_cis = self.freqs_cis[:1, :cap_feats.shape[1]].to(x.device)
748
 
749
- t = self.t_embedder(t) # (N, D)
750
  cap_mask_float = cap_mask.float().unsqueeze(-1)
751
- cap_feats_pool = (cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(dim=1)
 
 
752
  cap_feats_pool = cap_feats_pool.to(cap_feats)
753
  cap_emb = self.cap_embedder(cap_feats_pool)
754
  adaln_input = t + cap_emb
755
 
756
  cap_mask = cap_mask.bool()
757
  for layer in self.layers:
758
- x = layer(
759
- x, mask, freqs_cis, cap_feats, cap_mask,
760
- adaln_input=adaln_input
761
- )
762
 
763
  x = self.final_layer(x, adaln_input)
764
  x = self.unpatchify(x, img_size, return_tensor=x_is_tensor)
@@ -769,25 +867,48 @@ class NextDiT(nn.Module):
769
  x = [_.chunk(2, dim=0)[0] for _ in x]
770
  return x
771
 
772
- def forward_with_cfg(self, x, t, cap_feats, cap_mask, cfg_scale, rope_scaling_factor=None, ntk_factor=None, base_seqlen: Optional[int] = None, proportional_attn: bool = False):
 
 
 
 
 
 
 
 
 
 
 
773
  # """
774
- # Forward pass of DiT, but also batches the unconditional forward pass
775
  # for classifier-free guidance.
776
  # """
777
  # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
778
  # print(ntk_factor, rope_scaling_factor, self.ntk_factor, self.rope_scaling_factor)
779
  if rope_scaling_factor is not None or ntk_factor is not None:
780
- rope_scaling_factor = rope_scaling_factor if rope_scaling_factor is not None else self.rope_scaling_factor
 
 
 
 
781
  ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
782
- if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
783
- print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
 
 
 
 
 
 
784
  self.freqs_cis = NextDiT.precompute_freqs_cis(
785
- self.dim // self.n_heads, 384,
786
- rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
 
 
787
  )
788
  self.rope_scaling_factor = rope_scaling_factor
789
  self.ntk_factor = ntk_factor
790
-
791
  if proportional_attn:
792
  assert base_seqlen is not None
793
  for layer in self.layers:
@@ -817,7 +938,7 @@ class NextDiT(nn.Module):
817
  end: int,
818
  theta: float = 10000.0,
819
  rope_scaling_factor: float = 1.0,
820
- ntk_factor: float = 1.0
821
  ):
822
  """
823
  Precompute the frequency tensor for complex exponentials (cis) with
@@ -841,23 +962,27 @@ class NextDiT(nn.Module):
841
 
842
  theta = theta * ntk_factor
843
 
844
- logger.info(f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}")
845
- freqs = 1.0 / (theta ** (
846
- torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim
847
- ))
 
 
848
  t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
849
  t = t / rope_scaling_factor
850
  freqs = torch.outer(t, freqs).float() # type: ignore
851
  freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
852
 
853
- freqs_cis_h = freqs_cis.view(end, 1, dim//4, 1).repeat(1, end, 1, 1)
854
- freqs_cis_w = freqs_cis.view(1, end, dim//4, 1).repeat(end, 1, 1, 1)
855
  freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
856
  return freqs_cis
857
 
858
  def parameter_count(self) -> int:
859
  tensor_parallel_module_list = (
860
- ColumnParallelLinear, RowParallelLinear, ParallelEmbedding,
 
 
861
  )
862
  total_params = 0
863
 
@@ -865,10 +990,7 @@ class NextDiT(nn.Module):
865
  nonlocal total_params
866
  is_tp_module = isinstance(module, tensor_parallel_module_list)
867
  for param in module.parameters(recurse=False):
868
- total_params += param.numel() * (
869
- fs_init.get_model_parallel_world_size()
870
- if is_tp_module else 1
871
- )
872
  for submodule in module.children():
873
  _recursive_count_params(submodule)
874
 
@@ -880,9 +1002,7 @@ class NextDiT(nn.Module):
880
 
881
 
882
  #############################################################################
883
- # DiT Configs #
884
  #############################################################################
885
  def NextDiT_2B_patch2(**kwargs):
886
- return NextDiT(
887
- patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
888
- )
 
 
 
 
 
 
 
 
 
 
 
 
  import functools
  import logging
  import math
  from typing import Optional, Tuple, List

  from .components import RMSNorm
  from flash_attn import flash_attn_varlen_func
  from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
  import torch
 
22
  # Embedding Layers for Timesteps and Class Labels #
23
  #############################################################################
24
 
25
+
26
  class ParallelTimestepEmbedder(nn.Module):
27
  """
28
  Embeds scalar timesteps into vector representations.
29
  """
30
+
31
  def __init__(self, hidden_size, frequency_embedding_size=256):
32
  super().__init__()
33
  self.mlp = nn.Sequential(
34
+ nn.Linear(
35
+ frequency_embedding_size,
36
+ hidden_size,
37
+ bias=True,
38
  ),
39
  nn.SiLU(),
40
+ nn.Linear(
41
+ hidden_size,
42
+ hidden_size,
43
+ bias=True,
44
  ),
45
  )
46
  self.frequency_embedding_size = frequency_embedding_size
 
58
  # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py
59
  half = dim // 2
60
  freqs = torch.exp(
61
+ -math.log(max_period)
62
+ * torch.arange(start=0, end=half, dtype=torch.float32)
63
+ / half
64
  ).to(device=t.device)
65
  args = t[:, None].float() * freqs[None]
66
  embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
67
  if dim % 2:
68
+ embedding = torch.cat(
69
+ [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
70
+ )
71
  return embedding
72
 
73
  def forward(self, t):
 
80
  r"""Embeds class labels into vector representations. Also handles label
81
  dropout for classifier-free guidance.
82
  """
83
+
84
  def __init__(self, num_classes, hidden_size, dropout_prob):
85
  super().__init__()
86
  use_cfg_embedding = int(dropout_prob > 0)
87
+ self.embedding_table = nn.Embedding(
88
+ num_classes + use_cfg_embedding,
89
+ hidden_size,
90
  )
91
  self.num_classes = num_classes
92
  self.dropout_prob = dropout_prob
 
96
          Drops labels to enable classifier-free guidance.
          """
          if force_drop_ids is None:
+             drop_ids = (
+                 torch.rand(labels.shape[0], device=labels.device) < self.dropout_prob
              )
+             drop_ids = drop_ids.cuda()
              drop_ids = drop_ids.to(labels.device)
104
  else:
105
  drop_ids = force_drop_ids == 1
 
115
 
116
 
117
  #############################################################################
118
+ # Core NextDiT Model #
119
  #############################################################################
120
 
121
 
122
  class Attention(nn.Module):
123
  """Multi-head attention module."""
124
+
125
+ def __init__(
126
+ self,
127
+ dim: int,
128
+ n_heads: int,
129
+ n_kv_heads: Optional[int],
130
+ qk_norm: bool,
131
+ y_dim: int,
132
+ ):
133
  """
134
  Initialize the Attention module.
135
 
 
141
  """
142
  super().__init__()
143
  self.n_kv_heads = n_heads if n_kv_heads is None else n_kv_heads
144
+ model_parallel_size = 1
145
  self.n_local_heads = n_heads // model_parallel_size
146
  self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
147
  self.n_rep = self.n_local_heads // self.n_local_kv_heads
148
  self.head_dim = dim // n_heads
149
 
150
+ self.wq = nn.Linear(
151
+ dim,
152
+ n_heads * self.head_dim,
153
+ bias=False,
154
  )
155
+ self.wk = nn.Linear(
156
+ dim,
157
+ self.n_kv_heads * self.head_dim,
158
+ bias=False,
159
  )
160
+ self.wv = nn.Linear(
161
+ dim,
162
+ self.n_kv_heads * self.head_dim,
163
+ bias=False,
164
  )
165
  if y_dim > 0:
166
+ self.wk_y = nn.Linear(
167
+ y_dim,
168
+ self.n_kv_heads * self.head_dim,
169
+ bias=False,
170
  )
171
+ self.wv_y = nn.Linear(
172
+ y_dim,
173
+ self.n_kv_heads * self.head_dim,
174
+ bias=False,
175
  )
176
  self.gate = nn.Parameter(torch.zeros([self.n_local_heads]))
177
 
178
+ self.wo = nn.Linear(
179
+ n_heads * self.head_dim,
180
+ dim,
181
+ bias=False,
182
  )
183
 
184
  if qk_norm:
 
191
  else:
192
  self.q_norm = self.k_norm = nn.Identity()
193
  self.ky_norm = nn.Identity()
194
+
195
  # for proportional attention computation
196
  self.base_seqlen = None
197
  self.proportional_attn = False
 
221
  ndim = x.ndim
222
  assert 0 <= 1 < ndim
223
  assert freqs_cis.shape == (x.shape[1], x.shape[-1])
224
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
 
225
  return freqs_cis.view(*shape)
226
 
227
  @staticmethod
 
255
  return x_out.type_as(x_in)
256
 
257
  # copied from huggingface modeling_llama.py
258
+ def _upad_input(
259
+ self, query_layer, key_layer, value_layer, attention_mask, query_length
260
+ ):
261
 
262
  def _get_unpad_data(attention_mask):
263
  seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
264
  indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
265
  max_seqlen_in_batch = seqlens_in_batch.max().item()
266
+ cu_seqlens = F.pad(
267
+ torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
268
+ )
269
  return (
270
  indices,
271
  cu_seqlens,
 
276
  batch_size, kv_seq_len, num_key_value_heads, head_dim = key_layer.shape
277
 
278
  key_layer = index_first_axis(
279
+ key_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
280
+ indices_k,
281
  )
282
  value_layer = index_first_axis(
283
+ value_layer.reshape(batch_size * kv_seq_len, num_key_value_heads, head_dim),
284
+ indices_k,
285
  )
286
  if query_length == kv_seq_len:
287
  query_layer = index_first_axis(
288
+ query_layer.reshape(
289
+ batch_size * kv_seq_len, self.n_local_heads, head_dim
290
+ ),
291
+ indices_k,
292
  )
293
  cu_seqlens_q = cu_seqlens_k
294
  max_seqlen_in_batch_q = max_seqlen_in_batch_k
 
303
  else:
304
  # The -q_len: slice assumes left padding.
305
  attention_mask = attention_mask[:, -query_length:]
306
+ query_layer, indices_q, cu_seqlens_q, max_seqlen_in_batch_q = unpad_input(
307
+ query_layer, attention_mask
308
+ )
309
 
310
  return (
311
  query_layer,
 
354
 
355
  if dtype in [torch.float16, torch.bfloat16]:
356
  # begin var_len flash attn
357
+ (
358
+ query_states,
359
+ key_states,
360
+ value_states,
361
+ indices_q,
362
+ cu_seq_lens,
363
+ max_seq_lens,
364
+ ) = self._upad_input(xq, xk, xv, x_mask, seqlen)
365
 
366
  cu_seqlens_q, cu_seqlens_k = cu_seq_lens
367
  max_seqlen_in_batch_q, max_seqlen_in_batch_k = max_seq_lens
368
 
369
  if self.proportional_attn:
370
+ softmax_scale = math.sqrt(
371
+ math.log(seqlen, self.base_seqlen) / self.head_dim
372
+ )
373
  else:
374
  softmax_scale = math.sqrt(1 / self.head_dim)
375
 
 
381
  cu_seqlens_k=cu_seqlens_k,
382
  max_seqlen_q=max_seqlen_in_batch_q,
383
  max_seqlen_k=max_seqlen_in_batch_k,
384
+ dropout_p=0.0,
385
  causal=False,
386
+ softmax_scale=softmax_scale,
387
  )
388
  output = pad_input(attn_output_unpad, indices_q, bsz, seqlen)
389
  # end var_len_flash_attn
390
 
391
  else:
392
+ output = (
393
+ F.scaled_dot_product_attention(
394
+ xq.permute(0, 2, 1, 3),
395
+ xk.permute(0, 2, 1, 3),
396
+ xv.permute(0, 2, 1, 3),
397
+ attn_mask=x_mask.bool()
398
+ .view(bsz, 1, 1, seqlen)
399
+ .expand(-1, self.n_local_heads, seqlen, -1),
400
+ )
401
+ .permute(0, 2, 1, 3)
402
+ .to(dtype)
403
+ )
404
 
405
  if hasattr(self, "wk_y"):
406
  # todo better flash_attn support
407
+ yk = self.ky_norm(self.wk_y(y)).view(
408
+ bsz, -1, self.n_local_kv_heads, self.head_dim
409
+ )
410
  yv = self.wv_y(y).view(bsz, -1, self.n_local_kv_heads, self.head_dim)
411
  n_rep = self.n_local_heads // self.n_local_kv_heads
412
  if n_rep >= 1:
 
416
  xq.permute(0, 2, 1, 3),
417
  yk.permute(0, 2, 1, 3),
418
  yv.permute(0, 2, 1, 3),
419
+ y_mask.view(bsz, 1, 1, -1).expand(bsz, self.n_local_heads, seqlen, -1),
420
  ).permute(0, 2, 1, 3)
421
  output_y = output_y * self.gate.tanh().view(1, 1, -1, 1)
422
  output = output + output_y
 
446
  dimension. Defaults to None.
447
 
448
  Attributes:
449
+ w1 (nn.Linear): Linear transformation for the first
450
  layer.
451
+ w2 (nn.Linear): Linear transformation for the second layer.
452
+ w3 (nn.Linear): Linear transformation for the third
453
  layer.
454
 
455
  """
 
458
  # custom dim factor multiplier
459
  if ffn_dim_multiplier is not None:
460
  hidden_dim = int(ffn_dim_multiplier * hidden_dim)
461
+ hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
 
 
462
 
463
+ self.w1 = nn.Linear(
464
+ dim,
465
+ hidden_dim,
466
+ bias=False,
467
  )
468
+ self.w2 = nn.Linear(
469
+ hidden_dim,
470
+ dim,
471
+ bias=False,
472
  )
473
+ self.w3 = nn.Linear(
474
+ dim,
475
+ hidden_dim,
476
+ bias=False,
477
  )
478
 
479
  # @torch.compile
 
485
 
486
 
487
  class TransformerBlock(nn.Module):
488
+ def __init__(
489
+ self,
490
+ layer_id: int,
491
+ dim: int,
492
+ n_heads: int,
493
+ n_kv_heads: int,
494
+ multiple_of: int,
495
+ ffn_dim_multiplier: float,
496
+ norm_eps: float,
497
+ qk_norm: bool,
498
+ y_dim: int,
499
+ ) -> None:
500
  """
501
  Initialize a TransformerBlock.
502
 
 
527
  self.head_dim = dim // n_heads
528
  self.attention = Attention(dim, n_heads, n_kv_heads, qk_norm, y_dim)
529
  self.feed_forward = FeedForward(
530
+ dim=dim,
531
+ hidden_dim=4 * dim,
532
+ multiple_of=multiple_of,
533
  ffn_dim_multiplier=ffn_dim_multiplier,
534
  )
535
  self.layer_id = layer_id
 
540
 
541
  self.adaLN_modulation = nn.Sequential(
542
  nn.SiLU(),
543
+ nn.Linear(
544
+ min(dim, 1024),
545
+ 6 * dim,
546
+ bias=True,
547
  ),
548
  )
549
 
 
571
 
572
  """
573
  if adaln_input is not None:
574
+ shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
575
  self.adaLN_modulation(adaln_input).chunk(6, dim=1)
576
+ )
577
 
578
+ x = x + self.attention_norm1(
579
+ gate_msa.unsqueeze(1)
580
+ * self.attention(
581
+ modulate(self.attention_norm(x), shift_msa, scale_msa),
582
+ x_mask,
583
+ freqs_cis,
584
+ self.attention_y_norm(y),
585
+ y_mask,
586
+ )
587
+ )
588
  d = x.shape[-1]
589
+ x = x + self.ffn_norm1(
590
+ gate_mlp.unsqueeze(1)
591
+ * self.feed_forward(
592
+ modulate(self.ffn_norm(x), shift_mlp, scale_mlp).view(-1, d),
593
+ ).view(*x.shape)
594
+ )
595
 
596
  else:
597
+ x = x + self.attention_norm1(
598
+ self.attention(
599
+ self.attention_norm(x),
600
+ x_mask,
601
+ freqs_cis,
602
+ self.attention_y_norm(y),
603
+ y_mask,
604
+ )
605
+ )
606
  # for compatibility with torch.compile because the sequence length changes
607
  B, L, D = x.shape
608
+ x = x.view(B * L, D)
609
  x = x + self.ffn_norm1(self.feed_forward(self.ffn_norm(x)))
610
  x = x.view(B, L, D)
611
 
 
614
 
615
  class ParallelFinalLayer(nn.Module):
616
  """
617
+ The final layer of NextDiT.
618
  """
619
+
620
  def __init__(self, hidden_size, patch_size, out_channels):
621
  super().__init__()
622
  self.norm_final = nn.LayerNorm(
623
+ hidden_size,
624
+ elementwise_affine=False,
625
+ eps=1e-6,
626
  )
627
+ self.linear = nn.Linear(
628
+ hidden_size,
629
+ patch_size * patch_size * out_channels,
630
+ bias=True,
631
  )
632
  self.adaLN_modulation = nn.Sequential(
633
  nn.SiLU(),
634
+ nn.Linear(
635
+ min(hidden_size, 1024),
636
+ 2 * hidden_size,
637
+ bias=True,
638
  ),
639
  )
640
 
 
649
  """
650
  Diffusion model with a Transformer backbone.
651
  """
652
+
653
  def __init__(
654
  self,
655
  patch_size: int = 2,
 
664
  learn_sigma: bool = True,
665
  qk_norm: bool = False,
666
  cap_feat_dim: int = 5120,
667
+ rope_scaling_factor: float = 1.0,
668
+ ntk_factor: float = 1.0,
669
  ) -> None:
670
  super().__init__()
671
  self.learn_sigma = learn_sigma
 
673
  self.out_channels = in_channels * 2 if learn_sigma else in_channels
674
  self.patch_size = patch_size
675
 
676
+ self.x_embedder = nn.Linear(
677
  in_features=patch_size * patch_size * in_channels,
678
  out_features=dim,
679
  bias=True,
 
 
680
  )
681
+ nn.init.constant_(self.x_embedder.bias, 0.0)
682
 
683
  self.t_embedder = ParallelTimestepEmbedder(min(dim, 1024))
684
  self.cap_embedder = nn.Sequential(
685
  nn.LayerNorm(cap_feat_dim),
686
+ nn.Linear(
687
+ cap_feat_dim,
688
+ min(dim, 1024),
689
+ bias=True,
690
+ ),
691
  )
692
 
693
+ self.layers = nn.ModuleList(
694
+ [
695
+ TransformerBlock(
696
+ layer_id,
697
+ dim,
698
+ n_heads,
699
+ n_kv_heads,
700
+ multiple_of,
701
+ ffn_dim_multiplier,
702
+ norm_eps,
703
+ qk_norm,
704
+ cap_feat_dim,
705
+ )
706
+ for layer_id in range(n_layers)
707
+ ]
708
+ )
709
  self.final_layer = ParallelFinalLayer(dim, patch_size, self.out_channels)
710
 
711
  assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
712
  self.dim = dim
713
  self.n_heads = n_heads
714
  self.freqs_cis = NextDiT.precompute_freqs_cis(
715
+ dim // n_heads,
716
+ 384,
717
+ rope_scaling_factor=rope_scaling_factor,
718
+ ntk_factor=ntk_factor,
719
  )
720
  self.rope_scaling_factor = rope_scaling_factor
721
  self.ntk_factor = ntk_factor
 
724
  # nn.init.normal_(self.eol_token, std=0.02)
725
  nn.init.normal_(self.pad_token, std=0.02)
726
 
727
+ def unpatchify(
728
+ self, x: torch.Tensor, img_size: List[Tuple[int, int]], return_tensor=False
729
+ ) -> List[torch.Tensor]:
730
  """
731
  x: (N, T, patch_size**2 * C)
732
  imgs: (N, H, W, C)
 
744
  for i in range(x.size(0)):
745
  H, W = img_size[i]
746
  L = (H // pH) * (W // pW)
747
+ imgs.append(
748
+ x[i][:L]
749
+ .view(H // pH, W // pW, pH, pW, self.out_channels)
750
+ .permute(4, 0, 2, 1, 3)
751
+ .flatten(3, 4)
752
+ .flatten(1, 2)
753
+ )
754
  return imgs
755
 
756
  def patchify_and_embed(
757
+ self, x: List[torch.Tensor] | torch.Tensor
 
758
  ) -> Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]:
759
  self.freqs_cis = self.freqs_cis.to(x[0].device)
760
  if isinstance(x, torch.Tensor):
761
  pH = pW = self.patch_size
762
  B, C, H, W = x.size()
763
+ x = (
764
+ x.view(B, C, H // pH, pH, W // pW, pW)
765
+ .permute(0, 2, 4, 1, 3, 5)
766
+ .flatten(3)
767
+ )
768
  x = self.x_embedder(x)
769
  x = x.flatten(1, 2)
770
 
771
+ mask = torch.ones(
772
+ x.shape[0], x.shape[1], dtype=torch.int32, device=x.device
773
+ )
774
  # leave the first line for text
775
+ return (
776
+ x,
777
+ mask,
778
+ [(H, W)] * B,
779
+ self.freqs_cis[: H // pH, : W // pW].flatten(0, 1).unsqueeze(0),
780
+ )
781
  else:
782
  pH = pW = self.patch_size
783
  x_embed = []
 
787
 
788
  for img in x:
789
  C, H, W = img.size()
790
+ item_freqs_cis = self.freqs_cis[: H // pH, : W // pW]
791
+ freqs_cis.append(item_freqs_cis.flatten(0, 1))
792
  img_size.append((H, W))
793
+ img = (
794
+ img.view(C, H // pH, pH, W // pW, pW)
795
+ .permute(1, 3, 0, 2, 4)
796
+ .flatten(2)
797
+ )
798
  img = self.x_embedder(img)
799
  img = img.flatten(0, 1)
800
  l_effective_seq_len.append(len(img))
801
  x_embed.append(img)
802
 
803
  max_seq_len = max(l_effective_seq_len)
804
+ mask = torch.zeros(
805
+ len(x), max_seq_len, dtype=torch.int32, device=x[0].device
806
+ )
807
  padded_x_embed = []
808
  padded_freqs_cis = []
809
+ for i, (item_embed, item_freqs_cis, item_seq_len) in enumerate(
810
+ zip(x_embed, freqs_cis, l_effective_seq_len)
811
+ ):
812
+ item_embed = torch.cat(
813
+ [
814
+ item_embed,
815
+ self.pad_token.view(1, -1).expand(
816
+ max_seq_len - item_seq_len, -1
817
+ ),
818
+ ],
819
+ dim=0,
820
+ )
821
+ item_freqs_cis = torch.cat(
822
+ [
823
+ item_freqs_cis,
824
+ item_freqs_cis[-1:].expand(max_seq_len - item_seq_len, -1),
825
+ ],
826
+ dim=0,
827
+ )
828
  padded_x_embed.append(item_embed)
829
  padded_freqs_cis.append(item_freqs_cis)
830
  mask[i][:item_seq_len] = 1
 
835
 
836
  def forward(self, x, t, cap_feats, cap_mask):
837
  """
838
+ Forward pass of NextDiT.
839
  t: (N,) tensor of diffusion timesteps
840
  y: (N,) tensor of class labels
841
  """
 
845
 
846
  # cap_freqs_cis = self.freqs_cis[:1, :cap_feats.shape[1]].to(x.device)
847
 
848
+ t = self.t_embedder(t) # (N, D)
849
  cap_mask_float = cap_mask.float().unsqueeze(-1)
850
+ cap_feats_pool = (cap_feats * cap_mask_float).sum(dim=1) / cap_mask_float.sum(
851
+ dim=1
852
+ )
853
  cap_feats_pool = cap_feats_pool.to(cap_feats)
854
  cap_emb = self.cap_embedder(cap_feats_pool)
855
  adaln_input = t + cap_emb
856
 
857
  cap_mask = cap_mask.bool()
858
  for layer in self.layers:
859
+ x = layer(x, mask, freqs_cis, cap_feats, cap_mask, adaln_input=adaln_input)
 
 
 
860
 
861
  x = self.final_layer(x, adaln_input)
862
  x = self.unpatchify(x, img_size, return_tensor=x_is_tensor)
 
867
  x = [_.chunk(2, dim=0)[0] for _ in x]
868
  return x
869
 
870
+ def forward_with_cfg(
871
+ self,
872
+ x,
873
+ t,
874
+ cap_feats,
875
+ cap_mask,
876
+ cfg_scale,
877
+ rope_scaling_factor=None,
878
+ ntk_factor=None,
879
+ base_seqlen: Optional[int] = None,
880
+ proportional_attn: bool = False,
881
+ ):
882
  # """
883
+ # Forward pass of NextDiT, but also batches the unconditional forward pass
884
  # for classifier-free guidance.
885
  # """
886
  # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
887
  # print(ntk_factor, rope_scaling_factor, self.ntk_factor, self.rope_scaling_factor)
888
  if rope_scaling_factor is not None or ntk_factor is not None:
889
+ rope_scaling_factor = (
890
+ rope_scaling_factor
891
+ if rope_scaling_factor is not None
892
+ else self.rope_scaling_factor
893
+ )
894
  ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
895
+ if (
896
+ rope_scaling_factor != self.rope_scaling_factor
897
+ or ntk_factor != self.ntk_factor
898
+ ):
899
+ print(
900
+ f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}",
901
+ flush=True,
902
+ )
903
  self.freqs_cis = NextDiT.precompute_freqs_cis(
904
+ self.dim // self.n_heads,
905
+ 384,
906
+ rope_scaling_factor=rope_scaling_factor,
907
+ ntk_factor=ntk_factor,
908
  )
909
  self.rope_scaling_factor = rope_scaling_factor
910
  self.ntk_factor = ntk_factor
911
+
912
  if proportional_attn:
913
  assert base_seqlen is not None
914
  for layer in self.layers:
 
938
  end: int,
939
  theta: float = 10000.0,
940
  rope_scaling_factor: float = 1.0,
941
+ ntk_factor: float = 1.0,
942
  ):
943
  """
944
  Precompute the frequency tensor for complex exponentials (cis) with
 
962
 
963
  theta = theta * ntk_factor
964
 
965
+ logger.info(
966
+ f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}"
967
+ )
968
+ freqs = 1.0 / (
969
+ theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)
970
+ )
971
  t = torch.arange(end, device=freqs.device, dtype=torch.float) # type: ignore
972
  t = t / rope_scaling_factor
973
  freqs = torch.outer(t, freqs).float() # type: ignore
974
  freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64
975
 
976
+ freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
977
+ freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
978
  freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
979
  return freqs_cis
980
 
981
  def parameter_count(self) -> int:
982
  tensor_parallel_module_list = (
983
+ nn.Linear,
984
985
+ nn.Embedding,
986
  )
987
  total_params = 0
988
 
 
990
  nonlocal total_params
991
  is_tp_module = isinstance(module, tensor_parallel_module_list)
992
  for param in module.parameters(recurse=False):
993
+ total_params += param.numel()
 
 
 
994
  for submodule in module.children():
995
  _recursive_count_params(submodule)
996
 
 
1002
 
1003
 
1004
  #############################################################################
1005
+ # NextDiT Configs #
  #############################################################################
  def NextDiT_2B_patch2(**kwargs):
+     return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs)
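
For reference, a minimal usage sketch of the updated, fairscale-free model. Everything below is an assumption made for illustration (import path, tensor shapes, caption length, device); it is not taken from this repository's training or inference scripts, and flash-attn plus a CUDA device are still required by the imports at the top of the file:

import torch
from models.model import NextDiT_2B_patch2  # assumed import path

model = NextDiT_2B_patch2(qk_norm=True, cap_feat_dim=5120).cuda()
x = torch.randn(1, 4, 64, 64, device="cuda")         # latent batch (4 channels assumed)
t = torch.randint(0, 1000, (1,), device="cuda")      # diffusion timesteps
cap_feats = torch.randn(1, 77, 5120, device="cuda")  # caption features (length 77 assumed)
cap_mask = torch.ones(1, 77, dtype=torch.int32, device="cuda")
out = model(x, t, cap_feats, cap_mask)               # model prediction

With the parallel layers gone, the same forward pass runs in a single process; no fairscale initialization or model-parallel process group is needed before constructing the model.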