PommesPeter committed
Update models/model.py

models/model.py CHANGED (+5 -25)
@@ -592,7 +592,7 @@ class ParallelFinalLayer(nn.Module):
         return x
 
 
-class DiT_Llama(nn.Module):
+class NextDiT(nn.Module):
     """
     Diffusion model with a Transformer backbone.
     """
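The commit renames the backbone class from DiT_Llama to NextDiT; the architecture itself is untouched. If anything else in the Space still imports the old name, a one-line alias would keep it importable during migration. This shim is hypothetical, not part of the commit:

# Hypothetical compatibility shim (not in this commit): keeps the old
# class name importable while callers migrate to NextDiT.
DiT_Llama = NextDiT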
@@ -645,7 +645,7 @@ class DiT_Llama(nn.Module):
         assert (dim // n_heads) % 4 == 0, "2d rope needs head dim to be divisible by 4"
         self.dim = dim
         self.n_heads = n_heads
-        self.freqs_cis = DiT_Llama.precompute_freqs_cis(
+        self.freqs_cis = NextDiT.precompute_freqs_cis(
             dim // n_heads, 384, rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
         )
         self.rope_scaling_factor = rope_scaling_factor
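precompute_freqs_cis itself is not part of this diff; only the class it is looked up on changes. For orientation, here is a minimal sketch of what such a helper typically computes, assuming the standard NTK-aware base rescaling and linear position interpolation, and the two-axis layout implied by the `% 4 == 0` assert above. All names and formulas below are assumptions about a typical implementation, not the actual code in model.py:

import torch

def precompute_freqs_cis(head_dim, max_pos, theta=10000.0,
                         rope_scaling_factor=1.0, ntk_factor=1.0):
    # NTK-aware scaling (assumed formula): enlarge the rotary base so the
    # longest wavelengths cover an extended position range.
    theta = theta * ntk_factor
    # 2D axial RoPE gives each spatial axis half the head dim; each axis
    # then pairs dims for cos/sin, which is why head_dim % 4 must be 0.
    axis_dim = head_dim // 2
    freqs = 1.0 / (theta ** (torch.arange(0, axis_dim, 2).float() / axis_dim))
    # Linear interpolation: divide position indices by the scaling factor
    # so scaled-up sequences reuse the trained angle range.
    t = torch.arange(max_pos, dtype=torch.float32) / rope_scaling_factor
    angles = torch.outer(t, freqs)                       # (max_pos, axis_dim // 2)
    # Unit-magnitude complex rotations, one table shared by both axes; a
    # 2D grid combines one entry per axis for every (h, w) token.
    return torch.polar(torch.ones_like(angles), angles)  # complex64

Under these assumptions, the 384 at the call site is the maximum number of positions per spatial axis, and raising rope_scaling_factor or ntk_factor stretches that range without retraining.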
@@ -781,7 +781,7 @@ class DiT_Llama(nn.Module):
         ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
         if rope_scaling_factor != self.rope_scaling_factor or ntk_factor != self.ntk_factor:
             print(f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}", flush=True)
-            self.freqs_cis = DiT_Llama.precompute_freqs_cis(
+            self.freqs_cis = NextDiT.precompute_freqs_cis(
                 self.dim // self.n_heads, 384,
                 rope_scaling_factor=rope_scaling_factor, ntk_factor=ntk_factor
             )
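The guard above means repeated calls with the same factors reuse the cached table; the rebuild only happens when an override actually changes them. As a quick sanity check connecting this to the constructor's assert and the 2B config below (pure arithmetic, runnable on its own):

dim, n_heads = 2304, 32      # values from NextDiT_2B_patch2 below
head_dim = dim // n_heads    # 72
assert head_dim % 4 == 0     # the 2D-RoPE constraint from the constructor
# 72 dims per head -> 36 per spatial axis -> 18 rotary frequencies per axis
print(head_dim, head_dim // 2, head_dim // 4)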
@@ -882,27 +882,7 @@ class DiT_Llama(nn.Module):
 #############################################################################
 #                                 DiT Configs                               #
 #############################################################################
-
-
-def DiT_Llama_600M_patch2(**kwargs):
-    return DiT_Llama(
-        patch_size=2, dim=1536, n_layers=16, n_heads=32, **kwargs
-    )
-
-
-def DiT_Llama_2B_patch2(**kwargs):
-    return DiT_Llama(
+def NextDiT_2B_patch2(**kwargs):
+    return NextDiT(
         patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs
     )
-
-
-def DiT_Llama_3B_patch2(**kwargs):
-    return DiT_Llama(
-        patch_size=2, dim=3072, n_layers=32, n_heads=32, **kwargs
-    )
-
-
-def DiT_Llama_7B_patch2(**kwargs):
-    return DiT_Llama(
-        patch_size=2, dim=4096, n_layers=32, n_heads=32, **kwargs
-    )
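After this hunk, only the 2B configuration survives: the 600M, 3B, and 7B builders are deleted along with the DiT_Llama names. A hypothetical smoke test, assuming the repo root is on the import path and that the constructor has workable defaults for everything these kwargs don't set (neither assumption is shown in this diff):

from models.model import NextDiT_2B_patch2

model = NextDiT_2B_patch2()
n_params = sum(p.numel() for p in model.parameters())
print(f"{n_params / 1e9:.2f}B parameters")  # should land in the ~2B range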