PommesPeter commited on
Commit
556f26b
·
verified ·
1 Parent(s): e5e92a0

Update models/model.py

Browse files
Files changed (1) hide show
  1. models/model.py +5 -25
models/model.py CHANGED
@@ -592,7 +592,7 @@ class ParallelFinalLayer(nn.Module):
592
  return x
593
 
594
 
595
- class DiT_Llama(nn.Module):
596
  """
597
  Diffusion model with a Transformer backbone.
598
  """
@@ -645,7 +645,7 @@ class DiT_Llama(nn.Module):
645
  assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
646
  self.dim = dim
647
  self.n_heads = n_heads
648
- self.freqs_cis = DiT_Llama.precompute_freqs_cis(
649
  dim // n_heads, 384, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
650
  )
651
  self.rope_scaling_factor = rope_scaling_factor
@@ -781,7 +781,7 @@ class DiT_Llama(nn.Module):
781
  ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
782
  if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
783
  print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
784
- self.freqs_cis = DiT_Llama.precompute_freqs_cis(
785
  self.dim // self.n_heads, 384,
786
  rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
787
  )
@@ -882,27 +882,7 @@ class DiT_Llama(nn.Module):
882
  #############################################################################
883
  # DiT Configs #
884
  #############################################################################
885
-
886
-
887
- def DiT_Llama_600M_patch2(**kwargs):
888
- return DiT_Llama(
889
- patch_size=2, dim=1536, n_layers=16, n_heads=32, **kwargs
890
- )
891
-
892
-
893
- def DiT_Llama_2B_patch2(**kwargs):
894
- return DiT_Llama(
895
  patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
896
  )
897
-
898
-
899
- def DiT_Llama_3B_patch2(**kwargs):
900
- return DiT_Llama(
901
- patch_size=2, dim=3072, n_layers=32, n_heads=32, **kwargs
902
- )
903
-
904
-
905
- def DiT_Llama_7B_patch2(**kwargs):
906
- return DiT_Llama(
907
- patch_size=2, dim=4096, n_layers=32, n_heads=32, **kwargs
908
- )
 
592
  return x
593
 
594
 
595
+ class NextDiT(nn.Module):
596
  """
597
  Diffusion model with a Transformer backbone.
598
  """
 
645
  assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
646
  self.dim = dim
647
  self.n_heads = n_heads
648
+ self.freqs_cis = NextDiT.precompute_freqs_cis(
649
  dim // n_heads, 384, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
650
  )
651
  self.rope_scaling_factor = rope_scaling_factor
 
781
  ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
782
  if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
783
  print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
784
+ self.freqs_cis = NextDiT.precompute_freqs_cis(
785
  self.dim // self.n_heads, 384,
786
  rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
787
  )
 
882
  #############################################################################
883
  # DiT Configs #
884
  #############################################################################
885
+ def NextDiT_2B_patch2(**kwargs):
886
+ return NextDiT(
 
 
 
 
 
 
 
 
887
  patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
888
  )