PommesPeter committed
Commit e203a0e
1 Parent(s): b87fd1b

Update models/model.py

Files changed (1)
  1. models/model.py +21 -33
models/model.py CHANGED
@@ -885,29 +885,14 @@ class NextDiT(nn.Module):
         # """
         # # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
         # print(ntk_factor, rope_scaling_factor, self.ntk_factor, self.rope_scaling_factor)
-        if rope_scaling_factor is not None or ntk_factor is not None:
-            rope_scaling_factor = (
-                rope_scaling_factor
-                if rope_scaling_factor is not None
-                else self.rope_scaling_factor
-            )
-            ntk_factor = ntk_factor if ntk_factor is not None else self.ntk_factor
-            if (
-                rope_scaling_factor != self.rope_scaling_factor
-                or ntk_factor != self.ntk_factor
-            ):
-                print(
-                    f"override freqs_cis, rope_scaling {rope_scaling_factor}, ntk {ntk_factor}",
-                    flush=True,
-                )
-                self.freqs_cis = NextDiT.precompute_freqs_cis(
-                    self.dim // self.n_heads,
-                    384,
-                    rope_scaling_factor=rope_scaling_factor,
-                    ntk_factor=ntk_factor,
-                )
-                self.rope_scaling_factor = rope_scaling_factor
-                self.ntk_factor = ntk_factor
+        if scale_factor is not None:
+            assert scale_factor is not None
+            self.freqs_cis = NextDiT.precompute_freqs_cis(
+                self.dim // self.n_heads,
+                384,
+                scale_factor=scale_factor,
+                timestep=t[0],
+            )
 
         if proportional_attn:
             assert base_seqlen is not None
@@ -938,7 +923,8 @@ class NextDiT(nn.Module):
         end: int,
         theta: float = 10000.0,
         rope_scaling_factor: float = 1.0,
-        ntk_factor: float = 1.0,
+        scale_factor: float = 1.0,
+        timestep: float = 1.0,
     ):
         """
         Precompute the frequency tensor for complex exponentials (cis) with
@@ -959,23 +945,25 @@ class NextDiT(nn.Module):
             torch.Tensor: Precomputed frequency tensor with complex
                 exponentials.
         """
-
-        theta = theta * ntk_factor
-
-        logger.info(
-            f"theta {theta} rope scaling {rope_scaling_factor} ntk {ntk_factor}"
-        )
-        freqs = 1.0 / (
-            theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)
-        )
-        t = torch.arange(end, device=freqs.device, dtype=torch.float)  # type: ignore
-        t = t / rope_scaling_factor
-        freqs = torch.outer(t, freqs).float()  # type: ignore
+        freqs_inter = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim)) / scale_factor
+
+        target_dim = timestep * dim + 1
+        scale_factor = scale_factor ** (dim / target_dim)
+        theta = theta * scale_factor
+
+        freqs_time_scaled = 1.0 / (theta ** (torch.arange(0, dim, 4)[: (dim // 4)].float().cuda() / dim))
+
+        freqs = torch.max(freqs_inter, freqs_time_scaled)
+
+        timestep = torch.arange(end, device=freqs.device, dtype=torch.float)  # type: ignore
+
+        freqs = torch.outer(timestep, freqs).float()  # type: ignore
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
 
         freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
         freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
         freqs_cis = torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)
+
         return freqs_cis
 
     def parameter_count(self) -> int:
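
In effect, the new `precompute_freqs_cis` blends two schedules per frequency index i: plain position interpolation, theta^(-4i/dim) / scale_factor, and a timestep-dependent NTK-style rescaling, (theta * scale_factor^(dim / (timestep*dim + 1)))^(-4i/dim), taking the elementwise maximum of the two. The following is a minimal, self-contained sketch of that computation, not the repository's code verbatim: the helper name `time_aware_rope_freqs` and the toy values (`dim=64`, `end=16`, `scale_factor=2.0`, `timestep=0.5`) are illustrative assumptions, and it runs on CPU where the committed code calls `.cuda()`.

    import torch

    def time_aware_rope_freqs(
        dim: int,
        end: int,
        theta: float = 10000.0,
        scale_factor: float = 1.0,
        timestep: float = 1.0,
    ) -> torch.Tensor:
        # Exponents 4*i/dim for i = 0 .. dim//4 - 1 (the axial 2D RoPE here
        # uses dim//4 distinct frequencies per axis).
        exponents = torch.arange(0, dim, 4)[: (dim // 4)].float() / dim

        # Schedule 1: position interpolation; every frequency divided by scale_factor.
        freqs_inter = 1.0 / (theta ** exponents) / scale_factor

        # Schedule 2: timestep-dependent NTK-style rescaling of theta. As the
        # timestep grows, target_dim grows and the theta multiplier decays toward 1.
        target_dim = timestep * dim + 1
        theta_scaled = theta * scale_factor ** (dim / target_dim)
        freqs_time_scaled = 1.0 / (theta_scaled ** exponents)

        # Elementwise maximum keeps, per frequency, the less-compressed schedule.
        freqs = torch.max(freqs_inter, freqs_time_scaled)

        # Outer product with positions, then complex exponentials on a 2D grid.
        positions = torch.arange(end, dtype=torch.float)
        freqs = torch.outer(positions, freqs)
        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64

        freqs_cis_h = freqs_cis.view(end, 1, dim // 4, 1).repeat(1, end, 1, 1)
        freqs_cis_w = freqs_cis.view(1, end, dim // 4, 1).repeat(end, 1, 1, 1)
        return torch.cat([freqs_cis_h, freqs_cis_w], dim=-1).flatten(2)

    # Toy check with the assumed values: result has shape (end, end, dim // 2).
    print(time_aware_rope_freqs(dim=64, end=16, scale_factor=2.0, timestep=0.5).shape)

Taking the per-frequency maximum leaves the highest-frequency component (exponent 0) at its original scale and applies the stronger interpolation only where it wins, which is the timestep-aware behaviour this commit introduces in place of the fixed rope_scaling/NTK override.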