PommesPeter commited on
Commit
64f5019
1 Parent(s): 15222c4

Update models/model.py

Browse files
Files changed (1) hide show
  1. models/model.py +4 -1
models/model.py CHANGED
@@ -862,7 +862,7 @@ class NextDiT(nn.Module):
862
  x, _ = x.chunk(2, dim=1)
863
  else:
864
  x = [_.chunk(2, dim=0)[0] for _ in x]
865
- return x
866
 
867
  def forward_with_cfg(
868
  self,
@@ -991,3 +991,6 @@ class NextDiT(nn.Module):
991
  #############################################################################
992
  def NextDiT_2B_patch2(**kwargs):
993
  return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs)
 
 
 
 
862
  x, _ = x.chunk(2, dim=1)
863
  else:
864
  x = [_.chunk(2, dim=0)[0] for _ in x]
865
+ return x`
866
 
867
  def forward_with_cfg(
868
  self,
 
991
  #############################################################################
992
  def NextDiT_2B_patch2(**kwargs):
993
  return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs)
994
+
995
+ def NextDiT_2B_GQA_patch2(**kwargs):
996
+ return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, n_kv_heads=8, **kwargs)