diff --git a/models/stylegan2/op/fused_act.py b/models/stylegan2/op/fused_act.py
index 973a84f..6854b97 100644
--- a/models/stylegan2/op/fused_act.py
+++ b/models/stylegan2/op/fused_act.py
@@ -2,17 +2,18 @@ import os
 
 import torch
 from torch import nn
+from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load
 
-module_path = os.path.dirname(__file__)
-fused = load(
-    'fused',
-    sources=[
-        os.path.join(module_path, 'fused_bias_act.cpp'),
-        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
-    ],
-)
+#module_path = os.path.dirname(__file__)
+#fused = load(
+#    'fused',
+#    sources=[
+#        os.path.join(module_path, 'fused_bias_act.cpp'),
+#        os.path.join(module_path, 'fused_bias_act_kernel.cu'),
+#    ],
+#)
 
 
 class FusedLeakyReLUFunctionBackward(Function):
@@ -82,4 +83,19 @@ class FusedLeakyReLU(nn.Module):
 
 
 def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
-    return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
+    # CPU path: plain PyTorch bias add + leaky ReLU + scale; honours the caller's negative_slope.
+    if input.device.type == "cpu":
+        if bias is not None:
+            rest_dim = [1] * (input.ndim - bias.ndim - 1)
+            return (
+                F.leaky_relu(
+                    input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
+                )
+                * scale
+            )
+
+        else:
+            return F.leaky_relu(input, negative_slope=negative_slope) * scale
+
+    else:
+        return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
diff --git a/models/stylegan2/op/upfirdn2d.py b/models/stylegan2/op/upfirdn2d.py
index 7bc5a1e..5465d1a 100644
--- a/models/stylegan2/op/upfirdn2d.py
+++ b/models/stylegan2/op/upfirdn2d.py
@@ -1,17 +1,18 @@
 import os
 
 import torch
+from torch.nn import functional as F
 from torch.autograd import Function
 from torch.utils.cpp_extension import load
 
-module_path = os.path.dirname(__file__)
-upfirdn2d_op = load(
-    'upfirdn2d',
-    sources=[
-        os.path.join(module_path, 'upfirdn2d.cpp'),
-        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
-    ],
-)
+#module_path = os.path.dirname(__file__)
+#upfirdn2d_op = load(
+#    'upfirdn2d',
+#    sources=[
+#        os.path.join(module_path, 'upfirdn2d.cpp'),
+#        os.path.join(module_path, 'upfirdn2d_kernel.cu'),
+#    ],
+#)
 
 
 class UpFirDn2dBackward(Function):
@@ -97,8 +98,8 @@ class UpFirDn2d(Function):
 
         ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
 
-        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
-        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
+        out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+        out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
         ctx.out_size = (out_h, out_w)
 
         ctx.up = (up_x, up_y)
@@ -140,9 +141,13 @@ class UpFirDn2d(Function):
 
 
 def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
-    out = UpFirDn2d.apply(
-        input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
-    )
+    if input.device.type == "cpu":
+        out = upfirdn2d_native(input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1])
+
+    else:
+        out = UpFirDn2d.apply(
+            input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
+        )
 
     return out
 
@@ -150,6 +155,9 @@ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
 def upfirdn2d_native(
     input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
 ):
+    _, channel, in_h, in_w = input.shape
+    input = input.reshape(-1, in_h, in_w, 1)
+
     _, in_h, in_w, minor = input.shape
     kernel_h, kernel_w = kernel.shape
 
@@ -180,5 +188,9 @@ def upfirdn2d_native(
         in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
     )
     out = out.permute(0, 2, 3, 1)
+    out = out[:, ::down_y, ::down_x, :]
+
+    out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h + down_y) // down_y
+    out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w + down_x) // down_x
 
-    return out[:, ::down_y, ::down_x, :]
+    return out.view(-1, channel, out_h, out_w)