aliabd commited on
Commit
e3030de
β€’
1 Parent(s): 9bc5228

added op and samples

Browse files
op/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
2
+ from .upfirdn2d import upfirdn2d
op/fused_act.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
+ module_path = os.path.dirname(__file__)
8
+
9
+
10
+
11
+ class FusedLeakyReLU(nn.Module):
12
+ def __init__(self, channel, negative_slope=0.2, scale=2 ** 0.5):
13
+ super().__init__()
14
+
15
+ self.bias = nn.Parameter(torch.zeros(channel))
16
+ self.negative_slope = negative_slope
17
+ self.scale = scale
18
+
19
+ def forward(self, input):
20
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
21
+
22
+
23
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2 ** 0.5):
24
+ rest_dim = [1] * (input.ndim - bias.ndim - 1)
25
+ if input.ndim == 3:
26
+ return (
27
+ F.leaky_relu(
28
+ input + bias.view(1, *rest_dim, bias.shape[0]), negative_slope=negative_slope
29
+ )
30
+ * scale
31
+ )
32
+ else:
33
+ return (
34
+ F.leaky_relu(
35
+ input + bias.view(1, bias.shape[0], *rest_dim), negative_slope=negative_slope
36
+ )
37
+ * scale
38
+ )
39
+
op/upfirdn2d.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ module_path = os.path.dirname(__file__)
8
+
9
+
10
+
11
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
12
+ out = upfirdn2d_native(
13
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
14
+ )
15
+
16
+ return out
17
+
18
+
19
+ def upfirdn2d_native(
20
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
21
+ ):
22
+ _, channel, in_h, in_w = input.shape
23
+ input = input.reshape(-1, in_h, in_w, 1)
24
+
25
+ _, in_h, in_w, minor = input.shape
26
+ kernel_h, kernel_w = kernel.shape
27
+
28
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
29
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
30
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
31
+
32
+ out = F.pad(
33
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
34
+ )
35
+ out = out[
36
+ :,
37
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
38
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
39
+ :,
40
+ ]
41
+
42
+ out = out.permute(0, 3, 1, 2)
43
+ out = out.reshape(
44
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
45
+ )
46
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
47
+ out = F.conv2d(out, w)
48
+ out = out.reshape(
49
+ -1,
50
+ minor,
51
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
52
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
53
+ )
54
+ out = out.permute(0, 2, 3, 1)
55
+ out = out[:, ::down_y, ::down_x, :]
56
+
57
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
58
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
59
+
60
+ return out.view(-1, channel, out_h, out_w)
samples/female_11025.jpg ADDED
samples/female_12427.jpg ADDED
samples/margot_robbie.jpg ADDED
samples/output.mp4 ADDED
Binary file (4.22 MB). View file
 
samples/tiktok.mp4 ADDED
Binary file (1.33 MB). View file