Spaces: Running on Zero

Commit e203a0e
Parent(s): b87fd1b

Update models/model.py

models/model.py CHANGED (+21 -33)
@@ -885,29 +885,14 @@ class NextDiT(nn.Module):
         # """
         # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
         # print(ntk_factor, rope_scaling_factor, self.ntk_factor, self.rope_scaling_factor)
-        if rope_scaling_factor is not None or ntk_factor is not None:
-            rope_scaling_factor = (
-                rope_scaling_factor
-                if rope_scaling_factor is not None
-                else self.rope_scaling_factor
-            )
-            ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
-            if (
-                rope_scaling_factor != self.rope_scaling_factor
-                or ntk_factor != self.ntk_factor
-            ):
-                print(
-                    f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}",
-                    flush=True,
-                )
-                self.freqs_cis = NextDiT.precompute_freqs_cis(
-                    self.dim // self.n_heads,
-                    384,
-                    rope_scaling_factor=rope_scaling_factor,
-                    ntk_factor=ntk_factor,
-                )
-                self.rope_scaling_factor = rope_scaling_factor
-                self.ntk_factor = ntk_factor
+        if scale_factor is not None:
+            assert scale_factor is not None
+            self.freqs_cis = NextDiT.precompute_freqs_cis(
+                self.dim // self.n_heads,
+                384,
+                scale_factor=scale_factor,
+                timestep=t[0],
+            )
 
         if proportional_attn:
             assert base_seqlen is not None
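Net effect of this hunk: the old path cached freqs_cis and rebuilt it only when rope_scaling_factor or ntk_factor changed; the new path rebuilds it on every forward call from a single scale_factor and the current denoising timestep t[0]. A minimal sketch of the new call pattern, assuming an already-constructed NextDiT instance named model and a descending timestep grid (both are illustrative, not part of this commit; note that precompute_freqs_cis calls .cuda() internally, so a GPU is assumed):

import torch

scale_factor = 4.0  # e.g. sampling beyond the training resolution
for t in torch.linspace(1.0, 0.0, steps=30):
    # Mirrors the new call site above: freqs_cis is recomputed each step.
    model.freqs_cis = NextDiT.precompute_freqs_cis(
        model.dim // model.n_heads,
        384,
        scale_factor=scale_factor,
        timestep=float(t),  # plays the role of timestep=t[0] in the hunk
    )
    # ... run one denoising step with model(...) here ...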
@@ -938,7 +923,8 @@ class NextDiT(nn.Module):
         end: int,
         theta: float = 10000.0,
         rope_scaling_factor: float = 1.0,
-        ntk_factor: float = 1.0,
+        scale_factor: float = 1.0,
+        timestep: float = 1.0,
     ):
         """
         Precompute the frequency tensor for complex exponentials (cis) with
@@ -959,23 +945,25 @@ class NextDiT(nn.Module):
             torch.Tensor: Precomputed frequency tensor with complex
                 exponentials.
         """
+        freqs_inter = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)) / scale_factor
 
-        theta = theta * ntk_factor
+        target_dim = timestep * dim + 1
+        scale_factor = scale_factor ** (dim / target_dim)
+        theta = theta * scale_factor
 
-        freqs = 1.0 / (
-            theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)
-        )
-
-        t = torch.arange(end, device=freqs.device, dtype=torch.float)  # type: ignore
-
-
-        t = t / rope_scaling_factor
-        freqs = torch.outer(t, freqs).float()  # type: ignore
+        freqs_time_scaled = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim))
+
+        freqs = torch.max(freqs_inter, freqs_time_scaled)
+
+        timestep = torch.arange(end, device=freqs.device, dtype=torch.float)  # type: ignore
+
+        freqs = torch.outer(timestep, freqs).float()  # type: ignore
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
 
         freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
         freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
         freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
+
         return freqs_cis
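What the added lines compute, in isolation: two candidate frequency schedules per rotary channel, a position-interpolated one (all base frequencies divided by scale_factor) and an NTK-style one whose theta is inflated by scale_factor ** (dim / (timestep * dim + 1)), so the inflation strengthens as the timestep falls; the elementwise max then keeps the higher frequency in each channel. A runnable CPU sketch of just that schedule (the .cuda() calls from the diff are dropped so it runs anywhere; the function name is mine):

import torch

def time_aware_freqs(dim: int, theta: float = 10000.0,
                     scale_factor: float = 1.0, timestep: float = 1.0) -> torch.Tensor:
    exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
    # Position interpolation: uniformly shrink every frequency by scale_factor.
    freqs_inter = 1.0 / (theta ** exponents) / scale_factor
    # NTK-style scaling: inflate theta, more aggressively as timestep -> 0.
    target_dim = timestep * dim + 1
    theta_t = theta * scale_factor ** (dim / target_dim)
    freqs_time_scaled = 1.0 / (theta_t ** exponents)
    # Per-channel max: whichever scheme keeps the higher frequency wins.
    return torch.max(freqs_inter, freqs_time_scaled)

print(time_aware_freqs(64, scale_factor=2.0, timestep=1.0)[:4])
print(time_aware_freqs(64, scale_factor=2.0, timestep=0.1)[:4])

Near timestep = 1 the exponent dim / (dim + 1) is close to 1, so the result is essentially NTK theta scaling; as timestep approaches 0 the inflated theta blows up and every channel except the constant one (exponent 0, frequency 1) falls back to the position-interpolated floor.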
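A quick property check, also CPU-only: with scale_factor = 1.0 both branches collapse to the vanilla RoPE frequencies 1 / theta ** (k / dim) at every timestep, so the new schedule generalizes the old default rather than replacing it:

import torch

dim, theta = 64, 10000.0
exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim
base = 1.0 / (theta ** exponents)
for ts in (1.0, 0.5, 0.0):
    theta_t = theta * 1.0 ** (dim / (ts * dim + 1))  # scale_factor = 1.0
    freqs = torch.max(base / 1.0, 1.0 / (theta_t ** exponents))
    assert torch.allclose(freqs, base)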