Spaces:
Running
on
Zero
Running
on
Zero
Update models/model.py
Browse files- models/model.py +4 -1
models/model.py
CHANGED
@@ -862,7 +862,7 @@ class NextDiT(nn.Module):
|
|
862 |
x, _ = x.chunk(2, dim=1)
|
863 |
else:
|
864 |
x = [_.chunk(2, dim=0)[0] for _ in x]
|
865 |
-
return x
|
866 |
|
867 |
def forward_with_cfg(
|
868 |
self,
|
@@ -991,3 +991,6 @@ class NextDiT(nn.Module):
|
|
991 |
#############################################################################
|
992 |
def NextDiT_2B_patch2(**kwargs):
|
993 |
return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs)
|
|
|
|
|
|
|
|
862 |
x, _ = x.chunk(2, dim=1)
|
863 |
else:
|
864 |
x = [_.chunk(2, dim=0)[0] for _ in x]
|
865 |
+
return x`
|
866 |
|
867 |
def forward_with_cfg(
|
868 |
self,
|
|
|
991 |
#############################################################################
|
992 |
def NextDiT_2B_patch2(**kwargs):
|
993 |
return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, **kwargs)
|
994 |
+
|
995 |
+
def NextDiT_2B_GQA_patch2(**kwargs):
|
996 |
+
return NextDiT(patch_size=2, dim=2304, n_layers=24, n_heads=32, n_kv_heads=8, **kwargs)
|