safetyWaifu / patch
hysts
Add files
161d647
raw history blame
No virus
2.34 kB
diff --git a/model.py b/model.py
index 0134c39..3a7826c 100755
--- a/model.py
+++ b/model.py
@@ -395,6 +395,7 @@ class Generator(nn.Module):
style_dim,
n_mlp,
channel_multiplier=2,
+ additional_multiplier=2,
blur_kernel=[1, 3, 3, 1],
lr_mlp=0.01,
):
@@ -426,6 +427,9 @@ class Generator(nn.Module):
512: 32 * channel_multiplier,
1024: 16 * channel_multiplier,
}
+ if additional_multiplier > 1:
+ for k in list(self.channels.keys()):
+ self.channels[k] *= additional_multiplier
self.input = ConstantInput(self.channels[4])
self.conv1 = StyledConv(
@@ -518,7 +522,7 @@ class Generator(nn.Module):
getattr(self.noises, f"noise_{i}") for i in range(self.num_layers)
]
- if truncation < 1:
+ if truncation_latent is not None:
style_t = []
for style in styles:
diff --git a/op/fused_act.py b/op/fused_act.py
index 5d46e10..bc522ed 100755
--- a/op/fused_act.py
+++ b/op/fused_act.py
@@ -1,5 +1,3 @@
-import os
-
import torch
from torch import nn
from torch.nn import functional as F
@@ -7,16 +5,6 @@ from torch.autograd import Function
from torch.utils.cpp_extension import load
-module_path = os.path.dirname(__file__)
-fused = load(
- "fused",
- sources=[
- os.path.join(module_path, "fused_bias_act.cpp"),
- os.path.join(module_path, "fused_bias_act_kernel.cu"),
- ],
-)
-
-
class FusedLeakyReLUFunctionBackward(Function):
@staticmethod
def forward(ctx, grad_output, out, bias, negative_slope, scale):
diff --git a/op/upfirdn2d.py b/op/upfirdn2d.py
index 67e0375..6c5840e 100755
--- a/op/upfirdn2d.py
+++ b/op/upfirdn2d.py
@@ -1,5 +1,4 @@
from collections import abc
-import os
import torch
from torch.nn import functional as F
@@ -7,16 +6,6 @@ from torch.autograd import Function
from torch.utils.cpp_extension import load
-module_path = os.path.dirname(__file__)
-upfirdn2d_op = load(
- "upfirdn2d",
- sources=[
- os.path.join(module_path, "upfirdn2d.cpp"),
- os.path.join(module_path, "upfirdn2d_kernel.cu"),
- ],
-)
-
-
class UpFirDn2dBackward(Function):
@staticmethod
def forward(